OTPE

class braintrace.OTPE(model, mode='full', *, leak, name=None, vjp_method='single-step', trace_clip_abs=None, **kwargs)

Online Training with Postsynaptic Estimates for spiking networks.

OTPE maintains a leaky, additive estimate \(\hat R^t\) of each parameter’s accumulated influence on the postsynaptic state, then contracts it with the learning signal \(L^t\) to obtain the weight gradient:

\[\hat R^t = \lambda\,\hat R^{t-1} + \frac{\partial s^t}{\partial \theta} = \lambda\,\hat R^{t-1} + x^t \otimes \operatorname{diag}(D_f^t) , \qquad \nabla_{W}\mathcal{L}^t = L^t \cdot \hat R^t ,\]

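where \(x^t\) is the presynaptic input, \(D_f^t\) the state-to-output Jacobian (surrogate gradient of the spike), \(\lambda \in (0, 1)\) the membrane leak, and \(L^t = \partial \mathcal{L}^t / \partial s^t\) the learning signal. The contraction runs over the output dimension. Because \(D_f^t\) is diagonal, it reduces to scaling each output column of \(\hat R^t\) by the matching entry of \(L^t\), \((\nabla_W\mathcal{L}^t)_{ij} = L^t_j\,\hat R^t_{ij}\), which leaves a gradient with the weight's \((I, O)\) shape.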
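A minimal NumPy sketch of the full-mode recurrence and gradient for a single sample (no batch axis) follows. It is an illustration under stated assumptions, not braintrace's implementation: otpe_full_step and otpe_full_grad are hypothetical names, \(D_f^t\) is passed as its diagonal (a length-O vector), and \(L^t\) is a length-O vector. Unrolled, the recurrence gives \(\hat R^t = \sum_k \lambda^{t-k}\, x^k \otimes \operatorname{diag}(D_f^k)\), which the loop reproduces.

```python
import numpy as np

def otpe_full_step(R_hat, x, d_f, lam):
    """Hypothetical full-mode OTPE trace update for one sample.

    R_hat : (I, O) previous estimate R_hat^{t-1}
    x     : (I,)   presynaptic input x^t
    d_f   : (O,)   diagonal of D_f^t (surrogate spike gradient per output)
    lam   : float  global membrane leak lambda in (0, 1)
    """
    # lambda * R_hat^{t-1} + x^t (outer) diag(D_f^t)
    return lam * R_hat + np.outer(x, d_f)

def otpe_full_grad(R_hat, learning_signal):
    """Weight gradient from the estimate and L^t (shape (O,)).

    With a diagonal D_f, contracting over the output index reduces to
    scaling column j of R_hat by L^t_j; the result has shape (I, O).
    """
    return R_hat * learning_signal[np.newaxis, :]

# Run a few steps on random data.
rng = np.random.default_rng(0)
n_in, n_out, lam = 3, 2, 0.9
xs = [rng.standard_normal(n_in) for _ in range(4)]
dfs = [rng.random(n_out) for _ in range(4)]

R_hat = np.zeros((n_in, n_out))
for x, d_f in zip(xs, dfs):
    R_hat = otpe_full_step(R_hat, x, d_f, lam)

L_t = rng.standard_normal(n_out)
grad = otpe_full_grad(R_hat, L_t)
print(grad.shape)  # (3, 2)
```

Storing the estimate as an \((I, O)\) array rather than the full Jacobian is what the diagonal structure of \(D_f^t\) allows; it is also why full mode needs \(O(I\cdot O)\) memory per sample.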

In the low-rank 'approx' mode (F-OTPE) the estimate is factorized as an outer product, reducing memory from \(O(I\cdot O)\) to \(O(I+O)\) per layer:

\[\hat R^t \approx \hat z_\text{in}^t \otimes \bar g_\text{out}^t , \quad \hat z_\text{in}^t = \lambda\,\hat z_\text{in}^{t-1} + x^t , \quad \bar g_\text{out}^t = \lambda\,\bar g_\text{out}^{t-1} + \operatorname{diag}(D_f^t) ,\]

with gradient \(\nabla_{W}\mathcal{L}^t = \hat z_\text{in}^t \otimes (L^t \cdot \bar g_\text{out}^t)\).
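The factorized update can be sketched the same way. This assumes \(L^t \cdot \bar g_\text{out}^t\) is an elementwise product over the output dimension (giving a length-O vector); fotpe_step and fotpe_grad are hypothetical names, not braintrace API.

```python
import numpy as np

def fotpe_step(z_in, g_out, x, d_f, lam):
    """Hypothetical F-OTPE ('approx') update: two leaky vectors instead of an (I, O) matrix."""
    z_in = lam * z_in + x      # z_in^t  = lambda * z_in^{t-1}  + x^t,         shape (I,)
    g_out = lam * g_out + d_f  # g_out^t = lambda * g_out^{t-1} + diag(D_f^t), shape (O,)
    return z_in, g_out

def fotpe_grad(z_in, g_out, learning_signal):
    """Gradient z_in (outer) (L^t * g_out), the product taken elementwise over outputs."""
    return np.outer(z_in, learning_signal * g_out)

rng = np.random.default_rng(1)
n_in, n_out, lam = 4, 3, 0.8
z_in, g_out = np.zeros(n_in), np.zeros(n_out)  # O(I + O) state per layer

x, d_f = rng.standard_normal(n_in), rng.random(n_out)
z_in, g_out = fotpe_step(z_in, g_out, x, d_f, lam)
L_t = rng.standard_normal(n_out)
grad = fotpe_grad(z_in, g_out, L_t)
```

From a zero state the factorized gradient coincides with the full-mode one on the first step. After that, the product of the two leaky sums contains cross terms \(x^k \otimes \operatorname{diag}(D_f^m)\) with \(k \neq m\), and its matching-time terms decay as \(\lambda^{2(t-k)}\) instead of \(\lambda^{t-k}\); neither appears in the full estimate, and this is the approximation bias of the 'approx' mode.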

How it works. Unlike OTTT/OSTL, which assign temporal credit only within the current layer's output, OTPE keeps a per-parameter trace that decays with the membrane leak, approximating the entire temporal effect of a weight on downstream activity while staying local to each layer. Compared with OTTT/OSTL, this improves gradient alignment with BPTT in deep feed-forward SNNs at modest extra cost.

Parameters:
  • model (Module) – The SNN whose weights are trained online.

  • mode (str) – 'full' keeps the full (batch, I, O) estimate \(\hat R\) per layer. 'approx' (F-OTPE) factorizes it as an outer product for \(O(I+O)\) memory; emits a UserWarning when the network has more than one HiddenGroup, because the factorization bias compounds with depth.

  • leak (float) – Decay factor \(\lambda \in (0, 1)\), the membrane leak of the postsynaptic neuron whose influence is being accumulated. Required: it must be passed explicitly and is never inferred from the model, because auto-inferring it from model.states() would silently pick an arbitrary (often wrong) value on heterogeneous or multi-population models.

  • trace_clip_abs (Optional[float]) – Elementwise clip applied to \(\hat R\) each step (full mode only). None disables clipping.

  • name (Optional[str]) – Name of the algorithm instance.

  • vjp_method (str) – Forwarded to the base algorithm. Only 'single-step' is supported by OTPE v1; multi-step inputs raise NotImplementedError.
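The trace_clip_abs option can be pictured as an elementwise clip on the full-mode estimate. In this hypothetical sketch the clip is applied to the freshly updated \(\hat R^t\) at every step; whether braintrace clips before or after the update is an assumption here, and the function name is illustrative only.

```python
import numpy as np

def otpe_full_step_clipped(R_hat, x, d_f, lam, trace_clip_abs=None):
    """Hypothetical full-mode step followed by an elementwise clip of R_hat.

    trace_clip_abs=None disables clipping, mirroring the parameter description.
    """
    R_new = lam * R_hat + np.outer(x, d_f)
    if trace_clip_abs is not None:
        # Clamp every entry to [-trace_clip_abs, trace_clip_abs].
        R_new = np.clip(R_new, -trace_clip_abs, trace_clip_abs)
    return R_new

R = np.zeros((2, 2))
x = np.array([10.0, -0.5])
d_f = np.array([1.0, 0.2])
clipped = otpe_full_step_clipped(R, x, d_f, lam=0.9, trace_clip_abs=1.0)
unclipped = otpe_full_step_clipped(R, x, d_f, lam=0.9)
```

Only the entries driven by the large input (the first row) are clamped; small entries pass through unchanged.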

Limitations

OTPE's published derivation is narrower than OTTT's, and this implementation is a general operator that will happily run far outside that proven regime. The estimate \(\hat R\) is built on the assumption that the only temporal coupling of the postsynaptic state is the scalar membrane leak \(\partial U^t / \partial U^{t-1} = \lambda\), which is exactly the leaky integrate-and-fire (LIF) recurrence. On top of that scalar-leak assumption (inherited from OTTT), OTPE adds three further restrictions:

  1. A single global time constant. One scalar \(\lambda\) is shared by every traced connection. Heterogeneous leaks across neurons or layers break the estimate; leak is therefore a user-supplied global constant and is never inferred from the model (see the leak parameter).

  2. Feed-forward only. The trace omits the hidden-to-hidden Jacobian, so it is the postsynaptic estimate for feed-forward SNNs. Applying it to a recurrent network silently drops the recurrent temporal credit.

  3. Single-hidden-layer exactness. The estimate is gradient-exact for one hidden layer; with depth the per-layer factorization accumulates bias.

The low-rank 'approx' mode (F-OTPE) layers an additional outer-product approximation on top, which is itself justified only under the same linear-leak assumption; its bias compounds with network depth (hence the UserWarning for multi-group networks).

Concretely, braintrace exposes OTPE as a generic ETP operator: it accepts arbitrary ETP weights and hidden states, multi-layer stacks, recurrent connectivity, and even non-spiking cells (e.g. a tanh RNN). All of these run mechanically, but as soon as the model deviates from a feed-forward LIF network with a single global scalar leak, the computed gradient leaves the regime in which OTPE is proven correct and should be treated as a heuristic approximation rather than a faithful gradient estimate. The one structural case that is rejected outright is a multi-state hidden group (num_state > 1, e.g. ALIF with an adaptation variable): the leaky scalar estimate cannot assign per-state credit, so compile_graph() raises rather than silently summing across states.
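Restriction 1 can be checked numerically. Assuming each postsynaptic neuron j has its own leak \(\lambda_j\), the natural per-neuron analogue of the recurrence decays column j of \(\hat R\) with \(\lambda_j\). The hypothetical helper below compares that against the single global \(\lambda\) that OTPE uses: only columns whose true leak equals the global value agree.

```python
import numpy as np

def leaky_trace(xs, dfs, leak):
    """Hypothetical helper: R_hat after R <- leak * R + outer(x, d_f).

    `leak` may be a scalar or a per-output-neuron vector of shape (O,),
    which broadcasts over the columns of R.
    """
    R = np.zeros((xs[0].size, dfs[0].size))
    for x, d_f in zip(xs, dfs):
        R = leak * R + np.outer(x, d_f)
    return R

rng = np.random.default_rng(2)
xs = [rng.standard_normal(3) for _ in range(5)]
dfs = [rng.random(2) for _ in range(5)]

true_leaks = np.array([0.9, 0.5])           # heterogeneous per-neuron leaks
R_per_neuron = leaky_trace(xs, dfs, true_leaks)
R_global = leaky_trace(xs, dfs, 0.9)        # a single global leak=0.9

col_error = np.abs(R_global - R_per_neuron).max(axis=0)
print(col_error)  # column 0 is exactly 0.0; column 1 is not
```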

Raises:

ValueError – If mode is not 'full' or 'approx', if leak is not in \((0, 1)\), if a weight-to-hidden relation reaches more than one HiddenGroup (OTPE v1 requires one-hop per-layer relations), or (at compile_graph()) if a trained connection projects into a hidden group with num_state > 1.
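The first two ValueError conditions are plain argument checks and can be restated as a small helper. check_otpe_args is hypothetical and not part of braintrace; the graph-structure checks (multi-group relations and num_state > 1) need the compiled graph and are not modeled here.

```python
def check_otpe_args(mode, leak):
    """Hypothetical restatement of the documented constructor argument checks."""
    if mode not in ("full", "approx"):
        raise ValueError(f"mode must be 'full' or 'approx', got {mode!r}")
    if not (0.0 < leak < 1.0):
        raise ValueError(f"leak must lie in the open interval (0, 1), got {leak!r}")

check_otpe_args("full", 0.9)  # accepted
try:
    check_otpe_args("approx", 1.0)  # leak on the boundary is rejected
except ValueError as err:
    print(err)
```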

Examples

>>> import brainstate
>>> import braintrace
>>>
>>> class Net(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
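...         # A tanh RNN runs mechanically but lies outside OTPE's proven
...         # regime (feed-forward LIF with one global leak); see Limitations.
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')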
...         self.out = braintrace.nn.Linear(20, 1)
...     def update(self, x):
...         return x >> self.cell >> self.out
>>>
>>> model = Net()
>>> _ = brainstate.nn.init_all_states(model)
>>> # ``leak`` is the postsynaptic membrane leak and must be passed
>>> # explicitly; it is never inferred from the model.
>>> learner = braintrace.OTPE(model, mode='full', leak=0.9)
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)
>>> y = learner(x0)

compile_graph(*args)

Compile the eligibility trace graph of the relationship between etrace weights, states and operators.

The compilation process includes:

  • building the etrace graph

  • separating the states

  • initializing the etrace states

Parameters:

*args – The input arguments.

Return type:

None

init_etrace_state(*args, **kwargs)

Initialize the eligibility trace states of the etrace algorithm.

This method is needed once the etrace graph has been compiled; initializing the etrace states is the final step of .compile_graph(). See .compile_graph() for details.

Parameters:
  • *args – The positional arguments.

  • **kwargs – The keyword arguments.

Raises:

NotImplementedError – This method must be implemented by subclasses.

reset_state(batch_size=None, **kwargs)

State resetting function.