# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""LoRA (Low-Rank Adaptation) ETP primitives.

``etp_lora_mm_p`` (batched) and ``etp_lora_mv_p`` (unbatched) compute
:math:`y = \alpha \cdot x \mathbin{@} B \mathbin{@} A` plus an optional
bias. The trace and gradient states are pytrees with ``lora_b`` and
``lora_a`` (and optionally ``bias``) leaves; the originating ``ParamState``
holds all factors as a single pytree, e.g.
``{'lora_b': B, 'lora_a': A, 'bias': b}``.

**Forward operation**

.. math::

    y = \alpha \, x \, B \, A \;(+ b), \qquad
    B \in \mathbb{R}^{I \times r}, \;
    A \in \mathbb{R}^{r \times O}, \;
    r \ll \min(I, O).

The intermediate :math:`z = x B \in \mathbb{R}^{\dots \times r}` is what
flows through :math:`A` to produce :math:`y`. Both :math:`A` and
:math:`B` are trainable; :math:`\alpha` is a static scalar scaling factor.

**Role of each ETP rule**

Let :math:`g = \partial h / \partial y`. The chain rule yields

.. math::

    \frac{\partial h}{\partial A_{r,k}}
    \;=\; \alpha \, (x B)_{r} \, g_k, \qquad
    \frac{\partial h}{\partial B_{i,r}}
    \;=\; \alpha \, x_i \sum_k A_{r,k} \, g_k, \qquad
    \frac{\partial h}{\partial b_k}
    \;=\; g_k.

* ``xy_to_dw`` — VJP of :math:`y = \alpha\, x B A \,(+ b)` with respect to
  the whole dict ``{'lora_b', 'lora_a'[, 'bias']}``. A single ``jax.vjp``
  call delivers all pullbacks at once, giving the instantaneous
  :math:`\operatorname{diag}(\mathbf{D}_f^t)\otimes \mathbf{x}^t`
  term of D-RTRL (and the solve-time factor of ES-D-RTRL).
* ``yw_to_w`` — propagates :math:`g` only through the :math:`A` factor
  (plus elementwise through the bias). Intuition: :math:`A` is the
  "output-facing" factor, so :math:`\partial y / \partial A` attaches a
  :math:`g`-shaped scaling to the :math:`A` trace, exactly like dense
  matmul's :math:`y \to W` link. :math:`B` is "input-facing": its exact
  link to :math:`y` mixes all outputs through :math:`A`, and this rule
  does not apply that coupling, so the :math:`B` trace is carried
  unchanged through the :math:`y \to W` step (a diagonal approximation).
  The :math:`A`-dependence of the :math:`B` gradient enters only via
  ``xy_to_dw``.
* ``init_drtrl`` — allocates separate leaves for :math:`\boldsymbol{\epsilon}_B`,
  :math:`\boldsymbol{\epsilon}_A`, and optionally :math:`\boldsymbol{\epsilon}_b`,
  each of shape ``(*factor_shape, n_state)`` (with a leading batch axis in
  the batched primitive).
* ``init_pp`` — output-shaped df trace; same as dense.

**Dict rule API (N-trainable-input refactor)**

