# Source code for braintrace._etrace_op.sparse

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Sparse-matmul ETP primitives.

``etp_sp_mm_p`` (batched) and ``etp_sp_mv_p`` (unbatched). The sparse
structure is supplied as a static parameter (``sparse_mat``); only the
non-zero values flow through the primitive as the ``weight_data`` invar.
The structure object must implement ``with_data`` (substitute new data
into the structure) and ``yw_to_w_transposed`` (apply the transposed
sparse pattern to a trace).

**Forward operation**

Let :math:`W = \mathrm{sparse}(w_{\text{data}})` denote the dense matrix
obtained by placing the vector :math:`w_{\text{data}} \in \mathbb{R}^{nnz}`
into the fixed sparse pattern stored in ``sparse_mat``. The forward op
is just dense matmul over the materialised representation:

.. math::

    y = x\, W \;(+ b), \qquad
    W = \mathrm{sparse}(w_{\text{data}}).

Only the nnz non-zero entries are trainable; the structural zeros are
frozen.

**Role of each ETP rule**

* ``xy_to_dw(x, hidden_dim, weights)`` — pullback of :math:`y = x\,\mathrm{sparse}(w) + b`
  by :func:`jax.vjp`. Sparse-aware: the VJP natively restricts the
  Jacobian to the nnz entries, returning :math:`\partial h/\partial w_{\text{data}} \in \mathbb{R}^{nnz}`.
  This is the instantaneous
  :math:`\operatorname{diag}(\mathbf{D}_f^t)\otimes \mathbf{x}^t` term
  for D-RTRL, projected onto the sparse support.

* ``yw_to_w(hidden_dim, trace)`` — propagation of
  :math:`\mathbf{D}^t \boldsymbol{\epsilon}^{t-1}`. For the weight data,
  delegates to ``sparse_mat.yw_to_w_transposed``: this contracts
  ``hidden_dim`` along ``out`` and restricts to the sparse pattern in a
  single kernel call — equivalent to computing the dense
  :math:`(\partial h/\partial y) \cdot \mathrm{scatter}^{\top}` but only
  touching the nnz entries.

* ``init_drtrl`` — nnz-dimensional trace
  :math:`\boldsymbol{\epsilon}_w \in \mathbb{R}^{nnz \times n_{\text{state}}}`
  (plus bias trace) instead of :math:`I \times O`; this is the whole
  point of ``etp_sp_*`` — ETP memory scales with ``nnz`` not
  :math:`I \cdot O`.

* ``init_pp`` — output-shaped df trace, identical to the dense case
  (pp-prop factorises :math:`\boldsymbol{\epsilon} \approx \boldsymbol{\epsilon}_f \otimes \boldsymbol{\epsilon}_x`
  and the :math:`\boldsymbol{\epsilon}_f` side is output-shaped
  regardless of how :math:`W` is stored).

**Dict rule API (N-trainable-input refactor)**

Both primitives declare ``trainable_invars_fn``, which returns
``{'weight': 1}`` when ``has_bias=False`` and ``{'weight': 1, 'bias': 2}``
when ``has_bias=True``. The four ETP rules accept / return
``Dict[str, Array]`` instead of bare arrays so the executor can route
gradients to *both* weight and bias ``ParamState`` objects in one pass.

