OTPE
- class braintrace.OTPE(model, mode='full', *, leak, name=None, vjp_method='single-step', trace_clip_abs=None, **kwargs)
Online Training with Postsynaptic Estimates for spiking networks.
OTPE maintains a leaky, additive estimate \(\hat R^t\) of each parameter’s accumulated influence on the postsynaptic state, then contracts it with the learning signal \(L^t\) to obtain the weight gradient:
\[\hat R^t = \lambda\,\hat R^{t-1} + \frac{\partial s^t}{\partial \theta} = \lambda\,\hat R^{t-1} + x^t \otimes \operatorname{diag}(D_f^t) , \qquad \nabla_{W}\mathcal{L}^t = L^t \cdot \hat R^t ,\]
where \(x^t\) is the presynaptic input, \(D_f^t\) the state-to-output Jacobian (surrogate gradient of the spike), \(\lambda \in (0, 1)\) the membrane leak, and \(L^t = \partial \mathcal{L}^t / \partial s^t\) the learning signal. The contraction runs over the output dimension; because \(D_f^t\) is diagonal, it acts elementwise on each output column of \(\hat R^t\), leaving a gradient with the weight's shape.
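The full-mode recurrence can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not braintrace's implementation: the function name `otpe_full_step`, the `(batch, I, O)` array layout, storing \(D_f^t\) by its diagonal `d`, and summing per-sample gradients over the batch are all choices made for the sketch.

```python
import numpy as np


def otpe_full_step(R, x, d, L, lam):
    """One full-mode OTPE update (illustrative sketch, not braintrace's code).

    R   : (batch, I, O) previous estimate R_hat^{t-1}
    x   : (batch, I)    presynaptic input x^t
    d   : (batch, O)    diagonal of D_f^t (surrogate spike derivative)
    L   : (batch, O)    learning signal dL^t/ds^t
    lam : float         global leak in (0, 1)
    """
    # R_hat^t = lam * R_hat^{t-1} + x^t (outer) diag(D_f^t), per batch element
    R = lam * R + x[:, :, None] * d[:, None, :]
    # D_f is diagonal, so L^t scales each output column of R_hat^t;
    # summing over the batch (an assumption) gives an (I, O) weight gradient
    grad_W = np.sum(R * L[:, None, :], axis=0)
    return R, grad_W


rng = np.random.default_rng(0)
B, I, O, T, lam = 2, 3, 4, 5, 0.9
xs = rng.normal(size=(T, B, I))
ds = rng.uniform(size=(T, B, O))
Ls = rng.normal(size=(T, B, O))

R = np.zeros((B, I, O))
for t in range(T):
    R, grad_W = otpe_full_step(R, xs[t], ds[t], Ls[t], lam)
print(grad_W.shape)  # (3, 4)
```

Unrolling the loop gives \(\hat R^t = \sum_k \lambda^{t-k}\, x^k \otimes \operatorname{diag}(D_f^k)\); the running buffer stores this sum without keeping the input history.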
In the low-rank 'approx' mode (F-OTPE) the estimate is factorized as an outer product, reducing memory from \(O(I\cdot O)\) to \(O(I+O)\) per layer:
\[\hat R^t \approx \hat z_\text{in}^t \otimes \bar g_\text{out}^t , \quad \hat z_\text{in}^t = \lambda\,\hat z_\text{in}^{t-1} + x^t , \quad \bar g_\text{out}^t = \lambda\,\bar g_\text{out}^{t-1} + \operatorname{diag}(D_f^t) ,\]
with gradient \(\nabla_{W}\mathcal{L}^t = \hat z_\text{in}^t \otimes (L^t \cdot \bar g_\text{out}^t)\).
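The factorized update can be sketched the same way (NumPy; `fotpe_step` and the batch sum are hypothetical choices, not braintrace's API). The sketch runs the 'approx' and 'full' estimates side by side: from a zero state they agree after one step, and afterwards they differ because \(\hat z_\text{in}^t \otimes \bar g_\text{out}^t\) also contains cross terms \(\lambda^{t-k}x^k \otimes \lambda^{t-j}\operatorname{diag}(D_f^j)\) with \(k \neq j\), which the full estimate does not.

```python
import numpy as np


def fotpe_step(z_in, g_out, x, d, L, lam):
    """One 'approx' (F-OTPE) update; illustrative sketch only.

    z_in  : (batch, I) leaky presynaptic trace z_in^{t-1}
    g_out : (batch, O) leaky trace of diag(D_f), g_out^{t-1}
    x, d, L : (batch, I), (batch, O), (batch, O) as in full mode
    """
    z_in = lam * z_in + x    # z_in^t = lam * z_in^{t-1} + x^t
    g_out = lam * g_out + d  # g_out^t = lam * g_out^{t-1} + diag(D_f^t)
    # grad = z_in^t (outer) (L^t * g_out^t), summed over the batch -> (I, O)
    grad_W = np.einsum('bi,bo->io', z_in, L * g_out)
    return z_in, g_out, grad_W


rng = np.random.default_rng(1)
B, I, O, T, lam = 2, 3, 4, 6, 0.8
z_in, g_out = np.zeros((B, I)), np.zeros((B, O))
R_full = np.zeros((B, I, O))  # full-mode estimate, kept only for comparison
for t in range(T):
    x = rng.normal(size=(B, I))
    d = rng.uniform(size=(B, O))
    L = rng.normal(size=(B, O))
    z_in, g_out, grad_approx = fotpe_step(z_in, g_out, x, d, L, lam)
    R_full = lam * R_full + x[:, :, None] * d[:, None, :]
    grad_full = np.einsum('bio,bo->io', R_full, L)
    if t == 0:
        # from a zero state, one step gives z_in (outer) g_out == x (outer) d exactly
        first_step_match = np.allclose(grad_approx, grad_full)

# per batch element and layer, the state costs I*O numbers (full) vs I+O (approx)
print(I * O, I + O)  # 12 7
```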
How it works. Unlike OTTT/OSTL, which assign temporal credit only within the current layer’s output, OTPE keeps a per-parameter trace that decays with the membrane leak, approximating the entire temporal effect of a weight on downstream activity while staying local to each layer. This improves gradient alignment with BPTT in deep feed-forward SNNs at modest extra cost.
- Parameters:
model (Module) – The SNN whose weights are trained online.
mode (str) – 'full' keeps the full (batch, I, O) estimate \(\hat R\) per layer. 'approx' (F-OTPE) factorizes it as an outer product for \(O(I+O)\) memory; it emits a UserWarning when the network has more than one HiddenGroup, because the factorization bias compounds with depth.
leak (float) – Decay factor \(\lambda \in (0, 1)\). Required: it must be supplied explicitly and is never inferred from the model. \(\lambda\) is the membrane leak of the postsynaptic neuron whose influence is being accumulated; auto-inferring it from model.states() would silently pick an arbitrary (often wrong) value on heterogeneous or multi-population models, so the framework does not guess it.
trace_clip_abs (Optional[float]) – Elementwise clip applied to \(\hat R\) at each step (full mode only). None disables clipping.
vjp_method (str) – Forwarded to the base algorithm. Only 'single-step' is supported by OTPE v1; multi-step inputs raise NotImplementedError.
Limitations
OTPE's published derivation is narrower than OTTT's, and this implementation is a general operator that will happily run far outside that proven regime. The estimate \(\hat R\) is built on the assumption that the only temporal coupling of the postsynaptic state is the scalar membrane leak \(\partial U^t / \partial U^{t-1} = \lambda\), which is exactly the leaky integrate-and-fire (LIF) recurrence. On top of that scalar-leak assumption (inherited from OTTT), OTPE adds three further restrictions:
1. A single global time constant. One scalar \(\lambda\) is shared by every traced connection. Heterogeneous leaks across neurons or layers break the estimate; leak is therefore a user-supplied global constant and is never inferred from the model (see the leak parameter).
2. Feed-forward only. The trace omits the hidden-to-hidden Jacobian, so it is the postsynaptic estimate for feed-forward SNNs. Applying it to a recurrent network silently drops the recurrent temporal credit.
3. Single-hidden-layer exactness. The estimate is gradient-exact for one hidden layer; with depth, the per-layer factorization accumulates bias.
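The scalar-leak assumption and restriction 1 can be checked on a toy subthreshold LIF layer: with no spikes or reset, \(U^t = \lambda U^{t-1} + x^t W\), so the exact \(\partial U_o^T / \partial W_{io}\) equals a presynaptic trace decayed by \(\lambda\). The sketch below (NumPy; all function names hypothetical) compares that trace against a finite-difference derivative, first with one global leak and then with per-neuron leaks. It illustrates the leak assumption only and is not braintrace's OTPE implementation.

```python
import numpy as np


def run_membrane(W, xs, leaks):
    # Subthreshold LIF (no spike, no reset): U^t = leaks * U^{t-1} + x^t @ W
    u = np.zeros(W.shape[1])
    for x in xs:
        u = leaks * u + x @ W
    return u


def leaky_trace(xs, lam):
    # Presynaptic trace with ONE global leak: e^t = lam * e^{t-1} + x^t
    e = np.zeros(xs.shape[1])
    for x in xs:
        e = lam * e + x
    return e


def finite_diff_grad(W, xs, leaks, eps=1e-6):
    # Central-difference dU_o^T / dW_io (exact up to rounding, since U is linear in W)
    g = np.zeros_like(W)
    for i in range(W.shape[0]):
        for o in range(W.shape[1]):
            dW = np.zeros_like(W)
            dW[i, o] = eps
            diff = run_membrane(W + dW, xs, leaks) - run_membrane(W - dW, xs, leaks)
            g[i, o] = diff[o] / (2 * eps)
    return g


rng = np.random.default_rng(0)
xs = rng.normal(size=(20, 3))  # T=20 steps, I=3 inputs
W = rng.normal(size=(3, 2))    # O=2 postsynaptic neurons
lam = 0.9

# Under a global leak, dU_o/dW_io = e_i for every output o
predicted = np.repeat(leaky_trace(xs, lam)[:, None], W.shape[1], axis=1)
exact_homog = finite_diff_grad(W, xs, np.array([0.9, 0.9]))
exact_het = finite_diff_grad(W, xs, np.array([0.9, 0.5]))  # neuron 1 leaks faster

print(np.allclose(predicted, exact_homog))            # True
print(np.allclose(predicted[:, 1], exact_het[:, 1]))  # False
```

With a shared leak the trace reproduces the exact derivative; once neuron 1 leaks at 0.5, the single-\(\lambda\) trace is wrong for that neuron, which is the failure mode restriction 1 describes.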
The low-rank 'approx' mode (F-OTPE) layers an additional outer-product approximation on top, which is itself justified only under the same linear-leak assumption; its bias compounds with network depth (hence the UserWarning for multi-group networks).
Concretely, braintrace exposes OTPE as a generic ETP operator: it accepts arbitrary ETP weights and hidden states, multi-layer stacks, recurrent connectivity, and even non-spiking cells (e.g. a tanh RNN). All of these *run* mechanically, but as soon as the model deviates from a feed-forward LIF network with a single global scalar leak, the computed gradient leaves the regime in which OTPE is proven correct and should be treated as a heuristic approximation rather than a faithful gradient estimate. The one structural case rejected outright is a multi-state hidden group (num_state > 1, e.g. ALIF with an adaptation variable): the leaky scalar estimate cannot assign per-state credit, so compile_graph() raises rather than silently summing across states.
- Raises:
ValueError – If mode is not 'full' or 'approx', if leak is not in \((0, 1)\), if a weight-to-hidden relation reaches more than one HiddenGroup (OTPE v1 requires one-hop, per-layer relations), or (at compile_graph()) if a trained connection projects into a hidden group with num_state > 1.
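Because leak is never inferred, the caller has to derive it from the neuron model before constructing OTPE. A minimal sketch, assuming an LIF neuron with time constant \(\tau\) integrated exactly over a step \(dt\), so that \(\lambda = e^{-dt/\tau}\) (a forward-Euler LIF would use \(1 - dt/\tau\) instead). Both helpers are hypothetical and not part of braintrace; check_otpe_args mirrors only the mode and leak conditions listed under Raises, since the HiddenGroup and num_state checks need the compiled graph.

```python
import math


def leak_from_tau(tau: float, dt: float) -> float:
    # Assumes exact exponential integration of dU/dt = -U / tau over one step dt;
    # a forward-Euler LIF would instead use 1 - dt / tau.
    return math.exp(-dt / tau)


def check_otpe_args(mode: str, leak: float) -> None:
    # Hypothetical helper mirroring the mode/leak ValueError conditions above
    if mode not in ("full", "approx"):
        raise ValueError(f"mode must be 'full' or 'approx', got {mode!r}")
    if not 0.0 < leak < 1.0:  # open interval: 0 and 1 are both rejected
        raise ValueError(f"leak must be in (0, 1), got {leak!r}")


lam = leak_from_tau(tau=10.0, dt=1.0)
check_otpe_args("full", lam)  # passes silently
print(round(lam, 4))  # 0.9048
```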
Examples
>>> import brainstate
>>> import braintrace
>>>
>>> class Net(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         # A tanh RNN lies outside OTPE's proven regime (see Limitations);
...         # it is used here only to show the call pattern.
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
...         self.out = braintrace.nn.Linear(20, 1)
...     def update(self, x):
...         return x >> self.cell >> self.out
>>>
>>> model = Net()
>>> _ = brainstate.nn.init_all_states(model)
>>> # ``leak`` is the postsynaptic membrane leak and must be passed
>>> # explicitly; it is never inferred from the model.
>>> learner = braintrace.OTPE(model, mode='full', leak=0.9)
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)
>>> y = learner(x0)
References
- compile_graph(*args)
Compile the eligibility trace graph of the relationship between etrace weights, states and operators.
The compilation process builds the etrace graph, separates the states, and initializes the etrace states.
- Parameters:
*args – The input arguments.
- init_etrace_state(*args, **kwargs)
Initialize the eligibility trace states of the etrace algorithm.
This method is needed after compiling the etrace graph; see compile_graph() for details.
- Parameters:
*args – The positional arguments.
**kwargs – The keyword arguments.
- Raises:
NotImplementedError – This method must be implemented by subclasses.