compile

class braintrace.compile(model, algorithm, *example_inputs, **options)

Construct an online-learning algorithm for model and eagerly build its eligibility-trace graph, returning a ready-to-update learner.

Parameters:
  • model (brainstate.nn.Module) – The recurrent model. Its states must already be initialized, e.g. via brainstate.nn.init_all_states(model).

  • algorithm (type or str) – An ETraceAlgorithm subclass, or a registered string name (case-insensitive), e.g. 'D_RTRL', 'eprop', 'ottt'.

  • *example_inputs – Example call inputs (arrays / SingleStepData / MultiStepData) matching the inputs that learner.update(...) will later receive. They are forwarded to ETraceAlgorithm.compile_graph(), which traces the jaxpr graph. At least one is required.

  • **options – Keyword options forwarded to the algorithm constructor, e.g. vjp_method, leak, fast_solve, trace_dtype, feedback. Which options are accepted depends on the chosen algorithm.

Returns:

The compiled learner; call .update(*inputs) to train.

Return type:

ETraceAlgorithm

Raises:
  • ValueError – If algorithm is an unknown string name, or no example_inputs are given.

  • TypeError – If algorithm is neither an ETraceAlgorithm subclass nor a string.
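The control flow implied by the parameter and error descriptions above can be sketched in plain Python. This is a hypothetical illustration, not braintrace's implementation. The stub classes ETraceAlgorithm and DRTRL, the _REGISTRY dictionary, the compile_sketch name, and the assumption that the constructor takes the model as its first positional argument are all introduced here for illustration only.

```python
# Hypothetical sketch of the documented behaviour of braintrace.compile.
# ETraceAlgorithm, DRTRL, and _REGISTRY are illustrative stubs,
# not braintrace's real classes or registry.


class ETraceAlgorithm:
    """Stub standing in for braintrace.ETraceAlgorithm."""

    def __init__(self, model, **options):
        # Assumption: the constructor takes the model first, then keyword options.
        self.model = model
        self.options = options
        self.graph_inputs = None

    def compile_graph(self, *example_inputs):
        # The real method traces the jaxpr graph; this stub only records the inputs.
        self.graph_inputs = example_inputs


class DRTRL(ETraceAlgorithm):
    """Stub algorithm registered under the name 'D_RTRL'."""


# Hypothetical registry keyed by lower-cased names, giving case-insensitive lookup.
_REGISTRY = {"d_rtrl": DRTRL}


def compile_sketch(model, algorithm, *example_inputs, **options):
    # Resolve a string name case-insensitively, or accept a subclass directly.
    if isinstance(algorithm, str):
        try:
            algo_cls = _REGISTRY[algorithm.lower()]
        except KeyError:
            raise ValueError(f"Unknown algorithm name: {algorithm!r}") from None
    elif isinstance(algorithm, type) and issubclass(algorithm, ETraceAlgorithm):
        algo_cls = algorithm
    else:
        raise TypeError("algorithm must be an ETraceAlgorithm subclass or a str")

    if not example_inputs:
        raise ValueError("At least one example input is required")

    # Keyword options go to the constructor; example inputs go to compile_graph().
    learner = algo_cls(model, **options)
    learner.compile_graph(*example_inputs)
    return learner
```

For instance, compile_sketch(model, 'd_rtrl', x0, vjp_method='multi-step') resolves to the DRTRL stub, passes vjp_method to its constructor, and hands x0 to compile_graph() before returning the learner.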

Examples

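>>> import braintrace
>>> import brainstate
>>> import jax.numpy as jnp
>>> # MyRNN stands for any user-defined recurrent brainstate.nn.Module.
>>> model = MyRNN()
>>> # States must be initialized before compiling.
>>> brainstate.nn.init_all_states(model, batch_size=1)
>>> x0 = jnp.ones((3,))
>>> # x0 is the example input used to trace the eligibility-trace graph.
>>> learner = braintrace.compile(model, 'D_RTRL', x0, vjp_method='multi-step')
>>> # Train the compiled learner on an input.
>>> y = learner.update(x0)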