# Source code for braintrace._etrace_algorithms.otpe

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""OTPE — Online Training with Postsynaptic Estimates (Summe et al., 2023).

OTPE replaces RTRL's full Jacobian with a leaky-additive per-parameter
accumulator :math:`\\hat R \\leftarrow \\lambda\\,\\hat R + \\partial s/\\partial
\\theta_\\text{local}` that estimates how a parameter's influence persists in the
postsynaptic membrane across several time steps — temporal structure that the
single-step approximations OTTT and OSTL drop. Cross-layer coupling is handled
inside ``_solve_weight_gradients`` without relaxing the compiler's "no W→W→h"
invariant.

See :class:`OTPE` for the mathematical formulation, references, and an example.
"""

import warnings
from typing import Dict, Optional

import brainstate
import jax.numpy as jnp

from braintrace._etrace_op import is_batched_primitive
from braintrace._misc import etrace_df_key
from ._common import _route_grads_by_path, _update_dict
from .vjp_base import ETraceVjpAlgorithm

__all__ = ['OTPE']


class OTPE(ETraceVjpAlgorithm):
    r"""Online Training with Postsynaptic Estimates for spiking networks.

    OTPE maintains a leaky, additive estimate :math:`\hat R^t` of each
    parameter's accumulated influence on the postsynaptic state, then contracts
    it with the learning signal :math:`L^t` to obtain the weight gradient:

    .. math::

        \hat R^t = \lambda\,\hat R^{t-1}
                   + \frac{\partial s^t}{\partial \theta}
                 = \lambda\,\hat R^{t-1}
                   + x^t \otimes \operatorname{diag}(D_f^t) ,
        \qquad
        \nabla_{W}\mathcal{L}^t = L^t \cdot \hat R^t ,

    where :math:`x^t` is the presynaptic input, :math:`D_f^t` the
    state-to-output Jacobian (surrogate gradient of the spike), :math:`\lambda
    \in (0, 1)` the membrane leak, and :math:`L^t = \partial \mathcal{L}^t /
    \partial s^t` the learning signal. The contraction runs over the output
    dimension, leaving a gradient with the weight's shape.

    In the low-rank ``'approx'`` mode (**F-OTPE**) the estimate is factorized as
    an outer product, reducing memory from :math:`O(I\cdot O)` to :math:`O(I+O)`
    per layer:

    .. math::

        \hat R^t \approx \hat z_\text{in}^t \otimes \bar g_\text{out}^t ,
        \quad
        \hat z_\text{in}^t = \lambda\,\hat z_\text{in}^{t-1} + x^t ,
        \quad
        \bar g_\text{out}^t = \lambda\,\bar g_\text{out}^{t-1}
                              + \operatorname{diag}(D_f^t) ,

    with gradient :math:`\nabla_{W}\mathcal{L}^t
    = \hat z_\text{in}^t \otimes (L^t \cdot \bar g_\text{out}^t)`.

    **How it works.** Unlike OTTT/OSTL, which assign temporal credit only within
    the current layer's output, OTPE keeps a per-parameter trace that decays
    with the membrane leak, approximating the *entire* temporal effect of a
    weight on downstream activity while staying local to each layer. This
    improves gradient alignment with BPTT in deep feed-forward SNNs at modest
    extra cost.

    Parameters
    ----------
    model : brainstate.nn.Module
        The SNN whose weights are trained online.
    mode : {'full', 'approx'}, default 'full'
        ``'full'`` keeps the full ``(batch, I, O)`` estimate :math:`\hat R` per
        layer. ``'approx'`` (F-OTPE) factorizes it as an outer product for
        :math:`O(I+O)` memory; emits a :class:`UserWarning` when the network has
        more than one HiddenGroup, because the factorization bias compounds with
        depth.
    leak : float
        Decay factor :math:`\lambda \in (0, 1)`. **Required** — it must be
        supplied explicitly and is never inferred from the model. :math:`\lambda`
        is the membrane leak of the *postsynaptic* neuron whose influence is
        being accumulated; auto-inferring it from ``model.states()`` silently
        picks an arbitrary (often wrong) value on heterogeneous or
        multi-population models, so the framework will not guess it.
    trace_clip_abs : float, optional
        Elementwise clip applied to :math:`\hat R` each step (full mode only).
        ``None`` disables clipping. Must not be negative.
    name : str, optional
        Name of the algorithm instance.
    vjp_method : str, optional
        Forwarded to the base algorithm. Only ``'single-step'`` is supported by
        OTPE v1; multi-step inputs raise :class:`NotImplementedError`.
    **kwargs
        Accepted but not used by this class.

    Limitations
    -----------
    OTPE's published derivation is **narrower than OTTT's**, and this
    implementation is a *general operator* that will happily run far outside
    that proven regime. The estimate :math:`\hat R` is built on the assumption
    that the only temporal coupling of the postsynaptic state is the scalar
    membrane leak, :math:`\partial U^t / \partial U^{t-1} = \lambda` — exactly
    the leaky integrate-and-fire (LIF) recurrence. On top of that scalar-leak
    assumption (inherited from OTTT), OTPE adds three further restrictions:

    1. **A single global time constant.** One scalar :math:`\lambda` is shared by
       every traced connection. Heterogeneous leaks across neurons or layers
       break the estimate; ``leak`` is therefore a user-supplied global constant
       and is never inferred from the model (see the ``leak`` parameter).
    2. **Feed-forward only.** The trace omits the hidden-to-hidden Jacobian, so
       it is the *postsynaptic estimate* for feed-forward SNNs. Applying it to a
       recurrent network silently drops the recurrent temporal credit.
    3. **Single-hidden-layer exactness.** The estimate is gradient-exact for one
       hidden layer; with depth the per-layer factorization accumulates bias.

    The low-rank ``'approx'`` mode (**F-OTPE**) layers an additional
    outer-product approximation on top, which is itself justified only under the
    same linear-leak assumption; its bias compounds with network depth (hence
    the :class:`UserWarning` for multi-group networks).

    Concretely, ``braintrace`` exposes OTPE as a generic ETP operator: it accepts
    arbitrary ETP weights and hidden states, multi-layer stacks, recurrent
    connectivity, and even non-spiking cells (e.g. a ``tanh`` RNN). All of these
    *run* mechanically, but **the moment the model deviates from a feed-forward
    LIF network with a single global scalar leak, the computed gradient leaves
    the regime in which OTPE is proven correct** and should be treated as a
    heuristic approximation rather than a faithful gradient estimate. The one
    structural case that is rejected outright is a multi-state hidden group
    (``num_state > 1``, e.g. ALIF with an adaptation variable): the leaky scalar
    estimate cannot assign per-state credit, so :meth:`compile_graph` raises
    rather than silently summing across states.

    Raises
    ------
    ValueError
        If ``mode`` is not ``'full'`` or ``'approx'``, if ``leak`` is not in
        :math:`(0, 1)`, if ``trace_clip_abs`` is negative, if a
        weight-to-hidden relation reaches more than one HiddenGroup (OTPE v1
        requires one-hop per-layer relations), or (at :meth:`compile_graph`)
        if a trained connection projects into a hidden group with
        ``num_state > 1``.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import braintrace
        >>>
        >>> class Net(brainstate.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
        ...         self.out = braintrace.nn.Linear(20, 1)
        ...     def update(self, x):
        ...         return x >> self.cell >> self.out
        >>>
        >>> model = Net()
        >>> _ = brainstate.nn.init_all_states(model)
        >>> # ``leak`` is the postsynaptic membrane leak and must be passed
        >>> # explicitly; it is never inferred from the model.
        >>> learner = braintrace.OTPE(model, mode='full', leak=0.9)
        >>> x0 = brainstate.random.randn(1)
        >>> learner.compile_graph(x0)
        >>> y = learner(x0)

    References
    ----------
    .. [1] Summe, T. M., Schaefer, C. J. S., & Joshi, S. (2023). "Estimating
       Post-Synaptic Effects for Online Training of Feed-Forward SNNs."
       *arXiv preprint* arXiv:2311.16151. https://arxiv.org/abs/2311.16151
    """

    __module__ = 'braintrace'

    def __init__(
        self,
        model: brainstate.nn.Module,
        mode: str = 'full',
        *,
        leak: float,
        name: Optional[str] = None,
        vjp_method: str = 'single-step',
        trace_clip_abs: Optional[float] = None,
        **kwargs,
    ):
        """Validate the configuration and create empty trace stores.

        ``**kwargs`` is accepted but not used here; only ``name`` and
        ``vjp_method`` are forwarded to the base class.
        """
        if mode not in ('full', 'approx'):
            raise ValueError(f"mode must be 'full' or 'approx'; got {mode!r}")
        if not (0.0 < float(leak) < 1.0):
            raise ValueError(f'leak must be in (0, 1); got {leak}')
        # A negative bound would hand ``jnp.clip`` a lower limit above its
        # upper limit, which does not produce a symmetric clip.
        if trace_clip_abs is not None and trace_clip_abs < 0:
            raise ValueError(
                f'trace_clip_abs must be non-negative; got {trace_clip_abs}'
            )
        super().__init__(model, name=name, vjp_method=vjp_method)
        self.mode = mode
        self.leak = float(leak)
        self.trace_clip_abs = trace_clip_abs
        # Trace stores, keyed by ``id(rel.y_var)``. ``_R_hat`` is filled in
        # 'full' mode; ``_R_hat_x`` / ``_R_hat_g`` are filled in 'approx' mode.
        self._R_hat: Dict[int, brainstate.ShortTermState] = {}
        self._R_hat_x: Dict[int, brainstate.ShortTermState] = {}
        self._R_hat_g: Dict[int, brainstate.ShortTermState] = {}

    def compile_graph(self, *args) -> None:
        """Compile the graph via the base class, then check OTPE's constraints.

        Warns with :class:`UserWarning` when ``mode='approx'`` is used with
        more than one hidden group.

        Raises
        ------
        ValueError
            If a relation in ``graph.hidden_param_op_relations`` does not have
            exactly one hidden group, or if that group has ``num_state > 1``.
        """
        super().compile_graph(*args)
        if self.mode == 'approx' and len(self.graph.hidden_groups) > 1:
            # ``mode='approx'`` already is F-OTPE, so the only alternative
            # worth recommending is the full estimate.
            warnings.warn(
                "OTPE(mode='approx') (F-OTPE) bias compounds with network "
                "depth; consider mode='full'.",
                UserWarning,
                stacklevel=2,
            )
        # Invariant: each relation maps to exactly one HiddenGroup in OTPE v1.
        for rel in self.graph.hidden_param_op_relations:
            if len(rel.hidden_groups) != 1:
                raise ValueError(
                    f'OTPE requires per-layer one-hop weight-to-hidden relations; '
                    f'found relation reaching {len(rel.hidden_groups)} groups.'
                )
            # OTPE's derivation assumes a single scalar membrane state per
            # neuron (the LIF case). A hidden group bundling several states
            # (e.g. ALIF's membrane potential plus adaptation variable) cannot
            # be handled: the leaky scalar estimate cannot assign per-state
            # credit, and collapsing the num_state axis with a sum has no
            # theoretical basis.
            group = rel.hidden_groups[0]
            if group.num_state > 1:
                raise ValueError(
                    f'OTPE only supports hidden groups with num_state == 1 '
                    f'(single-state LIF-like neurons), but a trained connection '
                    f'projects into a group with num_state == {group.num_state}. '
                    f'Multi-state neurons (e.g. ALIF with an adaptation variable) '
                    f'are outside the regime where OTPE is derived.'
                )

    def init_etrace_state(self, *args, **kwargs):
        """Allocate zero-filled ``float32`` traces for every relation.

        Existing stores are discarded first. In 'full' mode one trace per
        relation is created with the shape of the relation's first trainable
        variable, with ``in_shape[0]`` prepended when the primitive is
        batched. In 'approx' mode two traces per relation are created, shaped
        like ``rel.x_var`` and ``rel.y_var``. Positional and keyword arguments
        are not used.
        """
        self._R_hat = {}
        self._R_hat_x = {}
        self._R_hat_g = {}
        for rel in self.graph.hidden_param_op_relations:
            rid = id(rel.y_var)
            in_shape = rel.x_var.aval.shape
            out_shape = rel.y_var.aval.shape
            if self.mode == 'full':
                # Only the first trainable variable determines the trace shape.
                weight_key = next(iter(rel.trainable_vars))
                weight_var = rel.trainable_vars[weight_key]
                weight_shape = weight_var.aval.shape
                if is_batched_primitive(rel.primitive):
                    shape = (in_shape[0], *weight_shape)
                else:
                    shape = weight_shape
                self._R_hat[rid] = brainstate.ShortTermState(
                    jnp.zeros(shape, dtype=jnp.float32)
                )
            else:
                self._R_hat_x[rid] = brainstate.ShortTermState(
                    jnp.zeros(in_shape, dtype=jnp.float32)
                )
                self._R_hat_g[rid] = brainstate.ShortTermState(
                    jnp.zeros(out_shape, dtype=jnp.float32)
                )

    def reset_state(self, batch_size: Optional[int] = None, **kwargs):
        """Reset the running index and zero every trace, keeping its dtype.

        Parameters
        ----------
        batch_size : int, optional
            When given, the leading axis of traces that belong to a *batched*
            primitive is resized to ``batch_size``. Traces of non-batched
            primitives are allocated without a batch axis, so their shape is
            left unchanged.
        **kwargs
            Not used.
        """
        self.running_index.value = 0

        stores = (self._R_hat, self._R_hat_x, self._R_hat_g)
        batched_rids = set()
        # The graph is only consulted when there is something to resize.
        if batch_size is not None and any(stores):
            batched_rids = {
                id(rel.y_var)
                for rel in self.graph.hidden_param_op_relations
                if is_batched_primitive(rel.primitive)
            }

        for store in stores:
            for rid, state in store.items():
                shape = state.value.shape
                if rid in batched_rids:
                    shape = (batch_size, *shape[1:])
                state.value = jnp.zeros(shape, dtype=state.value.dtype)

    def _get_etrace_data(self):
        """Return the current trace values.

        'full' mode returns one ``{rid: value}`` dict; 'approx' mode returns a
        ``(x_traces, g_traces)`` pair of such dicts.
        """
        if self.mode == 'full':
            return {rid: r.value for rid, r in self._R_hat.items()}
        return (
            {rid: r.value for rid, r in self._R_hat_x.items()},
            {rid: r.value for rid, r in self._R_hat_g.items()},
        )

    def _assign_etrace_data(self, vals):
        """Write trace values back into the stores.

        ``vals`` has the same structure as the result of
        :meth:`_get_etrace_data` for the active mode.
        """
        if self.mode == 'full':
            for rid, v in vals.items():
                self._R_hat[rid].value = v
        else:
            vals_x, vals_g = vals
            for rid, v in vals_x.items():
                self._R_hat_x[rid].value = v
            for rid, v in vals_g.items():
                self._R_hat_g[rid].value = v

    def _update_etrace_data(
        self,
        running_index,
        hist_vals,
        hid2weight_jac,
        hid2hid_jac,
        weight_vals,
        input_is_multi_step,
    ):
        """``R_hat ← λ·R_hat + ∂s/∂θ_local``. Ignores ``hid2hid_jac``.

        ``hid2weight_jac[0]`` is looked up by ``id(rel.x_var)`` and
        ``hid2weight_jac[1]`` by ``etrace_df_key(rel.y_var, group.index)``.
        ``running_index`` and ``weight_vals`` are not used. The result has the
        same structure as ``hist_vals``.

        Raises
        ------
        NotImplementedError
            If ``input_is_multi_step`` is truthy.
        """
        if input_is_multi_step:
            raise NotImplementedError('OTPE v1 supports single-step only')
        xs = hid2weight_jac[0]
        dfs = hid2weight_jac[1]

        if self.mode == 'full':
            new_R = {}
            for rel in self.graph.hidden_param_op_relations:
                rid = id(rel.y_var)
                group = rel.hidden_groups[0]
                x = xs[id(rel.x_var)]
                df = dfs[etrace_df_key(rel.y_var, group.index)]
                # Collapse the trailing axis of ``df`` before the outer product.
                df_proj = df.sum(axis=-1)
                if is_batched_primitive(rel.primitive):
                    local = jnp.einsum('bi,bo->bio', x, df_proj)
                else:
                    local = jnp.einsum('i,o->io', x, df_proj)
                updated = self.leak * hist_vals[rid] + local
                if self.trace_clip_abs is not None:
                    updated = jnp.clip(
                        updated, -self.trace_clip_abs, self.trace_clip_abs
                    )
                new_R[rid] = updated
            return new_R

        # 'approx' mode: the two factors are accumulated independently and
        # ``trace_clip_abs`` is not applied.
        new_Rx = {}
        new_Rg = {}
        hist_x, hist_g = hist_vals
        for rel in self.graph.hidden_param_op_relations:
            rid = id(rel.y_var)
            group = rel.hidden_groups[0]
            x = xs[id(rel.x_var)]
            df = dfs[etrace_df_key(rel.y_var, group.index)].sum(axis=-1)
            new_Rx[rid] = self.leak * hist_x[rid] + x
            new_Rg[rid] = self.leak * hist_g[rid] + df
        return (new_Rx, new_Rg)

    def _solve_weight_gradients(
        self,
        running_index,
        etrace_at_t,
        dl_to_hidden_groups,
        weight_vals,
        dl_to_nonetws_at_t,
        dl_to_etws_at_t,
    ):
        """Contract the traces with the learning signal into weight gradients.

        Starts from a dict with one ``None`` entry per key of
        ``self.param_states``, routes each relation's gradient into it through
        ``_route_grads_by_path``, then merges ``dl_to_nonetws_at_t`` and, when
        it is not ``None``, ``dl_to_etws_at_t`` through ``_update_dict``.
        ``running_index`` is not used.

        Returns
        -------
        dict
            The gradient dict described above.
        """
        dG = {path: None for path in self.param_states}

        if self.mode == 'full':
            for rel in self.graph.hidden_param_op_relations:
                rid = id(rel.y_var)
                R = etrace_at_t[rid]
                group = rel.hidden_groups[0]
                # Collapse the trailing axis of the learning signal.
                L = dl_to_hidden_groups[group.index].sum(axis=-1)
                if is_batched_primitive(rel.primitive):
                    # Sums over the batch axis.
                    dw = jnp.einsum('bo,bio->io', L, R)
                else:
                    dw = jnp.einsum('o,io->io', L, R)
                _route_grads_by_path(rel, {'weight': dw}, weight_vals, dG)
        else:
            Rx_map, Rg_map = etrace_at_t
            for rel in self.graph.hidden_param_op_relations:
                rid = id(rel.y_var)
                group = rel.hidden_groups[0]
                L = dl_to_hidden_groups[group.index].sum(axis=-1)
                Rx = Rx_map[rid]
                Rg = Rg_map[rid]
                if is_batched_primitive(rel.primitive):
                    dw = jnp.einsum('bi,bo->io', Rx, L * Rg)
                else:
                    dw = jnp.einsum('i,o->io', Rx, L * Rg)
                _route_grads_by_path(rel, {'weight': dw}, weight_vals, dG)

        for path, dg in dl_to_nonetws_at_t.items():
            _update_dict(dG, path, dg)
        if dl_to_etws_at_t is not None:
            for path, dg in dl_to_etws_at_t.items():
                _update_dict(dG, path, dg, error_when_no_key=True)
        return dG