ModuleInfo

class braintrace.ModuleInfo(stateful_model: StatefulFunction, closed_jaxpr: ClosedJaxpr, retrieved_model_states: FlattedDict[Tuple[str, ...], State], compiled_model_states: Sequence[State], state_id_to_path: Dict[int, Tuple[str, ...]], state_tree_invars: PyTree[jax._src.core.Var], state_tree_outvars: PyTree[jax._src.core.Var], hidden_path_to_invar: Dict[Tuple[str, ...], Var], hidden_path_to_outvar: Dict[Tuple[str, ...], Var], invar_to_hidden_path: Dict[Var, Tuple[str, ...]], outvar_to_hidden_path: Dict[Var, Tuple[str, ...]], hidden_outvar_to_invar: Dict[Var, Var], weight_invars: List[Var], weight_path_to_invars: Dict[Tuple[str, ...], List[Var]], invar_to_weight_path: Dict[Var, Tuple[str, ...]], num_var_out: int, num_var_state: int)

The model information for the ETrace compiler.

Bundles the abstract representation of a model and all the lookup tables the compiler needs. It groups information into five categories: the stateful model, the jaxpr, the states, the hidden states, and the parameter weights.

stateful_model

The stateful function that compiles the model into an abstract jaxpr representation.

Type:

brainstate.transform.StatefulFunction

closed_jaxpr

The closed-jaxpr representation of the model.

Type:

ClosedJaxpr

retrieved_model_states

The model states retrieved from model.states(), with well-defined paths and structures.

Type:

brainstate.util.FlattedDict

compiled_model_states

The model states collected while compiling the stateful model. They are accurate and consistent with the model jaxpr, but carry no path information.

Type:

sequence of brainstate.State

state_id_to_path

Mapping from each state id to its state path.

Type:

dict
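The compiled states carry no path, so this table is what links them back to the retrieved ones. A minimal sketch, assuming the state id is Python's built-in id() of the state object (this page does not say so explicitly); ToyState is a hypothetical stand-in for brainstate.State:

```python
from typing import Dict, List, Tuple

Path = Tuple[str, ...]


class ToyState:
    """Hypothetical stand-in for brainstate.State: it only holds a value."""

    def __init__(self, value):
        self.value = value


# Hypothetical retrieved states: every path points at a state object.
w = ToyState(0.5)
h = ToyState(0.0)
retrieved_model_states: Dict[Path, ToyState] = {("fc", "weight"): w, ("rnn", "h"): h}

# Compiled states: the same objects, possibly in another order, without paths.
compiled_model_states: List[ToyState] = [h, w]

# Assumption: the "state id" is the Python object id of the state.
state_id_to_path: Dict[int, Path] = {
    id(state): path for path, state in retrieved_model_states.items()
}

# Recover the path of every compiled state, in compiled order.
compiled_paths = [state_id_to_path[id(st)] for st in compiled_model_states]
print(compiled_paths)  # [('rnn', 'h'), ('fc', 'weight')]
```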

state_tree_invars

The input jaxpr variables of the states, as a pytree.

Type:

PyTree of Var

state_tree_outvars

The output jaxpr variables of the states, as a pytree.

Type:

PyTree of Var

hidden_path_to_invar

Mapping from each hidden-state path to the jaxpr input variable of that hidden state.

Type:

dict

hidden_path_to_outvar

Mapping from each hidden-state path to the jaxpr output variable of that hidden state.

Type:

dict

invar_to_hidden_path

Mapping from each hidden-state input variable to its hidden-state path (the inverse of hidden_path_to_invar).

Type:

dict

outvar_to_hidden_path

Mapping from each hidden-state output variable to its hidden-state path (the inverse of hidden_path_to_outvar).

Type:

dict

hidden_outvar_to_invar

Mapping from each hidden-state output variable to the input variable of the same hidden state.

Type:

dict
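The hidden-state tables describe one correspondence from several directions. The sketch below shows how the inverse maps and hidden_outvar_to_invar could be derived from hidden_path_to_invar and hidden_path_to_outvar; plain strings stand in for jaxpr Var objects, and this is an illustration rather than braintrace's own construction code.

```python
from typing import Dict, Tuple

Path = Tuple[str, ...]
Var = str  # hypothetical stand-in for a jaxpr Var

hidden_path_to_invar: Dict[Path, Var] = {("rnn", "h"): "a", ("rnn", "c"): "b"}
hidden_path_to_outvar: Dict[Path, Var] = {("rnn", "h"): "x", ("rnn", "c"): "y"}

# Inverse maps: variable -> hidden-state path.
invar_to_hidden_path = {v: p for p, v in hidden_path_to_invar.items()}
outvar_to_hidden_path = {v: p for p, v in hidden_path_to_outvar.items()}

# Pair each hidden output variable with the input variable of the same state.
hidden_outvar_to_invar = {
    outvar: hidden_path_to_invar[path]
    for outvar, path in outvar_to_hidden_path.items()
}
print(hidden_outvar_to_invar)  # {'x': 'a', 'y': 'b'}
```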

weight_invars

The weight input variables.

Type:

list of Var

weight_path_to_invars

Mapping from each weight path to its input variables.

Type:

dict

invar_to_weight_path

Mapping from each weight input variable to its weight path.

Type:

dict
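weight_path_to_invars maps each path to a list, presumably because one weight state can hold a pytree (for example a kernel and a bias) that flattens into several jaxpr inputs. Under that assumption, and with strings standing in for Var objects, the flat list and the inverse map could be derived as follows:

```python
from typing import Dict, List, Tuple

Path = Tuple[str, ...]
Var = str  # hypothetical stand-in for a jaxpr Var

# One weight state may flatten into several jaxpr inputs (e.g. kernel + bias).
weight_path_to_invars: Dict[Path, List[Var]] = {
    ("fc", "weight"): ["w", "b"],
    ("rnn", "Wh"): ["u"],
}

# Every leaf variable points back to the path of the weight it belongs to.
invar_to_weight_path: Dict[Var, Path] = {
    var: path for path, vars_ in weight_path_to_invars.items() for var in vars_
}

# Flat list of all weight input variables.
weight_invars: List[Var] = [v for vars_ in weight_path_to_invars.values() for v in vars_]

print(invar_to_weight_path["b"])  # ('fc', 'weight')
print(weight_invars)              # ['w', 'b', 'u']
```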

num_var_out

Number of original output variables.

Type:

int

num_var_state

Number of state-variable outputs.

Type:

int
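The two counts make it possible to slice the flat list of jaxpr outputs. A minimal sketch, assuming (not confirmed on this page) the layout [original outputs | state outputs | extra outputs appended by add_jaxpr_outs]; split_outvars is a hypothetical helper:

```python
from typing import List, Sequence, Tuple


def split_outvars(
    outvars: Sequence[str], num_var_out: int, num_var_state: int
) -> Tuple[List[str], List[str], List[str]]:
    """Hypothetical helper: split jaxpr outputs by the assumed layout
    [original outputs | state outputs | extra outputs]."""
    outs = list(outvars[:num_var_out])
    states = list(outvars[num_var_out:num_var_out + num_var_state])
    extras = list(outvars[num_var_out + num_var_state:])
    return outs, states, extras


# Two model outputs, three state outputs, one extra intermediate value.
outs, states, extras = split_outvars(["y0", "y1", "s0", "s1", "s2", "t0"], 2, 3)
print(outs, states, extras)  # ['y0', 'y1'] ['s0', 's1', 's2'] ['t0']
```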

See also

extract_module_info

Build a ModuleInfo from a model.

Examples

>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> module_info = braintrace.extract_module_info(gru, inputs)
>>> isinstance(module_info, braintrace.ModuleInfo)
True

add_jaxpr_outs(jax_vars)

Add extra jaxpr outputs to the model jaxpr.

Returns a new ModuleInfo whose jaxpr additionally outputs the given variables, so the compiler can recover the intermediate values it needs.

Parameters:

jax_vars (Sequence[Var]) – The extra jaxpr variables to append to the jaxpr outputs.

Returns:

A new ModuleInfo with the extended jaxpr.

Return type:

ModuleInfo
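ModuleInfo is a typing.NamedTuple (each field is documented as an "Alias for field number N"), so its fields cannot be reassigned, and returning a new instance is the natural pattern. The sketch below shows that pattern with a hypothetical two-field ToyInfo and NamedTuple._replace; it does not reproduce how braintrace rebuilds the jaxpr itself.

```python
from typing import NamedTuple, Sequence, Tuple


class ToyInfo(NamedTuple):
    """Hypothetical stand-in: only the jaxpr outputs and the output count."""
    outvars: Tuple[str, ...]
    num_var_out: int

    def add_outs(self, extra: Sequence[str]) -> "ToyInfo":
        # _replace builds a new tuple; the original instance is left untouched.
        return self._replace(outvars=self.outvars + tuple(extra))


info = ToyInfo(outvars=("y", "s"), num_var_out=1)
extended = info.add_outs(["t0", "t1"])
print(info.outvars)      # ('y', 's')
print(extended.outvars)  # ('y', 's', 't0', 't1')
```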

closed_jaxpr: ClosedJaxpr

Alias for field number 1

compiled_model_states: Sequence[State]

Alias for field number 3

dict()

Return this module info’s named fields as a plain dictionary.

Returns:

An ordered mapping from field name to value, as produced by the underlying typing.NamedTuple.

Return type:

Dict[str, Any]
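For a typing.NamedTuple this mapping is what the built-in _asdict() method returns, a plain dict in field order on Python 3.8 and later; dict() presumably wraps it. A hypothetical two-field ToyInfo shows the shape of the result and what "Alias for field number N" means:

```python
from typing import Any, Dict, NamedTuple


class ToyInfo(NamedTuple):
    """Hypothetical stand-in reusing two of ModuleInfo's field names."""
    num_var_out: int
    num_var_state: int


info = ToyInfo(num_var_out=2, num_var_state=3)
as_dict: Dict[str, Any] = info._asdict()  # field name -> value, in field order
print(as_dict)  # {'num_var_out': 2, 'num_var_state': 3}

# "Alias for field number 1" means the attribute reads positional index 1.
print(info[1] == info.num_var_state)  # True
```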

hidden_outvar_to_invar: Dict[Var, Var]

Alias for field number 11

hidden_path_to_invar: Dict[Tuple[str, ...], Var]

Alias for field number 7

hidden_path_to_outvar: Dict[Tuple[str, ...], Var]

Alias for field number 8

invar_to_hidden_path: Dict[Var, Tuple[str, ...]]

Alias for field number 9

invar_to_weight_path: Dict[Var, Tuple[str, ...]]

Alias for field number 14

property jaxpr: Jaxpr

The jaxpr of the model.

Returns:

The jaxpr extracted from closed_jaxpr.

Return type:

Jaxpr

jaxpr_call(*args, old_state_vals=None)

Evaluate the model on the given inputs using the compiled jaxpr.

Parameters:
  • *args (Any) – The inputs of the model.

  • old_state_vals (Optional[Sequence[Array]]) – The old state values. When None, the current values of the compiled model states are used. Default None.

Return type:

Tuple[Any, Dict[Tuple[str, ...], Any], Dict[Tuple[str, ...], Any], Dict[Var, Array]]

Returns:

  • out (Outputs) – The output of the model.

  • etrace_vals (ETraceVals) – The values for the eligibility-trace (hidden) states.

  • oth_state_vals (StateVals) – The other state values.

  • temps (TempData) – The temporary intermediate values.
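A minimal sketch of the documented old_state_vals default, together with one plausible way of sorting new state values into the path-keyed etrace_vals and oth_state_vals dictionaries. toy_jaxpr_call, the Python step function, and the hidden_paths set are hypothetical; the real method evaluates the compiled jaxpr, and the temps output is omitted here.

```python
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple

Path = Tuple[str, ...]


def toy_jaxpr_call(
    step: Callable[..., Tuple[Any, Sequence[Any]]],
    state_paths: Sequence[Path],
    current_vals: Sequence[Any],
    hidden_paths: Set[Path],
    *args: Any,
    old_state_vals: Optional[Sequence[Any]] = None,
):
    """Hypothetical helper mirroring the documented old_state_vals default."""
    if old_state_vals is None:
        # Documented behavior: fall back to the current state values.
        old_state_vals = current_vals
    out, new_vals = step(*args, *old_state_vals)
    etrace_vals: Dict[Path, Any] = {}
    oth_state_vals: Dict[Path, Any] = {}
    for path, val in zip(state_paths, new_vals):
        if path in hidden_paths:
            etrace_vals[path] = val
        else:
            oth_state_vals[path] = val
    return out, etrace_vals, oth_state_vals


def step(x, w, h):
    # Toy recurrent update: new hidden h' = w * h + x; the weight is unchanged.
    h_new = w * h + x
    return h_new, [w, h_new]


paths = [("fc", "w"), ("rnn", "h")]
out, hid, oth = toy_jaxpr_call(step, paths, [2.0, 1.0], {("rnn", "h")}, 3.0)
print(out, hid, oth)  # 5.0 {('rnn', 'h'): 5.0} {('fc', 'w'): 2.0}
```

Passing old_state_vals explicitly, for example old_state_vals=[2.0, 0.0], replaces the current values for that call only.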

num_var_out: int

Alias for field number 15

num_var_state: int

Alias for field number 16

outvar_to_hidden_path: Dict[Var, Tuple[str, ...]]

Alias for field number 10

retrieved_model_states: FlattedDict[Tuple[str, ...], State]

Alias for field number 2

split_state_outvars()

Split the state output variables into weight, hidden-state, and other-state groups.

Returns:

  • weight_jaxvar_tree (PyTree of Var) – The weight tree of jaxpr variables.

  • hidden_jaxvar (PyTree of Var) – The hidden tree of jaxpr variables.

  • other_state_jaxvar_tree (PyTree of Var) – The other-state tree of jaxpr variables.
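A minimal sketch of the three-way split, assuming each state output can be classified by its path: weight paths go to the weight group, hidden-state paths to the hidden group, and everything else to the other-state group. Flat path-keyed dicts of strings stand in for the PyTrees of Var objects that the method actually returns.

```python
from typing import Dict, Tuple

Path = Tuple[str, ...]
Var = str  # hypothetical stand-in for a jaxpr Var

# Hypothetical state output variables keyed by state path.
state_outvars: Dict[Path, Var] = {
    ("fc", "weight"): "o0",
    ("rnn", "h"): "o1",
    ("counter",): "o2",
}
weight_paths = {("fc", "weight")}  # assumed known from the weight tables
hidden_paths = {("rnn", "h")}      # assumed known from the hidden tables

weight_tree = {p: v for p, v in state_outvars.items() if p in weight_paths}
hidden_tree = {p: v for p, v in state_outvars.items() if p in hidden_paths}
other_tree = {
    p: v for p, v in state_outvars.items()
    if p not in weight_paths and p not in hidden_paths
}
# Prints: {('fc', 'weight'): 'o0'} {('rnn', 'h'): 'o1'} {('counter',): 'o2'}
print(weight_tree, hidden_tree, other_tree)
```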

state_id_to_path: Dict[int, Tuple[str, ...]]

Alias for field number 4

state_tree_invars: PyTree[jax._src.core.Var]

Alias for field number 5

state_tree_outvars: PyTree[jax._src.core.Var]

Alias for field number 6

stateful_model: StatefulFunction

Alias for field number 0

weight_invars: List[Var]

Alias for field number 12

weight_path_to_invars: Dict[Tuple[str, ...], List[Var]]

Alias for field number 13