ETP Operators & Core Types

This page documents the user-facing ETP operators — the ops you call inside a model’s update to make a weight participate in online learning — together with the small set of core types every algorithm consumes: input wrappers, the eligibility-trace state, gradient utilities, and the error hierarchy.

To add your own ETP primitive (with custom trace-propagation rules), see Custom ETP Primitives.

ETP Primitive Operators

These functions mark weight operations for inclusion in online learning. Use braintrace.matmul(x, w) instead of x @ w to include a weight in eligibility-trace computation; a parameter used through a regular JAX op is automatically excluded. There is no special parameter class — every brainstate.ParamState is eligible, and participation is decided purely by whether an ETP operator consumed it.

All operators accept physical-unit quantities (mantissa/unit are split, computed, and recombined) and come in batched and unbatched forms selected by input rank.
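
The two behaviours described above, unit handling and rank-based dispatch, can be pictured with a small NumPy sketch. Everything in it is a stand-in: `Quantity` and `toy_matmul` are hypothetical names invented for illustration, and the convention that a rank-1 input is unbatched while a rank-2 input is batched is an assumption, not something taken from the braintrace source.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Quantity:
    """Hypothetical stand-in for a physical quantity: a mantissa plus a unit label."""
    mantissa: np.ndarray
    unit: str


def toy_matmul(x, w):
    """Toy matmul: split off units, compute on mantissas, recombine."""
    # 1. Split: pull out the mantissa and remember the unit (None if unitless).
    x_m, x_u = (x.mantissa, x.unit) if isinstance(x, Quantity) else (np.asarray(x), None)
    w_m, w_u = (w.mantissa, w.unit) if isinstance(w, Quantity) else (np.asarray(w), None)

    # 2. Compute: pick the form from the input rank (assumed convention).
    if x_m.ndim == 1:        # unbatched input of shape (features,)
        out = x_m @ w_m
    elif x_m.ndim == 2:      # batched input of shape (batch, features)
        out = np.stack([row @ w_m for row in x_m])
    else:
        raise ValueError(f"unsupported input rank {x_m.ndim}")

    # 3. Recombine: a product of quantities carries the product of their units.
    units = [u for u in (x_u, w_u) if u is not None]
    return Quantity(out, "*".join(units)) if units else out


w = Quantity(np.ones((3, 2)), "nS")
single = toy_matmul(Quantity(np.array([1.0, 2.0, 3.0]), "mV"), w)
batch = toy_matmul(Quantity(np.ones((4, 3)), "mV"), w)
```

Here `single` has shape `(2,)` and `batch` has shape `(4, 2)`, and both carry the combined unit label. The real operators additionally register the weight for eligibility-trace computation, which this sketch does not attempt to model.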

matmul

ETP-aware matrix multiplication.

element_wise

ETP-aware element-wise operation.

conv

ETP-aware convolution.

sparse_matmul

ETP-aware sparse matrix multiplication.

lora_matmul

ETP-aware LoRA (Low-Rank Adaptation) matrix multiplication.

Controlling Parameter Participation

import jax
import braintrace
import brainstate

class MyRNN(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.w_rec = brainstate.ParamState(...)   # want ETP
        self.w_in = brainstate.ParamState(...)     # do NOT want ETP
        self.h = brainstate.ShortTermState(...)

    def update(self, x):
        # regular matmul -> w_in excluded from ETP
        inp = x @ self.w_in.value
        # ETP matmul -> w_rec included in ETP
        self.h.value = jax.nn.tanh(inp + braintrace.matmul(self.h.value, self.w_rec.value))
        return self.h.value

Table 1 Parameter Selection Rules

| Goal | How |
|---|---|
| Include a parameter in online learning | Use a braintrace.* ETP operator (e.g. braintrace.matmul(x, w)). |
| Exclude a parameter from online learning | Use a regular JAX op (e.g. x @ w). |
| Selection mechanism | The operation's primitive type, not the parameter's class. Every brainstate.ParamState is eligible; participation depends solely on whether an ETP primitive consumed it. |
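
The selection rule can be illustrated without the library at all. The sketch below is a toy model under stated assumptions: `Param`, `toy_etp_matmul`, and the `traced` set are hypothetical names, and the toy records participation by name at call time, whereas this page describes selection by the operation's primitive type. How braintrace detects that internally is not shown here.

```python
import numpy as np

# Names of the weights that an ETP-style op has consumed in this toy model.
traced = set()


class Param:
    """Hypothetical stand-in for a parameter state: a name plus a value."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


def toy_etp_matmul(x, param):
    """Toy ETP op: same result as x @ w, but records that `param` took part."""
    traced.add(param.name)
    return x @ param.value


w_in = Param("w_in", np.eye(3))
w_rec = Param("w_rec", 0.5 * np.eye(3))
h = np.zeros(3)
x = np.ones(3)

inp = x @ w_in.value                         # regular op: nothing is recorded
h = np.tanh(inp + toy_etp_matmul(h, w_rec))  # ETP-style op: w_rec is recorded

print(sorted(traced))  # prints ['w_rec']
```

Both weights are ordinary objects of the same class. Only the one that went through the ETP-style op ends up in `traced`, which mirrors the rule that participation depends on the operation and not on the parameter type.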

Input Data

Wrappers that tell an online-learning algorithm whether the input it receives covers a single time step or a whole sequence of time steps.

SingleStepData

A container marking input data as belonging to a single time step.

MultiStepData

A container marking input data as spanning multiple time steps.
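
To make the distinction concrete, here is a minimal sketch of how a consumer might branch on the two kinds of wrapper. `ToySingleStep`, `ToyMultiStep`, and `run` are hypothetical stand-ins written for this example; the constructor signatures and fields of the real `SingleStepData` and `MultiStepData` are not documented on this page and are not assumed here.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class ToySingleStep:
    """Hypothetical wrapper: `data` is the input for one time step."""
    data: Any


@dataclass
class ToyMultiStep:
    """Hypothetical wrapper: `data` is a sequence, one entry per time step."""
    data: List[Any]


def run(step_fn, wrapped):
    """Call `step_fn` once, or once per time step, depending on the wrapper."""
    if isinstance(wrapped, ToySingleStep):
        return step_fn(wrapped.data)
    if isinstance(wrapped, ToyMultiStep):
        return [step_fn(x_t) for x_t in wrapped.data]
    raise TypeError(f"expected a wrapped input, got {type(wrapped).__name__}")


def double(x):
    """Placeholder for a model's single-step update."""
    return 2 * x


one = run(double, ToySingleStep(3))           # one call
many = run(double, ToyMultiStep([1, 2, 3]))   # one call per time step
```

Tagging the input explicitly avoids guessing from array shape alone, since a leading axis could otherwise be read as either batch or time.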

Eligibility Trace State

The state object that stores the eligibility trace carried forward across time steps during online learning.

EligibilityTrace

The state for storing the eligibility trace during the computation of online learning algorithms.

Gradient Utilities

Helpers used when combining per-step gradients into the running online gradient.

GradExpon

Accumulate gradients with an exponential (leaky) running sum.
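
One common form of a leaky running sum is $G_t = \lambda\,G_{t-1} + g_t$ with a decay $\lambda \in [0, 1]$. The sketch below implements that form only to show the idea. `ToyLeakyGradSum` is a hypothetical class, and whether `GradExpon` uses exactly this update (for example, without a $(1 - \lambda)$ weighting on the new gradient) is an assumption that should be checked against the library.

```python
import numpy as np


class ToyLeakyGradSum:
    """Hypothetical accumulator implementing acc <- decay * acc + grad."""

    def __init__(self, decay):
        if not 0.0 <= decay <= 1.0:
            raise ValueError("decay must lie in [0, 1]")
        self.decay = decay
        self.acc = None  # no gradient seen yet

    def update(self, grad):
        grad = np.asarray(grad, dtype=float)
        # The first gradient initialises the sum; later ones decay the old sum first.
        self.acc = grad if self.acc is None else self.decay * self.acc + grad
        return self.acc


acc = ToyLeakyGradSum(decay=0.5)
for g in ([1.0, 2.0], [1.0, 2.0], [1.0, 2.0]):
    running = acc.update(g)
```

After three identical gradients `[1.0, 2.0]` with a decay of 0.5, `running` equals `[1.75, 3.5]`: older contributions are halved at every step, so recent gradients dominate.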

Errors

Exceptions raised by the compilation and execution machinery.

NotSupportedError

Exception raised for operations that are not supported.

CompilationError

Exception raised for errors that occur during compilation.