compile
class braintrace.compile(model, algorithm, *example_inputs, **options)
Construct an online-learning algorithm for model and eagerly build its eligibility-trace graph, returning a learner that is ready for update calls.

- Parameters:
  - model (brainstate.nn.Module) – The recurrent model. Its states must already be initialized, e.g. via brainstate.nn.init_all_states(model).
  - algorithm (type or str) – An ETraceAlgorithm subclass, or a registered string name (case-insensitive), e.g. 'D_RTRL', 'eprop', 'ottt'.
  - *example_inputs – Example call inputs (arrays, SingleStepData, or MultiStepData) matching what learner.update(...) will later receive. They are forwarded to ETraceAlgorithm.compile_graph() to trace the jaxpr graph. At least one is required.
  - **options – Keyword options forwarded to the algorithm constructor, e.g. vjp_method, leak, fast_solve, trace_dtype, feedback.
- Returns:
  The compiled learner; call .update(*inputs) on it to train.
- Return type:
- Raises:
  - ValueError – If algorithm is an unknown string name, or if no example_inputs are given.
  - TypeError – If algorithm is neither an ETraceAlgorithm subclass nor a string.
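The argument handling described under Parameters and Raises can be illustrated with a small pure-Python sketch. Everything in it is hypothetical: the stand-in ETraceAlgorithm and D_RTRL classes, the _REGISTRY dict, and compile_sketch are illustrative names rather than braintrace's implementation, and the order in which the checks run is an assumption.

```python
# Hypothetical sketch of the documented argument handling of braintrace.compile.
# This is NOT braintrace's source code; all names below are illustrative stand-ins.


class ETraceAlgorithm:
    """Hypothetical stand-in for braintrace's ETraceAlgorithm base class."""

    def __init__(self, model, **options):
        self.model = model
        self.options = options  # constructor receives the forwarded **options
        self.graph_inputs = None

    def compile_graph(self, *example_inputs):
        # The real method traces the model's jaxpr; this stand-in only records inputs.
        self.graph_inputs = example_inputs


class D_RTRL(ETraceAlgorithm):
    """Hypothetical stand-in for a registered algorithm."""


# Hypothetical registry mapping lowercase names to algorithm classes.
_REGISTRY = {"d_rtrl": D_RTRL}


def compile_sketch(model, algorithm, *example_inputs, **options):
    if isinstance(algorithm, str):
        key = algorithm.lower()  # string names are matched case-insensitively
        if key not in _REGISTRY:
            raise ValueError(f"Unknown algorithm name: {algorithm!r}")
        algorithm = _REGISTRY[key]
    elif not (isinstance(algorithm, type) and issubclass(algorithm, ETraceAlgorithm)):
        raise TypeError("algorithm must be an ETraceAlgorithm subclass or a str")
    if not example_inputs:
        raise ValueError("At least one example input is required")
    learner = algorithm(model, **options)   # options go to the constructor
    learner.compile_graph(*example_inputs)  # eagerly build the graph
    return learner
```

For instance, compile_sketch(model, 'd_rtrl', x0) and compile_sketch(model, 'D_RTRL', x0) resolve to the same class, while compile_sketch(model, 'D_RTRL') raises ValueError because no example input is given.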
Examples

>>> import braintrace
>>> import brainstate
>>> import jax.numpy as jnp
>>> model = MyRNN()  # placeholder: your own recurrent brainstate.nn.Module
>>> brainstate.nn.init_all_states(model, batch_size=1)
>>> x0 = jnp.ones((3,))
>>> learner = braintrace.compile(model, 'D_RTRL', x0, vjp_method='multi-step')
>>> y = learner.update(x0)