# Source code for braintrace._etrace_op.conv

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Convolution ETP primitive (``etp_conv_p``).

Always expects a batch dimension on the input. The full keyword surface
of ``jax.lax.conv_general_dilated`` is preserved; the wrapper splits and
recombines saiunit quantities for the input and kernel.

**Forward operation**

.. math::

    y_{b, \mathbf{s}, k}
      = \sum_{\mathbf{u}, c} x_{b, \mathbf{s}+\mathbf{u}, c}\,
                              K_{\mathbf{u}, c, k}
        \;+\; b_k

where :math:`\mathbf{s}` runs over spatial output positions,
:math:`\mathbf{u}` over the kernel spatial window, :math:`c` is the input
channel, :math:`k` the output channel, and the kernel layout follows the
``dimension_numbers`` supplied at bind-time. Because the bias is a single
value per output channel shared across every spatial position, its
Jacobian is fundamentally different from the kernel's.

**Role of each ETP rule**

Let :math:`\mathbf{D}_f^t = \partial h / \partial y` (one cotangent per
output element). The conv primitive implements:

* ``xy_to_dw`` — for the *kernel*, uses the conv VJP to produce the full
  weight Jacobian :math:`\partial h / \partial K` (requires :math:`x`).
  For the *bias*, stores the per-position cotangent
  :math:`\partial h / \partial b_k = (\partial h / \partial y)_{b,\mathbf{s},k}`
  **without** summing over spatial positions. The spatial summation is
  deferred to ``yw_to_w`` during trace propagation — because the bias is
  spatially shared, the "true" bias gradient requires summing the
  cotangent along spatial axes, but doing the sum *inside* the trace
  rather than at the instantaneous step keeps the linear algebra
  consistent with the D-RTRL recurrence (the trace must accumulate
  :math:`\mathbf{D}^t \boldsymbol{\epsilon}^{t-1}` with the *same*
  spatial shape the executor feeds in).

* ``yw_to_w`` — applies :math:`\partial h / \partial y` to the trace.
  For the kernel: reduces ``hidden_dim`` over spatial axes
  (the kernel is spatially shared) and broadcasts the result across the
  kernel's spatial dims. For the bias: elementwise multiply with
  ``hidden_dim`` then sum over spatial axes — this is the deferred
  spatial reduction that implements :math:`\sum_{\mathbf{s}} \partial h/\partial b_k`.

* ``init_drtrl`` — allocates weight-shaped :math:`\boldsymbol{\epsilon}_K`
  plus per-position :math:`\boldsymbol{\epsilon}_b` with output-shape
  (kept spatial so the trace can accumulate before the deferred sum).

* ``init_pp`` — allocates the pp-prop output-shaped df trace; the
  :math:`\boldsymbol{\epsilon}_x` factor in ES-D-RTRL is the full
  batched input tensor held by the executor.
