extract_module_info


class braintrace.extract_module_info(model, *model_args, **model_kwargs)

Extract the model information needed by the ETrace compiler.

Parameters:
  • model (Module) – The model from which to extract the information.

  • *model_args – Positional arguments with which the model is called.

  • **model_kwargs – Keyword arguments with which the model is called.

Returns:

The model information.

Return type:

ModuleInfo

See also

ModuleInfo

The returned data structure.

Examples

>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> module_info = braintrace.extract_module_info(gru, inputs)
>>> module_info.num_var_out
1
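The snippet below is not braintrace's implementation. It is a minimal analogy that assumes the information is obtained by tracing the model with the supplied arguments. It calls `jax.make_jaxpr` on a hypothetical toy recurrent cell (`toy_cell`, `W` and `U` are illustrative names, not part of braintrace) and counts the inputs and outputs of the traced program.

```python
import jax
import jax.numpy as jnp

# Hypothetical weights for a toy recurrent cell (illustrative only, not braintrace).
W = jnp.ones((4, 3))
U = jnp.ones((4, 4))


def toy_cell(x, h):
    # One hidden-state update: h_new = tanh(W @ x + U @ h)
    return jnp.tanh(W @ x + U @ h)


x = jnp.ones(3)   # example input, playing the role of *model_args
h = jnp.zeros(4)  # example hidden state

# Trace the function with concrete example arguments to obtain its jaxpr.
closed = jax.make_jaxpr(toy_cell)(x, h)

num_inputs = len(closed.jaxpr.invars)    # traced inputs: x and h
num_outputs = len(closed.jaxpr.outvars)  # traced outputs: the new hidden state
print(num_inputs, num_outputs)
```

Closed-over arrays such as `W` and `U` become constants of the traced program rather than inputs, which is why only `x` and `h` are counted as inputs.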