Compiler, Executor & Diagnostics

The compiler analyzes a model’s JAX intermediate representation (jaxpr) to discover the relationships between ETP primitives, weight parameters, and hidden states. It recognizes ETP primitives by primitive-type identity (never by string-matching names), and the result is an ETraceGraph that the executor and the online-learning algorithms consume.

Most users never call this layer directly: compile() and the algorithm classes drive it for you. It is documented here for anyone building custom algorithms, inspecting what the compiler discovered, or acting on diagnostics.
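The primitive-identity check described above can be illustrated with plain JAX. This is a minimal sketch, not the compiler's code: it uses the built-in `lax.dot_general_p` as a stand-in for an ETP primitive, and `step`, `w_var`, and `matmul_eqns` are names made up for the illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(w, h_prev, x):
    # One recurrent update written with lax ops so each call maps to one primitive.
    return lax.tanh(lax.add(lax.dot(x, w), h_prev))


w = jnp.ones((3, 4))
h_prev = jnp.zeros((2, 4))
x = jnp.ones((2, 3))

# Trace the function into a jaxpr; the first input variable corresponds to `w`.
closed = jax.make_jaxpr(step)(w, h_prev, x)
w_var = closed.jaxpr.invars[0]

# Compare the primitive object itself (identity), not its name string.
matmul_eqns = [eqn for eqn in closed.jaxpr.eqns if eqn.primitive is lax.dot_general_p]

# For each matched equation, check whether the weight variable is one of its inputs.
uses_w = [any(v is w_var for v in eqn.invars) for eqn in matmul_eqns]
```

Matching on the primitive object stays correct even if two unrelated primitives happen to share a name, which is presumably why the compiler avoids string matching.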

Graph Compilation

The entry point that compiles a model into an eligibility-trace graph, and the graph object it returns.

compile_etrace_graph

Construct the eligibility-trace graph for a given model and inputs.

ETraceGraph

The overall compiled graph for the eligibility trace.

Module Info

Extracts the jaxpr and state information from a brainstate.nn.Module. ModuleInfo is the compiler’s structured view of a model.

extract_module_info

Extract the model information for the ETrace compiler.

ModuleInfo

The model information for the ETrace compiler.

Hidden Groups

Sets of hidden states that are updated together in the recurrent computation. The finder functions discover them either from an extracted ModuleInfo or directly from a module.

HiddenGroup

The data structure recording a hidden-group relation.

find_hidden_groups_from_minfo

Find the hidden groups from the model information.

find_hidden_groups_from_module

Find hidden groups from a model.

Hidden–Parameter–Operation Relations

The core data structure connecting ETP primitives, weight parameters, and hidden states. Each relation encodes “weight W is used through ETP primitive P, and P’s output feeds hidden group H.” Per the non-parametric-tail invariant, a weight that reaches a hidden state only through another trainable ETP primitive is deliberately excluded.
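The exclusion rule above can be made concrete with a toy graph. The sketch below is an assumption-laden illustration, not the library's data structure or search: `Op`, `find_relations`, and the variable names are hypothetical, and the model is reduced to a chain of single-input operations.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    name: str        # label of the operation (hypothetical)
    trainable: bool  # True if the op stands for an ETP primitive with a weight
    src: str         # name of the input variable
    dst: str         # name of the output variable


def find_relations(ops, hidden_vars):
    """Return (weight op, hidden var) pairs reachable through a non-parametric tail."""
    consumers = {}
    for op in ops:
        consumers.setdefault(op.src, []).append(op)

    found = []
    for op in ops:
        if not op.trainable:
            continue
        frontier, seen = [op.dst], set()
        while frontier:
            var = frontier.pop()
            if var in seen:
                continue
            seen.add(var)
            if var in hidden_vars:
                found.append((op.name, var))
                continue
            # Follow only non-trainable consumers: a trainable op ends the tail.
            frontier.extend(c.dst for c in consumers.get(var, []) if not c.trainable)
    return found


# x --W1--> a --W2--> b --tanh--> h
ops = [
    Op("W1", True, "x", "a"),
    Op("W2", True, "a", "b"),
    Op("tanh", False, "b", "h"),
]
relations = find_relations(ops, {"h"})
```

Here `W2` is related to `h` because only `tanh` sits between them, while `W1` is left out because its only path to `h` passes through the trainable `W2`.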

HiddenParamOpRelation

Connection between an ETP primitive, its trainable parameters, and hidden states.

find_hidden_param_op_relations_from_minfo

Find ETP relations from a ModuleInfo.

find_hidden_param_op_relations_from_module

Find ETP relations from a model.

Hidden Perturbation

Perturbation structures used to compute hidden-to-hidden Jacobians (the diagonal approximation of \(\partial \mathbf{h}^t / \partial \mathbf{h}^{t-1}\)).
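To show what a diagonal hidden-to-hidden Jacobian looks like in practice, here is a minimal JAX sketch. It assumes an element-wise hidden update (so the full Jacobian is already diagonal) and injects an additive perturbation `eps` on the previous hidden state. `update`, `perturbed`, and the numeric values are hypothetical, and this is not presented as how HiddenPerturbation is implemented.

```python
import jax
import jax.numpy as jnp


def update(h_prev, drive):
    # Element-wise hidden update: unit i of h^t depends only on unit i of h^{t-1}.
    return 0.9 * h_prev + jnp.tanh(drive) * (1.0 - h_prev ** 2)


h_prev = jnp.array([0.1, -0.2, 0.3])
drive = jnp.array([0.5, 1.0, -1.5])


def perturbed(eps):
    # Additive perturbation on the previous hidden state; eps = 0 recovers the model.
    return update(h_prev + eps, drive)


# A forward-mode pass with an all-ones tangent returns J @ 1, which equals
# diag(J) whenever J is diagonal.
_, diag = jax.jvp(perturbed, (jnp.zeros_like(h_prev),), (jnp.ones_like(h_prev),))

# Reference: materialize the full Jacobian with respect to h_prev.
full = jax.jacobian(update)(h_prev, drive)
```

For updates that mix hidden units, `J @ 1` would give row sums instead of the diagonal, so this shortcut only applies under the element-wise assumption stated above.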

HiddenPerturbation

The hidden-perturbation information.

add_hidden_perturbation_from_minfo

Add hidden-state perturbations from a ModuleInfo.

add_hidden_perturbation_in_module

Add hidden-state perturbations from a model.

Graph Executor

Executes the compiled graph: runs the forward pass and computes the hidden-to-weight and hidden-to-hidden Jacobians the algorithms consume. ETraceVjpGraphExecutor is the VJP-based executor used by the ETraceVjpAlgorithm family.
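The VJP idea behind this executor family can be sketched with `jax.vjp` alone. The example below is a hedged illustration on a one-layer recurrent step, not the ETraceVjpGraphExecutor API: `step`, the shapes, and the all-ones cotangent are assumptions chosen for the demonstration.

```python
import jax
import jax.numpy as jnp


def step(w, h_prev, x):
    # Forward pass of a single recurrent step.
    return jnp.tanh(x @ w + h_prev)


w = jnp.arange(12.0).reshape(3, 4) * 0.1
h_prev = jnp.zeros(4)
x = jnp.array([1.0, -1.0, 0.5])

# One call yields both the forward result and a pullback for the weight.
h, pullback = jax.vjp(lambda w_: step(w_, h_prev, x), w)

# Pull a cotangent on the hidden state back to the weight: cotangent^T (dh/dW).
cotangent = jnp.ones_like(h)
(w_bar,) = pullback(cotangent)

# Closed form for this particular step, used only to check the result.
expected = jnp.outer(x, cotangent * (1.0 - h ** 2))
```

The pullback never materializes the full hidden-to-weight Jacobian; it only applies it to a cotangent, which keeps the cost close to one extra backward pass.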

ETraceGraphExecutor

The eligibility trace graph executor.

ETraceVjpGraphExecutor

The eligibility trace graph executor for the VJP-based online learning algorithms.

Diagnostics

Structured, leveled records emitted while the compiler analyzes a model. They surface issues that would otherwise be silent: for example, a trainable input that does not trace back to a ParamState, or an ETP weight excluded because it reaches a hidden state only through another trainable primitive. DiagnosticLevel orders records by severity (INFO < WARNING < ERROR), and DiagnosticKind names the specific condition.
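The severity ordering above lends itself to threshold filtering. The sketch below mirrors that idea with standard-library types only. `Level`, `Record`, `at_least`, and the kind strings are hypothetical stand-ins, not the actual DiagnosticLevel, CompilationRecord, or DiagnosticKind definitions.

```python
from dataclasses import dataclass
from enum import IntEnum


class Level(IntEnum):
    # Integer values give the ordering INFO < WARNING < ERROR.
    INFO = 0
    WARNING = 1
    ERROR = 2


@dataclass(frozen=True)
class Record:
    level: Level
    kind: str      # machine-readable reason (hypothetical string values)
    message: str   # human-readable context


def at_least(records, threshold):
    """Keep records whose severity is at or above the threshold."""
    return [r for r in records if r.level >= threshold]


records = [
    Record(Level.INFO, "relation_found", "weight linked to hidden group"),
    Record(Level.WARNING, "weight_excluded", "weight reaches hidden state via another trainable primitive"),
    Record(Level.ERROR, "untraceable_input", "trainable input does not trace back to a parameter state"),
]

serious = at_least(records, Level.WARNING)
```

Keeping the reason as a separate machine-readable field lets tooling branch on the condition without parsing message text.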

CompilationRecord

A single compiler decision, captured with structured context.

DiagnosticKind

Machine-readable reason for a CompilationRecord.

DiagnosticLevel

Severity of a CompilationRecord.