"""

from typing import Any, Optional, Sequence

import jax
import jax.numpy as jnp
import saiunit as u

from ._primitive import register_primitive

__all__ = [
    'etp_conv_p',
    'conv',
]


def _etp_conv_impl(
    *args,
    has_bias=False,
    strides=(1,),
    padding='SAME',
    lhs_dilation=None,
    rhs_dilation=None,
    feature_group_count=1,
    batch_group_count=1,
    dimension_numbers=None,
):
    """Forward implementation of the ``etp_conv`` primitive.

    Parameters
    ----------
    *args
        Positional operands ``(x, kernel)`` or ``(x, kernel, bias)``. The
        third operand is only read when ``has_bias`` is true.
    has_bias : bool
        Whether a bias operand is present at ``args[2]``.
    strides, padding, lhs_dilation, rhs_dilation, feature_group_count,
    batch_group_count, dimension_numbers
        Forwarded to :func:`jax.lax.conv_general_dilated`; ``strides`` is
        passed as its ``window_strides`` argument.

    Returns
    -------
    The convolution output, with the bias added by ordinary broadcasting
    when ``has_bias`` is true.
    """
    lhs = args[0]
    rhs = args[1]
    # The primitive exposes ``strides``; the lax API calls it
    # ``window_strides``.
    lax_options = dict(
        window_strides=strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
        dimension_numbers=dimension_numbers,
    )
    conv_out = jax.lax.conv_general_dilated(lhs=lhs, rhs=rhs, **lax_options)
    if not has_bias:
        return conv_out
    return conv_out + args[2]


def _conv_trainable_invars(params):
    """Map each trainable parameter name to its operand index.

    The kernel is always operand 1; the bias (operand 2) is included only
    when ``params['has_bias']`` is truthy.
    """
    if params.get('has_bias', False):
        return {'weight': 1, 'bias': 2}
    return {'weight': 1}


def _conv_layout(params):
    """Return ``(n_spatial, channel_axis, batch_axis, kernel_out_axis)``.

    ``n_spatial``:       spatial rank, taken from ``len(params['strides'])``
                         (defaults to 1 when ``strides`` is absent).
    ``channel_axis``:    position of the output-channel axis in the OUTPUT
                         tensor (``y`` / ``hidden_dim`` in batched form).
    ``batch_axis``:      position of the batch axis in the OUTPUT tensor.
    ``kernel_out_axis``: position of the out-channel dimension in the KERNEL
                         tensor (shape of ``w_trace`` *without* any batch
                         prefix, i.e. as stored in the weight array).

    Three forms of ``params['dimension_numbers']`` are understood:

    1. ``None`` — JAX defaults to ``iota`` (``(0, 1, 2, ...)``), which maps
       to NCHW / NCH for the output (batch=0, channel=1) and OIHW / OIH for
       the kernel (out-channel=0).
    2. A 3-element tuple or list of spec strings
       ``(lhs_spec, rhs_spec, out_spec)`` — the axes are the character
       positions of ``'N'`` / ``'C'`` in ``out_spec`` and of ``'O'`` in
       ``rhs_spec``.
    3. Anything else is treated as a ``ConvDimensionNumbers`` namedtuple —
       ``out_spec[0]`` is the batch position and ``out_spec[1]`` the channel
       position in the output; ``rhs_spec[0]`` is the out-channel position
       in the kernel.

    Example: ``('NHWC', 'HWIO', 'NHWC')`` gives ``batch_axis=0``,
    ``channel_axis=3``, ``kernel_out_axis=3`` (index of 'O' in 'HWIO').

    Raises
    ------
    ValueError
        If a string spec lacks the required ``'N'``, ``'C'`` or ``'O'``.
    AttributeError
        If ``dimension_numbers`` is neither ``None``, a string triple, nor an
        object exposing ``out_spec`` / ``rhs_spec``.
    """
    n_spatial = len(params.get('strides', (1,)))
    dn = params.get('dimension_numbers', None)

    if dn is None:
        # JAX default: iota = (0,1,2,...) → NCHW/NCH output, OIHW/OIH kernel.
        batch_axis, channel_axis, kernel_out_axis = 0, 1, 0
    elif isinstance(dn, (tuple, list)) and len(dn) == 3 and isinstance(dn[2], str):
        # String form, e.g. ('NHWC', 'HWIO', 'NHWC'). Lists are accepted as
        # well as tuples, matching what jax.lax.conv_general_dilated takes.
        # A ConvDimensionNumbers namedtuple is also a 3-tuple, but its
        # out_spec is a tuple of ints, so it does not match here.
        rhs_spec_str, out_spec_str = dn[1], dn[2]
        batch_axis = out_spec_str.index('N')
        channel_axis = out_spec_str.index('C')
        kernel_out_axis = rhs_spec_str.index('O')
    else:
        # ConvDimensionNumbers namedtuple: specs list axis positions in
        # logical order (batch/out-channel first).
        batch_axis = dn.out_spec[0]
        channel_axis = dn.out_spec[1]
        kernel_out_axis = dn.rhs_spec[0]
    return n_spatial, channel_axis, batch_axis, kernel_out_axis


def _conv_yw_to_w(hidden_dim, trace, **params):
    r"""Apply the hidden cotangent :math:`g = \partial h / \partial y` to the conv trace.

    **Role in D-RTRL.** This is the :math:`y \to (K, b)` factor of the
    :math:`\mathbf{D}^t \boldsymbol{\epsilon}^{t-1}` term. The kernel and
    the bias are shared across spatial output positions, since

    .. math::

        \frac{\partial y_{b, \mathbf{s}, k}}{\partial K_{\mathbf{u}, c, k'}}
          \;=\; \delta_{k k'}\, x_{b,\, \mathbf{s}+\mathbf{u},\, c}, \qquad
        \frac{\partial y_{b, \mathbf{s}, k}}{\partial b_{k'}}
          \;=\; \delta_{k k'},

    so pulling :math:`g` back contracts the spatial index
    :math:`\mathbf{s}` for both parameters:

    .. math::

        \frac{\partial h}{\partial K_{\mathbf{u}, c, k}}
          \;=\; \sum_{\mathbf{s}} g_{b, \mathbf{s}, k}\, x_{b, \mathbf{s}+\mathbf{u}, c}, \qquad
        \frac{\partial h}{\partial b_k}
          \;=\; \sum_{\mathbf{s}} g_{b, \mathbf{s}, k}.

    **Kernel leaf.** ``hidden_dim`` is summed over its spatial axes, then
    reshaped so it broadcasts against ``trace['weight']`` along the kernel's
    out-channel axis (and the leading batch axis when present), and the two
    are multiplied.

    **Bias leaf.** ``trace['bias']`` is multiplied elementwise by
    ``hidden_dim`` and then summed over the spatial axes. This is the
    spatial reduction that ``xy_to_dw`` defers by storing the per-position
    cotangent.

    **Execution contexts** (chosen from ``hidden_dim.ndim``):

    1. ``ndim == n_spatial + 2`` — batch retained; ``trace['weight']`` is
       expected to carry a batch prefix at axis 0.
    2. otherwise — batch stripped (e.g. by an outer vmap); the output
       layout is renumbered with the batch axis removed.
    3. ``ndim == 0`` — degenerate scalar cotangent: the weight trace is
       scaled and the bias product is summed over every element.

    Layout (batch / channel / kernel out-channel axes and spatial rank) is
    obtained from :func:`_conv_layout`.

    Parameters
    ----------
    hidden_dim
        Cotangent with the layout of the conv output (with or without the
        batch axis).
    trace : dict
        Holds ``'weight'`` and, when ``has_bias`` is set, ``'bias'``.
    **params
        Bind-time primitive parameters (``has_bias``, ``strides``,
        ``dimension_numbers``, ...).

    Returns
    -------
    dict
        ``{'weight': ...}`` plus ``'bias'`` when ``has_bias`` is true.
    """
    with_bias = params.get('has_bias', False)
    kernel_trace = trace['weight']

    # Degenerate scalar cotangent: plain scaling, no layout needed.
    if hidden_dim.ndim == 0:
        scaled = {'weight': kernel_trace * hidden_dim}
        if with_bias:
            scaled['bias'] = jnp.sum(trace['bias'] * hidden_dim)
        return scaled

    n_spatial, out_channel_axis, out_batch_axis, kernel_out_axis = _conv_layout(params)
    # Batch + spatial + channel means the batch axis is still there.
    batched = hidden_dim.ndim == n_spatial + 2

    if batched:
        # Every axis that is neither batch nor channel is spatial.
        spatial_axes = tuple(
            ax for ax in range(hidden_dim.ndim)
            if ax not in (out_batch_axis, out_channel_axis)
        )
    else:
        # Drop the batch axis from the batched layout and renumber: the
        # channel's new index is its rank among the surviving axes.
        surviving = [ax for ax in range(n_spatial + 2) if ax != out_batch_axis]
        channel_pos = surviving.index(out_channel_axis)
        spatial_axes = tuple(i for i in range(len(surviving)) if i != channel_pos)

    # ── Kernel leaf ───────────────────────────────────────────────────────────
    if spatial_axes:
        per_channel = jnp.sum(hidden_dim, axis=spatial_axes)
    else:
        per_channel = hidden_dim

    # With a batch prefix the kernel axes of the trace are shifted right by one.
    trace_out_axis = kernel_out_axis + 1 if batched else kernel_out_axis

    # Size-1 everywhere except the batch axis (if any) and the out-channel axis.
    bcast_shape = [1] * kernel_trace.ndim
    if batched:
        bcast_shape[0] = kernel_trace.shape[0]
    bcast_shape[trace_out_axis] = kernel_trace.shape[trace_out_axis]

    result = {'weight': kernel_trace * jnp.reshape(per_channel, bcast_shape)}

    # ── Bias leaf ─────────────────────────────────────────────────────────────
    if with_bias:
        weighted_bias = trace['bias'] * hidden_dim
        if spatial_axes:
            result['bias'] = jnp.sum(weighted_bias, axis=spatial_axes)
        else:
            result['bias'] = weighted_bias
    return result


def _conv_xy_to_dw(x, hidden_dim, weights, **params):
    r"""Instantaneous conv Jacobian :math:`\partial h / \partial (K, b)`.

    **Role in D-RTRL.** Supplies the
    :math:`\operatorname{diag}(\mathbf{D}_f^t) \otimes \mathbf{x}^t` term.
    From

    .. math::

        \frac{\partial y_{b, \mathbf{s}, k}}{\partial K_{\mathbf{u}, c, k'}}
          = \delta_{k k'}\, x_{b, \mathbf{s}+\mathbf{u}, c}, \qquad
        \frac{\partial y_{b, \mathbf{s}, k}}{\partial b_{k'}}
          = \delta_{k k'},

    the pullback of the cotangent ``hidden_dim`` is

    .. math::

        \left.\frac{\partial h}{\partial K}\right|_t
          \;=\; \text{VJP}_K\bigl(\mathrm{conv}(x, K)\bigr)(\partial h/\partial y), \qquad
        \left.\frac{\partial h}{\partial b_k}\right|_{t, \mathbf{s}}
          \;=\; (\partial h/\partial y)_{b, \mathbf{s}, k}.

    **Kernel.** Computed with ``jax.vjp`` through
    ``jax.lax.conv_general_dilated``, differentiating with respect to the
    kernel only.

    **Bias.** ``hidden_dim`` is returned unchanged (no spatial sum); the
    reduction over spatial positions happens in :func:`_conv_yw_to_w`.

    **Unbatched inputs.** When ``x.ndim == n_spatial + 1`` a leading axis is
    added to both ``x`` and ``hidden_dim`` before the convolution. The
    kernel gradient has no batch axis, so nothing needs stripping
    afterwards.

    Parameters
    ----------
    x
        Conv input, with or without a leading batch axis.
    hidden_dim
        Cotangent with the layout of the conv output.
    weights : dict
        Must contain ``'weight'`` (the kernel).
    **params
        Bind-time primitive parameters. Everything except ``has_bias`` is
        forwarded to ``conv_general_dilated``.

    Returns
    -------
    dict
        ``{'weight': dK}`` plus ``'bias': hidden_dim`` when ``has_bias``.
    """
    # Translate primitive params to lax keyword arguments: 'has_bias' is not
    # a conv argument and 'strides' is spelled 'window_strides' there.
    lax_kwargs = {
        ('window_strides' if name == 'strides' else name): value
        for name, value in params.items()
        if name != 'has_bias'
    }

    # batch + spatial + channel → ndim == n_spatial + 2; one less means the
    # batch axis is missing, so add a leading axis for the conv call.
    n_spatial = len(params.get('strides', (1,)))
    if x.ndim == n_spatial + 1:
        lhs, cotangent = x[None], hidden_dim[None]
    else:
        lhs, cotangent = x, hidden_dim

    def _conv_of_kernel(kernel):
        # Function of the kernel alone; the input is closed over.
        return u.get_mantissa(
            jax.lax.conv_general_dilated(lhs, kernel, **lax_kwargs)
        )

    _, pullback = jax.vjp(_conv_of_kernel, weights['weight'])
    grads = {'weight': u.get_mantissa(pullback(cotangent)[0])}

    if params.get('has_bias', False):
        # Per-position ∂h/∂b: the cotangent itself, not spatially summed.
        grads['bias'] = hidden_dim

    return grads


def _conv_init_drtrl(x_var, y_var, weight_vars, num_hidden_state):
    r"""Build the zero-initialised D-RTRL traces for the conv primitive.

    .. math::

        \boldsymbol{\epsilon}_K \in
          \mathbb{R}^{B \times \text{(kernel dims)} \times n_{\text{state}}}, \qquad
        \boldsymbol{\epsilon}_b \in
          \mathbb{R}^{B \times \text{(spatial out)} \times O \times n_{\text{state}}}.

    The kernel trace is kernel-shaped with a batch prefix. The bias trace
    keeps the output's per-position shape so that :func:`_conv_yw_to_w` can
    multiply by :math:`\partial h / \partial y` elementwise before summing
    over spatial positions.

    Parameters
    ----------
    x_var, y_var
        Objects exposing ``.aval.shape``. The batch size is read from
        ``x_var.aval.shape[0]`` and the per-position output shape from
        ``y_var.aval.shape[1:]``.
    weight_vars : dict
        Must contain ``'weight'``; a ``'bias'`` entry enables the bias trace.
    num_hidden_state : int
        Size of the trailing hidden-state axis.

    Returns
    -------
    dict
        ``{'weight': zeros}`` plus ``'bias'`` when present in ``weight_vars``.
    """
    # NOTE(review): no dtype is passed, so ``jnp.zeros`` uses its default
    # float dtype, whereas ``_conv_init_pp`` follows ``y_var.aval.dtype`` —
    # confirm this difference is intended.
    n_batch = x_var.aval.shape[0]
    kernel_shape = weight_vars['weight'].aval.shape
    traces = {'weight': jnp.zeros((n_batch, *kernel_shape, num_hidden_state))}
    if 'bias' in weight_vars:
        # Drop the output's leading axis and use the input's batch size instead.
        per_position_shape = y_var.aval.shape[1:]
        traces['bias'] = jnp.zeros((n_batch, *per_position_shape, num_hidden_state))
    return traces


def _conv_init_pp(x_var, y_var, weight_vars, num_hidden_state):
    r"""Build the zero-initialised pp-prop / ES-D-RTRL df trace.

    .. math::

        \boldsymbol{\epsilon}_f \in \mathbb{R}^{B \times \text{(spatial)} \times O \times n_{\text{state}}}.

    The trace has the shape of :math:`y` with one trailing hidden-state
    axis, and uses the dtype of ``y_var.aval``. ``x_var`` and
    ``weight_vars`` are accepted for signature compatibility and not read.
    The matching :math:`\boldsymbol{\epsilon}_x` in ES-D-RTRL is the raw
    batched input held by the executor's x-trace; :func:`_conv_xy_to_dw`
    combines the two via conv VJP at solve-time.
    """
    y_aval = y_var.aval
    trace_shape = tuple(y_aval.shape) + (num_hidden_state,)
    return jnp.zeros(trace_shape, dtype=y_aval.dtype)


# Create the conv ETP primitive at import time. ``_etp_conv_impl`` is its
# forward implementation; operand 0 is the input ``x`` and the trainable
# operands (kernel, optional bias) are reported by ``_conv_trainable_invars``.
etp_conv_p = register_primitive(
    'etp_conv',
    _etp_conv_impl,
    batched=True,
    trainable_invars_fn=_conv_trainable_invars,
    x_invar_index=0,
)
# Attach the four ETP rules defined above: trace propagation (``yw_to_w``),
# instantaneous Jacobian (``xy_to_dw``), and the two trace initialisers.
etp_conv_p.register_etp_rules(
    yw_to_w=_conv_yw_to_w,
    xy_to_dw=_conv_xy_to_dw,
    init_drtrl=_conv_init_drtrl,
    init_pp=_conv_init_pp,
)


def conv(
    x,
    kernel,
    bias=None,
    *,
    strides: Sequence[int] = (1,),
    padding: str = 'SAME',
    lhs_dilation: Optional[Sequence[int]] = None,
    rhs_dilation: Optional[Sequence[int]] = None,
    feature_group_count: int = 1,
    batch_group_count: int = 1,
    dimension_numbers: Any = None,
):
    r"""ETP-aware convolution.

    Computes :math:`y = \mathrm{conv}(x, kernel) \; (+ b)` by routing the
    kernel (and optional bias) through an ETP primitive so they participate
    in eligibility-trace computation. The full keyword surface of
    :func:`jax.lax.conv_general_dilated` is preserved. Always expects a
    batch dimension on ``x``.

    Parameters
    ----------
    x : ArrayLike
        Input tensor with a leading batch dimension.
    kernel : ArrayLike
        Convolution kernel, with layout governed by ``dimension_numbers``.
    bias : ArrayLike or None, optional
        Per-output-channel bias. Default ``None``.
    strides : Sequence[int], optional
        Window strides. Default ``(1,)``.
    padding : str, optional
        Padding mode (e.g. ``'SAME'`` or ``'VALID'``). Default ``'SAME'``.
    lhs_dilation : Sequence[int] or None, optional
        Left-hand-side (input) dilation factors. Default ``None``.
    rhs_dilation : Sequence[int] or None, optional
        Right-hand-side (kernel) dilation factors. Default ``None``.
    feature_group_count : int, optional
        Number of feature groups. Default ``1``.
    batch_group_count : int, optional
        Number of batch groups. Default ``1``.
    dimension_numbers : Any, optional
        Convolution dimension numbers (e.g. ``('NHWC', 'HWIO', 'NHWC')``).
        Default ``None``, which uses the JAX default layout.

    Returns
    -------
    ArrayLike
        Convolution output tensor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import braintrace
        >>>
        >>> brainstate.environ.set(precision=64)
        >>> # 1-D conv, NCH input and OIH kernel (JAX defaults)
        >>> x = brainstate.random.randn(8, 3, 16)
        >>> kernel = brainstate.random.randn(4, 3, 5)
        >>> y = braintrace.conv(x, kernel, strides=(1,), padding='SAME')
        >>> print(y.shape)
        (8, 4, 16)
    """
    # Sequence-valued options are normalised to tuples before binding
    # (presumably so the primitive parameters are hashable — confirm).
    conv_kwargs = dict(
        strides=tuple(strides),
        padding=padding,
        lhs_dilation=tuple(lhs_dilation) if lhs_dilation is not None else None,
        rhs_dilation=tuple(rhs_dilation) if rhs_dilation is not None else None,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
        dimension_numbers=dimension_numbers,
    )
    # Split off saiunit units; the primitive operates on the mantissas.
    x_v, x_u = u.split_mantissa_unit(x)
    kernel_v, kernel_u = u.split_mantissa_unit(kernel)
    # The conv output carries the product of the input and kernel units.
    unit = x_u * kernel_u
    if bias is not None:
        # Express the bias in the output unit before handing it over unitless.
        bias_v = u.Quantity(bias).to_decimal(unit)
        r = etp_conv_p.bind(x_v, kernel_v, bias_v, has_bias=True, **conv_kwargs)
    else:
        r = etp_conv_p.bind(x_v, kernel_v, has_bias=False, **conv_kwargs)
    # Re-attach the units to the result.
    return u.maybe_decimal(r * x_u * kernel_u)