# Source code for braintrace._etrace_op._primitive

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r""":class:`ETPPrimitive` and :func:`register_primitive`.

Each ETP primitive is a JAX :class:`~jax.core.Primitive` subclass with
four ETP-specific rule slots (``yw_to_w``, ``xy_to_dw``, ``init_drtrl``,
``init_pp``). All standard JAX rules — ``impl``, ``abstract_eval``,
MLIR lowering, JVP, transpose, batching — are auto-derived from a single
implementation function via :func:`register_primitive`.
"""

from functools import partial
from typing import Callable, Optional

import jax
import jax.numpy as jnp
import numpy as np
from jax.core import ShapedArray
from jax.interpreters import ad, batching, mlir

from braintrace._compatible_imports import Primitive
from ._registries import (
    BATCHED_PRIMITIVES,
    ETP_PRIMITIVES,
    ETP_RULES_INIT_DRTRL,
    ETP_RULES_INIT_PP,
    ETP_RULES_XY_TO_DW,
    ETP_RULES_YW_TO_W,
    ETP_TRAINABLE_INVARS_FNS,
    ETP_X_INVAR_INDICES,
    ETP_Y_OUTVAR_INDICES,
    GRADIENT_ENABLED_PRIMITIVES,
)

__all__ = [
    'ETPPrimitive',
    'register_primitive',
]


class ETPPrimitive(Primitive):
    """A JAX ``Primitive`` with ETP rule registration helpers.

    Returned by :func:`register_primitive`. Supports every standard JAX
    primitive operation (``bind``, ``def_impl``, ...) and adds five
    convenience methods for installing ETP-specific rules into the global
    registries. Each registry is keyed by the primitive instance itself.

    See Also
    --------
    register_primitive : Factory that creates and returns an instance.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import braintrace
        >>>
        >>> # Register a primitive whose forward delegates to a standard op.
        >>> def my_impl(x, w):
        ...     return x @ w
        >>> my_p = braintrace.register_primitive('etp_demo_mm', my_impl, batched=True)
        >>> y = my_p.bind(jnp.ones((2, 3)), jnp.ones((3, 4)))
        >>> print(y.shape)
        (2, 4)
    """

    def register_yw_to_w(self, fn: Callable) -> None:
        """Install a D-RTRL trace propagation rule.

        Parameters
        ----------
        fn : Callable
            Rule with signature ``(hidden_dim, trace, **params) -> trace``.
        """
        ETP_RULES_YW_TO_W[self] = fn

    def register_xy_to_dw(self, fn: Callable) -> None:
        """Install a weight-gradient rule.

        Parameters
        ----------
        fn : Callable
            Rule with signature ``(x, hidden_dim, w, **params) -> dw``.
        """
        ETP_RULES_XY_TO_DW[self] = fn

    def register_init_drtrl(self, fn: Callable) -> None:
        """Install a D-RTRL trace initialiser.

        Parameters
        ----------
        fn : Callable
            Rule with signature
            ``(x_var, y_var, weight_var, num_hidden_state) -> zeros``.
        """
        ETP_RULES_INIT_DRTRL[self] = fn

    def register_init_pp(self, fn: Callable) -> None:
        """Install a pp_prop (IO-dim) df trace initialiser.

        Parameters
        ----------
        fn : Callable
            Rule with signature
            ``(x_var, y_var, weight_var, num_hidden_state) -> zeros``.
        """
        ETP_RULES_INIT_PP[self] = fn

    def register_etp_rules(
        self,
        *,
        yw_to_w: Optional[Callable] = None,
        xy_to_dw: Optional[Callable] = None,
        init_drtrl: Optional[Callable] = None,
        init_pp: Optional[Callable] = None,
    ) -> None:
        """Install multiple ETP rules in one call.

        Any argument left as ``None`` is skipped. Each supplied rule is
        installed through the matching single-rule ``register_*`` method,
        so both entry points always write to the same registries.

        Parameters
        ----------
        yw_to_w : Callable, optional
            D-RTRL trace propagation rule. Default ``None``.
        xy_to_dw : Callable, optional
            Weight-gradient rule. Default ``None``.
        init_drtrl : Callable, optional
            D-RTRL trace initialiser. Default ``None``.
        init_pp : Callable, optional
            pp_prop (IO-dim) df trace initialiser. Default ``None``.
        """
        # Pair every single-rule installer with the rule supplied for it;
        # the order matches the original sequence of registry writes.
        installers = (
            (self.register_yw_to_w, yw_to_w),
            (self.register_xy_to_dw, xy_to_dw),
            (self.register_init_drtrl, init_drtrl),
            (self.register_init_pp, init_pp),
        )
        for install, rule in installers:
            if rule is not None:
                install(rule)
