OSTTP

class braintrace.OSTTP(model, B_list, target_timing='per-step', name=None, vjp_method='single-step', fast_solve=True, **kwargs)

Online Spatio-Temporal Learning with Target Projection.

OSTTP reuses the OSTL / D-RTRL per-parameter eligibility trace but replaces the back-propagated learning signal with a direct random target projection (DRTP):

\[\boldsymbol{\epsilon}^t \approx \mathbf{D}^t\,\boldsymbol{\epsilon}^{t-1} + \operatorname{diag}(\mathbf{D}_f^t)\otimes \mathbf{x}^t , \qquad L_l^t = y^{*\,t}\, B_l , \qquad \nabla_{W}\mathcal{L} = \sum_t L_l^t \circ \boldsymbol{\epsilon}^t ,\]

where \(y^{*\,t}\) is the task target at time \(t\), \(B_l \in \mathbb{R}^{n_\text{target}\times n_l}\) is a fixed random feedback matrix for HiddenGroup \(l\) (frozen via stop_gradient), \(\mathbf{D}^t\) is the hidden-to-hidden Jacobian, \(\mathbf{D}_f^t\) the state-to-output Jacobian, and \(\mathbf{x}^t\) the presynaptic input.

How it works. The eligibility trace carries the temporal credit exactly as in OSTLRecurrent ('with-H'), but the spatial credit normally obtained by back-propagating \(\partial \mathcal{L}/\partial h\) is replaced by a frozen random projection of the target. Because the projection matrices \(B_l\) are fixed, there is no weight transport and no backward pass: the rule runs entirely forward and is update-unlocked in both space and time.
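The trace recursion, target projection, and gradient sum above can be sketched in plain JAX for a single weight matrix. This is only an illustration under stated assumptions, not braintrace's implementation: the function name `osttp_step`, the treatment of \(\mathbf{D}^t\) as a diagonal stored in a vector, the constant Jacobian values, and all shapes are chosen for the example.

```python
import jax
import jax.numpy as jnp


def osttp_step(eps, grad, x, d_hh, d_f, y_target, B):
    """One online step for a weight matrix W of shape (n_in, n_hidden).

    eps, grad : (n_in, n_hidden) eligibility trace and running gradient
    x         : (n_in,) presynaptic input
    d_hh      : (n_hidden,) diagonal of D^t (assumed diagonal in this sketch)
    d_f       : (n_hidden,) diagonal of D_f^t
    y_target  : (n_target,) task target
    B         : (n_target, n_hidden) fixed random feedback matrix
    """
    # Temporal credit: decay the old trace and add the new outer-product term.
    eps = d_hh[None, :] * eps + x[:, None] * d_f[None, :]
    # Spatial credit: project the target through the frozen matrix B.
    signal = y_target @ jax.lax.stop_gradient(B)
    # Accumulate this step's contribution to the gradient estimate.
    grad = grad + signal[None, :] * eps
    return eps, grad


n_in, n_hidden, n_target = 3, 4, 2
key_x, key_b = jax.random.split(jax.random.PRNGKey(0))
B = jax.random.normal(key_b, (n_target, n_hidden))
eps = jnp.zeros((n_in, n_hidden))
grad = jnp.zeros((n_in, n_hidden))
xs = jax.random.normal(key_x, (5, n_in))

for x in xs:
    eps, grad = osttp_step(
        eps, grad, x,
        d_hh=jnp.full((n_hidden,), 0.9),  # placeholder Jacobian values
        d_f=jnp.ones((n_hidden,)),
        y_target=jnp.ones((n_target,)),
        B=B,
    )
```

No quantity in the loop depends on a later time step or on a backward pass through the network, which is what makes the rule usable online.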

Parameters:
  • model (Module) – The SNN whose weights are trained online.

  • B_list (Sequence[Array]) – One feedback matrix per HiddenGroup, each of shape (n_target, n_l). Frozen via stop_gradient at construction; the count and trailing dimension are validated against the compiled graph.

  • target_timing (str) – 'per-step' requires y_target at every update() call. 'sequence-end' zeros the learning signal on intermediate steps (the trace still accumulates) and applies the projection only when y_target is supplied.

  • name (Optional[str]) – Name of the algorithm instance.

  • vjp_method (str) – Forwarded verbatim to ParamDimVjpAlgorithm.

  • fast_solve (bool) – Forwarded verbatim to ParamDimVjpAlgorithm.
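For a model with several HiddenGroups, B_list needs one matrix per group. A minimal sketch of building such a list follows; the helper name `make_feedback_matrices` and the standard-normal initialization are assumptions for illustration (the library does not prescribe a distribution here), and the hidden widths must be known by the caller.

```python
import jax


def make_feedback_matrices(key, n_target, hidden_widths):
    """Return one (n_target, n_l) random matrix per hidden width n_l."""
    # Use an independent PRNG key per group so the matrices are uncorrelated.
    keys = jax.random.split(key, len(hidden_widths))
    return [
        jax.random.normal(k, (n_target, n_l))
        for k, n_l in zip(keys, hidden_widths)
    ]


# One HiddenGroup of width 20 and a scalar target.
B_list = make_feedback_matrices(jax.random.PRNGKey(0), n_target=1, hidden_widths=[20])
```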

Raises:

ValueError – If target_timing is invalid; if len(B_list) differs from the number of HiddenGroups; if a matrix’s trailing dimension does not match its HiddenGroup width; or if target_timing='per-step' and y_target is omitted from an update() call.
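The checks and the two target_timing modes described above can be mimicked in a few lines. This sketch only mirrors the documented behavior; `check_feedback_matrices` and `learning_signal` are hypothetical helpers, not braintrace functions, and the error messages are invented.

```python
import jax.numpy as jnp

VALID_TIMINGS = ('per-step', 'sequence-end')


def check_feedback_matrices(B_list, hidden_widths):
    """Validate the count and trailing dimension of the feedback matrices."""
    if len(B_list) != len(hidden_widths):
        raise ValueError('need exactly one feedback matrix per hidden group')
    for B, n_l in zip(B_list, hidden_widths):
        if B.shape[-1] != n_l:
            raise ValueError('trailing dimension must equal the group width')


def learning_signal(y_target, B, target_timing):
    """Return y_target @ B, or zeros on target-free 'sequence-end' steps."""
    if target_timing not in VALID_TIMINGS:
        raise ValueError('invalid target_timing')
    if y_target is None:
        if target_timing == 'per-step':
            raise ValueError("'per-step' requires y_target on every call")
        # Intermediate step: no learning signal; the trace is updated elsewhere.
        return jnp.zeros(B.shape[-1])
    return jnp.asarray(y_target) @ B


B = jnp.ones((1, 20))
check_feedback_matrices([B], hidden_widths=[20])
mid_step = learning_signal(None, B, 'sequence-end')             # all zeros
last_step = learning_signal(jnp.array([2.0]), B, 'sequence-end')
```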

Examples

>>> import brainstate
>>> import jax
>>> import braintrace
>>>
>>> class Net(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
...         self.out = braintrace.nn.Linear(20, 1)
...     def update(self, x):
...         return x >> self.cell >> self.out
>>>
>>> model = Net()
>>> _ = brainstate.nn.init_all_states(model)
>>> # one (n_target, n_l) feedback matrix per HiddenGroup (here n_l = 20)
>>> B = jax.random.normal(jax.random.PRNGKey(0), (1, 20))
>>> learner = braintrace.OSTTP(model, B_list=[B])
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)
>>> y = learner.update(x0, y_target=brainstate.random.randn(1))

compile_graph(*args)

Compile the eligibility-trace graph that relates the etrace weights, states, and operators.

The compilation process includes:

  • building the etrace graph

  • separating the states

  • initializing the etrace states

Parameters:

*args – The input arguments.

Return type:

None

update(x, y_target=None)

Stash y_target for the hook, then call super().update(x).
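The stash-then-delegate pattern can be illustrated with two toy classes. Everything here is hypothetical (`BaseLearner`, `TargetProjectionLearner`, the attribute `_y_target`, and the scalar arithmetic); it shows only the control flow, not braintrace's internals.

```python
class BaseLearner:
    """Hypothetical stand-in for the parent algorithm class."""

    def update(self, x):
        out = 2.0 * x                    # placeholder forward computation
        signal = self.learning_signal()  # hook supplied by the subclass
        return out, signal

    def learning_signal(self):
        raise NotImplementedError


class TargetProjectionLearner(BaseLearner):
    """Hypothetical subclass that hands y_target to the hook via an attribute."""

    def __init__(self, b):
        self.b = b
        self._y_target = None

    def update(self, x, y_target=None):
        self._y_target = y_target  # stash the target first
        return super().update(x)   # then run the parent update unchanged

    def learning_signal(self):
        # The hook reads the stashed target; no target means no signal.
        if self._y_target is None:
            return 0.0
        return self._y_target * self.b


learner = TargetProjectionLearner(b=0.5)
out, signal = learner.update(3.0, y_target=4.0)
```

Keeping the parent `update(x)` signature untouched means the target reaches the hook without changing the base class.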