ETraceGraph

class braintrace.ETraceGraph(module_info: ModuleInfo, hidden_groups: Sequence[HiddenGroup], hid_path_to_group: Dict[Tuple[str, ...], HiddenGroup], hidden_param_op_relations: Sequence[HiddenParamOpRelation], hidden_perturb: HiddenPerturbation | None, diagnostics: Tuple[CompilationRecord, ...] = ())

The overall compiled graph for the eligibility trace.

Tracks the relationship between the eligibility-trace weights (ParamState), the eligibility-trace variables (HiddenState), and the eligibility-trace operations (ETP primitives). It is the object returned by compile_etrace_graph() and consumed by the online-learning algorithms.

module_info

Information about the compiled model.

Type:

ModuleInfo

hidden_groups

The groups of hidden states (HiddenGroup objects) identified during compilation.

Type:

sequence of HiddenGroup

hid_path_to_group

Mapping from each hidden-state path to its HiddenGroup.

Type:

dict mapping Tuple[str, ...] to HiddenGroup

hidden_param_op_relations

The relations linking each eligibility-trace weight, the ETP operation that uses it, and the hidden states it affects.

Type:

sequence of HiddenParamOpRelation

hidden_perturb

The hidden perturbation, or None when perturbations are excluded.

Type:

HiddenPerturbation or None

diagnostics

The structured compilation records emitted while building the graph.

Type:

tuple of CompilationRecord

See also

compile_etrace_graph

Build an ETraceGraph from a model.

Examples

>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> graph = braintrace.compile_etrace_graph(gru, inputs)
>>> isinstance(graph, braintrace.ETraceGraph)
True

call_hidden_perturb(args, perturb_data, old_state_vals=None)

Run the forward pass with additive perturbations injected at the hidden states.

Evaluates the perturbed forward jaxpr built during compilation: the forward computation augmented so that an additive perturbation term is injected at each tracked hidden state. This is the primitive used to probe hidden-to-hidden and hidden-to-output sensitivities.

Parameters:
  • args (Any) – The model inputs for this step, matching the signature captured at compile time.

  • perturb_data (Sequence[Array]) – One perturbation array per tracked hidden state, added at the corresponding perturbation site.

  • old_state_vals (Optional[Sequence[Array]]) – The state values to run from. When None (default) the current values of the compiled model states are used.

Returns:

The processed model outputs, in the same structure produced by a normal forward call.

Return type:

object
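To make additive hidden-state perturbation concrete, here is a self-contained NumPy sketch. It does not call braintrace: `forward_with_perturb`, `probe_output_sensitivity`, and the weights `W`, `U`, `V` are hypothetical stand-ins for a one-layer tanh recurrence. A perturbation `eps` is added to the new hidden state, and the hidden-to-output sensitivity is recovered by perturbing one hidden unit at a time, then checked against the analytic Jacobian. It illustrates the principle only, not how call_hidden_perturb builds or evaluates its jaxpr.

```python
import numpy as np

# Hypothetical toy model for illustration; this is not braintrace code.
rng = np.random.default_rng(0)
n_in, n_hid, n_out = 3, 4, 2
W = rng.normal(size=(n_hid, n_hid))  # hidden -> hidden weights (illustrative)
U = rng.normal(size=(n_hid, n_in))   # input -> hidden weights (illustrative)
V = rng.normal(size=(n_out, n_hid))  # hidden -> output readout (illustrative)


def forward_with_perturb(x, h_old, eps):
    """One recurrent step with an additive perturbation at the hidden-state site."""
    h_new = np.tanh(W @ h_old + U @ x) + eps  # perturbation added to the new hidden state
    y = np.tanh(V @ h_new)                    # model output
    return y, h_new


def probe_output_sensitivity(x, h_old, delta=1e-5):
    """Estimate dy/dh_new with central differences, one hidden unit at a time."""
    jac = np.empty((n_out, n_hid))
    for i in range(n_hid):
        eps = np.zeros(n_hid)
        eps[i] = delta
        y_plus, _ = forward_with_perturb(x, h_old, eps)
        y_minus, _ = forward_with_perturb(x, h_old, -eps)
        jac[:, i] = (y_plus - y_minus) / (2 * delta)
    return jac


x = rng.normal(size=n_in)
h_old = rng.normal(size=n_hid)
jac = probe_output_sensitivity(x, h_old)

# With a zero perturbation the step is the ordinary forward pass.
y0, _ = forward_with_perturb(x, h_old, np.zeros(n_hid))
analytic = (1 - y0 ** 2)[:, None] * V  # d tanh(V h)/dh = diag(1 - y^2) V
print(np.allclose(jac, analytic, atol=1e-6))  # True
```

Finite differences keep the sketch dependency-light; with an automatic-differentiation framework one would instead differentiate the output with respect to the perturbation evaluated at zero, which yields the same Jacobian without truncation error.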

diagnostics: Tuple[CompilationRecord, ...]

Alias for field number 5

dict()

Return the graph’s fields as a plain dictionary.

Returns:

A mapping from field name to value for every attribute of this ETraceGraph.

Return type:

Dict
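The per-field "Alias for field number N" entries suggest that ETraceGraph is a named tuple, in which case dict() amounts to a field-name-to-value mapping. The sketch below assumes exactly that: `GraphFieldsSketch` is a hypothetical stand-in with the six documented fields in the same order, and its `dict()` is built on the standard named-tuple method `_asdict()`, which may differ from braintrace's actual implementation.

```python
from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple


class GraphFieldsSketch(NamedTuple):
    """Hypothetical stand-in that mirrors ETraceGraph's documented fields."""
    module_info: Any
    hidden_groups: Sequence[Any]
    hid_path_to_group: Dict[Tuple[str, ...], Any]
    hidden_param_op_relations: Sequence[Any]
    hidden_perturb: Optional[Any]
    diagnostics: Tuple[Any, ...] = ()

    def dict(self) -> Dict[str, Any]:
        # Assumed behaviour: one entry per field, keyed by field name, in field order.
        return dict(self._asdict())


g = GraphFieldsSketch("info", ["g0"], {("cell", "h"): "g0"}, [], None)
fields = g.dict()
print(list(fields))
# ['module_info', 'hidden_groups', 'hid_path_to_group',
#  'hidden_param_op_relations', 'hidden_perturb', 'diagnostics']
print(fields["diagnostics"])  # ()
```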

explain(*, weight_path=None, hidden_path=None, kind=None)

Return compilation records filtered by weight path, hidden path, or kind.

weight_path matches a record's weight_path exactly, hidden_path matches if it is one of the record's hidden_paths, and kind matches CompilationRecord.kind. All filters are optional and may be combined, in which case a record is kept only if it passes every filter given; with no filters the full diagnostic log is returned.

Parameters:
  • weight_path (Optional[Tuple[str, ...]]) – If given, keep only records whose weight_path equals this value. Default None.

  • hidden_path (Optional[Tuple[str, ...]]) – If given, keep only records whose hidden_paths contain this value. Default None.

  • kind (Optional[DiagnosticKind]) – If given, keep only records whose kind is this value. Default None.

Returns:

The matching records, in emission order.

Return type:

Tuple[CompilationRecord, ...]
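The filtering rules above can be mirrored in a few lines of plain Python. In this sketch, `RecordSketch` and `explain_sketch` are hypothetical stand-ins: the record keeps only the three attributes explain() is documented to read (kind, weight_path, hidden_paths), the strings "kind_a" and "kind_b" stand in for DiagnosticKind values, and combining several filters with logical AND is inferred from the "keep only" wording rather than taken from the library source.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple

Path = Tuple[str, ...]


@dataclass(frozen=True)
class RecordSketch:
    """Hypothetical stand-in for CompilationRecord (only the fields explain() reads)."""
    kind: str                        # stands in for a DiagnosticKind value
    weight_path: Optional[Path]
    hidden_paths: Tuple[Path, ...]


def explain_sketch(
    records: Sequence[RecordSketch],
    *,
    weight_path: Optional[Path] = None,
    hidden_path: Optional[Path] = None,
    kind: Optional[str] = None,
) -> Tuple[RecordSketch, ...]:
    """Apply the documented filters; a record must pass every filter that is given."""
    kept = []
    for rec in records:  # iterating in order preserves emission order
        if weight_path is not None and rec.weight_path != weight_path:
            continue  # weight_path: exact equality
        if hidden_path is not None and hidden_path not in rec.hidden_paths:
            continue  # hidden_path: membership in hidden_paths
        if kind is not None and rec.kind != kind:
            continue  # kind: exact equality
        kept.append(rec)
    return tuple(kept)


log = (
    RecordSketch("kind_a", ("fc", "weight"), (("cell", "h"),)),
    RecordSketch("kind_b", ("fc", "bias"), ()),
    RecordSketch("kind_a", ("rec", "weight"), (("cell", "h"), ("cell", "c"))),
)
print(len(explain_sketch(log)))                             # 3 (no filters: full log)
print(len(explain_sketch(log, hidden_path=("cell", "h"))))  # 2
print(len(explain_sketch(log, kind="kind_a", weight_path=("rec", "weight"))))  # 1
```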

hid_path_to_group: Dict[Tuple[str, ...], HiddenGroup]

Alias for field number 2

hidden_groups: Sequence[HiddenGroup]

Alias for field number 1

hidden_param_op_relations: Sequence[HiddenParamOpRelation]

Alias for field number 3

hidden_perturb: HiddenPerturbation | None

Alias for field number 4

module_info: ModuleInfo

Alias for field number 0
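"Alias for field number N" is the docstring Python's collections.namedtuple attaches to each field, so every attribute above should also be reachable by index and by tuple unpacking, in the order module_info (0) through diagnostics (5). A minimal sketch using a hypothetical stand-in, `ETraceGraphFields`, built with collections.namedtuple and the documented field order; it does not construct a real ETraceGraph.

```python
from collections import namedtuple

# Hypothetical stand-in with ETraceGraph's documented field order.
ETraceGraphFields = namedtuple(
    "ETraceGraphFields",
    ["module_info", "hidden_groups", "hid_path_to_group",
     "hidden_param_op_relations", "hidden_perturb", "diagnostics"],
)

g = ETraceGraphFields("info", ("g0",), {("cell", "h"): "g0"}, (), None, ())

# Attribute access and positional access refer to the same object.
assert g[4] is g.hidden_perturb
assert g[5] is g.diagnostics

# Tuple unpacking follows the field numbers 0..5.
module_info, hidden_groups, hid_path_to_group, relations, perturb, diagnostics = g
print(hid_path_to_group[("cell", "h")])  # g0

# namedtuple generates the docstring shown in the reference above.
print(ETraceGraphFields.diagnostics.__doc__)  # Alias for field number 5
```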