Source code for braintrace.nn._linear

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

import brainstate
import saiunit as u

from braintrace._etrace_op import matmul, sparse_matmul, lora_matmul
from braintrace._typing import ArrayLike

__all__ = [
    'Linear',
    'SignedWLinear',
    'SparseLinear',
    'LoRA',
]


class Linear(brainstate.nn.Linear):
    __module__ = 'braintrace.nn'
    __doc__ = brainstate.nn.Linear.__doc__.replace('brainstate', 'braintrace')

[docs] def update(self, x): """Apply the linear transform through the ETP ``matmul`` primitive. Routing the matrix multiplication through :func:`braintrace.matmul` (instead of a plain JAX dot) is what makes ``weight`` eligible for online-learning trace computation. Parameters ---------- x : ArrayLike Input array, of shape ``(..., in_size)``. Returns ------- ArrayLike The transformed output, of shape ``(..., out_size)``. """ w = self.weight.value['weight'] if self.w_mask is not None: w = w * self.w_mask b = self.weight.value.get('bias') return matmul(x, w, b)
class SignedWLinear(brainstate.nn.SignedWLinear): __module__ = 'braintrace.nn' __doc__ = brainstate.nn.SignedWLinear.__doc__.replace('brainstate', 'braintrace')
[docs] def update(self, x): """Apply the sign-constrained linear transform through ETP ``matmul``. The stored weight magnitudes are made non-negative and then given a fixed sign before being routed through :func:`braintrace.matmul`, so the weight participates in online-learning trace computation. Parameters ---------- x : ArrayLike Input array, of shape ``(..., in_size)``. Returns ------- ArrayLike The transformed output, of shape ``(..., out_size)``. """ w = u.math.abs(self.weight.value) if self.w_sign is not None: w = w * self.w_sign return matmul(x, w)
class ScaledWSLinear(brainstate.nn.ScaledWSLinear): __module__ = 'braintrace.nn' __doc__ = brainstate.nn.ScaledWSLinear.__doc__.replace('brainstate', 'braintrace') def update(self, x): """Apply the weight-standardized linear transform through ETP ``matmul``. The weight is standardized (zero-mean, unit-variance per output unit) before being routed through :func:`braintrace.matmul`, so it participates in online-learning trace computation. Parameters ---------- x : ArrayLike Input array, of shape ``(..., in_size)``. Returns ------- ArrayLike The transformed output, of shape ``(..., out_size)``. """ params = self.weight.value w = brainstate.nn.weight_standardization(params['weight'], self.eps, params.get('gain', None)) if self.w_mask is not None: w = w * self.w_mask b = params.get('bias', None) return matmul(x, w, b) class SparseLinear(brainstate.nn.SparseLinear): __module__ = 'braintrace.nn' __doc__ = brainstate.nn.SparseLinear.__doc__.replace('brainstate', 'braintrace')
[docs] def update(self, x): """Apply the sparse linear transform through the ETP ``sparse_matmul``. The dense data of the sparse weight is routed through :func:`braintrace.sparse_matmul`, so it participates in online-learning trace computation. Parameters ---------- x : ArrayLike Input array, of shape ``(..., in_size)``. Returns ------- ArrayLike The transformed output, of shape ``(..., out_size)``. """ weight = self.weight.value['weight'] bias = self.weight.value.get('bias', None) return sparse_matmul(x, weight, sparse_mat=self.spar_mat, bias=bias)
class LoRA(brainstate.nn.LoRA): r"""A standalone LoRA layer. LoRA (Low-Rank Adaptation) injects two low-rank factors into a layer so a large pre-trained model can be fine-tuned with far fewer parameters. This subclass preserves the upstream :class:`brainstate.nn.LoRA` constructor and replaces only the forward pass so that the multiplication is routed through :func:`braintrace.lora_matmul` and therefore participates in eligibility-trace computation. The layer adds a low-rank component :math:`\frac{1}{r} B A` to the base weight, where :math:`B` and :math:`A` are learnable factors of rank :math:`r`: .. math:: W_{\mathrm{LoRA}} = W_{\text{orig}} + \frac{1}{r} B A The scaling factor is fixed to ``1 / lora_rank``. Parameters ---------- in_features : int Number of input features. lora_rank : int Rank of the LoRA decomposition. out_features : int Number of output features. base_module : brainstate.nn.Module or None, optional Optional base layer that is called on ``x`` and added to the LoRA branch. Default ``None``. kernel_init : Callable or ArrayLike, optional Initializer used for **both** ``lora_a`` (in×rank) and ``lora_b`` (rank×out). Default is ``LecunNormal()``. To get the classic "LoRA-zero" initialisation use ``init.ZeroInit()``. param_type : type, optional ``ParamState`` subclass used to wrap the weights. Default is ``brainstate.ParamState``. in_size : int or Sequence[int], optional Optional explicit input size override. Default ``None``. Attributes ---------- in_features : int Number of input features. out_features : int Number of output features. base_module : brainstate.nn.Module or None The optional base layer added to the LoRA branch. weight : ParamState ``ParamState`` whose value is a dict with two keys: ``'lora_a'`` of shape ``(in_features, lora_rank)`` and ``'lora_b'`` of shape ``(lora_rank, out_features)``. Examples -------- .. code-block:: python >>> import brainstate >>> import braintrace >>> >>> # Create a standalone LoRA layer >>> brainstate.environ.set(precision=64) >>> layer = braintrace.nn.LoRA(in_features=3, lora_rank=2, out_features=4) >>> x = brainstate.random.randn(16, 3) >>> y = layer(x) >>> print(y.shape) (16, 4) >>> >>> # Wrap around an existing linear layer >>> linear = brainstate.nn.Linear(3, 4) >>> wrapper = braintrace.nn.LoRA(3, 2, 4, base_module=linear) >>> assert wrapper.base_module is linear >>> y = wrapper(x) >>> print(y.shape) (16, 4) """ __module__ = 'braintrace.nn'
[docs] def update(self, x: ArrayLike): r"""Apply the low-rank adaptation through the ETP ``lora_matmul``. Computes :math:`\frac{1}{r} B A\, x` via :func:`braintrace.lora_matmul` (so the LoRA factors participate in online-learning trace computation) and adds the optional base-module output. Parameters ---------- x : ArrayLike Input array, of shape ``(..., in_features)``. Returns ------- ArrayLike The adapted output, of shape ``(..., out_features)``. """ param = self.weight.value lora_rank = param['lora_b'].shape[0] out = lora_matmul(x, param['lora_b'], param['lora_a'], alpha=1.0 / lora_rank) if self.base_module is not None: out += self.base_module(x) return out