When ``has_bias=False`` the ``'bias'`` key is simply absent from every
dict, so the legacy (no-bias) code path is unchanged in behaviour.
"""

import jax
import jax.numpy as jnp
import saiunit as u

from ._primitive import register_primitive

__all__ = [
    'etp_sp_mm_p',
    'etp_sp_mv_p',
    'sparse_matmul',
]


def _etp_sp_matmul_impl(*args, sparse_mat=None, has_bias=False):
    """Forward implementation shared by both sparse ETP primitives.

    The positional inputs are ``(x, weight_data)``, followed by the bias
    when ``has_bias`` is true. ``sparse_mat.with_data(weight_data)``
    supplies the weight operand, which is right-multiplied onto ``x``.

    Parameters
    ----------
    *args
        ``args[0]`` is the input ``x``, ``args[1]`` the non-zero weight
        values, and ``args[2]`` the bias (read only when ``has_bias``).
    sparse_mat : object
        Sparse structure exposing ``with_data``.
    has_bias : bool
        Whether ``args[2]`` is added to the product.

    Returns
    -------
    The product ``x @ sparse_mat.with_data(weight_data)``, plus the bias
    when ``has_bias`` is true.
    """
    product = args[0] @ sparse_mat.with_data(args[1])
    # Any positional input past index 1 is ignored unless a bias was declared.
    return product + args[2] if has_bias else product


# ---------------------------------------------------------------------------
# trainable_invars_fn — shared by both mm and mv
# ---------------------------------------------------------------------------

def _sp_trainable_invars(params):
    """Map each trainable parameter name to its invar position.

    ``'weight'`` is always invar 1. ``'bias'`` (invar 2) is listed only
    when ``params`` carries a truthy ``'has_bias'`` entry; a missing entry
    is treated as ``False``.
    """
    if params.get('has_bias', False):
        return {'weight': 1, 'bias': 2}
    return {'weight': 1}


# ---------------------------------------------------------------------------
# etp_sp_mm_p — batched
# ---------------------------------------------------------------------------

def _sp_mm_yw_to_w(hidden_dim, trace, *, sparse_mat=None, has_bias=False):
    r"""Batched sparse ``yw_to_w``: push :math:`\partial h / \partial y`
    through the nnz-shaped D-RTRL trace.

    **Role in D-RTRL.** This is the :math:`y \to w_{\text{data}}` chain
    factor of :math:`\mathbf{D}^t \boldsymbol{\epsilon}^{t-1}`. In the
    dense-equivalent form :math:`y_j = \sum_i x_i W_{ij}` one has
    :math:`\partial y_j / \partial W_{ik} = \delta_{jk} x_i`; on the sparse
    support only positions :math:`(i, j) \in \mathrm{pattern}` survive.
    The contraction and the restriction to the pattern are both delegated
    to ``sparse_mat.yw_to_w_transposed``:

    .. math::

        \epsilon^{t}_{w, b, p} \;=\;
          \sum_j (\partial h / \partial y)_{b, j}\,
                 \epsilon^{t-1}_{w, b, p}\,
                 \mathbb{1}[\mathrm{col}(p) = j],

    for each nnz index :math:`p`.

    **Bias.** Since :math:`y_j = \dots + b_j` gives
    :math:`\partial y_j / \partial b_k = \delta_{jk}`, the bias trace is
    propagated by an elementwise product with ``hidden_dim``, as in the
    dense case.

    **Shapes.**
        scan context: ``hidden_dim : (batch, out)``,
                      ``trace['weight'] : (batch, nnz)``,
                      ``trace['bias']   : (batch, out)``.
        solve context: batch axis dropped by the outer vmap.

    Returns
    -------
    dict
        ``{'weight': ...}``, plus a ``'bias'`` entry when ``has_bias`` is
        true.
    """
    propagated = {
        'weight': sparse_mat.yw_to_w_transposed(hidden_dim, trace['weight']),
    }
    if not has_bias:
        return propagated
    propagated['bias'] = trace['bias'] * hidden_dim
    return propagated


def _sp_xy_to_dw(x, hidden_dim, weights, *, sparse_mat=None, has_bias=False):
    r"""Instantaneous sparse Jacobians :math:`\partial h / \partial w_{\text{data}}`
    and :math:`\partial h / \partial b`.

    **Role.** Supplies the
    :math:`\operatorname{diag}(\mathbf{D}_f^t)\otimes \mathbf{x}^t` term
    of D-RTRL (and the solve-time factor of ES-D-RTRL), restricted to the
    sparse support. By the chain rule,

    .. math::

        \frac{\partial h}{\partial w_p}
          \;=\; x_{\mathrm{row}(p)} \cdot
                \Bigl(\frac{\partial h}{\partial y}\Bigr)_{\mathrm{col}(p)},

    for each nnz index :math:`p`. Differentiating through
    ``sparse_mat.with_data`` with :func:`jax.vjp` yields this nnz-shaped
    gradient directly, so zeros outside the pattern are never
    materialised.

    The bias gradient matches the dense case:
    :math:`\partial h / \partial b = \partial h / \partial y`.

    Weight and bias pullbacks come from a single ``jax.vjp`` call over a
    forward function that takes the whole ``weights`` dict.

    Returns
    -------
    dict
        Cotangents with the same structure as ``weights``, each passed
        through ``u.get_mantissa``.
    """

    def _forward(param_dict):
        # Same forward as the primitive, written in terms of the parameter
        # dict so one pullback covers every trainable entry.
        out = x @ sparse_mat.with_data(param_dict['weight'])
        if has_bias:
            out = out + param_dict['bias']
        return u.get_mantissa(out)

    pullback = jax.vjp(_forward, weights)[1]
    # ``_forward`` has one primal argument, so the pullback returns a 1-tuple.
    (grads,) = pullback(hidden_dim)
    return jax.tree.map(u.get_mantissa, grads)


def _sp_mm_init_drtrl(x_var, y_var, weight_vars, num_hidden_state):
    r"""Allocate the zero-initialised D-RTRL trace for the batched primitive.

    This is where the sparse representation saves memory compared with
    the dense one:

    .. math::

        \boldsymbol{\epsilon}_w \in \mathbb{R}^{B \times nnz \times n_{\text{state}}}, \qquad
        \boldsymbol{\epsilon}_b \in \mathbb{R}^{B \times O \times n_{\text{state}}}.

    ``nnz`` can be orders of magnitude smaller than :math:`I \cdot O` for
    typical connectivity matrices.

    ``B`` is the leading axis of ``x_var.aval.shape`` and ``nnz`` the
    leading axis of ``weight_vars['weight'].aval.shape``. The bias trace
    is created only when ``weight_vars`` has a ``'bias'`` entry. ``y_var``
    is not consulted.
    """
    n_batch = x_var.aval.shape[0]
    n_nonzero = weight_vars['weight'].aval.shape[0]
    weight_trace_shape = (n_batch, n_nonzero, num_hidden_state)
    traces = {'weight': jnp.zeros(weight_trace_shape)}
    if 'bias' not in weight_vars:
        return traces
    bias_trace_shape = (n_batch, *weight_vars['bias'].aval.shape, num_hidden_state)
    traces['bias'] = jnp.zeros(bias_trace_shape)
    return traces


def _sp_mm_init_pp(x_var, y_var, weight_vars, num_hidden_state):
    r"""Allocate the batched sparse pp-prop / ES-D-RTRL df trace.

    .. math::

        \boldsymbol{\epsilon}_f \in \mathbb{R}^{B \times O \times n_{\text{state}}}.

    The trace is output-shaped, exactly as in the dense case: zeros of
    shape ``(*y_var.aval.shape, num_hidden_state)`` in the output dtype.
    The :math:`\boldsymbol{\epsilon}_x` factor is the raw dense input
    :math:`x`, held by the executor. ``x_var`` and ``weight_vars`` are not
    consulted.
    """
    out_aval = y_var.aval
    trace_shape = (*out_aval.shape, num_hidden_state)
    return jnp.zeros(trace_shape, dtype=out_aval.dtype)


# ---------------------------------------------------------------------------
# etp_sp_mv_p — unbatched
# ---------------------------------------------------------------------------

def _sp_mv_yw_to_w(hidden_dim, trace, *, sparse_mat=None, has_bias=False):
    r"""Unbatched sparse ``yw_to_w``: the batched algebra without a batch
    axis.

    :math:`\partial h / \partial y` is pushed through the sparse pattern
    by ``sparse_mat.yw_to_w_transposed``, and the bias trace is scaled
    elementwise:

    .. math::

        \epsilon^t_{w, p} \;=\;
          \sum_j (\partial h / \partial y)_j\,
                 \epsilon^{t-1}_{w, p}\,
                 \mathbb{1}[\mathrm{col}(p) = j], \qquad
        \epsilon^t_{b, k} \;=\; (\partial h / \partial y)_k\, \epsilon^{t-1}_{b, k}.

    Shapes:  ``hidden_dim : (out,)``,
             ``trace['weight'] : (nnz,)``,
             ``trace['bias']   : (out,)``.

    Returns
    -------
    dict
        ``{'weight': ...}``, plus a ``'bias'`` entry when ``has_bias`` is
        true.
    """
    weight_trace = sparse_mat.yw_to_w_transposed(hidden_dim, trace['weight'])
    if has_bias:
        return {'weight': weight_trace, 'bias': trace['bias'] * hidden_dim}
    return {'weight': weight_trace}


def _sp_mv_init_drtrl(x_var, y_var, weight_vars, num_hidden_state):
    r"""Allocate the zero-initialised D-RTRL trace for the unbatched primitive.

    .. math::

        \boldsymbol{\epsilon}_w \in \mathbb{R}^{nnz \times n_{\text{state}}}, \qquad
        \boldsymbol{\epsilon}_b \in \mathbb{R}^{O \times n_{\text{state}}}.

    ``nnz`` is the leading axis of ``weight_vars['weight'].aval.shape``.
    The bias trace is created only when ``weight_vars`` has a ``'bias'``
    entry. ``x_var`` and ``y_var`` are not consulted.
    """
    n_nonzero = weight_vars['weight'].aval.shape[0]
    traces = {'weight': jnp.zeros((n_nonzero, num_hidden_state))}
    if 'bias' not in weight_vars:
        return traces
    bias_trace_shape = (*weight_vars['bias'].aval.shape, num_hidden_state)
    traces['bias'] = jnp.zeros(bias_trace_shape)
    return traces


def _sp_mv_init_pp(x_var, y_var, weight_vars, num_hidden_state):
    r"""Allocate the unbatched sparse pp-prop / ES-D-RTRL df trace.

    .. math::

        \boldsymbol{\epsilon}_f \in \mathbb{R}^{O \times n_{\text{state}}}.

    Returns zeros of shape ``(*y_var.aval.shape, num_hidden_state)`` in
    the dtype of ``y_var.aval``. ``x_var`` and ``weight_vars`` are not
    consulted.
    """
    out_aval = y_var.aval
    return jnp.zeros(
        tuple(out_aval.shape) + (num_hidden_state,),
        dtype=out_aval.dtype,
    )


# ---------------------------------------------------------------------------
# Primitive registration
# ---------------------------------------------------------------------------

# Batched primitive. It shares the forward implementation and the
# trainable-invar mapping with the unbatched primitive registered below;
# the input ``x`` is invar 0 (weight data is invar 1, optional bias invar 2).
etp_sp_mm_p = register_primitive(
    'etp_sp_mm',
    _etp_sp_matmul_impl,
    batched=True,
    trainable_invars_fn=_sp_trainable_invars,
    x_invar_index=0,
)
# Attach the four ETP rules. ``xy_to_dw`` is the same function for both
# primitives; the other three have batch-aware variants.
etp_sp_mm_p.register_etp_rules(
    yw_to_w=_sp_mm_yw_to_w,
    xy_to_dw=_sp_xy_to_dw,
    init_drtrl=_sp_mm_init_drtrl,
    init_pp=_sp_mm_init_pp,
)

# Unbatched primitive: same forward implementation and invar layout as the
# batched one, registered with ``batched=False``.
etp_sp_mv_p = register_primitive(
    'etp_sp_mv',
    _etp_sp_matmul_impl,
    batched=False,
    trainable_invars_fn=_sp_trainable_invars,
    x_invar_index=0,
)
# Unbatched rule set: the trace initialisers allocate no leading batch axis.
etp_sp_mv_p.register_etp_rules(
    yw_to_w=_sp_mv_yw_to_w,
    xy_to_dw=_sp_xy_to_dw,
    init_drtrl=_sp_mv_init_drtrl,
    init_pp=_sp_mv_init_pp,
)


def sparse_matmul(x, weight_data, *, sparse_mat, bias=None):
    r"""ETP-aware sparse matrix multiplication.

    Evaluates :math:`y = x \mathbin{@} \mathrm{sparse}(w) \; (+ b)`. Only
    the non-zero entries (``weight_data``) of the fixed sparse pattern are
    trainable and take part in eligibility-trace computation. The batched
    primitive is selected when ``x.ndim >= 2`` and the unbatched one
    otherwise.

    Parameters
    ----------
    x : ArrayLike
        Input array.
    weight_data : ArrayLike
        Sparse-matrix data, i.e. the non-zero values, shape ``(nnz,)``.
    sparse_mat : object
        Sparse-matrix structure (e.g. a ``saiunit.sparse`` matrix object).
        Must expose ``with_data`` (substitute new data into the structure)
        and ``yw_to_w_transposed`` (apply the transposed sparse pattern to
        a trace).
    bias : ArrayLike or None, optional
        Bias vector. Default ``None``.

    Returns
    -------
    ArrayLike
        Output array.
    """
    primitive = etp_sp_mv_p if x.ndim < 2 else etp_sp_mm_p

    # The primitives are bound on bare mantissas; units are split off here
    # and multiplied back onto the result at the end.
    x_mantissa, x_unit = u.split_mantissa_unit(x)
    w_mantissa, w_unit = u.split_mantissa_unit(weight_data)
    product_unit = x_unit * w_unit

    if bias is None:
        raw = primitive.bind(
            x_mantissa,
            w_mantissa,
            sparse_mat=sparse_mat,
            has_bias=False,
        )
    else:
        # The bias is converted relative to the unit of ``x @ W`` before it
        # is handed to the primitive alongside the other mantissas.
        bias_mantissa = u.Quantity(bias).to_decimal(product_unit)
        raw = primitive.bind(
            x_mantissa,
            w_mantissa,
            bias_mantissa,
            sparse_mat=sparse_mat,
            has_bias=True,
        )
    return u.maybe_decimal(raw * x_unit * w_unit)