# Source code for braintrace._etrace_algorithms._common

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Shared helpers for the SNN online-learning algorithms."""

from functools import partial
from typing import Any, Dict, Optional

import brainstate
import jax
import jax.numpy as jnp
import saiunit as u

# Public names exported by ``from ... import *``; the underscore-prefixed
# helpers defined further down in this module are intentionally excluded.
__all__ = [
    'PresynapticTrace',
    'KappaFilter',
    'FixedRandomFeedback',
    'extract_y_target',
]


class _ZeroResetState(brainstate.ShortTermState):
    """Short-term state that remembers its initial layout and resets to zeros.

    The shape and dtype of ``init_value`` are captured at construction time.
    :meth:`reset_state` later overwrites the value with zeros of that dtype,
    either at the captured shape or with the leading axis set to
    ``batch_size`` (a new leading axis is created for scalar-shaped states).
    It is the shared base of :class:`PresynapticTrace` and
    :class:`KappaFilter`, which reset identically.
    """

    def __init__(self, init_value):
        super().__init__(init_value)
        # Snapshot the construction-time layout so that ``reset_state`` can
        # rebuild it independently of whatever the current value looks like.
        self._init_shape = jnp.shape(init_value)
        self._init_dtype = init_value.dtype

    def reset_state(self, batch_size: Optional[int] = None, **kwargs):
        """Overwrite the state with zeros of the recorded dtype.

        Parameters
        ----------
        batch_size : int or None, optional
            When ``None`` (default) the recorded unbatched shape is used.
            Otherwise the leading axis becomes ``batch_size``: it replaces the
            original leading axis, or is prepended if the state was scalar.
        **kwargs
            Unused; present so the method matches the brainstate reset
            protocol.
        """
        if batch_size is None:
            new_shape = self._init_shape
        else:
            # ``_init_shape[1:]`` is ``()`` for a scalar state, so one
            # expression covers both the "prepend" and the "replace" cases.
            new_shape = (batch_size,) + tuple(self._init_shape[1:])
        self.value = jnp.zeros(new_shape, dtype=self._init_dtype)


class PresynapticTrace(_ZeroResetState):
    r"""Leaky presynaptic accumulator used by OTTT and OTPE-Approx.

    The trace accumulates the presynaptic input with a multiplicative decay,
    following :math:`\hat{a} \leftarrow \lambda \cdot \hat{a} + x_t`.

    Parameters
    ----------
    init_value : jax.Array
        Initial value; also dictates the shape and dtype of the trace.
    leak : float
        Decay factor :math:`\lambda` in ``(0, 1)``. Pulled from the neuron's
        membrane leak in SNN usage.

    Raises
    ------
    ValueError
        If ``leak`` is not strictly inside the open interval ``(0, 1)``.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import braintrace
        >>> trace = braintrace.PresynapticTrace(jnp.zeros(3), leak=0.5)
        >>> out = trace.update(jnp.ones(3))
        >>> print(out)
        [1. 1. 1.]
        >>> out = trace.update(jnp.ones(3))
        >>> print(out)
        [1.5 1.5 1.5]
    """

    __module__ = 'braintrace'

    def __init__(self, init_value, leak: float):
        # Validate before allocating the underlying state, so an invalid
        # ``leak`` fails fast without constructing a half-initialised state.
        if not (0.0 < leak < 1.0):
            raise ValueError(f'leak must be in (0, 1); got {leak}')
        super().__init__(init_value)
        self.leak = float(leak)

    def update(self, x):
        r"""Apply one accumulation step :math:`\hat{a} \leftarrow \lambda \cdot \hat{a} + x`.

        Parameters
        ----------
        x : jax.Array
            The new presynaptic input added to the decayed trace.

        Returns
        -------
        jax.Array
            The updated trace value (also stored in ``self.value``).
        """
        self.value = self.leak * self.value + x
        return self.value
class KappaFilter(_ZeroResetState):
    r"""Low-pass output-side filter used by EProp.

    The filter smooths the output-side signal following
    :math:`x_{\mathrm{filt}} \leftarrow (1-\kappa) \cdot x + \kappa \cdot x_{\mathrm{filt}}`.

    Parameters
    ----------
    init_value : jax.Array
        Initial value; also dictates the shape and dtype of the filtered state.
    kappa : float
        Decay factor :math:`\kappa` in ``[0, 1)``. A value of ``0`` disables
        filtering (the output equals the latest input).

    Raises
    ------
    ValueError
        If ``kappa`` is not inside the half-open interval ``[0, 1)``.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import braintrace
        >>> filt = braintrace.KappaFilter(jnp.zeros(3), kappa=0.5)
        >>> out = filt.update(jnp.ones(3))
        >>> print(out)
        [0.5 0.5 0.5]
        >>> out = filt.update(jnp.ones(3))
        >>> print(out)
        [0.75 0.75 0.75]
    """

    __module__ = 'braintrace'

    def __init__(self, init_value, kappa: float):
        # Validate before allocating the underlying state, so an invalid
        # ``kappa`` fails fast without constructing a half-initialised state.
        if not (0.0 <= kappa < 1.0):
            raise ValueError(f'kappa must be in [0, 1); got {kappa}')
        super().__init__(init_value)
        self.kappa = float(kappa)

    def update(self, x):
        r"""Apply one low-pass step :math:`x_{\mathrm{filt}} \leftarrow (1-\kappa) x + \kappa\, x_{\mathrm{filt}}`.

        Parameters
        ----------
        x : jax.Array
            The new input mixed into the filtered state.

        Returns
        -------
        jax.Array
            The updated, filtered value (also stored in ``self.value``).
        """
        new = (1.0 - self.kappa) * x + self.kappa * self.value
        self.value = new
        return new
class FixedRandomFeedback:
    r"""Frozen random feedback matrix with a stop-gradient guard.

    The feedback matrix
    :math:`B \in \mathbb{R}^{n_{\mathrm{target}} \times n_{\mathrm{layer}}}`
    is sampled once at construction and frozen via
    :func:`jax.lax.stop_gradient`. It is used by OSTTP (per-HiddenGroup target
    projection) and EProp-random-feedback.

    Parameters
    ----------
    n_target : int
        Number of target dimensions (the row count of ``B``).
    n_layer : int
        Number of layer dimensions (the column count of ``B``).
    key : jax.Array
        A JAX PRNG key used to sample the feedback matrix.
    init_scale : float, optional
        Standard-deviation scaling applied to the sampled normal entries.
        Default is ``0.1``.

    Attributes
    ----------
    B : jax.Array
        The frozen feedback matrix of shape ``(n_target, n_layer)``.
    n_target : int
        Number of target dimensions.
    n_layer : int
        Number of layer dimensions.

    Examples
    --------
    .. code-block:: python

        >>> import jax
        >>> import braintrace
        >>> fb = braintrace.FixedRandomFeedback(2, 3, jax.random.PRNGKey(0))
        >>> print(fb.B.shape)
        (2, 3)
        >>> y = jax.numpy.ones(2)
        >>> print(fb.project(y).shape)
        (3,)
    """

    __module__ = 'braintrace'

    def __init__(self, n_target: int, n_layer: int, key, init_scale: float = 0.1):
        self.B = jax.lax.stop_gradient(
            init_scale * jax.random.normal(key, (n_target, n_layer))
        )
        self.n_target = int(n_target)
        self.n_layer = int(n_layer)

    def project(self, y_target):
        """Project the target onto the frozen feedback matrix.

        Parameters
        ----------
        y_target : jax.Array
            The target tensor to project, with trailing dimension
            ``n_target``. Both batched ``(..., n_target)`` and unbatched
            ``(n_target,)`` layouts are handled by the matrix product.

        Returns
        -------
        jax.Array
            The projection ``y_target @ B`` with ``B`` frozen. Gradients still
            flow to ``y_target``.
        """
        # The construction-time ``stop_gradient`` only matters for a trace that
        # is active during ``__init__``. Re-applying it here keeps ``B`` frozen
        # even if it later enters a differentiated computation as a traced value.
        return y_target @ jax.lax.stop_gradient(self.B)
def extract_y_target(args: tuple, *, index: int = -1) -> Optional[jax.Array]:
    """Fetch the target tensor from a positional-args tuple.

    Parameters
    ----------
    args : tuple
        Positional arguments forwarded to an algorithm's ``update`` call.
    index : int, optional
        Position of the target inside ``args``. Defaults to ``-1`` (the last
        position), which is OSTTP's convention: ``algo.update(x, y_target)``.

    Returns
    -------
    jax.Array or None
        ``args[index]``, or ``None`` when ``args`` is empty.

    Raises
    ------
    IndexError
        If ``args`` is non-empty and ``index`` is out of range.
    """
    if not args:
        return None
    return args[index]


def _reset_state_in_a_dict(
    state_dict: Dict[Any, brainstate.State],
    batch_size: Optional[int],
):
    """Reset every state in ``state_dict`` to zeros, in place.

    Each leaf of each state's value is replaced by
    ``_zeros_like_batch_or_not(batch_size, leaf)``.

    Args:
        state_dict (Dict[Any, brainstate.State]): Mapping whose values are the
            states to reset. Keys are not used.
        batch_size (Optional[int]): If ``None``, each leaf keeps its shape and
            dtype. Otherwise each leaf's leading dimension becomes
            ``batch_size`` (a new leading axis for scalar leaves).

    Returns:
        None: The states' ``.value`` attributes are overwritten in place.
    """
    zero_leaf = partial(_zeros_like_batch_or_not, batch_size)
    for state in state_dict.values():
        state.value = jax.tree.map(zero_leaf, state.value)


def _zeros_like_batch_or_not(
    batch_size: Optional[int],
    x: jax.Array
):
    """Create a zeros array shaped like ``x``, optionally re-batched.

    Args:
        batch_size (Optional[int]): If ``None``, the result is
            ``u.math.zeros_like(x)``. Otherwise it must be an ``int``, and the
            result has shape ``(batch_size,) + x.shape[1:]`` and dtype
            ``x.dtype``. The leading dimension of ``x`` is *replaced* by
            ``batch_size``, not added; for a scalar ``x`` this yields shape
            ``(batch_size,)``.
        x (jax.Array): Reference array providing shape and dtype.

    Returns:
        jax.Array: The zeros array described above.

    Raises:
        AssertionError: If ``batch_size`` is not ``None`` and not an ``int``.
    """
    if batch_size is None:
        return u.math.zeros_like(x)
    assert isinstance(batch_size, int), 'The batch size should be an integer. '
    return u.math.zeros((batch_size,) + x.shape[1:], x.dtype)


def _batched_zeros_like(
    batch_size: Optional[int],
    num_state: int,  # the number of hidden states
    x: jax.Array  # the input array
):
    """Create zeros shaped like ``x`` with a trailing hidden-state axis.

    Unlike ``_zeros_like_batch_or_not``, the batch axis (if any) is
    *prepended* to the full shape of ``x`` rather than replacing its leading
    dimension.

    Args:
        batch_size (Optional[int]): If ``None``, the result has shape
            ``(*x.shape, num_state)``; otherwise
            ``(batch_size, *x.shape, num_state)``.
        num_state (int): Size of the trailing hidden-state axis.
        x (jax.Array): Reference array providing shape and dtype.

    Returns:
        jax.Array: A zeros array with ``x.dtype`` and the shape described
        above.
    """
    if batch_size is None:
        return u.math.zeros((*x.shape, num_state), x.dtype)
    else:
        return u.math.zeros((batch_size, *x.shape, num_state), x.dtype)


def _sum_dim(xs: jax.Array, axis: int = -1):
    """Sum every array leaf of a PyTree along ``axis``.

    Args:
        xs (jax.Array): A PyTree of arrays (a bare array is also accepted).
        axis (int): Axis to reduce in each leaf. Defaults to ``-1`` (the last
            dimension).

    Returns:
        jax.Array: A PyTree with the same structure as ``xs`` in which each
        leaf has been summed over ``axis``.
    """
    return jax.tree.map(lambda x: u.math.sum(x, axis=axis), xs)


def _unit_safe_add(a, b):
    """Add two leaves, stripping units only when exactly one side has units.

    Gradient contributions for the same weight may come from paths that
    preserve physical units (e.g. VJP through the original jaxpr) and paths
    that already strip them (e.g. ETP ``xy_to_dw`` rules). When the two sides
    disagree on unit representation, both are reduced to plain arrays before
    adding; otherwise units are preserved.
    """
    a_is_q = isinstance(a, u.Quantity)
    b_is_q = isinstance(b, u.Quantity)
    if a_is_q != b_is_q:
        a = u.get_mantissa(a) if a_is_q else a
        b = u.get_mantissa(b) if b_is_q else b
    return u.math.add(a, b)


def _extract_leaf(pytree_val: brainstate.typing.PyTree, leaf_idx: int):
    """Return the leaf at ``leaf_idx`` in ``jax.tree.leaves(pytree_val)``.

    Bare arrays (treedef with a single leaf) are returned unchanged for
    ``leaf_idx == 0``. A pytree with no leaves (e.g. ``None``) is returned
    as-is, whatever ``leaf_idx`` is.

    Raises:
        IndexError: If the pytree has leaves and ``leaf_idx`` is outside
            ``range(len(leaves))``.
    """
    leaves = jax.tree.leaves(pytree_val)
    if not leaves:
        return pytree_val
    if leaf_idx < 0 or leaf_idx >= len(leaves):
        raise IndexError(
            f'leaf_idx {leaf_idx} out of range for pytree with {len(leaves)} leaves'
        )
    return leaves[leaf_idx]


def _wrap_leaves_as_pytree(
    reference_pytree: brainstate.typing.PyTree,
    leaf_grads: Dict[int, jax.Array],
):
    """Build a pytree shaped like ``reference_pytree`` from sparse leaf values.

    Each leaf whose flat index (in ``jax.tree.leaves(reference_pytree)``
    order) is a key of ``leaf_grads`` takes the supplied value. Every other
    leaf is ``u.math.zeros_like(leaf)``.

    When the reference is a bare array (its treedef is a single leaf), the
    only valid index is 0. ``leaf_grads[0]`` is returned directly (no
    wrapping), or ``zeros_like(reference_pytree)`` if ``leaf_grads`` is empty.

    Raises:
        IndexError: If any key of ``leaf_grads`` is outside
            ``range(len(jax.tree.leaves(reference_pytree)))``. The check runs
            before the bare-array fast path, so an out-of-range index is never
            silently dropped.
    """
    ref_treedef = jax.tree.structure(reference_pytree)
    leaves = jax.tree.leaves(reference_pytree)
    n = len(leaves)
    # Validate up front: previously the bare-array path skipped this check and
    # silently discarded gradients supplied at a non-zero index.
    for idx in leaf_grads:
        if idx < 0 or idx >= n:
            raise IndexError(
                f'leaf_idx {idx} out of range for pytree with {n} leaves'
            )

    # Bare-array fast path.
    # jax's PyTreeDef stubs omit num_leaves and __eq__; both are valid at runtime.
    if ref_treedef.num_leaves <= 1 and ref_treedef == jax.tree.structure(0):  # type: ignore[attr-defined, operator]
        if 0 in leaf_grads:
            return leaf_grads[0]
        return u.math.zeros_like(reference_pytree)

    new_leaves = [
        leaf_grads[i] if i in leaf_grads else u.math.zeros_like(leaf)
        for i, leaf in enumerate(leaves)
    ]
    return jax.tree.unflatten(ref_treedef, new_leaves)


def _route_grads_by_path(
    relation,
    per_key_grads: Dict[str, jax.Array],
    weight_vals: Dict[Any, brainstate.typing.PyTree],
    target_dict: Dict[Any, brainstate.typing.PyTree],
) -> None:
    """Route per-key gradients from a dict-API rule into per-path pytrees.

    Both D-RTRL and ES-D-RTRL share this bookkeeping. For each key in
    ``per_key_grads`` (returned by ``xy_to_dw`` or ``yw_to_w``):

    1. Look up the owning ``ParamState`` path and the leaf index from
       ``relation``, and group the gradients by path.
    2. Wrap each path's gradients with ``_wrap_leaves_as_pytree``.
    3. Merge the result into ``target_dict`` via ``_update_dict``.

    Args:
        relation: HiddenParamOpRelation — provides ``trainable_paths`` and
            ``trainable_leaf_indices``.
        per_key_grads: Dict[str, Array] — gradient contributions keyed by
            trainable invar name (e.g. ``'weight'``, ``'lora_b'``).
        weight_vals: Dict[Path, PyTree] — current ParamState pytree values;
            used as the structure template for ``_wrap_leaves_as_pytree``.
        target_dict: Dict[Path, PyTree] — accumulation target, modified in
            place.
    """
    per_path: Dict[Any, Dict[int, jax.Array]] = {}
    for key, grad in per_key_grads.items():
        path = relation.trainable_paths[key]
        leaf_idx = relation.trainable_leaf_indices[key]
        # If two keys map to the same (path, leaf_idx), the later one wins.
        per_path.setdefault(path, {})[leaf_idx] = grad
    for path, leaf_to_grad in per_path.items():
        wrapped = _wrap_leaves_as_pytree(weight_vals[path], leaf_to_grad)
        _update_dict(target_dict, path, wrapped)


def _update_dict(
    the_dict: Dict,
    key: Any,
    value: brainstate.typing.PyTree,
    error_when_no_key: Optional[bool] = False
):
    """Accumulate ``value`` into ``the_dict[key]``, in place.

    If ``key`` is absent, or its stored value is ``None``, ``value`` is
    stored as-is. Otherwise the stored pytree and ``value`` are added leaf by
    leaf with ``_unit_safe_add``, treating ``u.Quantity`` objects as leaves.

    Args:
        the_dict: The dictionary to update.
        key: The key to update.
        value: The pytree to store or add.
        error_when_no_key: If true, raise instead of inserting a missing key.

    Raises:
        ValueError: If ``error_when_no_key`` is true and ``key`` is absent.
    """
    if key not in the_dict:
        if error_when_no_key:
            raise ValueError(f'The key {key} does not exist in the dictionary. ')
        the_dict[key] = value
    else:
        old_value = the_dict[key]
        if old_value is None:
            the_dict[key] = value
        else:
            the_dict[key] = jax.tree.map(
                _unit_safe_add,
                old_value,
                value,
                is_leaf=lambda x: isinstance(x, u.Quantity)
            )