EProp
- class braintrace.EProp(model, feedback='symmetric', kappa_filter_decay=0.0, random_feedback_key=None, name=None, vjp_method='single-step', fast_solve=True, normalize_matrix_spectrum=False, **kwargs)
Eligibility Propagation (e-prop) for recurrent spiking networks.
E-prop approximates the gradient of a loss \(\mathcal{L}\) with respect to a recurrent weight \(W_{ji}\) by the product of a local eligibility trace and a global learning signal, dropping the temporally non-local terms of BPTT:
\[\frac{d\mathcal{L}}{dW_{ji}} \approx \sum_t L_j^t \, \bar{e}_{ji}^t ,\]
where
\[e_{ji}^t = \frac{\partial h_j^t}{\partial W_{ji}} \approx D_j^t \, e_{ji}^{t-1} + \big[\operatorname{diag}(D_{f,j}^t)\big]\, x_i^t , \qquad \bar{e}_{ji}^t = \kappa\,\bar{e}_{ji}^{t-1} + e_{ji}^t .\]
Here \(h_j^t\) is the hidden state of neuron \(j\) at time \(t\), \(x_i^t\) the presynaptic input, \(D_j^t\) the hidden-to-hidden (recurrent) Jacobian diagonal, \(D_{f,j}^t\) the state-to-output Jacobian, and \(\kappa \in [0, 1)\) the readout-side low-pass factor. The learning signal is
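To make the trace recursion concrete, here is a minimal NumPy sketch (not braintrace's implementation) for a hypothetical layer of leaky linear units \(h^t = \alpha h^{t-1} + W x^t\) with an assumed loss \(\tfrac12\sum_t \lVert h^t - y^t\rVert^2\). In this toy model \(D_j^t = \alpha\) and \(D_{f,j}^t = 1\). Because the neurons are not coupled to each other, no indirect path is dropped, so with \(\kappa = 0\) the e-prop sum matches a finite-difference gradient exactly. With recurrent coupling between neurons, the same recursion would only approximate the gradient. The helper names eprop_grad_leaky and loss are illustrative.

```python
import numpy as np


def eprop_grad_leaky(W, x, y_target, alpha, kappa=0.0):
    """Hypothetical e-prop sketch for leaky linear units h^t = alpha*h^{t-1} + W x^t.

    x: (T, n_in) inputs, y_target: (T, n_hid) targets, W: (n_hid, n_in).
    Loss (assumed): 0.5 * sum_t ||h^t - y^t||^2.
    """
    n_hid, n_in = W.shape
    h = np.zeros(n_hid)
    e = np.zeros((n_hid, n_in))      # eligibility trace e_ji^t
    e_bar = np.zeros((n_hid, n_in))  # filtered trace bar e_ji^t
    grad = np.zeros_like(W)
    D = alpha                        # recurrent Jacobian diagonal D_j^t (constant here)
    D_f = 1.0                        # dh_j^t / d(W x^t)_j (identity here)
    for t in range(x.shape[0]):
        h = alpha * h + W @ x[t]
        e = D * e + D_f * x[t][None, :]  # same presynaptic x_i^t for every neuron j
        e_bar = kappa * e_bar + e
        L = h - y_target[t]              # learning signal dL/dh_j^t
        grad += L[:, None] * e_bar       # accumulate L_j^t * bar e_ji^t
    return grad


def loss(W, x, y_target, alpha):
    """Forward pass of the same toy model, used only for the gradient check."""
    h = np.zeros(W.shape[0])
    total = 0.0
    for t in range(x.shape[0]):
        h = alpha * h + W @ x[t]
        total += 0.5 * np.sum((h - y_target[t]) ** 2)
    return total


rng = np.random.default_rng(0)
T, n_in, n_hid, alpha = 20, 3, 4, 0.9
W = rng.normal(size=(n_hid, n_in))
x = rng.normal(size=(T, n_in))
y = rng.normal(size=(T, n_hid))

g = eprop_grad_leaky(W, x, y, alpha)  # kappa = 0: unfiltered eligibility trace

# Central finite differences over every weight W_ji
eps = 1e-5
g_num = np.zeros_like(W)
for j in range(n_hid):
    for i in range(n_in):
        dW = np.zeros_like(W)
        dW[j, i] = eps
        g_num[j, i] = (loss(W + dW, x, y, alpha) - loss(W - dW, x, y, alpha)) / (2 * eps)

print(np.max(np.abs(g - g_num)))  # only floating-point rounding remains
```

With kappa > 0 the same function returns the filtered estimate \(\sum_t L_j^t \bar{e}_{ji}^t\). That is no longer the exact gradient of this loss.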
\[L_j^t = \begin{cases} \dfrac{\partial \mathcal{L}}{\partial h_j^t} & \text{(symmetric feedback: standard backprop through the readout)} \\[2ex] \big(B\,e^t\big)_j & \text{(random feedback: a fixed random projection } B \text{ of the readout error } e^t\text{, not the eligibility trace)} . \end{cases}\]
How it works. The eligibility trace \(e_{ji}^t\) is exactly the per-parameter trace maintained by D_RTRL. It depends only on quantities local to the synapse and is updated forward in time. The learning signal \(L_j^t\) is broadcast from the readout. E-prop is therefore online (no backward pass through time), and its memory grows linearly with the number of parameters rather than with the sequence length. With kappa_filter_decay > 0 the learning signal is additionally low-pass filtered. With feedback='random' the symmetric readout gradient is replaced by a frozen random matrix, which removes the biologically implausible weight-transport requirement.
- Parameters:
  - model (Module) – The recurrent SNN whose weights are trained online.
  - feedback (str) – 'symmetric' uses the \(\partial \mathcal{L}/\partial h\) computed by reverse-mode AD (standard backprop through the readout). 'random' replaces the readout gradient with a frozen random projection (requires random_feedback_key).
  - kappa_filter_decay (float) – Readout-side low-pass factor \(\kappa\). If > 0, each HiddenGroup's learning signal is filtered at every step (\(\bar L^t = (1-\kappa)L^t + \kappa\bar L^{t-1}\)). 0 disables filtering.
  - random_feedback_key (Optional[Array]) – Seed for the random-feedback matrices. Required when feedback='random'; ignored otherwise.
  - normalize_matrix_spectrum (bool) – Forwarded verbatim to D_RTRL.
- Raises:
  - ValueError – If feedback is not one of {'symmetric', 'random'}, or if feedback='random' is given without random_feedback_key.
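The two feedback modes and the kappa_filter_decay filter can be illustrated with a toy readout. The NumPy sketch below makes several assumptions. It uses a linear readout \(y = W_{\text{out}} h\) with squared-error loss. learning_signal and filter_signal are hypothetical helpers, not part of braintrace. The filter state is assumed to start at zero. With symmetric feedback the signal is the exact \(\partial\mathcal{L}/\partial h\), which requires the transpose of the readout weights. With random feedback, a matrix B is drawn once, then frozen, and takes the place of that transpose.

```python
import numpy as np


def learning_signal(h, W_out, y_target, feedback="symmetric", B=None):
    """Hypothetical learning signal L^t for a linear readout y = W_out @ h
    and loss 0.5 * ||y - y_target||^2 (both assumed for illustration)."""
    err = W_out @ h - y_target   # readout error dL/dy
    if feedback == "symmetric":
        return W_out.T @ err     # exact dL/dh: needs the transpose of W_out
    if feedback == "random":
        if B is None:
            raise ValueError("feedback='random' needs a fixed matrix B")
        return B @ err           # frozen random projection replaces W_out.T
    raise ValueError(f"unknown feedback mode: {feedback!r}")


def filter_signal(L_seq, kappa):
    """Low-pass filter bar L^t = (1 - kappa) * L^t + kappa * bar L^{t-1}, with bar L^{-1} = 0."""
    L_bar = np.zeros_like(L_seq[0], dtype=float)
    out = []
    for L in L_seq:
        L_bar = (1.0 - kappa) * L + kappa * L_bar
        out.append(L_bar)
    return np.stack(out)


rng = np.random.default_rng(1)
n_hid, n_out = 5, 2
W_out = rng.normal(size=(n_out, n_hid))
B = rng.normal(size=(n_hid, n_out))   # drawn once, then never updated
h = rng.normal(size=n_hid)
y_star = rng.normal(size=n_out)

L_sym = learning_signal(h, W_out, y_star)
L_rand = learning_signal(h, W_out, y_star, feedback="random", B=B)
print(filter_signal(np.ones((3, 1)), kappa=0.5).ravel().tolist())  # [0.5, 0.75, 0.875]
```

Setting kappa=0 returns the input sequence unchanged, which matches the parameter description above where 0 disables filtering.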
Examples
>>> import brainstate
>>> import braintrace
>>>
>>> class RSNN(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
...         self.out = braintrace.nn.Linear(20, 1)
...     def update(self, x):
...         return x >> self.cell >> self.out
...
>>> model = RSNN()
>>> _ = brainstate.nn.init_all_states(model)
>>> learner = braintrace.EProp(model, kappa_filter_decay=0.9)
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)  # trace the graph once
>>> y = learner(x0)  # forward pass + eligibility-trace update
References