IODimVjpAlgorithm
- class braintrace.IODimVjpAlgorithm(model, decay_or_rank, name=None, vjp_method='single-step', fast_solve=True, **kwargs)
Online gradient algorithm with diagonal approximation and input-output-dimension complexity.
It computes weight gradients online using a diagonal approximation of the hidden-to-hidden Jacobian, with trace memory that scales with the number of input and output dimensions rather than with the number of parameters. It is based on Real-Time Recurrent Learning (RTRL).
- Parameters:
  - model (Module) – The model function, which receives the input arguments and returns the model output.
  - decay_or_rank (float | int) – The exponential smoothing factor for the eligibility trace. If a float, it is the decay factor and should be in the range \((0, 1)\). If an integer, it is the number of approximation ranks and should be greater than 0.
  - vjp_method (str) – The method for computing the VJP. It should be either "single-step" or "multi-step".
    - "single-step": the VJP is computed at the current time step, i.e., \(\partial L^t/\partial h^t\).
    - "multi-step": the VJP is computed across multiple time steps, i.e., \(\partial L^t/\partial h^{t-k}\), where \(k\) is determined by the data input.
  - mode (braintrace.mixin.Mode, optional) – The computing mode, indicating the batching information.
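The float/int dispatch on `decay_or_rank` can be sketched as follows. This is a hypothetical helper, not braintrace's code: the rank-to-decay formula \(\alpha = (r-1)/(r+1)\) is an assumption (it is the exponential-moving-average "span" convention that pandas uses for `ewm(span=...)`), chosen because it maps the rank 19 mentioned in the example below to the decay 0.9 used there.

```python
def rank_to_decay(rank: int) -> float:
    """Hypothetical helper: map an integer rank to a smoothing factor alpha.

    Assumes the EMA "span" convention alpha = (rank - 1) / (rank + 1).
    """
    if not isinstance(rank, int) or rank <= 0:
        raise ValueError("rank must be a positive integer")
    return (rank - 1) / (rank + 1)


def resolve_decay(decay_or_rank):
    """Hypothetical helper mirroring the documented float/int behaviour."""
    if isinstance(decay_or_rank, float):
        # A float is used directly as the decay and must lie in (0, 1).
        if not 0.0 < decay_or_rank < 1.0:
            raise ValueError("decay must lie in (0, 1)")
        return decay_or_rank
    # An integer is interpreted as a number of approximation ranks.
    return rank_to_decay(decay_or_rank)


print(resolve_decay(0.9))  # 0.9
print(resolve_decay(19))   # 0.9
```

Under this convention a larger rank gives a decay closer to 1, i.e. a longer effective memory for the trace.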
Notes
The learning rule is
\[\begin{aligned}
& \boldsymbol{\epsilon}^t \approx \boldsymbol{\epsilon}_{\mathbf{f}}^t \otimes \boldsymbol{\epsilon}_{\mathbf{x}}^t \\
& \boldsymbol{\epsilon}_{\mathbf{x}}^t=\alpha \boldsymbol{\epsilon}_{\mathbf{x}}^{t-1}+\mathbf{x}^t \\
& \boldsymbol{\epsilon}_{\mathbf{f}}^t=\alpha \operatorname{diag}\left(\mathbf{D}^t\right) \circ \boldsymbol{\epsilon}_{\mathbf{f}}^{t-1}+(1-\alpha) \operatorname{diag}\left(\mathbf{D}_f^t\right) \\
& \nabla_{\boldsymbol{\theta}} \mathcal{L}=\sum_{t^{\prime} \in \mathcal{T}} \frac{\partial \mathcal{L}^{t^{\prime}}}{\partial \mathbf{h}^{t^{\prime}}} \circ \boldsymbol{\epsilon}^{t^{\prime}}
\end{aligned}\]
where \(\boldsymbol{\epsilon}_{\mathbf{x}}^t\) is the input-side trace, \(\boldsymbol{\epsilon}_{\mathbf{f}}^t\) the output-side trace, \(\alpha\) the exponential-smoothing factor, \(\mathbf{D}^t\) the hidden-to-hidden Jacobian, \(\mathbf{D}_f^t\) the state-to-output Jacobian, \(\mathbf{x}^t\) the presynaptic input, \(\mathbf{h}^t\) the hidden state, and \(\mathcal{T}\) the set of time steps over which the loss is accumulated.
The full per-parameter D-RTRL trace \(\boldsymbol{\epsilon}^t \in \mathbb{R}^{I\times O}\) is approximated by the outer product of two exponentially smoothed vectors, one over the input dimension and one over the output dimension. Storing these two factors instead of the full matrix reduces the per-layer trace memory from \(O(I\cdot O)\) to \(O(I+O)\). The decay \(\alpha\) (or, equivalently, an approximation rank) controls how much temporal history the factored trace retains; the bias of the exponential estimator is corrected at solve time.
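The factored recursion of the learning rule can be sketched in NumPy for a single layer with an element-wise (diagonal) recurrence. This is an illustrative sketch, not braintrace's implementation: the function names are hypothetical, the Jacobian diagonals and `dL_dh` are random stand-ins for what the library would obtain by tracing the model, the weight is assumed to be laid out as `(I, O)` (so `h = x @ W`), and the bias correction is omitted.

```python
import numpy as np


def update_traces(eps_x, eps_f, x, diag_D, diag_Df, alpha):
    """One step of the factored trace recursion."""
    # Input-side trace: eps_x^t = alpha * eps_x^{t-1} + x^t, shape (I,)
    eps_x = alpha * eps_x + x
    # Output-side trace:
    # eps_f^t = alpha * diag(D^t) * eps_f^{t-1} + (1 - alpha) * diag(D_f^t), shape (O,)
    eps_f = alpha * diag_D * eps_f + (1.0 - alpha) * diag_Df
    return eps_x, eps_f


def grad_contribution(eps_x, eps_f, dL_dh):
    """(dL^t/dh^t) * (eps_x outer eps_f) for a weight stored as (I, O)."""
    # Broadcasting dL/dh along the output axis equals scaling eps_f first.
    return np.outer(eps_x, eps_f * dL_dh)


def accumulate_grad(xs, diag_Ds, diag_Dfs, dL_dhs, alpha):
    """Run the recursion over T steps; only eps_x and eps_f persist between steps."""
    I, O = xs.shape[1], diag_Ds.shape[1]
    eps_x, eps_f = np.zeros(I), np.zeros(O)
    grad = np.zeros((I, O))  # gradient accumulator, same shape as the weight
    for x, d, df, g in zip(xs, diag_Ds, diag_Dfs, dL_dhs):
        eps_x, eps_f = update_traces(eps_x, eps_f, x, d, df, alpha)
        grad += grad_contribution(eps_x, eps_f, g)
    return grad


rng = np.random.default_rng(0)
T, I, O = 5, 3, 4
xs = rng.normal(size=(T, I))                  # presynaptic inputs x^t
diag_Ds = rng.uniform(0.5, 1.0, size=(T, O))  # diag(D^t)
diag_Dfs = rng.normal(size=(T, O))            # diag(D_f^t)
dL_dhs = rng.normal(size=(T, O))              # dL^t/dh^t from the VJP
grad = accumulate_grad(xs, diag_Ds, diag_Dfs, dL_dhs, alpha=0.9)
print(grad.shape)  # (3, 4)
```

Setting `alpha=0.0` collapses both traces to the current step, so the sum reduces to \(\sum_t \mathbf{x}^t \otimes (\operatorname{diag}(\mathbf{D}_f^t) \circ \partial \mathcal{L}^t/\partial \mathbf{h}^t)\), the purely instantaneous gradient; a larger \(\alpha\) lets the two vectors carry temporal history at \(O(I+O)\) storage cost.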
This algorithm has \(O(BI+BO)\) memory complexity and \(O(BIO)\) computational complexity, where \(I\) and \(O\) are the numbers of input and output dimensions and \(B\) is the batch size. In particular, for a linear layer with \(n\) hidden units (\(I = O = n\)), the weight gradients are computed with \(O(Bn)\) memory and \(O(Bn^2)\) computation.
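The memory figures can be checked with a little arithmetic. The sketch below counts stored trace entries for one layer under the full and the factored scheme; the function name and the sizes (\(B = 32\), \(I = O = 512\)) are arbitrary illustrations, not values taken from braintrace.

```python
def trace_entries(batch: int, n_in: int, n_out: int) -> dict:
    """Count stored eligibility-trace entries for one layer."""
    return {
        "full": batch * n_in * n_out,        # per-parameter D-RTRL trace, O(B*I*O)
        "factored": batch * (n_in + n_out),  # input-side + output-side vectors, O(B*I + B*O)
    }


counts = trace_entries(batch=32, n_in=512, n_out=512)
print(counts)                                # {'full': 8388608, 'factored': 32768}
print(counts["full"] // counts["factored"])  # 256
```

For a square layer the ratio is \(n^2 / 2n = n/2\), so the saving grows linearly with layer width.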
For more details, please see the ES-D-RTRL algorithm presented in our manuscript.
Examples
>>> import brainstate
>>> import braintrace
>>>
>>> class RNN(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
...         self.out = braintrace.nn.Linear(20, 1)
...
...     def update(self, x):
...         return x >> self.cell >> self.out
>>>
>>> model = RNN()
>>> _ = brainstate.nn.init_all_states(model)
>>> learner = braintrace.pp_prop(model, decay_or_rank=0.9)  # or rank: decay_or_rank=19
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)  # trace the graph once
>>> y = learner(x0)  # forward pass + eligibility-trace update
- get_etrace_of(weight)
Get the eligibility trace of the given weight.
- Parameters:
  - weight (Union[ParamState, Tuple[str, ...]]) – The weight whose eligibility trace is requested, given either as a brainstate.ParamState instance or as its path in the model.
- Returns:
  - etrace_xs (dict) – The input-side eligibility traces keyed by the weight-input variable.
  - etrace_dfs (dict) – The output-side eligibility traces keyed by (y_var, hidden-group index).
- Raises:
ValueError – If no eligibility trace is found for the given weight.