ETP Operators & Core Types
This page documents the user-facing ETP operators — the ops you call inside
a model’s update to make a weight participate in online learning — together
with the small set of core types every algorithm consumes: input wrappers, the
eligibility-trace state, gradient utilities, and the error hierarchy.
To add your own ETP primitive (with custom trace-propagation rules), see Custom ETP Primitives.
ETP Primitive Operators
These functions mark weight operations for inclusion in online learning. Use
braintrace.matmul(x, w) instead of x @ w to include a weight in
eligibility-trace computation; a parameter used through a regular JAX op is
automatically excluded. There is no special parameter class — every
brainstate.ParamState is eligible, and participation is decided purely by
whether an ETP operator consumed it.
All operators accept physical-unit quantities (mantissa/unit are split, computed, and recombined) and come in batched and unbatched forms selected by input rank.
| Operator summary |
|---|
| ETP-aware matrix multiplication. |
| ETP-aware element-wise operation. |
| ETP-aware convolution. |
| ETP-aware sparse matrix multiplication. |
| ETP-aware LoRA (Low-Rank Adaptation) matrix multiplication. |
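The unit handling and rank-based dispatch described above can be pictured with a small library-free sketch. Everything here is hypothetical and for illustration only: `Quantity`, `unit_aware_matmul`, and the string-valued units are stand-ins, not braintrace or brainunit APIs, and the real operators work on JAX arrays rather than nested lists.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Quantity:
    """Hypothetical stand-in for a physical-unit array: numbers plus a unit label."""
    mantissa: list
    unit: str


def _rank(a):
    # Nesting depth of a rectangular nested list (1 = vector, 2 = matrix).
    depth = 0
    while isinstance(a, list):
        a = a[0]
        depth += 1
    return depth


def _vec_mat(x, w):
    # x: length-n vector, w: n-by-m matrix given as a list of rows.
    return [sum(xi * row[j] for xi, row in zip(x, w)) for j in range(len(w[0]))]


def unit_aware_matmul(x: Quantity, w: Quantity) -> Quantity:
    # 1. Split: pull the bare numbers out of both operands.
    xm, wm = x.mantissa, w.mantissa
    # 2. Compute on mantissas only; the input rank selects the form.
    if _rank(xm) == 1:        # unbatched: a single vector
        out = _vec_mat(xm, wm)
    elif _rank(xm) == 2:      # batched: one vector per row
        out = [_vec_mat(row, wm) for row in xm]
    else:
        raise ValueError("expected a rank-1 or rank-2 input")
    # 3. Recombine: attach the product unit to the result.
    return Quantity(out, f"{x.unit}*{w.unit}")


w = Quantity([[1.0, 2.0], [3.0, 4.0]], "nS")
single = unit_aware_matmul(Quantity([1.0, 1.0], "mV"), w)
batch = unit_aware_matmul(Quantity([[1.0, 0.0], [0.0, 1.0]], "mV"), w)
print(single)  # Quantity(mantissa=[4.0, 6.0], unit='mV*nS')
print(batch)   # Quantity(mantissa=[[1.0, 2.0], [3.0, 4.0]], unit='mV*nS')
```

The same call site serves both forms, so model code does not need a separate batched entry point.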
Controlling Parameter Participation
```python
import jax
import braintrace
import brainstate


class MyRNN(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.w_rec = brainstate.ParamState(...)  # want ETP
        self.w_in = brainstate.ParamState(...)   # do NOT want ETP
        self.h = brainstate.ShortTermState(...)

    def update(self, x):
        # regular matmul -> w_in excluded from ETP
        inp = x @ self.w_in.value
        # ETP matmul -> w_rec included in ETP
        self.h.value = jax.nn.tanh(inp + braintrace.matmul(self.h.value, self.w_rec.value))
        return self.h.value
```
| Goal | How |
|---|---|
| Include a parameter in online learning | Use a braintrace ETP operator (e.g. braintrace.matmul(x, w)) |
| Exclude a parameter from online learning | Use a regular JAX op (e.g. x @ w) |
| Selection mechanism | The operation's primitive type, not the parameter's class. Every brainstate.ParamState is eligible. |
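The selection rule summarized above (the operator decides participation, not the parameter's class) can be illustrated with a toy model. This is a sketch under loud assumptions: `Param`, `Recorder`, `etp_matmul`, and `plain_matmul` are hypothetical names, and the explicit bookkeeping here only stands in for braintrace's actual mechanism, which looks at the operation's primitive type.

```python
class Param:
    """Hypothetical trainable state. Note there is no special 'ETP' subclass."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


def _dot(x, w):
    # Vector-matrix product on nested lists (w is a list of rows).
    return [sum(xi * row[j] for xi, row in zip(x, w)) for j in range(len(w[0]))]


class Recorder:
    """Remembers which parameters were consumed by the marked operator."""

    def __init__(self):
        self.etp_params = set()

    def etp_matmul(self, x, param):
        # Marked op: using it is what opts the parameter in.
        self.etp_params.add(param.name)
        return _dot(x, param.value)


def plain_matmul(x, param):
    # Ordinary op: same arithmetic, but nothing is recorded.
    return _dot(x, param.value)


rec = Recorder()
w_in = Param("w_in", [[1.0, 0.0], [0.0, 1.0]])
w_rec = Param("w_rec", [[0.5, 0.0], [0.0, 0.5]])

x = [1.0, 2.0]
h = [2.0, 4.0]
inp = plain_matmul(x, w_in)            # w_in is not recorded
rec_out = rec.etp_matmul(h, w_rec)     # w_rec is recorded
h_new = [a + b for a, b in zip(inp, rec_out)]

print(sorted(rec.etp_params))  # ['w_rec']
```

Both parameters are instances of the same class, yet only the one routed through the marked operator ends up in the eligible set.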
Input Data
Wrappers that tell an online-learning algorithm whether a call carries data for a single time step or for a whole sequence of time steps.
| Wrapper summary |
|---|
| A container marking input data as belonging to a single time step. |
| A container marking input data as spanning multiple time steps. |
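To make the distinction concrete, here is a minimal sketch of how a driver might dispatch on the two kinds of wrapper. The class names `SingleStep` and `MultiStep`, the `run` helper, and the convention that the leading axis is time are all assumptions made for this illustration; they are not taken from the braintrace API.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SingleStep:
    """Hypothetical wrapper: the payload is the input for one time step."""
    data: float


@dataclass
class MultiStep:
    """Hypothetical wrapper: the payload is a sequence, one entry per time step."""
    data: List[float]


def run(step_fn, state, inp):
    """Apply step_fn once, or once per time step, depending on the wrapper."""
    if isinstance(inp, SingleStep):
        return step_fn(state, inp.data)
    if isinstance(inp, MultiStep):
        outs = []
        for x_t in inp.data:  # iterate over the (assumed) leading time axis
            state, out = step_fn(state, x_t)
            outs.append(out)
        return state, outs
    raise TypeError(f"unsupported input wrapper: {type(inp).__name__}")


def counter_step(total, x):
    # Toy model: the state is a running total, and the output echoes it.
    total = total + x
    return total, total


state_a, out_a = run(counter_step, 0.0, SingleStep(1.0))
state_b, outs_b = run(counter_step, 0.0, MultiStep([1.0, 1.0, 1.0]))
print(out_a)   # 1.0
print(outs_b)  # [1.0, 2.0, 3.0]
```

Wrapping the input makes the intent explicit instead of asking the algorithm to guess from array shape whether a leading axis means time or batch.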
Eligibility Trace State
The state object that stores the eligibility trace carried forward across time steps during online learning.
The state for storing the eligibility trace during the computation of online learning algorithms.
Gradient Utilities
Helpers used when combining per-step gradients into the running online gradient.
Accumulate gradients with an exponential (leaky) running sum.
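As a rough picture of what a leaky running sum does, the sketch below assumes the update `acc <- decay * acc + grad` applied independently to each named parameter. The function name `leaky_accumulate`, the dict-of-floats layout, and the exact weighting are assumptions for illustration; the library's helper may scale the terms differently and operates on JAX pytrees.

```python
def leaky_accumulate(acc, grad, decay):
    """Return decay * acc + grad for every named parameter (assumed form)."""
    return {name: decay * acc[name] + grad[name] for name in grad}


# Per-step gradients for two hypothetical parameters over two time steps.
per_step_grads = [
    {"w_rec": 2.0, "b": 1.0},
    {"w_rec": 4.0, "b": 1.0},
]

leaky = {"w_rec": 0.0, "b": 0.0}
plain = {"w_rec": 0.0, "b": 0.0}
for g in per_step_grads:
    leaky = leaky_accumulate(leaky, g, decay=0.5)  # older steps count for less
    plain = leaky_accumulate(plain, g, decay=1.0)  # decay=1 is an ordinary sum

print(leaky)  # {'w_rec': 5.0, 'b': 1.5}
print(plain)  # {'w_rec': 6.0, 'b': 2.0}
```

With `decay` below 1 the contribution of a gradient shrinks geometrically with its age, which keeps the running value bounded over long sequences.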
Errors
Exceptions raised by the compilation and execution machinery.
| Exception summary |
|---|
| Exception raised for operations that are not supported. |
| Exception raised for errors that occur during compilation. |