OSTLFeedforward

class braintrace.OSTLFeedforward(model, decay_or_rank=1e-06, name=None, **kwargs)

OSTL ‘without-H’ regime: feedforward, with no recurrent Jacobian.

The ‘without-H’ regime drops the hidden-to-hidden Jacobian \(\mathbf{D}^t\), so the temporal term of the eligibility trace vanishes and only the instantaneous (spatial) contribution survives:

\[\boldsymbol{\epsilon}^t \approx \operatorname{diag}(\mathbf{D}_f^t) \otimes \mathbf{x}^t , \qquad \nabla_{\boldsymbol{\theta}}\mathcal{L} = \sum_t \frac{\partial \mathcal{L}^t}{\partial \mathbf{h}^t} \circ \boldsymbol{\epsilon}^t .\]
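As a numerical sanity check of the formula above, the following NumPy sketch assumes a single stateless layer \(\mathbf{h}^t = \tanh(\mathbf{W}\mathbf{x}^t)\) with a per-step squared-error loss. Both the layer and the loss are illustrative assumptions, not part of braintrace, and the helper names are hypothetical. For such a layer \(\operatorname{diag}(\mathbf{D}_f^t)\) is the activation derivative, its outer product with \(\mathbf{x}^t\) gives \(\boldsymbol{\epsilon}^t\), and summing the row-scaled traces over time should reproduce a finite-difference gradient of the total loss.

```python
import numpy as np


def without_h_gradient(W, xs, targets):
    """'Without-H' gradient for h^t = tanh(W x^t) with L^t = 0.5 * ||h^t - y^t||^2."""
    grad = np.zeros_like(W)
    for x, y in zip(xs, targets):
        h = np.tanh(W @ x)              # hidden state h^t
        d_f = 1.0 - h ** 2              # diag(D_f^t): tanh'(W x^t)
        eps = np.outer(d_f, x)          # eps^t = diag(D_f^t) (x) x^t, shape (out, in)
        dL_dh = h - y                   # dL^t / dh^t for the squared error
        grad += dL_dh[:, None] * eps    # scale row i of eps^t by dL^t / dh_i^t
    return grad


def total_loss(W, xs, targets):
    """Sum over time of the per-step losses L^t."""
    return sum(0.5 * np.sum((np.tanh(W @ x) - y) ** 2) for x, y in zip(xs, targets))


def finite_difference_grad(W, xs, targets, step=1e-6):
    """Central finite differences of total_loss with respect to each entry of W."""
    fd = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        w_plus, w_minus = W.copy(), W.copy()
        w_plus[idx] += step
        w_minus[idx] -= step
        fd[idx] = (total_loss(w_plus, xs, targets)
                   - total_loss(w_minus, xs, targets)) / (2 * step)
    return fd


rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))          # 4 inputs -> 3 hidden units
xs = rng.normal(size=(5, 4))         # 5 time steps
targets = rng.normal(size=(5, 3))

g = without_h_gradient(W, xs, targets)
fd = finite_difference_grad(W, xs, targets)
print("max abs difference:", np.max(np.abs(g - fd)))
```

Because the assumed layer carries no state between steps, dropping \(\mathbf{D}^t\) loses nothing here, which is why the two gradients agree up to finite-difference error.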

This approximation is the appropriate one for feed-forward SNNs, for which it is exact. It is implemented by delegating to pp_prop (the input-output factorized trace) with a negligible decay, so the trace effectively does not accumulate across time steps.
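The effect of a negligible decay can be illustrated with a hypothetical exponentially smoothed trace \(\bar{\mathbf{x}}^t = \alpha \bar{\mathbf{x}}^{t-1} + \mathbf{x}^t\). This update rule is an assumption for illustration only; the exact form used inside pp_prop may differ (for example, by also scaling the input term). The sketch compares a decay of \(10^{-6}\), like the default, against a decay of 0.9.

```python
import numpy as np


def smoothed_trace(xs, decay):
    """Hypothetical IO-dim trace: xbar^t = decay * xbar^{t-1} + x^t (assumed form)."""
    trace = np.zeros_like(xs[0])
    history = []
    for x in xs:
        trace = decay * trace + x      # exponential smoothing of the input
        history.append(trace.copy())
    return np.stack(history)


rng = np.random.default_rng(1)
xs = rng.normal(size=(50, 8))            # 50 time steps, 8 input features

tiny = smoothed_trace(xs, decay=1e-6)    # default-like decay: no effective memory
slow = smoothed_trace(xs, decay=0.9)     # a large decay carries past inputs forward

print("decay 1e-6, max |trace - x|:", np.max(np.abs(tiny - xs)))
print("decay 0.9,  max |trace - x|:", np.max(np.abs(slow - xs)))
```

With the tiny decay the trace differs from the current input only by \(\alpha\) times the previous trace, so it collapses to the instantaneous term \(\mathbf{x}^t\), which is exactly what the ‘without-H’ formula requires.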

Parameters:
  • model (Module) – The SNN whose weights are trained online.

  • decay_or_rank (float | int) – Exponential-smoothing factor of the input-output-dimension (IO-dim) trace. The tiny default makes the temporal contribution negligible, matching the ‘without-H’ regime. A float must lie in (0, 1); an int is interpreted as an approximation rank.

  • name (Optional[str]) – Forwarded verbatim to pp_prop.

  • vjp_method (optional, passed via **kwargs) – Forwarded verbatim to pp_prop.

  • fast_solve (optional, passed via **kwargs) – Forwarded verbatim to pp_prop.
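The documented constraints on decay_or_rank (a float in (0, 1), or an int read as a rank) can be expressed as a small validation helper. This is a hypothetical sketch, not braintrace's own check: the name interpret_decay_or_rank is invented, and requiring a positive rank and rejecting bool are assumptions.

```python
def interpret_decay_or_rank(value):
    """Hypothetical helper: classify decay_or_rank as a decay factor or a rank."""
    # bool is a subclass of int in Python, so reject it explicitly (assumption).
    if isinstance(value, bool):
        raise TypeError("decay_or_rank must be a float or an int, not bool")
    if isinstance(value, float):
        # Documented constraint: a float decay must lie strictly inside (0, 1).
        if not 0.0 < value < 1.0:
            raise ValueError(f"decay must lie in (0, 1), got {value}")
        return ("decay", value)
    if isinstance(value, int):
        # Assumption: an approximation rank must be a positive integer.
        if value < 1:
            raise ValueError(f"rank must be a positive integer, got {value}")
        return ("rank", value)
    raise TypeError(f"unsupported type for decay_or_rank: {type(value).__name__}")


print(interpret_decay_or_rank(1e-6))   # the documented default
print(interpret_decay_or_rank(4))      # an int is read as a rank
```

Checking bool before int matters because isinstance(True, int) is True in Python.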

Examples

>>> import brainstate
>>> import braintrace
>>>
>>> class Net(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
...         self.out = braintrace.nn.Linear(20, 1)
...     def update(self, x):
...         return x >> self.cell >> self.out
>>>
>>> model = Net()
>>> _ = brainstate.nn.init_all_states(model)
>>> learner = braintrace.OSTLFeedforward(model)
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)
>>> y = learner(x0)

References