# Source code for braintrace._etrace_algorithms.e_prop

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""E-Prop — Eligibility Propagation (Bellec et al., 2020).

E-prop factorizes the BPTT gradient of a recurrent SNN into a *local*
eligibility trace and a *learning signal* broadcast from the readout. This
module builds on ``D_RTRL``'s per-parameter trace and adds the two ingredients
that make the rule biologically plausible:

- An optional κ-filter on each HiddenGroup's learning signal
  (:math:`\\bar L = F_\\kappa(L)`), matching the paper's readout-side low-pass.
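- An optional random-feedback variant (``feedback='random'``) that applies a
  fixed random square projection to each HiddenGroup's reverse-AD learning
  signal. It is intended to relax the weight-transport requirement; note that
  the projection acts on the reverse-AD signal itself, not on the raw readout
  error.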

See :class:`EProp` for the mathematical formulation, references, and an example.
"""

from typing import Dict, Optional

import brainstate
import jax
import jax.numpy as jnp

from ._common import FixedRandomFeedback, KappaFilter
from .param_dim_vjp import ParamDimVjpAlgorithm

# Public API of this module: only the e-prop learner class is exported.
__all__ = [
    "EProp",
]


class EProp(ParamDimVjpAlgorithm):
    r"""Eligibility Propagation (e-prop) for recurrent spiking networks.

    E-prop approximates the gradient of a loss :math:`\mathcal{L}` with respect
    to a recurrent weight :math:`W_{ji}` by the product of a *local* eligibility
    trace and a *global* learning signal, dropping the temporally non-local
    terms of BPTT. As implemented here,

    .. math::

        \frac{d\mathcal{L}}{dW_{ji}}
        \approx \sum_t \bar{L}_j^t \, e_{ji}^t ,

    where

    .. math::

        e_{ji}^t = \frac{\partial h_j^t}{\partial W_{ji}}
                 \approx D_j^t \, e_{ji}^{t-1}
                 + \big[\operatorname{diag}(D_{f,j}^t)\big]\, x_i^t ,
        \qquad
        \bar{L}_j^t = (1 - \kappa)\, L_j^t + \kappa\, \bar{L}_j^{t-1} .

    Here :math:`h_j^t` is the hidden state of neuron :math:`j` at time
    :math:`t`, :math:`x_i^t` the presynaptic input, :math:`D_j^t` the
    hidden-to-hidden (recurrent) Jacobian diagonal, :math:`D_{f,j}^t` the
    state-to-output Jacobian, and :math:`\kappa \in [0, 1)` the readout-side
    low-pass factor (``kappa_filter_decay``; :math:`\kappa = 0` gives
    :math:`\bar L = L`). In Bellec et al. [1]_ the low-pass filter is applied to
    the eligibility trace
    (:math:`\bar{e}_{ji}^t = \kappa\,\bar{e}_{ji}^{t-1} + e_{ji}^t`); this
    implementation instead filters the learning signal forward in time.

    The (unfiltered) learning signal is

    .. math::

        L_j^t =
        \begin{cases}
          \dfrac{\partial \mathcal{L}}{\partial h_j^t}
            & \text{(symmetric feedback, standard backprop through readout)} \\[2ex]
          \Big(B^\top \dfrac{\partial \mathcal{L}}{\partial h^t}\Big)_j
            & \text{(random feedback: a fixed random square matrix } B\text{)} .
        \end{cases}

    **How it works.** The eligibility trace :math:`e_{ji}^t` is exactly the
    per-parameter trace maintained by :class:`~braintrace.D_RTRL`; it depends
    only on quantities local to the synapse and is updated forward in time. The
    learning signal :math:`L_j^t` is broadcast from the readout. E-prop is
    therefore *online* (no backward pass through time) and uses memory linear in
    the number of parameters. With ``kappa_filter_decay > 0`` the learning
    signal is additionally low-pass filtered; with ``feedback='random'`` it is
    mixed by a frozen random matrix :math:`B` along the last (neuron) axis of
    each hidden group.

    .. note::

       The random projection is applied to the reverse-AD signal
       :math:`\partial\mathcal{L}/\partial h` rather than to the raw readout
       error, so the readout weights presumably still enter the signal through
       reverse-AD. Do not read it as a complete weight-transport-free feedback
       scheme.

       For hidden groups with several state variables (``num_state > 1``) the
       κ-filter sums the learning signal over the state axis, filters that sum,
       and broadcasts the filtered value back to every state slot.

    Parameters
    ----------
    model : brainstate.nn.Module
        The recurrent SNN whose weights are trained online.
    feedback : {'symmetric', 'random'}, default 'symmetric'
        ``'symmetric'`` uses reverse-AD's :math:`\partial \mathcal{L}/\partial h`
        (standard backprop through the readout). ``'random'`` additionally
        passes that signal through a frozen random square projection (requires
        ``random_feedback_key``).
    kappa_filter_decay : float in [0, 1), default 0.0
        Readout-side low-pass factor :math:`\kappa`. If ``> 0``, each
        HiddenGroup's learning signal is filtered each step
        (:math:`\bar L^t = (1-\kappa)L^t + \kappa\bar L^{t-1}`). ``0`` disables
        filtering.
    random_feedback_key : jax.random.PRNGKey, optional
        Seed for the random-feedback matrices. Required when
        ``feedback='random'``; ignored otherwise.
    name : str, optional
        Name of the algorithm instance.
    vjp_method, fast_solve, normalize_matrix_spectrum
        Forwarded verbatim to :class:`~braintrace.D_RTRL`.

    Raises
    ------
    ValueError
        If ``feedback`` is not one of ``{'symmetric', 'random'}``, if
        ``feedback='random'`` is given without ``random_feedback_key``, or if
        ``kappa_filter_decay`` is not in ``[0, 1)``.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import braintrace
        >>>
        >>> class RSNN(brainstate.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
        ...         self.out = braintrace.nn.Linear(20, 1)
        ...     def update(self, x):
        ...         return x >> self.cell >> self.out
        >>>
        >>> model = RSNN()
        >>> _ = brainstate.nn.init_all_states(model)
        >>> learner = braintrace.EProp(model, kappa_filter_decay=0.9)
        >>> x0 = brainstate.random.randn(1)
        >>> learner.compile_graph(x0)   # trace the graph once
        >>> y = learner(x0)             # forward pass + eligibility-trace update

    References
    ----------
    .. [1] Bellec, G., Scherr, F., Subramoney, A., Hajek, E., Salaj, D.,
       Legenstein, R., & Maass, W. (2020). "A solution to the learning dilemma
       for recurrent networks of spiking neurons." *Nature Communications*,
       11, 3625. https://doi.org/10.1038/s41467-020-17236-y
    """

    __module__ = 'braintrace'

    def __init__(
        self,
        model: brainstate.nn.Module,
        feedback: str = 'symmetric',
        kappa_filter_decay: float = 0.0,
        random_feedback_key: Optional[jax.Array] = None,
        name: Optional[str] = None,
        vjp_method: str = 'single-step',
        fast_solve: bool = True,
        normalize_matrix_spectrum: bool = False,
        **kwargs,
    ):
        # Validate every e-prop-specific option before constructing the parent,
        # so a bad configuration fails before any graph/state is built.
        if feedback not in ('symmetric', 'random'):
            raise ValueError(
                f"feedback must be 'symmetric' or 'random'; got {feedback!r}"
            )
        if feedback == 'random' and random_feedback_key is None:
            raise ValueError(
                "feedback='random' requires random_feedback_key=<PRNGKey>"
            )
        kappa = float(kappa_filter_decay)
        # Written as ``not (0 <= k < 1)`` so that NaN is rejected too.
        if not 0.0 <= kappa < 1.0:
            raise ValueError(
                f"kappa_filter_decay must lie in [0, 1); got {kappa_filter_decay!r}"
            )
        super().__init__(
            model,
            name=name,
            vjp_method=vjp_method,
            fast_solve=fast_solve,
            normalize_matrix_spectrum=normalize_matrix_spectrum,
            **kwargs,
        )
        self.feedback = feedback
        self.kappa_filter_decay = kappa
        self._random_feedback_key = random_feedback_key
        # Per-hidden-group extras, keyed by ``group.index``; (re)built in
        # ``init_etrace_state``.
        self._kappa_filters: Dict[int, KappaFilter] = {}
        self._random_feedback: Dict[int, FixedRandomFeedback] = {}

    def _unique_hidden_groups(self):
        """Return each hidden group of ``self.graph`` once, in first-seen order.

        A group that appears in several hidden/parameter relations is reported
        only for its first occurrence (deduplicated by ``group.index``). Keeping
        first-seen order makes PRNG key splitting in ``init_etrace_state``
        deterministic.
        """
        seen = set()
        groups = []
        for rel in self.graph.hidden_param_op_relations:
            for group in rel.hidden_groups:
                if group.index in seen:
                    continue
                seen.add(group.index)
                groups.append(group)
        return groups

    def init_etrace_state(self, *args, **kwargs):
        """Initialize eligibility-trace state plus the e-prop extras.

        Calls the parent initializer, then rebuilds:

        - one ``KappaFilter`` (initial state ``zeros(group.varshape)``) per
          distinct hidden group when ``kappa_filter_decay > 0``;
        - one square ``FixedRandomFeedback`` of size ``group.varshape[-1]`` per
          distinct hidden group when ``feedback == 'random'``.

        Both dictionaries are keyed by ``group.index``.
        """
        super().init_etrace_state(*args, **kwargs)
        self._kappa_filters = {}
        self._random_feedback = {}

        # Only touch the compiled graph when one of the extras is enabled.
        if self.kappa_filter_decay <= 0.0 and self.feedback != 'random':
            return
        groups = self._unique_hidden_groups()

        if self.kappa_filter_decay > 0.0:
            for group in groups:
                self._kappa_filters[group.index] = KappaFilter(
                    jnp.zeros(group.varshape, dtype=jnp.float32),
                    self.kappa_filter_decay,
                )

        if self.feedback == 'random':
            # Split the user key once per group, in discovery order, so the
            # same key always reproduces the same frozen matrices.
            key = self._random_feedback_key
            for group in groups:
                key, sub = jax.random.split(key)
                n_layer = int(group.varshape[-1])
                # n_target == n_layer: square projection over the reverse-AD
                # learning signal's neuron axis.
                self._random_feedback[group.index] = FixedRandomFeedback(
                    n_target=n_layer, n_layer=n_layer, key=sub, init_scale=0.1
                )

    def reset_state(self, batch_size: Optional[int] = None, **kwargs):
        """Reset the parent's state and every κ-filter's state.

        ``batch_size`` is forwarded to each ``KappaFilter.reset_state``. The
        random feedback matrices are not touched here.
        """
        super().reset_state(batch_size=batch_size, **kwargs)
        for flt in self._kappa_filters.values():
            flt.reset_state(batch_size=batch_size)

    def _compute_learning_signal(self, dl_autodiff, args):
        """Turn per-group reverse-AD gradients into e-prop learning signals.

        Parameters
        ----------
        dl_autodiff : sequence of jax.Array
            One entry per hidden group. The entry's position is used as the
            group id when looking up the per-group filters and feedback
            matrices. Expected shape ``(*varshape, num_state)``.
        args
            Not used by this implementation.

        Returns
        -------
        list of jax.Array
            One learning signal per group, each with the same shape as the
            corresponding ``dl_autodiff`` entry: randomly projected when
            ``feedback == 'random'``, then κ-filtered when enabled.
        """
        signals = list(dl_autodiff)

        if self.feedback == 'random' and self._random_feedback:
            def _project(B, s):
                # Contract the neuron axis (-2) with B, keeping the trailing
                # num_state axis: out[..., k, j] = sum_l s[..., l, j] * B[l, k].
                return jnp.einsum('...lj,lk->...kj', s, B)

            signals = [
                _project(self._random_feedback[gid].B, s)
                if gid in self._random_feedback else s
                for gid, s in enumerate(signals)
            ]

        if self._kappa_filters:
            def _filter(flt, s):
                # The KappaFilter state has shape ``varshape`` while the signal
                # carries an extra trailing num_state axis: sum that axis away
                # for the filter update, then broadcast the filtered value back
                # to every state slot so the output keeps the input's shape
                # (the unfiltered path returns that shape too).
                # NOTE(review): for num_state > 1 every state slot receives the
                # summed signal; confirm this is the intended approximation.
                filtered = flt.update(s.sum(axis=-1))
                expanded = jnp.expand_dims(filtered, axis=-1)
                return jnp.broadcast_arrays(expanded, s)[0].astype(s.dtype)

            signals = [
                _filter(self._kappa_filters[gid], s)
                if gid in self._kappa_filters else s
                for gid, s in enumerate(signals)
            ]
        return signals