GradExpon

class braintrace.GradExpon(grad_shape, tau_or_decay)

Accumulate gradients with an exponential (leaky) running sum.

Maintains a decaying accumulator over a pytree of gradients, useful for smoothing online-learning gradient signals across time steps.

Parameters:
  • grad_shape (PyTree) – A pytree whose leaves give the shape and dtype of the gradients to accumulate. The accumulator is initialised to zeros matching each leaf.

  • tau_or_decay (Quantity | float) – Either a decay time constant \(\tau\) (as a Quantity), from which the decay factor is computed as \(\exp(-1 / (\tau / \mathrm{dt})) = \exp(-\mathrm{dt} / \tau)\), where \(\mathrm{dt}\) is the simulation time step; or the decay factor itself (a float in the open interval \((0, 1)\)).
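As a quick numeric check of the time-constant form, the sketch below computes the decay factor with plain Python floats. It assumes \(\tau\) and \(\mathrm{dt}\) are already expressed in the same unit (braintrace itself accepts \(\tau\) as a Quantity), and the helper `decay_from_tau` is hypothetical, not part of the library.

```python
import math


def decay_from_tau(tau: float, dt: float) -> float:
    """Hypothetical helper: per-step decay factor for time constant tau.

    tau and dt must be plain floats in the same unit.
    """
    # exp(-1 / (tau / dt)) is algebraically the same as exp(-dt / tau).
    return math.exp(-1.0 / (tau / dt))


# Illustrative values: tau = 10 ms and dt = 0.1 ms, both given in ms.
decay = decay_from_tau(10.0, 0.1)
print(round(decay, 5))  # 0.99005

# After tau / dt = 100 steps, an old contribution is scaled by about 1/e.
print(round(decay ** 100, 4))  # 0.3679
```

A decay factor close to 1 (a long \(\tau\) relative to \(\mathrm{dt}\)) gives the accumulator a long memory; a factor close to 0 makes it track mostly the latest gradient.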

Notes

The update rule is

\[g_{t+1} = \mathrm{decay} \cdot g_t + \mathrm{grads},\]

where \(g_t\) is the accumulated gradient at time \(t\), \(\mathrm{grads}\) is the new gradient at time \(t\), and \(\mathrm{decay}\) is the decay factor.
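Unrolling this recurrence from the zero initial state \(g_0 = 0\), with \(\mathrm{grads}_k\) denoting the gradient passed at step \(k\), shows that the accumulator is an exponentially weighted sum of past gradients:

\[g_T = \sum_{k=0}^{T-1} \mathrm{decay}^{\,T-1-k}\, \mathrm{grads}_k.\]

For a constant gradient this is a geometric series, \(g_T = \mathrm{grads}\,(1 - \mathrm{decay}^T)/(1 - \mathrm{decay})\), which approaches \(\mathrm{grads}/(1 - \mathrm{decay})\) as \(T\) grows. With \(\mathrm{decay} = 0.9\), \(\mathrm{grads} = 1\) and \(T = 2\), this gives \(1 + 0.9 = 1.9\), the value printed in the example below.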

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>> import braintrace
>>> acc = braintrace.GradExpon(jnp.zeros((3,)), 0.9)
>>> acc.update(jnp.ones((3,)))
>>> acc.update(jnp.ones((3,)))
>>> print(acc.gradients.value)
[1.9 1.9 1.9]

update(grads)

Update the accumulated gradients with the exponential decay rule.

Applies \(g_{t+1} = \mathrm{decay} \cdot g_t + \mathrm{grads}\), where \(g_t\) is the accumulated gradient, grads the new gradient, and \(\mathrm{decay}\) the decay factor. The accumulator stored in self.gradients is updated in place.

Parameters:
  • grads (PyTree) – The new gradients to incorporate into the accumulated gradients. Must match the pytree structure of the accumulator.

Returns:

Nothing; the accumulated gradients stored in self.gradients are updated in place.

Return type:

None
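The update can be pictured as a leaf-by-leaf map over the pytree. The following is a minimal functional sketch in JAX, not braintrace's implementation: the helper `grad_expon_step` is hypothetical, and it returns a new pytree rather than writing into self.gradients, because JAX arrays are immutable.

```python
import jax
import jax.numpy as jnp


def grad_expon_step(acc, grads, decay):
    """Hypothetical helper: one step of acc <- decay * acc + grads.

    Applied leaf by leaf; not part of braintrace.
    """
    # tree_map walks both pytrees together, so `grads` must have the
    # same structure as `acc`.
    return jax.tree_util.tree_map(lambda g, x: decay * g + x, acc, grads)


# Accumulator over a dict pytree, initialised to zeros.
acc = {"w": jnp.zeros((2, 2)), "b": jnp.zeros((2,))}
grads = {"w": jnp.ones((2, 2)), "b": jnp.ones((2,))}

# Two updates with decay 0.9: 0 -> 1 -> 0.9 * 1 + 1 = 1.9 in every entry.
for _ in range(2):
    acc = grad_expon_step(acc, grads, 0.9)

print(acc["b"])  # [1.9 1.9]
```

Because jax.tree_util.tree_map requires the pytrees to share a structure, passing gradients whose structure differs from the accumulator raises an error, which mirrors the requirement that grads match the accumulator's pytree structure.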