SingleStepData

class braintrace.SingleStepData(data)

A container marking input data as belonging to a single time step.

Wraps an arbitrary pytree of arrays so that the online-learning machinery can distinguish single-step inputs, which are reused unchanged at every step, from time-major inputs, whose leading axis indexes time. The class is registered as a JAX pytree node, so instances pass through jax.jit and jax.grad transparently.
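The pytree registration described above can be illustrated with a minimal sketch. SingleStepSketch is a hypothetical stand-in written for this page, not braintrace's source; it assumes the wrapper's only traced child is its data attribute, and uses jax.tree_util.register_pytree_node_class to show why such an object can cross jit and grad boundaries.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class SingleStepSketch:
    """Hypothetical stand-in for braintrace.SingleStepData (illustration only)."""

    def __init__(self, data):
        self.data = data  # arbitrary pytree of arrays for one time step

    def tree_flatten(self):
        # `data` is the only child JAX traces; no static auxiliary data is needed.
        return (self.data,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the wrapper from the (possibly transformed) child.
        return cls(*children)


@jax.jit
def scale(x):
    # The wrapper is flattened on entry to jit and rebuilt on the way out.
    return SingleStepSketch(jax.tree_util.tree_map(lambda a: 2.0 * a, x.data))


def loss(x):
    return jnp.sum(x.data["w"] ** 2)


inp = SingleStepSketch({"w": jnp.ones((2, 3))})
out = scale(inp)              # still a SingleStepSketch, leaves doubled
grads = jax.grad(loss)(inp)   # gradient has the same wrapper type as the input
```

Because data is returned as a child rather than as auxiliary data, JAX traces its arrays; the gradient of loss (2 * w, i.e. all 2.0 here) comes back wrapped in the same container type as the input.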

Parameters:

data (Any) – The pytree of arrays to store for a single time step.

See also

MultiStepData

Container for inputs spanning multiple time steps.
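The contrast between the two containers can be pictured with a hypothetical driver. This page does not describe how braintrace's online-learning machinery consumes them; the sketch below assumes a loop built on jax.lax.scan in which time-major inputs are sliced along their leading axis while single-step inputs are passed unchanged at every step. Single, Multi, and run_over_time are illustrative names, not braintrace APIs.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Single:
    """Hypothetical single-step wrapper (illustration only)."""

    def __init__(self, data):
        self.data = data

    def tree_flatten(self):
        return (self.data,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@register_pytree_node_class
class Multi:
    """Hypothetical multi-step wrapper: every leaf has a leading time axis."""

    def __init__(self, data):
        self.data = data

    def tree_flatten(self):
        return (self.data,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def run_over_time(step_fn, *inputs):
    """Hypothetical driver; assumes at least one Multi input so scan knows the length."""
    is_multi = [isinstance(x, Multi) for x in inputs]
    # Only time-major data goes through scan's `xs`; scan slices its axis 0.
    xs = tuple(x.data for x, m in zip(inputs, is_multi) if m)

    def body(carry, xs_t):
        it = iter(xs_t)
        # Multi inputs receive slice t; Single inputs are reused unchanged.
        args = [next(it) if m else x.data for x, m in zip(inputs, is_multi)]
        return carry, step_fn(*args)

    _, ys = jax.lax.scan(body, None, xs)
    return ys  # per-step outputs stacked along a new leading time axis


bias = Single(jnp.array([1.0, 2.0, 3.0]))      # same value at every step
seq = Multi(jnp.arange(12.0).reshape(4, 3))    # 4 time steps, 3 features
ys = run_over_time(lambda b, x: x + b, bias, seq)
# ys[0] -> [1., 3., 5.], ys[3] -> [10., 12., 14.]
```

Wrapping the inputs, rather than inferring their role from array shapes, avoids ambiguity when a single-step array happens to have a leading axis of the same size as the sequence length.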

Examples

>>> import brainstate
>>> import braintrace
>>> data = braintrace.SingleStepData(brainstate.random.randn(2, 3))
>>> print(data.data.shape)
(2, 3)