GradExpon
- class braintrace.GradExpon(grad_shape, tau_or_decay)
Accumulate gradients with an exponential (leaky) running sum.
Maintains a decaying accumulator over a pytree of gradients, useful for smoothing online-learning gradient signals across time steps.
- Parameters:
  - grad_shape (PyTree) – A pytree whose leaves give the shape and dtype of the gradients to accumulate. The accumulator is initialised to zeros matching each leaf.
  - tau_or_decay (Quantity | float) – Either a decay time constant \(\tau\) (as a Quantity), from which the decay factor is computed as \(\exp(-1 / (\tau / \mathrm{dt})) = \exp(-\mathrm{dt} / \tau)\) with \(\mathrm{dt}\) the simulation time step, or the decay factor itself (a float in the open interval \((0, 1)\)).
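The conversion from a time constant to a per-step decay factor is a one-liner. The sketch below is illustrative only: `decay_from_tau` is a hypothetical helper that takes `tau` and `dt` as plain floats in the same unit, whereas the class itself accepts a `Quantity`; where the class obtains `dt` is not specified on this page.

```python
import math


def decay_from_tau(tau: float, dt: float) -> float:
    """Hypothetical helper: per-step decay factor from a time constant.

    Computes exp(-1 / (tau / dt)), which is the same as exp(-dt / tau).
    tau and dt are plain floats in the same time unit (an assumption;
    the real class takes a Quantity for tau).
    """
    if tau <= 0 or dt <= 0:
        raise ValueError("tau and dt must be positive")
    return math.exp(-1.0 / (tau / dt))


# Example: tau = 10 ms with a 0.1 ms step gives a decay just below 1.
decay = decay_from_tau(10.0, 0.1)
print(round(decay, 6))  # 0.99005
```

Larger `tau` relative to `dt` pushes the factor toward 1 (longer memory); any positive `tau` and `dt` yield a value in \((0, 1)\), the same range required when a float is passed directly.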
Notes
The update rule is
\[g_{t+1} = \mathrm{decay} \cdot g_t + \mathrm{grads},\]
where \(g_t\) is the accumulated gradient at time \(t\), \(\mathrm{grads}\) is the new gradient at time \(t\), and \(\mathrm{decay}\) is the decay factor.
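Applied leaf by leaf over a pytree, the rule can be sketched in plain JAX. This is a minimal functional sketch under stated assumptions, not braintrace's implementation: `leaky_update` is a hypothetical helper, and it returns a new pytree rather than mutating state in place as `GradExpon.update` does.

```python
import jax
import jax.numpy as jnp


def leaky_update(acc, grads, decay):
    """Hypothetical helper: one step of g_{t+1} = decay * g_t + grads.

    acc and grads must be pytrees with the same structure; the rule is
    applied independently to each pair of matching leaves.
    """
    return jax.tree_util.tree_map(lambda g, x: decay * g + x, acc, grads)


# A two-leaf pytree, initialised to zeros like the accumulator.
acc = {"w": jnp.zeros((2, 2)), "b": jnp.zeros((2,))}
grads = {"w": jnp.ones((2, 2)), "b": jnp.full((2,), 2.0)}

for _ in range(3):
    acc = leaky_update(acc, grads, 0.9)

# With a constant gradient x, after n steps from zero the accumulator is
# x * (1 - decay**n) / (1 - decay); for n = 3 and decay = 0.9 that is 2.71 * x.
print(acc["w"][0, 0], acc["b"][0])
```

The closed form is the partial sum of a geometric series, so with a constant gradient the accumulator approaches \(\mathrm{grads} / (1 - \mathrm{decay})\) as \(t\) grows.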
Examples
>>> import brainstate
>>> import jax.numpy as jnp
>>> import braintrace
>>> acc = braintrace.GradExpon(jnp.zeros((3,)), 0.9)
>>> acc.update(jnp.ones((3,)))
>>> acc.update(jnp.ones((3,)))
>>> print(acc.gradients.value)
[1.9 1.9 1.9]
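The smoothing across time steps mentioned above can also be expressed as a single pass over a whole gradient sequence with `jax.lax.scan`. This is a sketch in plain JAX, not part of the braintrace API: `smooth_over_time` is a hypothetical name, the gradients are assumed to be stacked along a leading time axis, and the accumulator is assumed to start at zeros as in the class.

```python
import jax
import jax.numpy as jnp


def smooth_over_time(grad_seq, decay):
    """Hypothetical helper: run g_{t+1} = decay * g_t + grads over time.

    grad_seq has shape (T, ...). The accumulator starts at zeros and the
    function returns its value after every step, with shape (T, ...).
    """
    def step(acc, grads):
        acc = decay * acc + grads
        return acc, acc  # carry the new value forward and also record it

    init = jnp.zeros_like(grad_seq[0])
    _, history = jax.lax.scan(step, init, grad_seq)
    return history


# Two time steps of all-ones gradients, mirroring the doctest.
history = smooth_over_time(jnp.ones((2, 3)), 0.9)
print(history[-1])
```

`history[-1]` holds the same values as the doctest's accumulator after two updates (1.9 per entry, up to float32 rounding). Multiplying the accumulator by \(1 - \mathrm{decay}\) turns it into a zero-initialised exponential moving average of the gradients, which is one way to read the "leaky running sum" description.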
- update(grads)
Update the accumulated gradients with the exponential decay rule.
Applies \(g_{t+1} = \mathrm{decay} \cdot g_t + \mathrm{grads}\), where \(g_t\) is the accumulated gradient, grads is the new gradient, and \(\mathrm{decay}\) is the decay factor. The accumulator stored in self.gradients is updated in place.
- Parameters:
grads (PyTree) – The new gradients to incorporate into the accumulated gradients. Must match the pytree structure of the accumulator.
- Returns:
Nothing; the self.gradients attribute is updated in place.
- Return type:
None
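Because `grads` must match the accumulator's pytree structure, a caller may want to validate it before updating. The sketch below is an assumption-labelled illustration, not braintrace behaviour: `checked_update` is a hypothetical helper, and whether `GradExpon.update` performs any such check itself is not stated on this page. Checking leaf shapes matters because JAX broadcasting would otherwise let, for example, a `(1,)` leaf silently combine with a `(3,)` accumulator.

```python
import jax
import jax.numpy as jnp


def checked_update(acc, grads, decay):
    """Hypothetical helper: validate grads against acc, then apply the rule.

    Raises ValueError if the pytree structures differ or any pair of
    matching leaves has different shapes; otherwise returns
    decay * acc + grads computed leaf by leaf.
    """
    acc_def = jax.tree_util.tree_structure(acc)
    grad_def = jax.tree_util.tree_structure(grads)
    if acc_def != grad_def:
        raise ValueError(f"pytree structure mismatch: {acc_def} vs {grad_def}")
    leaves = zip(jax.tree_util.tree_leaves(acc), jax.tree_util.tree_leaves(grads))
    for a, g in leaves:
        if jnp.shape(a) != jnp.shape(g):
            raise ValueError(f"leaf shape mismatch: {jnp.shape(a)} vs {jnp.shape(g)}")
    return jax.tree_util.tree_map(lambda a, g: decay * a + g, acc, grads)


acc = {"w": jnp.zeros((3,))}
acc = checked_update(acc, {"w": jnp.ones((3,))}, 0.9)  # structures and shapes match
# checked_update(acc, {"v": jnp.ones((3,))}, 0.9)      # would raise ValueError (key differs)
# checked_update(acc, {"w": jnp.ones((1,))}, 0.9)      # would raise ValueError (shape differs)
```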