MultiStepData

class braintrace.MultiStepData(data)

A container marking input data as spanning multiple time steps.

Wraps an arbitrary pytree of arrays whose leading axis is the time dimension, so the online-learning machinery can iterate over time steps. The class is registered as a JAX pytree node, so instances can be passed into and returned from transformed functions such as those produced by jit and grad.

Parameters:

data (Any) – The pytree of arrays to store. The first dimension of each array represents the time steps.

See also

SingleStepData

Container for an input at a single time step.

Examples

>>> import brainstate
>>> import braintrace
>>> # data at 10 time steps, 2 samples each, 3 features per sample
>>> data = braintrace.MultiStepData(brainstate.random.randn(10, 2, 3))
>>> print(data.data.shape)
(10, 2, 3)
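
To illustrate what "a pytree of arrays whose leading axis is the time dimension" means in practice, the following is a minimal sketch in plain JAX. It does not use or reproduce braintrace internals: the dictionary keys, the `step_slice` helper, and the `body` function are hypothetical names chosen for this example. It shows how one time step can be sliced out of every leaf, and how `jax.lax.scan` consumes the leading axis of such a pytree one step at a time.

```python
import jax
import jax.numpy as jnp

# A pytree (here a dict) of arrays sharing a leading time axis of length 4.
# The keys "current" and "mask" are illustrative only.
inputs = {
    "current": jnp.arange(4 * 2 * 3, dtype=jnp.float32).reshape(4, 2, 3),
    "mask": jnp.ones((4, 2), dtype=jnp.float32),
}


def step_slice(tree, t):
    """Return the pytree holding every leaf's slice at time index t."""
    return jax.tree_util.tree_map(lambda x: x[t], tree)


# Slicing one step drops the time axis from each leaf.
first = step_slice(inputs, 0)


def body(carry, x_t):
    """Accumulate the masked input of a single time step."""
    carry = carry + x_t["current"] * x_t["mask"][:, None]
    return carry, None


# scan iterates over the leading axis of every leaf in `inputs`.
total, _ = jax.lax.scan(body, jnp.zeros((2, 3), dtype=jnp.float32), inputs)
```

Here `first["current"]` has shape `(2, 3)` and `first["mask"]` has shape `(2,)`, while `total` is the sum of the per-step inputs over the four time steps. Every leaf must share the same leading length for this kind of iteration to be well defined.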