OTTT
- class braintrace.OTTT(model, mode='A', *, leak, name=None, vjp_method='single-step', **kwargs)
Online Training Through Time for spiking neural networks.
OTTT tracks only a leaky presynaptic trace and forms the weight gradient each step as an outer product with the local learning signal:
\[\hat{a}^t = \begin{cases} \lambda\,\hat{a}^{t-1} + x^t & \text{(mode='A', accumulated)} \\ x^t & \text{(mode='O', instantaneous)} \end{cases}\]
\[\nabla_{W}\mathcal{L}^t = \hat{a}^t \otimes \Big( \frac{\partial \mathcal{L}^t}{\partial s^t}\,\sigma'(u^t) \Big) \;=\; \hat{a}^t \otimes L^t ,\]
where \(x^t\) is the presynaptic input, \(u^t\) the postsynaptic membrane potential, \(s^t = \sigma(u^t)\) the (surrogate) spike, \(\sigma'\) the surrogate-gradient function, \(\lambda \in (0, 1)\) the membrane leak of the postsynaptic neuron, and \(L^t\) the learning signal already propagated through the spike nonlinearity (the product inside the parentheses is element-wise).
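The two update rules can be written out directly for one dense layer. The following is a minimal NumPy sketch, not braintrace code. It assumes the weight matrix is laid out as (inputs, outputs) and that \(\partial \mathcal{L}^t / \partial s^t\) and \(\sigma'(u^t)\) are already available from the backward pass. The helpers `update_trace`, `learning_signal` and `ottt_weight_grad` are hypothetical names.

```python
import numpy as np


def update_trace(a_hat, x, leak, mode="A"):
    """One step of the presynaptic trace (hypothetical helper, not braintrace API)."""
    if mode == "A":
        # accumulated: a_hat^t = lambda * a_hat^{t-1} + x^t
        return leak * a_hat + x
    if mode == "O":
        # instantaneous: a_hat^t = x^t (the leak is unused)
        return x
    raise ValueError(f"mode must be 'A' or 'O', got {mode!r}")


def learning_signal(dloss_ds, surrogate_grad):
    """L^t = dL^t/ds^t * sigma'(u^t), element-wise over postsynaptic units."""
    return dloss_ds * surrogate_grad


def ottt_weight_grad(a_hat, L):
    """Rank-1 gradient: a_hat^t (shape I) outer L^t (shape O) -> shape (I, O)."""
    return np.outer(a_hat, L)


# Two steps for a layer with I = 3 inputs and O = 2 postsynaptic neurons.
a = np.zeros(3)
a = update_trace(a, np.array([1.0, 0.0, 1.0]), leak=0.9)
a = update_trace(a, np.array([0.0, 1.0, 1.0]), leak=0.9)  # approx [0.9, 1.0, 1.9]
L = learning_signal(np.array([0.5, -1.0]), np.array([1.0, 0.5]))
grad = ottt_weight_grad(a, L)  # shape (3, 2)
```

Nothing in the gradient refers to earlier learning signals. All temporal information enters through `a_hat`.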
How it works. Starting from BPTT, OTTT keeps the spatial credit assignment but drops the hidden-to-hidden recurrent Jacobian. The only state it carries forward in time is the presynaptic trace \(\hat{a}^t\), a vector over the layer's inputs. The per-step weight gradient is therefore the rank-1 outer product of that trace with the instantaneous learning signal.

Training memory is \(O(B \cdot I)\) per layer, where \(B\) is the batch size and \(I\) the number of presynaptic inputs, and it does not grow with the sequence length. This makes OTTT the cheapest of the algorithms here, at the cost of ignoring longer-range temporal credit.
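The memory claim can be made concrete with a batched NumPy sketch of mode='A'. It assumes the learning signals \(L^t\) are handed in precomputed; in the usage below they are random stand-ins, not signals derived from a loss. `ottt_accumulate` and `discounted_reference` are hypothetical names.

- The online loop keeps a single (B, I) trace.
- The reference recomputes the same discounted presynaptic sum from the stored history, so its memory grows with T.

Both compute the OTTT gradient, not the BPTT gradient.

```python
import numpy as np


def ottt_accumulate(xs, Ls, leak):
    """Online OTTT weight gradient over a sequence, mode='A' (hypothetical helper).

    xs: presynaptic inputs, shape (T, B, I)
    Ls: learning signals,   shape (T, B, O), assumed precomputed
    """
    T, B, I = xs.shape
    O = Ls.shape[-1]
    a_hat = np.zeros((B, I))  # the only temporal state: O(B * I), independent of T
    grad = np.zeros((I, O))
    for t in range(T):
        a_hat = leak * a_hat + xs[t]  # leaky presynaptic trace
        # outer product a_hat^t (x) L^t, summed over the batch
        grad += np.einsum("bi,bo->io", a_hat, Ls[t])
    return grad, a_hat


def discounted_reference(xs, Ls, leak):
    """Same quantity rebuilt from the full input history (memory grows with T)."""
    grad = 0.0
    for t in range(xs.shape[0]):
        # closed form of the trace: sum_{k <= t} leak^(t - k) * x^k
        a_t = sum(leak ** (t - k) * xs[k] for k in range(t + 1))
        grad = grad + np.einsum("bi,bo->io", a_t, Ls[t])
    return grad


rng = np.random.default_rng(0)
xs = (rng.random((50, 4, 3)) < 0.2).astype(float)  # sparse binary spikes, T=50, B=4, I=3
Ls = rng.normal(size=(50, 4, 2))                    # stand-in learning signals, O=2
g_online, trace = ottt_accumulate(xs, Ls, leak=0.9)
g_ref = discounted_reference(xs, Ls, leak=0.9)
print(trace.shape, np.allclose(g_online, g_ref))    # (4, 3) True
```

The agreement shows that the recursive trace summarizes the whole discounted input history, which is why nothing else needs to be stored.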
- Parameters:
model (Module) – The SNN whose weights are trained online.
mode (str) – 'A' accumulates the presynaptic trace over time (\(\hat a \leftarrow \lambda\,\hat a + x\)); 'O' uses the instantaneous presynaptic spike only (\(\hat a := x^t\)).
leak (float) – Leak \(\lambda \in (0, 1)\) applied to the presynaptic trace. Required: it must be supplied explicitly and is never inferred from the model (see Limitations). Mathematically, \(\lambda\) is the membrane leak of the postsynaptic neuron that the trained connection projects into.
vjp_method (str) – Forwarded to the base algorithm. Only 'single-step' is supported by OTTT v1; multi-step inputs raise NotImplementedError.
Limitations
- The leak must be supplied by the user. OTTT does not try to read \(\lambda\) off the model's neuron states.
  - A previous version walked model.states() and took the first state exposing a leak attribute.
  - On heterogeneous or multi-population models, that silently picks an arbitrary (often wrong) value: for example, the leak of the presynaptic layer, a readout filter, or whichever population happens to be enumerated first.
  - By the derivation, \(\lambda\) is the membrane leak of the postsynaptic neuron of each trained connection, so the framework cannot guess it safely.
  - A single network with different leaks per layer therefore cannot be trained correctly with one global leak and is unsupported.
- Single-state hidden groups only. Each trained connection must project into a HiddenGroup with num_state == 1.
  - The weight gradient contracts the learning signal L (shape (*varshape, num_state)) down to (*varshape,).
  - Collapsing a num_state > 1 tail has no theoretical justification. An example is an ALIF neuron carrying both a membrane potential and an adaptation variable. The trace is a single leaky scalar and cannot disentangle per-state credit.
  - OTTT therefore raises at compile time instead of silently summing across states.
- Single-step inputs only (OTTT v1); multi-step inputs raise NotImplementedError.
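The single-state restriction amounts to a shape check before the contraction over the trailing state axis. The sketch below is hypothetical: `contract_learning_signal` and its error messages are illustrative and are not braintrace's actual compile-time check. It only shows which shapes are accepted.

```python
import numpy as np


def contract_learning_signal(L, num_state):
    """Hypothetical sketch of the num_state == 1 rule (not braintrace API).

    L has shape (*varshape, num_state). With one state the trailing axis is
    simply dropped; with more, a single scalar trace cannot split credit
    between states, so refuse rather than summing them.
    """
    if L.shape[-1] != num_state:
        raise ValueError("last axis of L must equal num_state")
    if num_state != 1:
        raise ValueError(f"OTTT supports only num_state == 1, got {num_state}")
    return L[..., 0]  # shape (*varshape,)


lif_signal = np.ones((4, 10, 1))  # e.g. LIF: membrane potential only
print(contract_learning_signal(lif_signal, num_state=1).shape)  # (4, 10)

alif_signal = np.ones((4, 10, 2))  # e.g. ALIF: potential + adaptation
try:
    contract_learning_signal(alif_signal, num_state=2)
except ValueError as err:
    print("rejected:", err)
```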
- Raises:
ValueError – If mode is not 'A' or 'O', if leak is not in \((0, 1)\), or (at compile_graph()) if a trained connection projects into a hidden group with num_state > 1.
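The mode and leak checks above are easy to mirror when building a configuration. The num_state check needs the compiled graph, so it is not covered here. The sketch also shows one common way to obtain \(\lambda\): for a membrane that decays with time constant \(\tau\) and is integrated exactly over a step \(dt\), the per-step leak is \(e^{-dt/\tau}\). Whether a particular braintrace neuron uses this discretization is an assumption to verify against that neuron's update rule. `validate_ottt_args` and `lif_leak` are hypothetical helpers.

```python
import math


def validate_ottt_args(mode: str, leak: float) -> None:
    """Hypothetical mirror of OTTT's documented ValueError checks."""
    if mode not in ("A", "O"):
        raise ValueError(f"mode must be 'A' or 'O', got {mode!r}")
    # Open interval: 0, 1 and NaN are all rejected.
    if not (0.0 < leak < 1.0):
        raise ValueError(f"leak must lie in (0, 1), got {leak!r}")


def lif_leak(dt: float, tau: float) -> float:
    """Per-step decay exp(-dt / tau) of an exactly integrated leaky membrane."""
    if dt <= 0 or tau <= 0:
        raise ValueError("dt and tau must be positive")
    return math.exp(-dt / tau)


# dt = 1 ms and tau = 10 ms give a leak of about 0.905.
leak = lif_leak(1.0, 10.0)
validate_ottt_args("A", leak)
```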
Examples
>>> import brainstate
>>> import braintrace
>>>
>>> class Net(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
...         self.out = braintrace.nn.Linear(20, 1)
...     def update(self, x):
...         return x >> self.cell >> self.out
>>>
>>> model = Net()
>>> _ = brainstate.nn.init_all_states(model)
>>> # ``leak`` is the postsynaptic membrane leak and must be passed
>>> # explicitly; it is never inferred from the model.
>>> learner = braintrace.OTTT(model, mode='A', leak=0.9)
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)
>>> y = learner(x0)
- init_etrace_state(*args, **kwargs)
Initialize the eligibility trace states of the etrace algorithm.
Call this method after compiling the eligibility-trace graph. See compile_graph() for details.
- Parameters:
*args – The positional arguments.
**kwargs – The keyword arguments.
- Raises:
NotImplementedError – This method must be implemented by subclasses.