Source code for braintrace._grad_exponential

# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import numbers

import brainstate
import jax.tree
import saiunit as u

__all__ = [
    'GradExpon',
]


class GradExpon(brainstate.nn.Module):
    r"""Accumulate gradients with an exponential (leaky) running sum.

    Maintains a decaying accumulator over a pytree of gradients, useful for
    smoothing online-learning gradient signals across time steps.

    Parameters
    ----------
    grad_shape : brainstate.typing.PyTree
        A pytree whose leaves give the shape and dtype of the gradients to
        accumulate. The accumulator is initialised to zeros matching each leaf.
    tau_or_decay : saiunit.Quantity or float
        Either a decay time constant (as a :class:`~saiunit.Quantity`), from
        which the decay factor is computed as
        :math:`\exp(-1 / (\tau / \mathrm{dt}))`, or the decay factor itself
        (a real number in the open interval :math:`(0, 1)`).

    Raises
    ------
    ValueError
        If ``tau_or_decay`` is a plain real number that does not lie in the
        open interval :math:`(0, 1)`.
    TypeError
        If ``tau_or_decay`` is neither a :class:`~saiunit.Quantity` nor a
        real number.

    Notes
    -----
    The update rule is

    .. math::

        g_{t+1} = \mathrm{decay} \cdot g_t + \mathrm{grads},

    where :math:`g_t` is the accumulated gradient at time :math:`t`,
    :math:`\mathrm{grads}` is the new gradient at time :math:`t`, and
    :math:`\mathrm{decay}` is the decay factor.

    When a time constant is given, the decay factor is computed once, in
    ``__init__``, from the value returned by ``brainstate.environ.get_dt()``
    at that moment. Changing ``dt`` afterwards does not update an existing
    instance.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> import braintrace
        >>> acc = braintrace.GradExpon(jnp.zeros((3,)), 0.9)
        >>> acc.update(jnp.ones((3,)))
        >>> acc.update(jnp.ones((3,)))
        >>> print(acc.gradients.value)
        [1.9 1.9 1.9]
    """

    def __init__(
        self,
        grad_shape: brainstate.typing.PyTree,
        tau_or_decay: u.Quantity | float,
    ):
        super().__init__()

        # gradients (stored as LongTermState for proper JAX transform tracking)
        self.gradients = brainstate.LongTermState(
            jax.tree.map(jax.numpy.zeros_like, grad_shape)
        )

        # decay factor
        if isinstance(tau_or_decay, u.Quantity):
            # Express the time constant as a (unitless) number of simulation steps.
            tau = u.maybe_decimal(tau_or_decay / brainstate.environ.get_dt())
            decay = u.math.exp(-1.0 / tau)
        elif isinstance(tau_or_decay, numbers.Real):
            # Explicit check instead of ``assert`` so the validation is not
            # stripped when Python runs with ``-O``. NaN also fails this test.
            if not 0.0 < tau_or_decay < 1.0:
                raise ValueError(
                    f"Decay must be between 0 and 1, but got {tau_or_decay}"
                )
            # Normalise to a plain Python float: it is weakly typed in JAX and
            # therefore does not promote the accumulator's dtype (e.g. when a
            # NumPy scalar such as ``np.float64`` is passed).
            decay = float(tau_or_decay)
        else:
            raise TypeError(
                f"tau_or_decay must be a Quantity or a float, but got {tau_or_decay}"
            )
        self.decay = decay

    def update(self, grads: brainstate.typing.PyTree):
        r"""Update the accumulated gradients with the exponential decay rule.

        Applies :math:`g_{t+1} = \mathrm{decay} \cdot g_t + \mathrm{grads}`,
        where :math:`g_t` is the accumulated gradient, ``grads`` the new
        gradient, and :math:`\mathrm{decay}` the decay factor. The accumulator
        stored in ``self.gradients`` is updated in place.

        Parameters
        ----------
        grads : brainstate.typing.PyTree
            The new gradients to incorporate into the accumulated gradients.
            Must match the pytree structure of the accumulator.

        Returns
        -------
        None
            The ``self.gradients`` attribute is updated in place.
        """
        # ``is_leaf=u.math.is_quantity`` stops the traversal at Quantity
        # objects, so the arithmetic below is applied to whole Quantities.
        self.gradients.value = jax.tree.map(
            lambda x, y: x * self.decay + y,
            self.gradients.value,
            grads,
            is_leaf=u.math.is_quantity,
        )