def register_primitive(
    name: str,
    impl_fn: Callable,
    *,
    batched: bool = False,
    gradient_enabled: bool = False,
    trainable_invars_fn: Optional[Callable] = None,
    x_invar_index: Optional[int] = 0,
    y_outvar_index: int = 0,
) -> 'ETPPrimitive':
    """Create an :class:`ETPPrimitive` with all JAX rules auto-derived.

    Only the four ETP-specific rules need hand-writing — call the returned
    primitive's ``register_*`` methods.

    Parameters
    ----------
    name : str
        Primitive name (e.g. ``'etp_mm'``).
    impl_fn : Callable
        Implementation function. It is called with the primitive's
        positional operands and its bind-time parameters as keyword
        arguments, and must return a single array-like result (the
        abstract-eval and lowering rules are registered for one output).
    batched : bool, optional
        Whether this primitive operates on batched inputs. Default
        ``False``.
    gradient_enabled : bool, optional
        If ``True``, the compiler will *evaluate* this primitive when
        walking ``y -> h`` (identity-like ops such as ``etp_elemwise_p``).
        Default ``False``.
    trainable_invars_fn : Callable or None, optional
        Function ``eqn.params -> {key: invar_index}`` declaring the
        primitive's full trainable-input layout. Used by the compiler and
        executors to support N-trainable-input primitives (e.g.
        ``{weight, bias}`` for Linear, ``{B, A, bias}`` for LoRA). If
        ``None``, the compiler falls back to the single-weight
        ``{'weight': 1}`` layout. Default ``None``.
    x_invar_index : int or None, optional
        Position of the input ``x`` in ``eqn.invars``, or ``None`` for
        primitives with no external input (currently only
        ``etp_elemwise_p``). Default ``0``.
    y_outvar_index : int, optional
        Position of the output ``y`` in ``eqn.outvars``. Default ``0``.

    Returns
    -------
    ETPPrimitive
        The registered primitive.

    Notes
    -----
    The following standard JAX rules are installed automatically:

    - **impl** — eager execution.
    - **abstract_eval** — via ``jax.eval_shape(impl)``.
    - **lowering** — via ``mlir.lower_fun(impl)``.
    - **JVP** — via ``jax.jvp(impl)``. Symbolic-zero tangents are
      materialised first; for integer/bool operands they become ``float0``
      zeros, which is the tangent dtype ``jax.jvp`` requires for them.
    - **transpose** — derived by JAX from the JVP.
    - **batching** — via ``jax.vmap(impl)``.
    """
    p = ETPPrimitive(name)

    # Record the primitive and its metadata in the global ETP registries.
    ETP_PRIMITIVES.add(p)
    if batched:
        BATCHED_PRIMITIVES.add(p)
    if gradient_enabled:
        GRADIENT_ENABLED_PRIMITIVES.add(p)
    if trainable_invars_fn is not None:
        ETP_TRAINABLE_INVARS_FNS[p] = trainable_invars_fn
    ETP_X_INVAR_INDICES[p] = x_invar_index
    ETP_Y_OUTVAR_INDICES[p] = y_outvar_index

    p.def_impl(impl_fn)

    @p.def_abstract_eval
    def _abstract(*args, **params):
        # Infer the output shape/dtype by abstractly evaluating impl_fn.
        shapes = tuple(ShapedArray(a.shape, a.dtype) for a in args)
        out = jax.eval_shape(partial(impl_fn, **params), *shapes)
        return ShapedArray(out.shape, out.dtype)

    mlir.register_lowering(
        p,
        mlir.lower_fun(impl_fn, multiple_results=False),
    )

    def _zero_tangent(primal):
        # jax.jvp only accepts a tangent whose dtype is the tangent dtype of
        # its primal: the same dtype for inexact (float/complex) primals and
        # float0 for everything else (int, bool, ...). Building an int/bool
        # zero here would make jax.jvp raise a TypeError.
        if jnp.issubdtype(primal.dtype, jnp.inexact):
            return jnp.zeros(primal.shape, primal.dtype)
        return np.zeros(primal.shape, dtype=jax.dtypes.float0)

    def _jvp(primals, tangents, **params):
        # Replace symbolic zeros with concrete zero tangents so that every
        # primal has a matching tangent when delegating to jax.jvp.
        tans = tuple(
            _zero_tangent(pr) if isinstance(t, ad.Zero) else t
            for pr, t in zip(primals, tangents)
        )
        return jax.jvp(partial(impl_fn, **params), primals, tans)

    ad.primitive_jvps[p] = _jvp

    def _batching(args, dims, **params):
        # vmap places the mapped axis of the result at position 0 by default.
        return jax.vmap(partial(impl_fn, **params), in_axes=dims)(*args), 0

    batching.primitive_batchers[p] = _batching
    return p