lora_matmul

class braintrace.lora_matmul(x, B, A, *, alpha=1.0, bias=None)

ETP-aware LoRA (Low-Rank Adaptation) matrix multiplication.

Computes \(y = \alpha \, x B A + b\), where the bias term \(b\) is added only when bias is given. Both low-rank factors \(B\) and \(A\) (and the bias, if present) are routed through an ETP primitive so that they participate in eligibility-trace computation. The call dispatches automatically to a batched or unbatched implementation based on x.ndim.
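As a minimal sketch of the arithmetic only, the NumPy function below evaluates the same formula and checks it against a dense layer with weight \(W = \alpha B A\). The name lora_matmul_reference is hypothetical and not part of braintrace; the sketch does not model the ETP primitive or eligibility traces.

```python
import numpy as np


def lora_matmul_reference(x, B, A, alpha=1.0, bias=None):
    """Hypothetical NumPy reference for y = alpha * x @ B @ A (+ bias).

    Illustrates the arithmetic only: no ETP primitive, no eligibility traces,
    and not the braintrace implementation.
    """
    x, B, A = np.asarray(x), np.asarray(B), np.asarray(A)
    # Project down to the rank-r space first; the intermediate has shape
    # (..., rank), so the dense (in_features, out_features) matrix is never built.
    h = x @ B
    # Project back up and apply the scalar scale.
    y = alpha * (h @ A)
    # The bias, when given, is added after scaling (it is not multiplied by alpha).
    if bias is not None:
        y = y + np.asarray(bias)
    return y


rng = np.random.default_rng(0)
x = rng.standard_normal((16, 8))   # batched input, (batch, in_features)
B = rng.standard_normal((8, 2))    # (in_features, rank)
A = rng.standard_normal((2, 4))    # (rank, out_features)
bias = rng.standard_normal(4)      # (out_features,)

y = lora_matmul_reference(x, B, A, alpha=0.5, bias=bias)
print(y.shape)  # (16, 4)

# Equivalent dense layer with weight W = alpha * B @ A.
W = 0.5 * (B @ A)
print(np.allclose(y, x @ W + bias))  # True

# A 1-D (unbatched) input gives a 1-D output of shape (out_features,).
print(lora_matmul_reference(x[0], B, A, alpha=0.5, bias=bias).shape)  # (4,)
```

NumPy's @ already handles both 1-D and batched x, so this sketch needs no explicit branch on x.ndim.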

Parameters:
  • x (ArrayLike) – Input array, shape (..., in_features) or (in_features,).

  • B (ArrayLike) – Low-rank matrix \(B\), shape (in_features, rank).

  • A (ArrayLike) – Low-rank matrix \(A\), shape (rank, out_features).

  • alpha (float, optional) – Scalar scaling factor \(\alpha\). Default 1.0.

  • bias (ArrayLike or None, optional) – Bias vector, shape (out_features,). Default None.
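The shape contract in the parameter list can be written as a small check. The helper lora_output_shape below is hypothetical (not a braintrace function); it assumes B and A are 2-D, raises ValueError on a mismatch, and returns the output shape described under Returns.

```python
def lora_output_shape(x_shape, B_shape, A_shape, bias_shape=None):
    """Hypothetical helper: validate lora_matmul operand shapes, return output shape."""
    in_features, rank = B_shape          # B: (in_features, rank)
    rank_a, out_features = A_shape       # A: (rank, out_features)
    if len(x_shape) == 0 or x_shape[-1] != in_features:
        raise ValueError(f"x must end in in_features={in_features}, got {x_shape}")
    if rank_a != rank:
        raise ValueError(f"A must have shape ({rank}, out_features), got {A_shape}")
    if bias_shape is not None and tuple(bias_shape) != (out_features,):
        raise ValueError(f"bias must have shape ({out_features},), got {bias_shape}")
    # Unbatched x (ndim == 1) yields (out_features,); batched x keeps its leading dims.
    return tuple(x_shape[:-1]) + (out_features,)


print(lora_output_shape((8,), (8, 2), (2, 4)))               # (4,)
print(lora_output_shape((16, 8), (8, 2), (2, 4), (4,)))      # (16, 4)
print(lora_output_shape((3, 16, 8), (8, 2), (2, 4)))         # (3, 16, 4)
```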

Returns:

Output array of shape (..., out_features) for batched input, with the same leading dimensions as x, or (out_features,) for a 1-D x.

Return type:

ArrayLike

Examples

>>> import brainstate
>>> import braintrace
>>>
>>> brainstate.environ.set(precision=64)
>>> x = brainstate.random.randn(16, 8)
>>> B = brainstate.random.randn(8, 2)
>>> A = brainstate.random.randn(2, 4)
>>> y = braintrace.lora_matmul(x, B, A, alpha=0.5)
>>> print(y.shape)
(16, 4)