ETraceGraph
- class braintrace.ETraceGraph(module_info: ModuleInfo, hidden_groups: Sequence[HiddenGroup], hid_path_to_group: Dict[Tuple[str, ...], HiddenGroup], hidden_param_op_relations: Sequence[HiddenParamOpRelation], hidden_perturb: HiddenPerturbation | None, diagnostics: Tuple[CompilationRecord, ...] = ())
The overall compiled graph for the eligibility trace.
Tracks the relationship between the eligibility-trace weights (ParamState), the eligibility-trace variables (HiddenState), and the eligibility-trace operations (ETP primitives). It is the object returned by compile_etrace_graph() and consumed by the online-learning algorithms.
- module_info
The model information.
- Type: ModuleInfo
- hidden_groups
The hidden groups.
- Type: sequence of HiddenGroup
- hid_path_to_group
Mapping from each hidden-state path to its HiddenGroup.
- Type: Dict[Tuple[str, ...], HiddenGroup]
- hidden_param_op_relations
The hidden parameter-operation relations.
- Type: sequence of HiddenParamOpRelation
- hidden_perturb
The hidden perturbation, or None when perturbations are excluded.
- Type: HiddenPerturbation or None
- diagnostics
The structured compilation records emitted while building the graph.
- Type: Tuple[CompilationRecord, ...]
See also
compile_etrace_graph(): Build an ETraceGraph from a model.
Examples
>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> graph = braintrace.compile_etrace_graph(gru, inputs)
>>> isinstance(graph, braintrace.ETraceGraph)
True
Run the forward pass with additive perturbations injected at the hidden states.
Evaluates the perturbed-forward jaxpr built during compilation: the forward computation augmented so that a perturbation term is added to each tracked hidden state. This is the primitive used to probe hidden->hidden and hidden->output sensitivities.
- Parameters:
  - args (Any) – The model inputs for this step, matching the signature captured at compile time.
  - perturb_data (Sequence[Array]) – One perturbation array per tracked hidden state, added at the corresponding perturbation site.
  - old_state_vals (Optional[Sequence[Array]]) – The state values to run from. When None (the default), the current values of the compiled model states are used.
- Returns:
The processed model outputs, in the same structure produced by a normal forward call.
- Return type:
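The idea behind the perturbed forward pass can be shown without braintrace: because the perturbation p enters additively at a hidden state, the derivative of any downstream quantity with respect to p, evaluated at p = 0, equals its derivative with respect to that hidden state. The toy scalar cell below is a sketch under stated assumptions, not braintrace's method: the weights W_HH, W_XH, W_HY and the function names are hypothetical, and central finite differences stand in for the automatic differentiation a JAX-based implementation would more likely use.

```python
import math

# Hypothetical scalar weights for a toy recurrent cell (not braintrace's API).
W_HH, W_XH, W_HY = 0.8, 0.5, 1.5

def perturbed_forward(x, h_prev, perturb):
    """One step: h = tanh(W_HH*h_prev + W_XH*x) + perturb, then y = tanh(W_HY*h)."""
    h = math.tanh(W_HH * h_prev + W_XH * x) + perturb  # additive perturbation site
    y = math.tanh(W_HY * h)                            # hidden -> output
    return y, h

def output_sensitivity(x, h_prev, eps=1e-6):
    """Central finite-difference estimate of dy/dp at p = 0."""
    y_plus, _ = perturbed_forward(x, h_prev, eps)
    y_minus, _ = perturbed_forward(x, h_prev, -eps)
    return (y_plus - y_minus) / (2 * eps)

x, h_prev = 0.3, 0.1
_, h = perturbed_forward(x, h_prev, 0.0)         # zero perturbation = plain forward pass
analytic = W_HY * (1 - math.tanh(W_HY * h) ** 2)  # dy/dh at the unperturbed hidden state
print(abs(output_sensitivity(x, h_prev) - analytic) < 1e-6)  # True
```

Passing a zero perturbation reproduces the ordinary forward pass, which is why the same compiled computation can serve both for running the model and for probing sensitivities.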
- diagnostics: Tuple[CompilationRecord, ...]
Alias for field number 5
- dict()
Return the graph's fields as a plain dictionary.
- Returns:
A mapping from field name to value for every attribute of this ETraceGraph.
- Return type: dict
- explain(*, weight_path=None, hidden_path=None, kind=None)
Return compilation records filtered by weight path, hidden path, or kind.
weight_path must equal a record's weight_path exactly, hidden_path must be a member of the record's hidden_paths, and kind must match CompilationRecord.kind. All filters are optional; with no filters, the full diagnostic log is returned.
- Parameters:
  - weight_path (Optional[Tuple[str, ...]]) – If given, keep only records whose weight_path equals this value. Default None.
  - hidden_path (Optional[Tuple[str, ...]]) – If given, keep only records whose hidden_paths contain this value. Default None.
  - kind (Optional[DiagnosticKind]) – If given, keep only records whose kind is this value. Default None.
- Returns:
The matching records, in emission order.
- Return type:
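The filter semantics of explain() can be sketched in plain Python. This is an illustration only, assuming that the filters combine conjunctively (a record must pass every filter that is given). Record is a hypothetical stand-in carrying just the three CompilationRecord fields the filters read, the kind strings are hypothetical placeholders for DiagnosticKind values, and the sketch returns a list, whereas the real return type is not stated here.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

Path = Tuple[str, ...]

@dataclass(frozen=True)
class Record:
    """Hypothetical stand-in for CompilationRecord (only the filtered fields)."""
    kind: str
    weight_path: Optional[Path]
    hidden_paths: Tuple[Path, ...]

def explain(records: Sequence[Record], *, weight_path: Optional[Path] = None,
            hidden_path: Optional[Path] = None, kind: Optional[str] = None) -> List[Record]:
    """Keep records that pass every given filter, preserving emission order."""
    matches = []
    for rec in records:
        if weight_path is not None and rec.weight_path != weight_path:
            continue  # weight_path: exact equality
        if hidden_path is not None and hidden_path not in rec.hidden_paths:
            continue  # hidden_path: membership in hidden_paths
        if kind is not None and rec.kind != kind:
            continue  # kind: equality
        matches.append(rec)
    return matches

log = [
    Record("relation", ("layer", "w"), (("layer", "h"),)),
    Record("skip", ("layer", "w"), (("layer", "h"), ("layer", "c"))),
    Record("relation", ("out", "w"), (("layer", "c"),)),
]
print(len(explain(log, kind="relation")))  # 2
print(explain(log, weight_path=("layer", "w"), hidden_path=("layer", "c")) == [log[1]])  # True
```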
- hid_path_to_group: Dict[Tuple[str, ...], HiddenGroup]
Alias for field number 2
- hidden_groups: Sequence[HiddenGroup]
Alias for field number 1
- hidden_param_op_relations: Sequence[HiddenParamOpRelation]
Alias for field number 3
- hidden_perturb: HiddenPerturbation | None
Alias for field number 4
- module_info: ModuleInfo
Alias for field number 0
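The "Alias for field number N" entries are what Python generates for named-tuple fields, so each attribute is also reachable by position. The minimal stand-in below, assuming typing.NamedTuple and using Any in place of the braintrace types (GraphSketch is hypothetical), shows the field order above and one plausible way a dict() method could be built on _asdict(); the actual implementation may differ.

```python
from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple

class GraphSketch(NamedTuple):
    """Hypothetical stand-in with ETraceGraph's field order; Any replaces braintrace types."""
    module_info: Any                               # field 0
    hidden_groups: Sequence[Any]                   # field 1
    hid_path_to_group: Dict[Tuple[str, ...], Any]  # field 2
    hidden_param_op_relations: Sequence[Any]       # field 3
    hidden_perturb: Optional[Any]                  # field 4
    diagnostics: Tuple[Any, ...] = ()              # field 5, defaults to empty

    def dict(self) -> Dict[str, Any]:
        # Inside the method body, `dict` resolves to the builtin, not this method.
        return dict(self._asdict())

g = GraphSketch("info", [], {}, [], None)
print(GraphSketch.diagnostics.__doc__)  # Alias for field number 5
print(g[4] is g.hidden_perturb)         # True
print(list(g.dict()))                   # field names, in field order
```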