Source code for braintrace._etrace_algorithms.ottt

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""OTTT — Online Training Through Time (Xiao et al., 2022).

OTTT is derived from BPTT but discards the hidden-to-hidden recurrent Jacobian,
so it keeps only a leaky presynaptic eligibility trace
:math:`\\hat a^t \\leftarrow \\lambda\\,\\hat a^{t-1} + x^t` (default mode ``'A'``)
and forms the weight gradient at each step as the outer product of that trace
with the instantaneous learning signal,
:math:`\\Delta W = \\hat a^t \\otimes \\big(\\frac{\\partial \\mathcal{L}^t}{\\partial s^t}\\,\\sigma'(u^t)\\big)`,
where :math:`s^t` is the (surrogate) spike and :math:`u^t` the membrane potential.
This yields constant training memory, independent of the number of time steps.

See :class:`OTTT` for the mathematical formulation, references, and an example.
"""

from typing import Dict, Optional

import brainstate
import jax.numpy as jnp

from braintrace._etrace_op import is_batched_primitive
from ._common import PresynapticTrace, _route_grads_by_path, _update_dict
from .vjp_base import ETraceVjpAlgorithm

# Public API of this module: only the OTTT learner is exported.
__all__ = ['OTTT']


class OTTT(ETraceVjpAlgorithm):
    r"""Online Training Through Time for spiking neural networks.

    OTTT tracks only a leaky **presynaptic trace** and forms the weight gradient
    each step as an outer product with the local learning signal:

    .. math::

        \hat{a}^t =
        \begin{cases}
          \lambda\,\hat{a}^{t-1} + x^t & \text{(mode='A', accumulated)} \\
          x^t                           & \text{(mode='O', instantaneous)}
        \end{cases}

    .. math::

        \nabla_{W}\mathcal{L}^t
        = \hat{a}^t \otimes
          \Big( \frac{\partial \mathcal{L}^t}{\partial s^t}\,\sigma'(u^t) \Big)
        \;=\; \hat{a}^t \otimes L^t ,

    where :math:`x^t` is the presynaptic input, :math:`u^t` the membrane
    potential, :math:`s^t = \sigma(u^t)` the (surrogate) spike, :math:`\sigma'`
    the surrogate-gradient function, :math:`\lambda \in (0, 1)` the membrane
    leak, and :math:`L^t` the learning signal already propagated through the
    spike nonlinearity.

    **How it works.** Starting from BPTT, OTTT keeps the spatial credit
    assignment but **drops the hidden-to-hidden recurrent Jacobian**. The only
    state it carries forward in time is the presynaptic trace vector
    :math:`\hat{a}^t`, so the per-step gradient is the rank-1 outer product of
    that trace with the instantaneous learning signal. Training memory is
    therefore :math:`O(B \cdot I)` per presynaptic input -- :math:`B` being the
    batch size and :math:`I` the number of presynaptic features -- and
    **independent of the sequence length**: the cheapest of the algorithms
    here, at the cost of ignoring longer-range temporal credit. Connections
    that read the same presynaptic input share a single trace.

    Parameters
    ----------
    model : brainstate.nn.Module
        The SNN whose weights are trained online.
    mode : {'A', 'O'}, default 'A'
        ``'A'`` accumulates the presynaptic trace over time
        (:math:`\hat a \leftarrow \lambda\,\hat a + x`). ``'O'`` uses the
        instantaneous presynaptic input only (:math:`\hat a := x^t`).
    leak : float
        Presynaptic-trace leak :math:`\lambda \in (0, 1)`. **Required** -- it
        must be supplied explicitly and is never inferred from the model (see
        *Limitations*). Mathematically :math:`\lambda` is the membrane leak of
        the *postsynaptic* neuron that receives the trained connection; the
        presynaptic trace is low-pass filtered with that leak. The same value
        is used for every trace.
    name : str, optional
        Name of the algorithm instance.
    vjp_method : str, default 'single-step'
        Forwarded to the base algorithm. Independently of this setting, OTTT v1
        only handles single-step inputs: updating the trace on multi-step
        inputs raises :class:`NotImplementedError`.
    **kwargs
        Accepted for call-site compatibility but ignored; they are *not*
        forwarded to the base algorithm.

    Limitations
    -----------
    - **The leak must be supplied by the user.** OTTT does *not* try to read
      :math:`\lambda` off the model's neuron states. A previous version walked
      ``model.states()`` and took the first state exposing a ``leak`` attribute,
      but on heterogeneous or multi-population models that silently picks an
      arbitrary (often wrong) value -- e.g. the leak of the *presynaptic* layer,
      a readout filter, or whichever population happens to be enumerated first.
      Since :math:`\lambda` is, by the derivation, the membrane leak of the
      postsynaptic neuron of each trained connection, the framework cannot
      guess it safely. A single network with different leaks per layer therefore
      cannot be trained correctly with one global ``leak`` and is unsupported.
    - **Single-state hidden groups only.** Each trained connection must project
      into a :class:`HiddenGroup` with ``num_state == 1``. The weight gradient
      contracts the learning signal ``L`` (shape ``(*varshape, num_state)``)
      down to ``(*varshape,)``; collapsing a ``num_state > 1`` tail (e.g. an
      ALIF neuron carrying both membrane potential and an adaptation variable)
      has no theoretical justification -- the trace is a single leaky scalar and
      cannot disentangle per-state credit -- so OTTT raises at compile time
      instead of silently summing across states.
    - **Single-step inputs only** (OTTT v1); multi-step inputs raise
      :class:`NotImplementedError`.

    Raises
    ------
    ValueError
        If ``mode`` is not ``'A'`` or ``'O'``, if ``leak`` is not in
        :math:`(0, 1)`, or (when the trace state is initialised during
        :meth:`compile_graph`) if a trained connection projects into a hidden
        group with ``num_state > 1``.
    NotImplementedError
        If the trace is updated with multi-step inputs.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import braintrace
        >>>
        >>> class Net(brainstate.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
        ...         self.out = braintrace.nn.Linear(20, 1)
        ...     def update(self, x):
        ...         return x >> self.cell >> self.out
        >>>
        >>> model = Net()
        >>> _ = brainstate.nn.init_all_states(model)
        >>> # ``leak`` is the postsynaptic membrane leak and must be passed
        >>> # explicitly; it is never inferred from the model.
        >>> learner = braintrace.OTTT(model, mode='A', leak=0.9)
        >>> x0 = brainstate.random.randn(1)
        >>> learner.compile_graph(x0)
        >>> y = learner(x0)

    References
    ----------
    .. [1] Xiao, M., Meng, Q., Zhang, Z., He, D., & Lin, Z. (2022). "Online
       Training Through Time for Spiking Neural Networks." *Advances in Neural
       Information Processing Systems (NeurIPS)* 35.
       https://arxiv.org/abs/2210.04195
    """

    __module__ = 'braintrace'

    def __init__(
        self,
        model: brainstate.nn.Module,
        mode: str = 'A',
        *,
        leak: float,
        name: Optional[str] = None,
        vjp_method: str = 'single-step',
        **kwargs,
    ):
        if mode not in ('A', 'O'):
            raise ValueError(f"mode must be 'A' or 'O'; got {mode!r}")
        # Convert once; the error message deliberately echoes the value the
        # caller passed, not its float conversion.
        leak_value = float(leak)
        if not (0.0 < leak_value < 1.0):
            raise ValueError(f'leak must be in (0, 1); got {leak}')
        # NOTE: ``kwargs`` is accepted but intentionally not forwarded here.
        super().__init__(model, name=name, vjp_method=vjp_method)
        self.mode = mode
        self.leak = leak_value
        # One trace per distinct presynaptic input variable, keyed by
        # ``id(x_var)``; populated by ``init_etrace_state``.
        self._pre_traces: Dict[int, PresynapticTrace] = {}

    def init_etrace_state(self, *args, **kwargs):
        """Allocate one zero-initialised presynaptic trace per distinct input.

        Walks every hidden-parameter-operation relation of the compiled graph,
        rejects hidden groups carrying more than one state, and creates a
        :class:`PresynapticTrace` with the shape and dtype of the relation's
        input variable. Relations that read the same input variable share a
        single trace.

        Raises
        ------
        ValueError
            If a trained connection projects into a hidden group with
            ``num_state > 1``.
        """
        self._pre_traces = {}
        for rel in self.graph.hidden_param_op_relations:
            for group in rel.hidden_groups:
                if group.num_state > 1:
                    raise ValueError(
                        f'OTTT only supports hidden groups with num_state == 1, '
                        f'but a trained connection projects into a group with '
                        f'num_state == {group.num_state}. Collapsing the learning '
                        f'signal across multiple hidden states (e.g. an ALIF '
                        f'neuron with membrane potential plus an adaptation '
                        f'variable) has no theoretical basis for OTTT; the leaky '
                        f'scalar presynaptic trace cannot assign per-state credit.'
                    )
            # Trace creation is per relation (not per group) so that every
            # relation has a trace for the relation-level lookup in
            # ``_solve_weight_gradients``.
            rid = id(rel.x_var)
            if rid in self._pre_traces:
                # Another relation already reads this input; share its trace.
                continue
            shape = rel.x_var.aval.shape
            dtype = rel.x_var.aval.dtype
            self._pre_traces[rid] = PresynapticTrace(
                jnp.zeros(shape, dtype=dtype), leak=self.leak
            )

    def reset_state(self, batch_size: Optional[int] = None, **kwargs):
        """Rewind the step counter and reset every presynaptic trace.

        Parameters
        ----------
        batch_size : int, optional
            Forwarded to each trace's own ``reset_state``.
        **kwargs
            Ignored.
        """
        self.running_index.value = 0
        for trace in self._pre_traces.values():
            trace.reset_state(batch_size=batch_size)

    def _get_etrace_data(self):
        """Return the current trace values keyed by ``id(x_var)``."""
        return {rid: t.value for rid, t in self._pre_traces.items()}

    def _assign_etrace_data(self, vals):
        """Write trace values back into the trace states.

        Every key of ``vals`` must already be present in ``_pre_traces``;
        an unknown key raises ``KeyError``.
        """
        for rid, v in vals.items():
            self._pre_traces[rid].value = v

    def _update_etrace_data(
        self,
        running_index,
        hist_vals,
        hid2weight_jac,
        hid2hid_jac,
        weight_vals,
        input_is_multi_step,
    ):
        """Advance every presynaptic trace by one time step.

        ``â ← λ·â + x_t`` (mode='A') or ``â := x_t`` (mode='O').

        ``hid2hid_jac`` is ignored -- OTTT's core approximation. ``running_index``
        and ``weight_vals`` are not used either.

        Parameters
        ----------
        hist_vals : dict
            Previous trace values keyed by ``id(x_var)``.
        hid2weight_jac : sequence
            Its first element is indexed with the same keys as ``hist_vals``
            to obtain the current presynaptic input ``x_t``.
        input_is_multi_step : bool
            Must be false; multi-step inputs are not supported.

        Returns
        -------
        dict
            New trace values with the same keys as ``hist_vals``.

        Raises
        ------
        NotImplementedError
            If ``input_is_multi_step`` is true.
        """
        if input_is_multi_step:
            raise NotImplementedError('OTTT v1 supports single-step only')
        xs_at_t = hid2weight_jac[0]
        new_vals = {}
        for rid, old in hist_vals.items():
            x_t = xs_at_t[rid]
            if self.mode == 'A':
                new_vals[rid] = self.leak * old + x_t
            else:
                new_vals[rid] = x_t
        return new_vals

    def _solve_weight_gradients(
        self,
        running_index,
        etrace_at_t,
        dl_to_hidden_groups,
        weight_vals,
        dl_to_nonetws_at_t,
        dl_to_etws_at_t,
    ):
        """``ΔW = outer(â, L)`` where ``L`` is the (already σ'-propagated) signal.

        Parameters
        ----------
        etrace_at_t : dict
            Current presynaptic traces keyed by ``id(x_var)``.
        dl_to_hidden_groups : sequence
            Learning signal per hidden group, indexed by ``group.index``.
        dl_to_nonetws_at_t : dict
            Gradients of parameters not handled by eligibility traces, merged
            into the result by parameter path.
        dl_to_etws_at_t : dict or None
            Extra gradients for trace-based weights; every path must already
            exist in the result.

        Returns
        -------
        dict
            Gradients keyed by parameter path; every path of
            ``self.param_states`` is present, initialised to ``None``.
        """
        dG = {path: None for path in self.param_states}
        for rel in self.graph.hidden_param_op_relations:
            a_hat = etrace_at_t[id(rel.x_var)]
            for group in rel.hidden_groups:
                L = dl_to_hidden_groups[group.index]
                # L shape = (*varshape, num_state); num_state == 1 is enforced at
                # compile time (see init_etrace_state), so this drops the singleton
                # tail rather than summing across genuinely distinct hidden states.
                L_proj = L.sum(axis=-1)
                if is_batched_primitive(rel.primitive):
                    # a_hat: (batch, in), L_proj: (batch, out). ΔW: (in, out);
                    # the batch axis is summed out.
                    dw = jnp.einsum('bi,bo->io', a_hat, L_proj)
                else:
                    dw = jnp.einsum('i,o->io', a_hat, L_proj)
                # ``dw`` depends on the group, so route it once per group.
                _route_grads_by_path(rel, {'weight': dw}, weight_vals, dG)
        for path, dg in dl_to_nonetws_at_t.items():
            _update_dict(dG, path, dg)
        if dl_to_etws_at_t is not None:
            for path, dg in dl_to_etws_at_t.items():
                _update_dict(dG, path, dg, error_when_no_key=True)
        return dG