compile_etrace_graph

class braintrace.compile_etrace_graph(model, *model_args, include_hidden_perturb=True)

Construct the eligibility-trace graph for a given model and inputs.

This is the primary entry point of the ETrace compiler. It builds the graph for the model and records the relationships among the eligibility-trace weights (ParamState), the eligibility-trace states (HiddenState), and the eligibility-trace operations (ETP primitives). These relationships are later used to compute the weight spatial gradients, the hidden-state Jacobian, and the hidden-state-to-weight Jacobian.
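To show how the hidden-state Jacobian and the hidden-state-to-weight Jacobian combine into an eligibility trace, here is a minimal pure-Python sketch for a hypothetical one-unit RNN, h_t = tanh(w * x_t + u * h_{t-1}). It does not use braintrace, and the function and variable names are illustrative assumptions, not part of the library's API. The trace follows the recursion e_t = (dh_t/dh_{t-1}) * e_{t-1} + (dh_t/dw).

```python
import math


def run_with_trace(w, u, xs, h0=0.0):
    """Hypothetical toy RNN: return the final hidden state h_T and the
    eligibility trace e_T for the input weight w."""
    h, e = h0, 0.0
    for x in xs:
        h = math.tanh(w * x + u * h)
        d = 1.0 - h * h          # derivative of tanh evaluated at the new state
        hidden_jac = d * u       # dh_t / dh_{t-1}: the hidden-state Jacobian
        weight_jac = d * x       # immediate dh_t / dw: the hidden-to-weight Jacobian
        # Propagate the old trace through the hidden Jacobian, then add the
        # immediate weight contribution of the current step.
        e = hidden_jac * e + weight_jac
    return h, e


h_T, e_T = run_with_trace(w=0.5, u=0.8, xs=[1.0, -0.3, 0.7, 0.2])
```

With a single hidden unit and a single weight, this recursion yields the exact gradient dh_T/dw, so it can be checked against a finite-difference estimate.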

Parameters:
  • model (Module) – The model for which the eligibility-trace graph is built.

  • *model_args (Tuple) – The positional arguments required by the model.

  • include_hidden_perturb (bool) – Whether to include hidden perturbations in the graph. Defaults to True.

Returns:

The compiled eligibility-trace graph containing module information, hidden groups, hidden parameter-operation relations, and optional hidden perturbations.

Return type:

ETraceGraph

Raises:

NotImplementedError – If a recursive call to the compiler is detected.

See also

ETraceGraph

The returned compiled-graph data structure.

Examples

>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> graph = braintrace.compile_etrace_graph(gru, inputs)
>>> len(graph.hidden_groups)
1