Both primitives declare ``trainable_invars_fn``, which returns
``{'lora_b': 1, 'lora_a': 2}`` when ``has_bias=False`` and
``{'lora_b': 1, 'lora_a': 2, 'bias': 3}`` when ``has_bias=True``.
The keys ``'lora_b'`` / ``'lora_a'`` match the pytree leaf names in
``braintrace.nn.LoRALinear``'s merged ``ParamState``.
"""
import jax
import jax.numpy as jnp
import saiunit as u
from ._primitive import register_primitive
# Public API: the batched / unbatched LoRA ETP primitives and the
# unit-aware ``lora_matmul`` front end.
__all__ = ['etp_lora_mm_p', 'etp_lora_mv_p', 'lora_matmul']
def _etp_lora_impl(*args, alpha=1.0, has_bias=False):
    """Reference forward pass shared by both LoRA primitives.

    Positional inputs are ``(x, B, A)`` or, when ``has_bias`` is true,
    ``(x, B, A, bias)``. Returns ``alpha * (x @ B @ A)`` with ``bias`` added
    in the biased case.
    """
    x = args[0]
    lora_b = args[1]
    lora_a = args[2]
    # Left-associative: (x @ B) first, so the rank-r intermediate is formed
    # before the expansion through A.
    product = alpha * ((x @ lora_b) @ lora_a)
    return product + args[3] if has_bias else product
def _lora_trainable_invars(params):
    """Map each trainable LoRA leaf name to its positional invar index.

    ``x`` (index 0) is not trainable. ``B`` and ``A`` are always inputs 1
    and 2; the bias is input 3 and is only listed when
    ``params['has_bias']`` is truthy. A fresh dict is returned on each call.
    """
    indices = {'lora_b': 1, 'lora_a': 2}
    return {**indices, 'bias': 3} if params.get('has_bias', False) else indices
def _lora_mm_yw_to_w(hidden_dim, trace, *, alpha=1.0, has_bias=False):
    r"""Batched LoRA ``yw_to_w``: apply :math:`g = \partial h / \partial y` to the trace.

    **Role in D-RTRL.** Implements the :math:`y \to (A, B, b)` link of
    :math:`\mathbf{D}^t \boldsymbol{\epsilon}^{t-1}` for the LoRA op, after
    the :math:`\mathbf{D}^t` contraction along the hidden axis has already
    been applied. From :math:`y_k = \alpha \sum_r (xB)_r A_{r,k}`:

    .. math::

        \frac{\partial y_k}{\partial A_{r,k'}} = \delta_{k k'}\, \alpha\, (xB)_r,
        \qquad
        \frac{\partial y_k}{\partial B_{i,r}} = \alpha\, A_{r,k}\, x_i.

    * :math:`A` — the link is diagonal in the output index, so the trace is
      scaled by :math:`g` broadcast across the ``rank`` axis:
      :math:`\epsilon^t_{A,r,k} = g_k\, \epsilon^{t-1}_{A,r,k}`.
    * :math:`B` — the exact link mixes all outputs through :math:`A`. This
      rule does not apply that coupling: the :math:`B` trace is returned
      unchanged, and the :math:`A`-dependent contribution to the :math:`B`
      gradient comes from ``xy_to_dw`` each step (diagonal approximation).
    * bias — elementwise scaling by :math:`g`.

    **Broadcast rule.** ``jnp.expand_dims(hidden_dim, axis=-2)`` inserts a
    singleton just before the last (``out``) axis::

        (out,)        -> (1, out)         against (rank, out)
        (batch, out)  -> (batch, 1, out)  against (batch, rank, out)

    NOTE(review): ``_lora_mm_init_drtrl`` allocates leaves with a trailing
    ``n_state`` axis; the broadcast above only lines up if that axis is not
    present in the arrays passed here (presumably the executor maps this
    rule over ``n_state``) — confirm.

    Parameters
    ----------
    hidden_dim : jax.Array
        :math:`g`, presumably shaped ``(batch, out)`` (see note above).
    trace : dict
        Leaves ``'lora_b'``, ``'lora_a'`` and, when ``has_bias``, ``'bias'``.
    alpha : float
        Not used by this rule; :math:`\alpha` enters through ``xy_to_dw``.
    has_bias : bool
        Whether to propagate the ``'bias'`` leaf; when false it is omitted
        from the result.

    Returns
    -------
    dict
        ``'lora_b'`` unchanged, ``'lora_a'`` scaled by :math:`g`, and
        ``'bias'`` scaled by :math:`g` when ``has_bias``.
    """
    # Singleton rank axis so each g_k multiplies every row r of the A trace.
    g_over_rank = jnp.expand_dims(hidden_dim, axis=-2)
    propagated = {
        'lora_b': trace['lora_b'],
        'lora_a': trace['lora_a'] * g_over_rank,
    }
    if not has_bias:
        return propagated
    propagated['bias'] = trace['bias'] * hidden_dim
    return propagated
def _lora_mv_yw_to_w(hidden_dim, trace, *, alpha=1.0, has_bias=False):
    r"""Unbatched LoRA ``yw_to_w`` — same algebra as the batched rule, no batch axis.

    Only :math:`A` (and the bias) are touched:

    * ``'lora_a'`` is multiplied by ``jnp.expand_dims(hidden_dim, axis=0)``,
      i.e. :math:`g` with a leading singleton inserted so that it broadcasts
      across the ``rank`` axis (``(out,) -> (1, out)``).
    * ``'bias'`` (only when ``has_bias``) is multiplied elementwise by
      :math:`g`.
    * ``'lora_b'`` is returned unchanged; its :math:`y \to B` factor is not
      applied here, and the :math:`B` contribution comes from ``xy_to_dw``.

    NOTE(review): ``_lora_mv_init_drtrl`` allocates leaves with a trailing
    ``n_state`` axis. A 1-D :math:`g` of shape ``(out,)`` right-aligns with
    the *last* two axes of a ``(rank, out, n_state)`` array, so either the
    executor maps this rule over ``n_state`` (trace seen as ``(rank, out)``)
    or ``hidden_dim`` itself carries ``n_state`` (``(out, n_state)`` ->
    ``(1, out, n_state)``) — confirm which.

    ``alpha`` is not used by this rule; :math:`\alpha` enters via
    ``xy_to_dw``.
    """
    scaled_a = trace['lora_a'] * jnp.expand_dims(hidden_dim, axis=0)
    result = dict(lora_b=trace['lora_b'], lora_a=scaled_a)
    if has_bias:
        result['bias'] = trace['bias'] * hidden_dim
    return result
def _lora_xy_to_dw(x, hidden_dim, weights, *, alpha=1.0, has_bias=False):
    r"""Instantaneous LoRA weight Jacobian, obtained from a single VJP.

    **Role in D-RTRL / ES-D-RTRL.** With :math:`g` = ``hidden_dim`` used as
    the cotangent of :math:`y`, one ``jax.vjp`` pass over the weight dict
    yields every factor's pullback:

    .. math::

        \frac{\partial h}{\partial A_{r,k}} = \alpha\, (xB)_r\, g_k, \qquad
        \frac{\partial h}{\partial B_{i,r}} = \alpha\, x_i\, (A g)_r, \qquad
        \frac{\partial h}{\partial b_k} = g_k.

    In D-RTRL this is the
    :math:`\operatorname{diag}(\mathbf{D}_f^t)\otimes \mathbf{x}^t` term; in
    ES-D-RTRL it is the solve-time pullback combining
    :math:`\boldsymbol{\epsilon}_f^t` with :math:`\boldsymbol{\epsilon}_x^t`.

    Parameters
    ----------
    x : array
        Layer input.
    hidden_dim : array
        Cotangent for the output :math:`y`.
    weights : dict
        ``{'lora_b': B, 'lora_a': A}`` plus ``'bias'`` when ``has_bias``.
    alpha : float
        Static scale applied to ``x @ B @ A``.
    has_bias : bool
        Whether ``weights['bias']`` is added in the forward pass.

    Returns
    -------
    dict
        Cotangents with the same keys as ``weights``, each passed through
        ``u.get_mantissa``.
    """

    def _forward(params):
        # Same expression as the primitive's forward, but written over the
        # weight dict so one VJP returns a cotangent for every leaf.
        low_rank = x @ params['lora_b'] @ params['lora_a']
        out = alpha * low_rank
        if has_bias:
            out = out + params['bias']
        return u.get_mantissa(out)

    _, pullback = jax.vjp(_forward, weights)
    (weight_cotangents,) = pullback(hidden_dim)
    return jax.tree.map(u.get_mantissa, weight_cotangents)
def _lora_mm_init_drtrl(x_var, y_var, weight_vars, num_hidden_state):
    r"""Allocate zero-filled D-RTRL traces for the batched LoRA primitive.

    Each trainable factor gets its own leaf. Every leaf starts with the batch
    size :math:`N` = ``x_var.aval.shape[0]`` and ends with
    ``num_hidden_state``:

    .. math::

        \boldsymbol{\epsilon}_B \in \mathbb{R}^{N \times I \times r \times n_{\text{state}}}, \quad
        \boldsymbol{\epsilon}_A \in \mathbb{R}^{N \times r \times O \times n_{\text{state}}}, \quad
        \boldsymbol{\epsilon}_b \in \mathbb{R}^{N \times O \times n_{\text{state}}}.

    The factor leaves total :math:`N\, r\, (I + O)\, n_{\text{state}}`
    entries, versus :math:`N\, I\, O\, n_{\text{state}}` for a trace over a
    full :math:`I \times O` weight. The bias leaf is only created when
    ``'bias'`` is a key of ``weight_vars``.

    Leaves use JAX's default float dtype (no ``dtype`` is passed), unlike
    ``_lora_mm_init_pp``, which uses ``y``'s dtype. ``y_var`` is unused.
    NOTE(review): only ``x.shape[0]`` is treated as batch; inputs with more
    than one leading axis are not represented here — confirm.
    """
    batch_size = x_var.aval.shape[0]

    def _zero_leaf(name):
        # (batch, *factor_shape, n_state)
        factor_shape = weight_vars[name].aval.shape
        return jnp.zeros((batch_size, *factor_shape, num_hidden_state))

    traces = {name: _zero_leaf(name) for name in ('lora_b', 'lora_a')}
    if 'bias' in weight_vars:
        traces['bias'] = _zero_leaf('bias')
    return traces
def _lora_mm_init_pp(x_var, y_var, weight_vars, num_hidden_state):
    r"""Allocate the zero df trace for batched pp-prop / ES-D-RTRL.

    .. math::

        \boldsymbol{\epsilon}_f \in \mathbb{R}^{N \times O \times n_{\text{state}}}.

    The shape is ``y``'s shape plus a trailing ``num_hidden_state`` axis,
    with ``y``'s dtype. It does not depend on how :math:`W = \alpha B A` is
    factored: the :math:`\boldsymbol{\epsilon}_x` factor is the raw
    :math:`x`, and the :math:`B, A, b` split happens in ``_lora_xy_to_dw``
    at solve time. ``x_var`` and ``weight_vars`` are unused.
    """
    out_aval = y_var.aval
    trace_shape = tuple(out_aval.shape) + (num_hidden_state,)
    return jnp.zeros(trace_shape, dtype=out_aval.dtype)
def _lora_mv_init_drtrl(x_var, y_var, weight_vars, num_hidden_state):
    r"""Allocate zero-filled D-RTRL traces for the unbatched LoRA primitive.

    Each leaf is the factor's own shape with ``num_hidden_state`` appended:

    .. math::

        \boldsymbol{\epsilon}_B \in \mathbb{R}^{I \times r \times n_{\text{state}}}, \quad
        \boldsymbol{\epsilon}_A \in \mathbb{R}^{r \times O \times n_{\text{state}}}, \quad
        \boldsymbol{\epsilon}_b \in \mathbb{R}^{O \times n_{\text{state}}}.

    The bias leaf exists only if ``'bias'`` is a key of ``weight_vars``.
    Leaves use JAX's default float dtype (no ``dtype`` is passed).
    ``x_var`` and ``y_var`` are unused.
    """
    trace_keys = ['lora_b', 'lora_a']
    if 'bias' in weight_vars:
        trace_keys.append('bias')
    return {
        key: jnp.zeros((*weight_vars[key].aval.shape, num_hidden_state))
        for key in trace_keys
    }
def _lora_mv_init_pp(x_var, y_var, weight_vars, num_hidden_state):
    r"""Allocate the zero df trace for unbatched pp-prop / ES-D-RTRL.

    .. math::

        \boldsymbol{\epsilon}_f \in \mathbb{R}^{O \times n_{\text{state}}}.

    ``y``'s shape with a trailing ``num_hidden_state`` axis, in ``y``'s
    dtype; ``x_var`` and ``weight_vars`` are unused.
    """
    out_aval = y_var.aval
    trace_shape = tuple(out_aval.shape) + (num_hidden_state,)
    return jnp.zeros(trace_shape, dtype=out_aval.dtype)
# Batched LoRA primitive (``lora_matmul`` selects it when ``x`` has two or
# more dimensions). Forward is ``_etp_lora_impl``; trainable inputs are
# listed by ``_lora_trainable_invars``; ``x_invar_index=0`` because ``x`` is
# positional input 0 of ``_etp_lora_impl``.
etp_lora_mm_p = register_primitive(
'etp_lora_mm',
_etp_lora_impl,
batched=True,
trainable_invars_fn=_lora_trainable_invars,
x_invar_index=0,
)
# Attach the batched ETP rules. ``xy_to_dw`` is shared with the unbatched
# primitive because the VJP does not depend on a batch axis.
etp_lora_mm_p.register_etp_rules(
yw_to_w=_lora_mm_yw_to_w,
xy_to_dw=_lora_xy_to_dw,
init_drtrl=_lora_mm_init_drtrl,
init_pp=_lora_mm_init_pp,
)
# Unbatched LoRA primitive (``lora_matmul`` selects it for 1-D ``x``). Same
# forward and trainable-input mapping as the batched one; only the
# ``batched`` flag and the shape-specific rules differ.
etp_lora_mv_p = register_primitive(
'etp_lora_mv',
_etp_lora_impl,
batched=False,
trainable_invars_fn=_lora_trainable_invars,
x_invar_index=0,
)
# Attach the unbatched ETP rules (no leading batch axis in the traces).
etp_lora_mv_p.register_etp_rules(
yw_to_w=_lora_mv_yw_to_w,
xy_to_dw=_lora_xy_to_dw,
init_drtrl=_lora_mv_init_drtrl,
init_pp=_lora_mv_init_pp,
)
def lora_matmul(x, B, A, *, alpha=1.0, bias=None):
    r"""ETP-aware LoRA (Low-Rank Adaptation) matrix multiplication.

    Computes :math:`y = \alpha \cdot x \mathbin{@} B \mathbin{@} A \; (+ b)`,
    routing both low-rank factors (and the optional bias) through an ETP
    primitive so they participate in eligibility-trace computation.
    Dispatches to ``etp_lora_mm_p`` when ``x`` has two or more dimensions
    and to ``etp_lora_mv_p`` otherwise.

    Units (``saiunit``) are stripped from ``x``, ``B`` and ``A`` before the
    primitive is bound. The result is re-attached with the unit
    ``unit(x) * unit(B) * unit(A)``, and ``bias`` is converted to that same
    unit via ``to_decimal`` (presumably raising if it is incompatible).

    Parameters
    ----------
    x : ArrayLike
        Input array, shape ``(..., in_features)`` or ``(in_features,)``.
        NOTE(review): the batched D-RTRL trace initialiser only allocates
        ``x.shape[0]`` as a batch axis, so inputs with more than one leading
        axis may not be supported by the trace rules — confirm.
    B : ArrayLike
        Low-rank matrix :math:`B`, shape ``(in_features, rank)``.
    A : ArrayLike
        Low-rank matrix :math:`A`, shape ``(rank, out_features)``.
    alpha : float, optional
        Scalar scaling factor :math:`\alpha`. Default ``1.0``.
    bias : ArrayLike or None, optional
        Bias vector, shape ``(out_features,)``. Default ``None``.

    Returns
    -------
    ArrayLike
        Output array, shape ``(..., out_features)`` or ``(out_features,)``.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import braintrace
        >>>
        >>> brainstate.environ.set(precision=64)
        >>> x = brainstate.random.randn(16, 8)
        >>> B = brainstate.random.randn(8, 2)
        >>> A = brainstate.random.randn(2, 4)
        >>> y = braintrace.lora_matmul(x, B, A, alpha=0.5)
        >>> print(y.shape)
        (16, 4)
    """
    x_v, x_u = u.split_mantissa_unit(x)
    B_v, B_u = u.split_mantissa_unit(B)
    A_v, A_u = u.split_mantissa_unit(A)
    # Decide batched vs. unbatched from the mantissa: ``jnp.ndim`` also works
    # for array-likes that have no ``.ndim`` attribute.
    p = etp_lora_mm_p if jnp.ndim(x_v) >= 2 else etp_lora_mv_p
    unit = x_u * B_u * A_u
    if bias is not None:
        # The bias is added to x @ B @ A, so it must be expressed in the
        # product's unit before its mantissa enters the primitive.
        bias_v = u.Quantity(bias).to_decimal(unit)
        y_v = p.bind(x_v, B_v, A_v, bias_v, alpha=alpha, has_bias=True)
    else:
        y_v = p.bind(x_v, B_v, A_v, alpha=alpha, has_bias=False)
    return u.maybe_decimal(y_v * unit)