IODimVjpAlgorithm

class braintrace.IODimVjpAlgorithm(model, decay_or_rank, name=None, vjp_method='single-step', fast_solve=True, **kwargs)

Online gradient algorithm with diagonal approximation and input-output-dimension complexity.

This algorithm computes the weight gradients online using a diagonal approximation of the hidden-state Jacobian, with memory that scales with the input and output dimensions rather than with the number of parameters. It is based on the Real-Time Recurrent Learning (RTRL) algorithm.

Parameters:
  • model (Module) – The model function, which receives the input arguments and returns the model output.

  • decay_or_rank (float | int) – The exponential smoothing factor for the eligibility trace. If a float, it is the decay factor and should be in the range \((0, 1)\). If an integer, it is the number of approximation ranks for the algorithm and should be greater than 0.

  • vjp_method (str) –

    The method for computing the VJP. It should be either "single-step" or "multi-step".

    • "single-step": the VJP is computed at the current time step, i.e., \(\partial L^t/\partial h^t\).

    • "multi-step": the VJP is computed at multiple time steps, i.e., \(\partial L^t/\partial h^{t-k}\), where \(k\) is determined by the data input.

  • name (Optional[str]) – The name of the etrace algorithm.

  • mode (braintrace.mixin.Mode, optional) – The computing mode, indicating the batching information.
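The example on this page pairs `decay_or_rank=0.9` with `decay_or_rank=19`, which is consistent with the relation rank = 2 / (1 - decay) - 1. The sketch below encodes that relation as an assumption. The helper names `decay_to_rank` and `rank_to_decay` are hypothetical and are not part of braintrace, and the library's actual conversion may differ.

```python
def decay_to_rank(decay: float) -> int:
    """Hypothetical helper: map a decay factor in (0, 1) to an integer rank.

    Assumes rank = 2 / (1 - decay) - 1, rounded to the nearest integer.
    """
    if not 0.0 < decay < 1.0:
        raise ValueError("decay must lie in the open interval (0, 1)")
    return round(2.0 / (1.0 - decay) - 1.0)


def rank_to_decay(rank: int) -> float:
    """Hypothetical helper: inverse mapping, decay = (rank - 1) / (rank + 1)."""
    if rank <= 0:
        raise ValueError("rank must be a positive integer")
    return (rank - 1) / (rank + 1)


print(decay_to_rank(0.9))   # rank assumed to correspond to decay 0.9
print(rank_to_decay(19))    # decay assumed to correspond to rank 19
```

Under this assumed relation, a larger rank corresponds to a decay closer to 1, that is, a trace that retains more temporal history.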

Notes

The learning rule is

\[\begin{split}\begin{aligned} & \boldsymbol{\epsilon}^t \approx \boldsymbol{\epsilon}_{\mathbf{f}}^t \otimes \boldsymbol{\epsilon}_{\mathbf{x}}^t \\ & \boldsymbol{\epsilon}_{\mathbf{x}}^t=\alpha \boldsymbol{\epsilon}_{\mathbf{x}}^{t-1}+\mathbf{x}^t \\ & \boldsymbol{\epsilon}_{\mathbf{f}}^t=\alpha \operatorname{diag}\left(\mathbf{D}^t\right) \circ \boldsymbol{\epsilon}_{\mathbf{f}}^{t-1}+(1-\alpha) \operatorname{diag}\left(\mathbf{D}_f^t\right) \\ & \nabla_{\boldsymbol{\theta}} \mathcal{L}=\sum_{t^{\prime} \in \mathcal{T}} \frac{\partial \mathcal{L}^{t^{\prime}}}{\partial \mathbf{h}^{t^{\prime}}} \circ \boldsymbol{\epsilon}^{t^{\prime}} \end{aligned}\end{split}\]

where \(\boldsymbol{\epsilon}_{\mathbf{x}}^t\) is the input-side trace, \(\boldsymbol{\epsilon}_{\mathbf{f}}^t\) the output-side trace, \(\alpha\) the exponential-smoothing factor, \(\mathbf{D}^t\) the hidden-to-hidden Jacobian, \(\mathbf{D}_f^t\) the state-to-output Jacobian, and \(\mathbf{x}^t\) the presynaptic input.

The full per-parameter D-RTRL trace \(\boldsymbol{\epsilon}^t \in \mathbb{R}^{I\times O}\) is approximated by the outer product of two exponentially smoothed vectors, one over the input dimension and one over the output dimension. Storing the two factors instead of the full matrix reduces the memory from \(O(I\cdot O)\) to \(O(I+O)\) per layer. The decay \(\alpha\) (or, equivalently, the approximation rank) controls how much temporal history the factored trace retains; the bias of the exponential estimator is corrected at solve time.
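The learning rule above can be sketched in NumPy for a single layer. This is a minimal illustration, not the library's implementation: it assumes one sample (no batch axis), Jacobian diagonals supplied directly as vectors, a weight laid out as \(I \times O\), and no bias correction. The function names `update_traces` and `weight_gradient` are hypothetical.

```python
import numpy as np


def update_traces(eps_x, eps_f, x, diag_D, diag_Df, alpha):
    """One step of the factored trace update (hypothetical helper).

    eps_x:   input-side trace, shape (I,)
    eps_f:   output-side trace, shape (O,)
    diag_D:  diagonal of the hidden-to-hidden Jacobian, shape (O,)
    diag_Df: diagonal of the state-to-output Jacobian, shape (O,)
    """
    eps_x = alpha * eps_x + x
    eps_f = alpha * diag_D * eps_f + (1.0 - alpha) * diag_Df
    return eps_x, eps_f


def weight_gradient(eps_x, eps_f, dL_dh):
    """Gradient term for one time step: dL/dh applied to the rank-one trace."""
    # Entry (i, o) is eps_x[i] * eps_f[o] * dL_dh[o]; the full I x O trace
    # is only materialised here, never stored between steps.
    return np.outer(eps_x, dL_dh * eps_f)


# Toy values, chosen only for illustration.
num_in, num_out, alpha = 2, 3, 0.5
eps_x = np.zeros(num_in)
eps_f = np.zeros(num_out)
x = np.array([1.0, 2.0])
diag_D = np.full(num_out, 0.5)
diag_Df = np.ones(num_out)

for _ in range(2):  # two time steps with the same input
    eps_x, eps_f = update_traces(eps_x, eps_f, x, diag_D, diag_Df, alpha)

grad = weight_gradient(eps_x, eps_f, dL_dh=np.array([1.0, 0.0, 2.0]))
print(grad.shape)
```

Only the two vectors `eps_x` and `eps_f` persist across time steps, which is where the \(O(I+O)\) storage comes from.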

This algorithm has \(O(BI+BO)\) memory complexity and \(O(BIO)\) computational complexity, where \(I\) and \(O\) are the numbers of input and output dimensions and \(B\) is the batch size. In particular, for a linear transformation layer with \(I = O = n\) hidden dimensions, the weight gradients are computed with \(O(Bn)\) memory complexity and \(O(Bn^2)\) computational complexity.
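To make the memory figures concrete, the following sketch counts the trace entries stored per layer for a full \(B \times I \times O\) trace versus the two factored vectors. The helper name `trace_floats` and the layer sizes are illustrative assumptions, and constant factors and other state are ignored.

```python
def trace_floats(batch: int, num_in: int, num_out: int):
    """Count stored trace entries for one layer (hypothetical helper)."""
    full = batch * num_in * num_out        # full per-parameter trace
    factored = batch * (num_in + num_out)  # input-side plus output-side vectors
    return full, factored


full, factored = trace_floats(batch=32, num_in=1000, num_out=1000)
print(full, factored, full // factored)
```

For a square layer with \(I = O = n\), the ratio between the two counts is \(n/2\), so the saving grows linearly with the layer width.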

For more details, please see the ES-D-RTRL algorithm presented in our manuscript.

Examples

>>> import brainstate
>>> import braintrace
>>>
>>> class RNN(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
...         self.out = braintrace.nn.Linear(20, 1)
...     def update(self, x):
...         return x >> self.cell >> self.out
>>>
>>> model = RNN()
>>> _ = brainstate.nn.init_all_states(model)
>>> learner = braintrace.pp_prop(model, decay_or_rank=0.9)  # or rank: decay_or_rank=19
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)   # trace the graph once
>>> y = learner(x0)             # forward pass + eligibility-trace update

get_etrace_of(weight)

Get the eligibility trace of the given weight.

Parameters:

weight (Union[ParamState, Tuple[str, ...]]) – The weight whose eligibility trace is requested, given either as a brainstate.ParamState instance or as its path in the model.

Return type:

Tuple[Dict, Dict]

Returns:

  • etrace_xs (dict) – The input-side eligibility traces keyed by the weight-input variable.

  • etrace_dfs (dict) – The output-side eligibility traces keyed by (y_var, hidden-group index).

Raises:

ValueError – If no eligibility trace is found for the given weight.

init_etrace_state(*args, **kwargs)

Initialize the eligibility trace states of the etrace algorithm.

Call this method after the etrace graph has been compiled. See compile_graph() for details.

reset_state(batch_size=None, **kwargs)

Reset the eligibility trace states.

Parameters:

batch_size (Optional[int]) – The batch size used to reshape the reset trace states. Defaults to None.