MultiStepData
class braintrace.MultiStepData(data)
A container marking input data as spanning multiple time steps.
Wraps an arbitrary pytree of arrays whose leading axis is the time dimension, so that the online-learning machinery can iterate over time steps. The class is registered as a JAX pytree node, so instances pass through jit/grad boundaries transparently.
Parameters:
data (Any) – The pytree of arrays to store. The first dimension of each array represents the time steps.
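To illustrate what "registered as a JAX pytree node" buys, the sketch below defines a hypothetical stand-in class, `ToyMultiStepData`, registered with `jax.tree_util.register_pytree_node_class`. It is not braintrace's implementation; it only assumes that the container exposes its wrapped pytree as a `data` attribute and flattens to that pytree's leaves. Because JAX can flatten and rebuild the container, it can be passed straight into `jax.jit` and `jax.grad`, and the gradient comes back as the same container type.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyMultiStepData:
    """Hypothetical stand-in for a multi-step container (not braintrace's code)."""

    def __init__(self, data):
        # `data` is any pytree whose leaves share a leading time axis.
        self.data = data

    def tree_flatten(self):
        # The wrapped pytree is the only child; there is no static aux data.
        return (self.data,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the container from its (possibly transformed) child.
        return cls(*children)


def loss(container):
    # Sum of squares over every leaf and every time step.
    return sum(jnp.sum(x ** 2) for x in jax.tree_util.tree_leaves(container))


# 10 time steps in both leaves; "a" also carries 2 samples x 3 features.
x = ToyMultiStepData({"a": jnp.ones((10, 2, 3)), "b": jnp.arange(10.0)})

# The container crosses both the jit and grad boundaries unchanged.
grads = jax.jit(jax.grad(loss))(x)
print(type(grads).__name__)  # ToyMultiStepData
print(grads.data["a"].shape)  # (10, 2, 3)
```

The gradient has the same structure as the input, so downstream code can keep treating it as a multi-step container rather than a bare dict.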
See also
SingleStepData – Container for an input at a single time step.
Examples
>>> import brainstate
>>> import braintrace
>>> # data at 10 time steps, 2 samples each, 3 features per sample
>>> data = braintrace.MultiStepData(brainstate.random.randn(10, 2, 3))
>>> print(data.data.shape)
(10, 2, 3)
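The description says the leading axis is what the online-learning machinery iterates over. The sketch below shows one common way to consume such a pytree step by step in JAX, using `jax.lax.scan`, which slices every leaf along axis 0 in lockstep. This is an assumption for illustration only: braintrace's internal loop may be implemented differently, and the dict `seq` here simply plays the role of a container's `data` attribute.

```python
import jax
import jax.numpy as jnp

# A pytree whose leaves share a leading time axis of length 10
# (10 steps, 2 samples, 3 features), mirroring the example above.
seq = {"x": jnp.ones((10, 2, 3)), "t": jnp.arange(10.0)}


def step(carry, inputs):
    # `inputs` is the same pytree with the time axis stripped:
    # inputs["x"] has shape (2, 3) and inputs["t"] is a scalar.
    carry = carry + inputs["x"] * inputs["t"]
    # Also emit one per-step output, stacked by scan along a new axis 0.
    return carry, jnp.sum(inputs["x"])


final, per_step = jax.lax.scan(step, jnp.zeros((2, 3)), seq)
print(final[0, 0])     # 45.0, i.e. 0 + 1 + ... + 9
print(per_step.shape)  # (10,)
```

Every leaf must have the same length along axis 0; otherwise `scan` raises an error, which is why each array in the stored pytree is expected to share the time dimension.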