EProp

class braintrace.EProp(model, feedback='symmetric', kappa_filter_decay=0.0, random_feedback_key=None, name=None, vjp_method='single-step', fast_solve=True, normalize_matrix_spectrum=False, **kwargs)

Eligibility Propagation (e-prop) for recurrent spiking networks.

E-prop approximates the gradient of a loss \(\mathcal{L}\) with respect to a recurrent weight \(W_{ji}\) by the product of a local eligibility trace and a global learning signal, dropping the temporally non-local terms of BPTT:

\[\frac{d\mathcal{L}}{dW_{ji}} \approx \sum_t L_j^t \, \bar{e}_{ji}^t ,\]

where

\[e_{ji}^t = \frac{\partial h_j^t}{\partial W_{ji}} \approx D_j^t \, e_{ji}^{t-1} + D_{f,j}^t \, x_i^t , \qquad \bar{e}_{ji}^t = \kappa\,\bar{e}_{ji}^{t-1} + e_{ji}^t .\]

Here \(h_j^t\) is the hidden state of neuron \(j\) at time \(t\), \(x_i^t\) the presynaptic input, \(D_j^t\) the \(j\)-th diagonal entry of the hidden-to-hidden (recurrent) Jacobian \(\partial h^t/\partial h^{t-1}\), \(D_{f,j}^t\) the state-to-output Jacobian, and \(\kappa \in [0, 1)\) the readout-side low-pass factor. The learning signal is
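The gradient and trace formulas above can be checked on a toy model. The sketch below assumes a leaky linear layer \(h^t = \alpha h^{t-1} + W x^t\) with loss \(\mathcal{L} = \sum_t \tfrac12 \lVert h^t - y^t\rVert^2\). For this model \(D_j^t = \alpha\), the factor multiplying \(x_i^t\) is 1, and the learning signal is the symmetric one, \(L_j^t = \partial\mathcal{L}/\partial h_j^t = h_j^t - y_j^t\). The model, the loss, and the helper names `forward_loss` and `eprop_gradient` are illustrative assumptions and are not part of braintrace. Because each unit only recurs onto itself, the approximation is exact when \(\kappa = 0\), so the result must agree with finite differences.

```python
import numpy as np


def forward_loss(W, x, y_target, alpha):
    """L = sum_t 0.5 * ||h^t - y^t||^2 with h^t = alpha * h^{t-1} + W @ x^t and h^{-1} = 0."""
    h = np.zeros(W.shape[0])
    total = 0.0
    for x_t, y_t in zip(x, y_target):
        h = alpha * h + W @ x_t
        total += 0.5 * np.sum((h - y_t) ** 2)
    return total


def eprop_gradient(W, x, y_target, alpha, kappa=0.0):
    """Accumulate sum_t L_j^t * e_bar_ji^t forward in time (no backward pass)."""
    n_out, n_in = W.shape
    d_f = np.ones(n_out)             # factor multiplying x_i^t; equals 1 in this toy model
    h = np.zeros(n_out)
    e = np.zeros((n_out, n_in))      # eligibility trace e_ji^t, same shape as W
    e_bar = np.zeros((n_out, n_in))  # filtered trace e_bar_ji^t
    grad = np.zeros((n_out, n_in))
    for x_t, y_t in zip(x, y_target):
        h = alpha * h + W @ x_t
        e = alpha * e + d_f[:, None] * x_t[None, :]   # D_j^t = alpha for every unit
        e_bar = kappa * e_bar + e
        L = h - y_t                                    # symmetric learning signal dL/dh_j^t
        grad += L[:, None] * e_bar
    return grad


rng = np.random.default_rng(0)
T, n_in, n_out, alpha = 20, 3, 2, 0.8
x = rng.standard_normal((T, n_in))
y = rng.standard_normal((T, n_out))
W = rng.standard_normal((n_out, n_in))

g = eprop_gradient(W, x, y, alpha)

# Central finite differences: exact up to rounding because L is quadratic in W.
eps = 1e-5
g_fd = np.zeros_like(W)
for j in range(n_out):
    for i in range(n_in):
        dW = np.zeros_like(W)
        dW[j, i] = eps
        g_fd[j, i] = (forward_loss(W + dW, x, y, alpha)
                      - forward_loss(W - dW, x, y, alpha)) / (2 * eps)

print(np.max(np.abs(g - g_fd)))  # tiny: exact for purely self-recurrent units with kappa = 0
```

With cross-unit recurrence, the trace would drop the terms coupling different hidden units. The result would then approximate the BPTT gradient and no longer match it exactly.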

\[\begin{split}L_j^t = \begin{cases} \dfrac{\partial \mathcal{L}}{\partial h_j^t} & \text{(symmetric feedback, standard backprop through readout)} \\[2ex] \big(B\,\delta^t\big)_j & \text{(random feedback: a fixed random projection } B \text{ of the output error } \delta^t\text{)} . \end{cases}\end{split}\]
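The two feedback modes can be contrasted in a short sketch. It assumes a linear readout \(y = W_{\text{out}} h\) with a squared-error loss. Under that assumption the symmetric signal is \(W_{\text{out}}^\top\) applied to the readout error, and random feedback swaps \(W_{\text{out}}^\top\) for a matrix \(B\) that is drawn once and then frozen. The readout, the loss, the helper `learning_signal`, and the NumPy generator standing in for `random_feedback_key` are all illustrative assumptions. This is not how braintrace samples \(B\).

```python
import numpy as np


def learning_signal(err, W_out, feedback="symmetric", B=None):
    """Map the readout error dL/dy^t to a per-hidden-unit learning signal L^t."""
    if feedback == "symmetric":
        return W_out.T @ err      # exact dL/dh^t for a linear readout y = W_out @ h
    if feedback == "random":
        return B @ err            # fixed random matrix replaces W_out.T (no weight transport)
    raise ValueError(f"unknown feedback: {feedback!r}")


rng = np.random.default_rng(42)   # stands in for a fixed random_feedback_key
n_hidden, n_readout = 5, 2
W_out = rng.standard_normal((n_readout, n_hidden))
B = rng.standard_normal((n_hidden, n_readout))   # sampled once, then frozen

h = rng.standard_normal(n_hidden)
y_star = np.zeros(n_readout)
err = W_out @ h - y_star          # dL/dy for L = 0.5 * ||W_out @ h - y_star||^2

L_sym = learning_signal(err, W_out, "symmetric")
L_rnd = learning_signal(err, W_out, "random", B)
print(L_sym, L_rnd)
```

Because \(B\) does not change after sampling, the same key always yields the same feedback pathway.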

How it works. The eligibility trace \(e_{ji}^t\) is exactly the per-parameter trace maintained by D_RTRL; it depends only on quantities local to the synapse and is updated forward in time. The learning signal \(L_j^t\) is broadcast from the readout. E-prop is therefore online (no backward pass through time) and uses memory linear in the number of parameters. With kappa_filter_decay > 0 the learning signal is additionally low-pass filtered; with feedback='random' the symmetric readout gradient is replaced by a frozen random matrix, removing the biologically implausible weight-transport requirement.
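The "online, no backward pass" property is shown in the sketch below. Weights change at every step using only the current trace and learning signal, and the only learning state is an array shaped like the weight matrix. The toy setup is an assumption chosen for illustration: a leaky linear student learns to imitate a teacher of the same form, using plain SGD with \(\kappa = 0\) and symmetric feedback. It is not braintrace's training loop.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, T = 3, 2, 3000
alpha, lr = 0.5, 0.01

W_teacher = rng.standard_normal((n_out, n_in))  # generates the targets
W = np.zeros((n_out, n_in))                     # student weights, learned online

h_teacher = np.zeros(n_out)
h = np.zeros(n_out)
e = np.zeros((n_out, n_in))   # the only learning state: same shape as W, independent of T
losses = []
for t in range(T):
    x_t = rng.standard_normal(n_in)
    h_teacher = alpha * h_teacher + W_teacher @ x_t
    h = alpha * h + W @ x_t
    # e_ji^t = D_j e_ji^{t-1} + x_i^t, with D_j = alpha and input factor 1 in this toy model
    e = alpha * e + x_t[None, :]
    L = h - h_teacher                     # learning signal for L^t = 0.5 * ||h - h_teacher||^2
    W -= lr * L[:, None] * e              # update now; no earlier states are stored
    losses.append(0.5 * float(np.sum(L ** 2)))

print(np.mean(losses[:100]), np.mean(losses[-100:]))  # the error shrinks by many orders of magnitude
```

BPTT would need to keep all \(T\) hidden states to run its backward pass. Here memory stays fixed at one trace per weight, however long the sequence runs.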

Parameters:
  • model (Module) – The recurrent SNN whose weights are trained online.

  • feedback (str) – 'symmetric' uses reverse-AD’s \(\partial \mathcal{L}/\partial h\) (standard backprop through the readout). 'random' replaces the readout gradient with a frozen random projection (requires random_feedback_key).

  • kappa_filter_decay (float) – Readout-side low-pass factor \(\kappa\). If > 0, each HiddenGroup’s learning signal is filtered each step (\(\bar L^t = (1-\kappa)L^t + \kappa\bar L^{t-1}\)). 0 disables filtering.

  • random_feedback_key (Optional[Array]) – Seed for the random-feedback matrices. Required when feedback='random'; ignored otherwise.

  • name (Optional[str]) – Name of the algorithm instance.

  • vjp_method (str) – Forwarded verbatim to D_RTRL.

  • fast_solve (bool) – Forwarded verbatim to D_RTRL.

  • normalize_matrix_spectrum (bool) – Forwarded verbatim to D_RTRL.
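The learning-signal filter described for kappa_filter_decay, \(\bar L^t = (1-\kappa)L^t + \kappa\bar L^{t-1}\), fits in a few lines. This sketch assumes \(\bar L\) starts at zero. That initial condition and the helper name `filter_learning_signal` are hypothetical and were not taken from braintrace.

```python
import numpy as np


def filter_learning_signal(L_seq, kappa):
    """Apply L_bar^t = (1 - kappa) * L^t + kappa * L_bar^{t-1} along axis 0.

    Assumes L_bar starts at zero (hypothetical initial condition).
    kappa = 0 returns the signal unchanged.
    """
    L_bar = np.zeros_like(L_seq[0], dtype=float)
    out = []
    for L in L_seq:
        L_bar = (1.0 - kappa) * L + kappa * L_bar
        out.append(L_bar)
    return np.stack(out)


L_seq = np.ones((5, 2))   # constant learning signal for 2 hidden units over 5 steps
print(filter_learning_signal(L_seq, kappa=0.9)[:, 0])  # rises toward 1 as 1 - 0.9**(t + 1)
```

Because of the \((1-\kappa)\) factor, the filter has unit gain at steady state: a constant learning signal settles to its own value.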

Raises:

ValueError – If feedback is not one of {'symmetric', 'random'}, or if feedback='random' is given without random_feedback_key.
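The two documented error conditions amount to the check below. `check_feedback_args` is a hypothetical helper written to mirror the Raises entry. It is not braintrace's code, and the real error messages may differ.

```python
def check_feedback_args(feedback, random_feedback_key):
    """Hypothetical mirror of the documented ValueError conditions."""
    if feedback not in ("symmetric", "random"):
        raise ValueError(f"feedback must be 'symmetric' or 'random', got {feedback!r}")
    if feedback == "random" and random_feedback_key is None:
        raise ValueError("feedback='random' requires random_feedback_key")
    # With feedback='symmetric' the key is ignored, so any value (including None) passes.


check_feedback_args("symmetric", None)   # accepted
```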

Examples

>>> import brainstate
>>> import braintrace
>>>
>>> class RSNN(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
...         self.out = braintrace.nn.Linear(20, 1)
...     def update(self, x):
...         return x >> self.cell >> self.out
>>>
>>> model = RSNN()
>>> _ = brainstate.nn.init_all_states(model)
>>> learner = braintrace.EProp(model, kappa_filter_decay=0.9)
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)   # trace the graph once
>>> y = learner(x0)             # forward pass + eligibility-trace update


init_etrace_state(*args, **kwargs)

Initialize the eligibility trace states of the etrace algorithm.

This method must be called after the etrace graph has been compiled. See compile_graph() for details.

reset_state(batch_size=None, **kwargs)

Reset the eligibility trace states.

Parameters:

batch_size (Optional[int]) – Batch size used to shape the reset trace states. Defaults to None.
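The sketch below illustrates what initializing and resetting per-parameter traces could mean. It assumes each trace is stored as a zero array shaped like its weight, with an optional leading batch axis when batch_size is given. `TraceStore` is a hypothetical class, and its layout is an assumption. It reuses the method names above only for readability and is not braintrace's internal state format.

```python
import numpy as np


class TraceStore:
    """Hypothetical container holding one eligibility trace per weight matrix."""

    def __init__(self, weight_shapes):
        self.weight_shapes = dict(weight_shapes)
        self.traces = {}

    def init_etrace_state(self):
        # Allocate traces once the weights (and hence the trace shapes) are known.
        self.reset_state()

    def reset_state(self, batch_size=None):
        # Zero every trace; with a batch size, prepend a batch axis so each sample has its own trace.
        self.traces = {
            name: np.zeros(shape if batch_size is None else (batch_size, *shape))
            for name, shape in self.weight_shapes.items()
        }


store = TraceStore({"w_rec": (20, 20), "w_in": (20, 1)})
store.init_etrace_state()
store.reset_state(batch_size=8)
print({k: v.shape for k, v in store.traces.items()})
```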