pp_prop

class braintrace.pp_prop(model, decay_or_rank, name=None, vjp_method='single-step', fast_solve=True, **kwargs)

Online gradient algorithm with diagonal approximation and input-output-dimension complexity.

pp_prop is the canonical name for the input-output-dimension eligibility-trace algorithm implemented by IODimVjpAlgorithm. It computes weight gradients online using a diagonal approximation, at a cost that scales with the input and output dimensions.

This subclass inherits all behavior from IODimVjpAlgorithm without modification; it exists to provide the canonical pp_prop name. See IODimVjpAlgorithm for the full parameter list.

See also

IODimVjpAlgorithm

The implementing class with the full parameter list.

Notes

The learning rule is

\[\begin{split}\begin{aligned} & \boldsymbol{\epsilon}^t \approx \boldsymbol{\epsilon}_{\mathbf{f}}^t \otimes \boldsymbol{\epsilon}_{\mathbf{x}}^t \\ & \boldsymbol{\epsilon}_{\mathbf{x}}^t=\alpha \boldsymbol{\epsilon}_{\mathbf{x}}^{t-1}+\mathbf{x}^t \\ & \boldsymbol{\epsilon}_{\mathbf{f}}^t=\alpha \operatorname{diag}\left(\mathbf{D}^t\right) \circ \boldsymbol{\epsilon}_{\mathbf{f}}^{t-1}+(1-\alpha) \operatorname{diag}\left(\mathbf{D}_f^t\right) \\ & \nabla_{\boldsymbol{\theta}} \mathcal{L}=\sum_{t^{\prime} \in \mathcal{T}} \frac{\partial \mathcal{L}^{t^{\prime}}}{\partial \mathbf{h}^{t^{\prime}}} \circ \boldsymbol{\epsilon}^{t^{\prime}} \end{aligned}\end{split}\]

For more details, please see the ES-D-RTRL algorithm presented in our manuscript.
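
The recursions above can be sketched in plain NumPy to make the shapes concrete. This is an illustration under stated assumptions, not the library's implementation: `update_traces` and `step_gradient` are hypothetical helpers, `d_diag` and `df_diag` stand in for the diagonals of \(\mathbf{D}^t\) and \(\mathbf{D}_f^t\) (which this page does not define), the weight layout is assumed to be `(n_out, n_in)`, and all inputs are random placeholders.

```python
import numpy as np


def update_traces(eps_x, eps_f, x, d_diag, df_diag, alpha):
    """Apply the two trace recursions once.

    eps_x, x               : shape (n_in,)
    eps_f, d_diag, df_diag : shape (n_out,)
    d_diag and df_diag are assumed to hold diag(D^t) and diag(D_f^t).
    """
    eps_x = alpha * eps_x + x
    eps_f = alpha * d_diag * eps_f + (1.0 - alpha) * df_diag
    return eps_x, eps_f


def step_gradient(dl_dh, eps_f, eps_x):
    """Contribution of one time step to the weight gradient.

    Uses eps ~= outer(eps_f, eps_x) and scales each row by dL/dh.
    Assumed weight layout: (n_out, n_in).
    """
    return np.outer(dl_dh * eps_f, eps_x)


rng = np.random.default_rng(0)
n_in, n_out, n_steps, alpha = 3, 2, 5, 0.9

eps_x = np.zeros(n_in)            # input-side trace
eps_f = np.zeros(n_out)           # output-side trace
grad = np.zeros((n_out, n_in))    # accumulated gradient estimate

for _ in range(n_steps):
    x = rng.standard_normal(n_in)
    d_diag = rng.uniform(0.0, 1.0, n_out)    # placeholder for diag(D^t)
    df_diag = rng.uniform(0.0, 1.0, n_out)   # placeholder for diag(D_f^t)
    dl_dh = rng.standard_normal(n_out)       # placeholder for dL^t/dh^t
    eps_x, eps_f = update_traces(eps_x, eps_f, x, d_diag, df_diag, alpha)
    grad += step_gradient(dl_dh, eps_f, eps_x)

print(grad.shape)  # (2, 3)
```

In this sketch only the two vectors `eps_x` and `eps_f` are carried between steps, so the trace state grows with `n_in + n_out` rather than with the number of weights.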

Examples

>>> import brainstate
>>> import braintrace
>>>
>>> class RNN(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
...         self.out = braintrace.nn.Linear(20, 1)
...     def update(self, x):
...         return x >> self.cell >> self.out
>>>
>>> model = RNN()
>>> _ = brainstate.nn.init_all_states(model)
>>> learner = braintrace.pp_prop(model, decay_or_rank=0.9)  # float: decay factor; or pass an int rank, e.g. decay_or_rank=19
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)   # trace the graph once
>>> y = learner(x0)             # forward pass + eligibility-trace update
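
The example pairs a decay of 0.9 with a rank of 19. One mapping consistent with that pairing is rank = 2 / (1 - decay) - 1, equivalently decay = (rank - 1) / (rank + 1). The sketch below only illustrates that arithmetic; the relation is an assumption inferred from the example rather than something this page states, and `rank_from_decay` and `decay_from_rank` are hypothetical helpers, not part of braintrace.

```python
def rank_from_decay(decay: float) -> int:
    """Assumed mapping from a decay factor in (0, 1) to an integer rank."""
    if not 0.0 < decay < 1.0:
        raise ValueError("decay must lie strictly between 0 and 1")
    return round(2.0 / (1.0 - decay) - 1.0)


def decay_from_rank(rank: int) -> float:
    """Assumed inverse mapping from a positive integer rank to a decay factor."""
    if rank <= 0:
        raise ValueError("rank must be a positive integer")
    return (rank - 1) / (rank + 1)


print(rank_from_decay(0.9))   # 19
print(decay_from_rank(19))    # 0.9
```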

References