pp_prop
- class braintrace.pp_prop(model, decay_or_rank, name=None, vjp_method='single-step', fast_solve=True, **kwargs)
Online gradient algorithm with diagonal approximation and input-output-dimension complexity.
pp_prop is the canonical name for the input-output-dimension eligibility trace algorithm implemented by IODimVjpAlgorithm. It computes the gradients of the weights with the diagonal approximation and the input-output dimensional complexity.
This subclass inherits all behavior from IODimVjpAlgorithm without modification; it exists to provide the canonical pp_prop name. See IODimVjpAlgorithm for the full parameter list.
See also
IODimVjpAlgorithm: The implementing class with the full parameter list.
Notes
The learning rule is
\[\begin{split}\begin{aligned}
& \boldsymbol{\epsilon}^t \approx \boldsymbol{\epsilon}_{\mathbf{f}}^t \otimes \boldsymbol{\epsilon}_{\mathbf{x}}^t \\
& \boldsymbol{\epsilon}_{\mathbf{x}}^t=\alpha \boldsymbol{\epsilon}_{\mathbf{x}}^{t-1}+\mathbf{x}^t \\
& \boldsymbol{\epsilon}_{\mathbf{f}}^t=\alpha \operatorname{diag}\left(\mathbf{D}^t\right) \circ \boldsymbol{\epsilon}_{\mathbf{f}}^{t-1}+(1-\alpha) \operatorname{diag}\left(\mathbf{D}_f^t\right) \\
& \nabla_{\boldsymbol{\theta}} \mathcal{L}=\sum_{t^{\prime} \in \mathcal{T}} \frac{\partial \mathcal{L}^{t^{\prime}}}{\partial \mathbf{h}^{t^{\prime}}} \circ \boldsymbol{\epsilon}^{t^{\prime}}
\end{aligned}\end{split}\]
For more details, please see the ES-D-RTRL algorithm presented in our manuscript.
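The recurrences in the learning rule above can be illustrated with a small, library-independent sketch. It is not the braintrace implementation: the helper names `pp_prop_step` and `weight_gradient_term` are hypothetical, and because this page does not define \(\mathbf{D}^t\) and \(\mathbf{D}_f^t\), their diagonals are simply passed in as given vectors. The sketch only shows that the state carried between steps is one vector per input and one per output, rather than a full input-by-output matrix.

```python
# Illustrative sketch only; helper names are hypothetical, not braintrace API.

def pp_prop_step(eps_x, eps_f, x, diag_D, diag_Df, alpha):
    """Advance both traces by one time step."""
    # Input trace: eps_x^t = alpha * eps_x^{t-1} + x^t
    new_eps_x = [alpha * e + xi for e, xi in zip(eps_x, x)]
    # Output trace (elementwise):
    # eps_f^t = alpha * diag(D^t) * eps_f^{t-1} + (1 - alpha) * diag(D_f^t)
    new_eps_f = [alpha * d * e + (1.0 - alpha) * df
                 for d, e, df in zip(diag_D, eps_f, diag_Df)]
    return new_eps_x, new_eps_f


def weight_gradient_term(dL_dh, eps_x, eps_f):
    """One summand of the gradient: dL/dh applied row-wise to outer(eps_f, eps_x)."""
    return [[g * f * ex for ex in eps_x] for g, f in zip(dL_dh, eps_f)]


# Toy run: 2 inputs, 3 outputs, all values assumed for illustration.
alpha = 0.5
eps_x, eps_f = [0.0, 0.0], [0.0, 0.0, 0.0]
grad = [[0.0, 0.0] for _ in range(3)]
inputs = [[1.0, 2.0], [1.0, 0.0]]
for x in inputs:
    eps_x, eps_f = pp_prop_step(eps_x, eps_f, x,
                                diag_D=[0.5] * 3, diag_Df=[1.0] * 3, alpha=alpha)
    term = weight_gradient_term([1.0, 0.0, -1.0], eps_x, eps_f)
    # Accumulate the per-step contributions, as in the sum over t'.
    grad = [[a + b for a, b in zip(ga, ta)] for ga, ta in zip(grad, term)]
```

Between steps only `eps_x` and `eps_f` are stored; the outer product is formed when a gradient contribution is needed, which is one way to read the "input-output-dimension complexity" in the description.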
Examples
>>> import brainstate
>>> import braintrace
>>>
>>> class RNN(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
...         self.out = braintrace.nn.Linear(20, 1)
...     def update(self, x):
...         return x >> self.cell >> self.out
>>>
>>> model = RNN()
>>> _ = brainstate.nn.init_all_states(model)
>>> learner = braintrace.pp_prop(model, decay_or_rank=0.9)  # or rank: decay_or_rank=19
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)  # trace the graph once
>>> y = learner(x0)  # forward pass + eligibility-trace update
References