ParamDimVjpAlgorithm
- class braintrace.ParamDimVjpAlgorithm(model, name=None, vjp_method='single-step', fast_solve=True, normalize_matrix_spectrum=False, trace_dtype=None, **kwargs)
Online gradient algorithm with diagonal approximation and parameter-dimension complexity.
This algorithm computes weight gradients online using a diagonal approximation of Real-Time Recurrent Learning (RTRL), with memory that scales with the number of parameters.
- Parameters:
model (Module) – The model function, which receives the input arguments and returns the model output.
vjp_method (str) – The method for computing the VJP. It should be either "single-step" or "multi-step".
  - "single-step": the VJP is computed at the current time step, i.e., \(\partial L^t/\partial h^t\).
  - "multi-step": the VJP is computed at multiple time steps, i.e., \(\partial L^t/\partial h^{t-k}\), where \(k\) is determined by the data input.
mode (braintrace.mixin.Mode, optional) – The computing mode, indicating the batching behavior.
Notes
The learning rule is
\[\begin{split}\begin{aligned} &\boldsymbol{\epsilon}^t \approx \mathbf{D}^t \boldsymbol{\epsilon}^{t-1}+\operatorname{diag}\left(\mathbf{D}_f^t\right) \otimes \mathbf{x}^t \\ & \nabla_{\boldsymbol{\theta}} \mathcal{L}=\sum_{t^{\prime} \in \mathcal{T}} \frac{\partial \mathcal{L}^{t^{\prime}}}{\partial \mathbf{h}^{t^{\prime}}} \circ \boldsymbol{\epsilon}^{t^{\prime}} \end{aligned}\end{split}\]
where \(\boldsymbol{\epsilon}^t\) is the per-parameter eligibility trace, \(\mathbf{D}^t\) the hidden-to-hidden Jacobian, \(\mathbf{D}_f^t\) the state-to-output Jacobian, \(\mathbf{x}^t\) the presynaptic input, and \(\partial \mathcal{L}^{t'}/\partial \mathbf{h}^{t'}\) the learning signal back-propagated from the loss at each step.
Real-Time Recurrent Learning (RTRL) propagates the full sensitivity \(\partial \mathbf{h}^t/\partial \boldsymbol{\theta}\) forward in time, which costs \(O(|\theta| \cdot H)\) memory. D-RTRL, the algorithm implemented by this class, keeps only the diagonal of the hidden-to-hidden Jacobian, collapsing the trace to one value per parameter. The trace is then contracted with the instantaneous learning signal at each step to accumulate the gradient, so there is no backward pass through time and memory is linear in the parameter count.
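The recursion in the learning rule can be illustrated with a small NumPy sketch. It is not braintrace's implementation: it assumes a toy cell \(\mathbf{h}^t = \tanh(\mathbf{W}\mathbf{x}^t + \mathbf{d} \circ \mathbf{h}^{t-1})\) with a fixed diagonal recurrent vector `d`, a squared-error loss, and hypothetical helper names (`sequence_loss`, `d_rtrl_gradient`, `numerical_gradient`). Because the recurrence in this toy cell is already diagonal, the diagonal trace is exact here, so the online gradient can be compared against finite differences.

```python
import numpy as np


def sequence_loss(W, d, xs, ys):
    """Summed squared error of the toy cell h_t = tanh(W x_t + d * h_{t-1})."""
    h = np.zeros(W.shape[0])
    total = 0.0
    for x, y in zip(xs, ys):
        h = np.tanh(W @ x + d * h)
        total += 0.5 * np.sum((h - y) ** 2)
    return total


def d_rtrl_gradient(W, d, xs, ys):
    """Accumulate dL/dW online, keeping one trace value per entry of W."""
    h = np.zeros(W.shape[0])
    eps = np.zeros_like(W)    # eligibility trace, same shape as W
    grad = np.zeros_like(W)
    for x, y in zip(xs, ys):
        h = np.tanh(W @ x + d * h)
        df = 1.0 - h ** 2                 # derivative of tanh at this step
        diag_D = df * d                   # diagonal of dh_t / dh_{t-1}
        eps = diag_D[:, None] * eps + np.outer(df, x)   # trace recursion
        grad += (h - y)[:, None] * eps    # learning signal times trace
    return grad


def numerical_gradient(W, d, xs, ys, step=1e-6):
    """Central finite differences, used only as a reference."""
    grad = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        W_plus, W_minus = W.copy(), W.copy()
        W_plus[idx] += step
        W_minus[idx] -= step
        grad[idx] = (sequence_loss(W_plus, d, xs, ys)
                     - sequence_loss(W_minus, d, xs, ys)) / (2 * step)
    return grad


rng = np.random.default_rng(0)
H, I, T = 4, 3, 6                         # hidden size, input size, steps
W = rng.normal(scale=0.5, size=(H, I))
d = rng.uniform(-0.9, 0.9, size=H)
xs = rng.normal(size=(T, I))
ys = rng.normal(size=(T, H))

online = d_rtrl_gradient(W, d, xs, ys)
reference = numerical_gradient(W, d, xs, ys)
max_abs_error = float(np.max(np.abs(online - reference)))
```

The trace `eps` has the same shape as `W` at every step, which is the parameter-dimension memory cost described above. In a densely connected recurrent layer the same recursion would drop the off-diagonal terms of \(\mathbf{D}^t\), so the result would be an approximation instead of the exact gradient.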
ParamDimVjpAlgorithm is a subclass of brainstate.nn.Module and is sensitive to the context/mode of the computation. In particular, it is sensitive to brainstate.mixin.Batching behavior.
This algorithm has \(O(B\theta)\) memory complexity, where \(\theta\) is the number of parameters and \(B\) the batch size. For a convolutional layer, the weight gradients are computed with \(O(B\theta)\) memory complexity, where \(\theta\) is the dimension of the convolutional kernel. For a linear transformation layer, the weight gradients are computed with \(O(BIO)\) computational complexity, where \(I\) and \(O\) are the number of input and output dimensions.
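To make the scaling concrete, the following back-of-the-envelope sketch counts trace entries for a single linear layer. The function names and example sizes are illustrative assumptions, biases are ignored, and the full-RTRL count simply multiplies the parameter count by the hidden size \(H\), following the \(O(|\theta| \cdot H)\) figure given for RTRL.

```python
def d_rtrl_trace_entries(batch, n_in, n_out):
    """Trace entries for one linear layer under O(B * theta) scaling."""
    return batch * n_in * n_out


def full_rtrl_trace_entries(batch, n_in, n_out, hidden):
    """Sensitivity entries under O(theta * H) scaling, per batch element."""
    return batch * n_in * n_out * hidden


batch, n_in, n_out = 16, 100, 200
hidden = n_out  # assumption: the layer's outputs are the hidden states

diagonal = d_rtrl_trace_entries(batch, n_in, n_out)
full = full_rtrl_trace_entries(batch, n_in, n_out, hidden)
print(diagonal)          # 320000
print(full)              # 64000000
print(full // diagonal)  # 200
```

Under these assumptions the diagonal trace is smaller than the full sensitivity by a factor equal to the hidden size.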
For more details, please see the D-RTRL algorithm presented in our manuscript.
Examples
>>> import brainstate
>>> import braintrace
>>>
>>> class RNN(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
...         self.out = braintrace.nn.Linear(20, 1)
...     def update(self, x):
...         return x >> self.cell >> self.out
>>>
>>> model = RNN()
>>> _ = brainstate.nn.init_all_states(model)
>>> learner = braintrace.D_RTRL(model)  # alias of ParamDimVjpAlgorithm
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)  # trace the graph once
>>> y = learner(x0)  # forward pass + eligibility-trace update
- get_etrace_of(weight)
Get the eligibility trace of the given weight.
- Parameters:
weight (Union[ParamState, Tuple[str, ...]]) – The weight whose eligibility trace is requested, given either as a brainstate.ParamState instance or as its path in the model.
- Returns:
A dictionary mapping (y_var id, hidden-group index) keys to the eligibility-trace values associated with the given weight.
- Return type:
- Raises:
ValueError – If no eligibility trace is found for the given weight.