ParamDimVjpAlgorithm

class braintrace.ParamDimVjpAlgorithm(model, name=None, vjp_method='single-step', fast_solve=True, normalize_matrix_spectrum=False, trace_dtype=None, **kwargs)

Online gradient algorithm with diagonal approximation and parameter-dimension complexity.

This algorithm computes weight gradients online using a diagonal approximation of the hidden-to-hidden Jacobian, so its cost scales with the number of parameters. It is based on Real-Time Recurrent Learning (RTRL).

Parameters:
  • model (Module) – The model, which receives the input arguments and returns the model output.

  • vjp_method (str) –

    The method for computing the VJP (vector-Jacobian product). It must be either "single-step" or "multi-step". Default "single-step".

    • "single-step": the VJP is computed at the current time step, i.e., \(\partial L^t/\partial h^t\).

    • "multi-step": the VJP is computed at multiple time steps, i.e., \(\partial L^t/\partial h^{t-k}\), where \(k\) is determined by the data input.

  • name (Optional[str]) – The name of the etrace algorithm.

  • mode (braintrace.mixin.Mode, optional) – The computing mode, indicating the batching behavior.

Notes

The learning rule is

\[\begin{split}\begin{aligned} &\boldsymbol{\epsilon}^t \approx \mathbf{D}^t \boldsymbol{\epsilon}^{t-1}+\operatorname{diag}\left(\mathbf{D}_f^t\right) \otimes \mathbf{x}^t \\ & \nabla_{\boldsymbol{\theta}} \mathcal{L}=\sum_{t^{\prime} \in \mathcal{T}} \frac{\partial \mathcal{L}^{t^{\prime}}}{\partial \mathbf{h}^{t^{\prime}}} \circ \boldsymbol{\epsilon}^{t^{\prime}} \end{aligned}\end{split}\]

where \(\boldsymbol{\epsilon}^t\) is the per-parameter eligibility trace, \(\mathbf{D}^t\) the hidden-to-hidden Jacobian, \(\mathbf{D}_f^t\) the state-to-output Jacobian, \(\mathbf{x}^t\) the presynaptic input, and \(\partial \mathcal{L}^{t'}/\partial \mathbf{h}^{t'}\) the learning signal back-propagated from the loss at each step.

Real-Time Recurrent Learning (RTRL) propagates the full sensitivity \(\partial \mathbf{h}^t/\partial \boldsymbol{\theta}\) forward in time, which costs \(O(|\theta| \cdot H)\) memory. D-RTRL, the algorithm implemented by this class, keeps only the diagonal of the hidden-to-hidden Jacobian, collapsing the trace to one value per parameter. The trace is then contracted with the instantaneous learning signal at each step to accumulate the gradient, so no backward pass through time is needed and memory is linear in the parameter count.
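As an illustration of the learning rule, the following NumPy sketch applies the trace recursion to a toy cell \(\mathbf{h}^t = \tanh(\mathbf{W}\mathbf{x}^t + \mathbf{d} \circ \mathbf{h}^{t-1})\) whose recurrent weight \(\mathbf{d}\) acts element-wise, so the hidden-to-hidden Jacobian is exactly diagonal. It is not braintrace's implementation: the function names `drtrl_gradient` and `total_loss`, the squared-error loss, and the reading of \(\mathbf{D}_f^t\) as the derivative of the hidden state with respect to \(\mathbf{W}\mathbf{x}^t\) are all assumptions made for this example.

```python
import numpy as np

def drtrl_gradient(W, d, xs, targets):
    """Accumulate dL/dW online for h^t = tanh(W x^t + d * h^{t-1}) (toy model)."""
    n_out, n_in = W.shape
    h = np.zeros(n_out)
    eps = np.zeros((n_out, n_in))      # eligibility trace: one value per weight
    grad = np.zeros_like(W)
    for x, target in zip(xs, targets):
        h = np.tanh(W @ x + d * h)
        D_f = 1.0 - h ** 2             # diagonal of dh^t / d(W x^t)  (assumed meaning of D_f)
        D = D_f * d                    # diagonal of dh^t / dh^{t-1}
        eps = D[:, None] * eps + D_f[:, None] * x[None, :]   # trace recursion
        signal = h - target            # dL^t/dh^t for L^t = 0.5 * ||h^t - target||^2
        grad += signal[:, None] * eps  # contract learning signal with the trace
    return grad

def total_loss(W, d, xs, targets):
    """Sum of the per-step squared-error losses, used only as a reference."""
    h = np.zeros(W.shape[0])
    loss = 0.0
    for x, target in zip(xs, targets):
        h = np.tanh(W @ x + d * h)
        loss += 0.5 * np.sum((h - target) ** 2)
    return loss

# Hypothetical sizes: 2 inputs, 3 hidden units, 5 time steps.
rng = np.random.default_rng(0)
W = rng.normal(scale=0.5, size=(3, 2))
d = rng.uniform(0.1, 0.9, size=3)
xs = rng.normal(size=(5, 2))
targets = rng.normal(size=(5, 3))
online_grad = drtrl_gradient(W, d, xs, targets)
```

Because the recurrence in this toy cell is element-wise, the diagonal approximation is exact here and `online_grad` should agree with a finite-difference gradient of `total_loss`. With dense recurrent weights, the same recursion would only approximate the true gradient.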

ParamDimVjpAlgorithm is a subclass of brainstate.nn.Module, and its behavior depends on the context/mode of the computation, in particular on brainstate.mixin.Batching.

This algorithm has \(O(B\theta)\) memory complexity, where \(\theta\) is the number of parameters and \(B\) is the batch size. For a convolutional layer, \(\theta\) is the dimension (number of elements) of the convolutional kernel. For a linear transformation layer, the weight gradients are computed with \(O(BIO)\) computational complexity, where \(I\) and \(O\) are the input and output dimensions.
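To make the scaling concrete, the sketch below counts stored trace values for a single linear layer. The helper `trace_entries` and the layer sizes are hypothetical, and the full-RTRL count simply multiplies the \(O(|\theta| \cdot H)\) figure by the batch size, which is an assumption of this example.

```python
def trace_entries(batch, n_in, n_out, n_hidden):
    """Count stored trace values for one linear layer with n_in * n_out weights."""
    n_params = n_in * n_out
    d_rtrl = batch * n_params                # O(B * theta): one value per weight
    full_rtrl = batch * n_params * n_hidden  # O(B * theta * H): full sensitivity
    return d_rtrl, full_rtrl

d_rtrl, full_rtrl = trace_entries(batch=16, n_in=128, n_out=256, n_hidden=256)
print(d_rtrl, full_rtrl, full_rtrl // d_rtrl)  # prints: 524288 134217728 256
```

In this example the diagonal trace is smaller than the full sensitivity by a factor equal to the hidden size.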

For more details, please see the D-RTRL algorithm presented in our manuscript.

Examples

>>> import brainstate
>>> import braintrace
>>>
>>> class RNN(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
...         self.out = braintrace.nn.Linear(20, 1)
...     def update(self, x):
...         return x >> self.cell >> self.out
>>>
>>> model = RNN()
>>> _ = brainstate.nn.init_all_states(model)
>>> learner = braintrace.D_RTRL(model)  # alias of ParamDimVjpAlgorithm
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)   # trace the graph once
>>> y = learner(x0)             # forward pass + eligibility-trace update


get_etrace_of(weight)

Get the eligibility trace of the given weight.

Parameters:

weight (Union[ParamState, Tuple[str, ...]]) – The weight whose eligibility trace is requested, given either as a brainstate.ParamState instance or as its path in the model.

Returns:

A dictionary mapping (y_var id, hidden-group index) keys to the eligibility-trace values associated with the given weight.

Return type:

Dict

Raises:

ValueError – If no eligibility trace is found for the given weight.

init_etrace_state(*args, **kwargs)

Initialize the eligibility trace states of the etrace algorithm.

Call this method after compiling the etrace graph. See compile_graph() for details.

reset_state(batch_size=None, **kwargs)

Reset the eligibility trace states.

Parameters:

batch_size (Optional[int]) – The batch size used to reshape the reset trace states. Default None.