ModuleInfo
- class braintrace.ModuleInfo(stateful_model: StatefulFunction, closed_jaxpr: ClosedJaxpr, retrieved_model_states: FlattedDict[Tuple[str, ...], State], compiled_model_states: Sequence[State], state_id_to_path: Dict[int, Tuple[str, ...]], state_tree_invars: PyTree[jax._src.core.Var], state_tree_outvars: PyTree[jax._src.core.Var], hidden_path_to_invar: Dict[Tuple[str, ...], Var], hidden_path_to_outvar: Dict[Tuple[str, ...], Var], invar_to_hidden_path: Dict[Var, Tuple[str, ...]], outvar_to_hidden_path: Dict[Var, Tuple[str, ...]], hidden_outvar_to_invar: Dict[Var, Var], weight_invars: List[Var], weight_path_to_invars: Dict[Tuple[str, ...], List[Var]], invar_to_weight_path: Dict[Var, Tuple[str, ...]], num_var_out: int, num_var_state: int)
The model information for the ETrace compiler.
Bundles the abstract representation of a model and all the lookup tables the compiler needs. It groups information into five categories: the stateful model, the jaxpr, the states, the hidden states, and the parameter weights.
- stateful_model
The stateful function that compiles the model into an abstract jaxpr representation.
- Type:
brainstate.transform.StatefulFunction
- closed_jaxpr
The closed-jaxpr representation of the model.
- Type:
ClosedJaxpr
- retrieved_model_states
The model states retrieved from model.states(), with well-defined paths and structures.
- Type:
brainstate.util.FlattedDict
- compiled_model_states
The model states compiled from the stateful model; accurate and consistent with the model jaxpr but lacking path information.
- Type:
sequence of brainstate.State
- state_tree_invars
The input jaxpr variables of the states, as a pytree.
- Type:
PyTree of Var
- state_tree_outvars
The output jaxpr variables of the states, as a pytree.
- Type:
PyTree of Var
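- hidden_path_to_invar
Mapping from each hidden path to its input variable.
- Type:
Dict[Tuple[str, ...], Var]
- hidden_path_to_outvar
Mapping from each hidden path to its output variable.
- Type:
Dict[Tuple[str, ...], Var]
- invar_to_hidden_path
Mapping from each input variable to its hidden path.
- Type:
Dict[Var, Tuple[str, ...]]
- outvar_to_hidden_path
Mapping from each output variable to its hidden path.
- Type:
Dict[Var, Tuple[str, ...]]
- hidden_outvar_to_invar
Mapping from each output variable to its input variable.
- Type:
Dict[Var, Var]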
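The relationship between the five hidden-state lookup tables can be sketched in plain Python. This is only an illustration, not braintrace's implementation: plain strings stand in for jaxpr `Var` objects, the paths and variable names are hypothetical, and it assumes the reverse tables are exact inverses of the forward ones.

```python
# Hypothetical stand-ins: strings play the role of jaxpr Var objects,
# and the paths below are made up for illustration.
hidden_path_to_invar = {("lstm", "h"): "h_in", ("lstm", "c"): "c_in"}
hidden_path_to_outvar = {("lstm", "h"): "h_out", ("lstm", "c"): "c_out"}

# Assumption: the reverse tables are inversions of the forward tables.
invar_to_hidden_path = {var: path for path, var in hidden_path_to_invar.items()}
outvar_to_hidden_path = {var: path for path, var in hidden_path_to_outvar.items()}

# Output variable -> input variable, joined on the shared hidden path.
hidden_outvar_to_invar = {
    outvar: hidden_path_to_invar[path]
    for path, outvar in hidden_path_to_outvar.items()
}

print(hidden_outvar_to_invar)
```

Keeping both directions precomputed lets a lookup start from either a path or a variable without scanning the other table.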
See also
extract_module_info
Build a ModuleInfo from a model.
Examples
>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> module_info = braintrace.extract_module_info(gru, inputs)
>>> isinstance(module_info, braintrace.ModuleInfo)
True
- add_jaxpr_outs(jax_vars)
Add extra jaxpr outputs to the model jaxpr.
Returns a new ModuleInfo whose jaxpr additionally outputs the given variables, so the compiler can recover the intermediate values it needs.
- Parameters:
jax_vars (Sequence[Var]) – The extra jaxpr variables to append to the jaxpr outputs.
- Returns:
A new ModuleInfo with the extended jaxpr.
- Return type:
ModuleInfo
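The "return a new record with extra outputs" pattern described above can be illustrated with a toy `NamedTuple`. This is a sketch under stated assumptions, not the library's code: `ToyRecord` and `add_outs` are hypothetical names, strings stand in for jaxpr variables, and no real jaxpr is rebuilt.

```python
from typing import NamedTuple, Sequence, Tuple


class ToyRecord(NamedTuple):
    """Hypothetical stand-in for ModuleInfo that tracks only output names."""

    outvars: Tuple[str, ...]

    def add_outs(self, extra: Sequence[str]) -> "ToyRecord":
        # NamedTuples are immutable, so _replace builds a new record
        # and leaves the original one untouched.
        return self._replace(outvars=self.outvars + tuple(extra))


record = ToyRecord(outvars=("y", "h_out"))
extended = record.add_outs(["pre_activation"])

print(record.outvars)
print(extended.outvars)
```

Because the original record is never mutated, a caller can keep using the unextended version alongside the extended one.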
- closed_jaxpr: ClosedJaxpr
Alias for field number 1
- dict()
Return this module info’s named fields as a plain dictionary.
- Returns:
An ordered mapping from field name to value, as produced by the underlying typing.NamedTuple.
- Return type:
dict
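As a rough illustration of what a field-name-to-value dictionary from a `typing.NamedTuple` looks like, the sketch below uses the standard `_asdict()` method. `ToyFields` is a hypothetical class, and wrapping `_asdict()` is an assumption about how such a method could be written, not a statement about braintrace's source.

```python
from typing import NamedTuple


class ToyFields(NamedTuple):
    """Hypothetical two-field record used only for illustration."""

    name: str
    num_var_out: int

    def dict(self):
        # _asdict() is the standard NamedTuple method that returns
        # the fields in declaration order.
        return dict(self._asdict())


fields = ToyFields(name="gru", num_var_out=1).dict()
print(fields)
```

The keys come back in field-declaration order, which matches the positional "field number" of each attribute.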
- property jaxpr: Jaxpr
The jaxpr of the model.
- Returns:
The jaxpr extracted from closed_jaxpr.
- Return type:
Jaxpr
- jaxpr_call(*args, old_state_vals=None)
Evaluate the model on the given inputs using the compiled jaxpr.
- Parameters:
- Return type:
Tuple[Any, Dict[Tuple[str, ...], Any], Dict[Tuple[str, ...], Any], Dict[Var, Array]]
- Returns:
out (Outputs) – The output of the model.
etrace_vals (ETraceVals) – The values for the eligibility-trace (hidden) states.
oth_state_vals (StateVals) – The other state values.
temps (TempData) – The temporary intermediate values.
- split_state_outvars()
Split the state outvars into weight, hidden, and other states.
- Returns:
weight_jaxvar_tree (PyTree of Var) – The weight tree of jaxpr variables.
hidden_jaxvar (PyTree of Var) – The hidden tree of jaxpr variables.
other_state_jaxvar_tree (PyTree of Var) – The other-state tree of jaxpr variables.
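A three-way split of this kind can be sketched as a simple partition over paths. The function `split_by_kind`, its arguments, and the membership-based classification rule are all hypothetical; the sketch uses flat dictionaries and strings in place of pytrees and jaxpr variables, and the library may classify states differently.

```python
def split_by_kind(state_outvars, weight_paths, hidden_paths):
    """Partition a path -> variable mapping into weight, hidden, and other."""
    weights, hidden, others = {}, {}, {}
    for path, var in state_outvars.items():
        if path in weight_paths:
            weights[path] = var
        elif path in hidden_paths:
            hidden[path] = var
        else:
            # Anything that is neither a weight nor a hidden state.
            others[path] = var
    return weights, hidden, others


# Hypothetical paths and variable names, for illustration only.
state_outvars = {("cell", "W"): "w_out", ("cell", "h"): "h_out", ("rng",): "key_out"}
weights, hidden, others = split_by_kind(
    state_outvars,
    weight_paths={("cell", "W")},
    hidden_paths={("cell", "h")},
)
print(weights, hidden, others)
```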
- state_tree_invars: PyTree[jax._src.core.Var]
Alias for field number 5
- state_tree_outvars: PyTree[jax._src.core.Var]
Alias for field number 6
- stateful_model: StatefulFunction
Alias for field number 0