find_hidden_groups_from_module

braintrace.find_hidden_groups_from_module(model, *model_args, **model_kwargs)

Find the hidden groups of a model, using the given example arguments to trace the model's computation.

Parameters:
  • model (Module) – The model whose hidden states are grouped.

  • *model_args – Positional arguments passed to the model when it is called.

  • **model_kwargs – Keyword arguments passed to the model when it is called.

Return type:

Tuple[Sequence[HiddenGroup], PrettyDict]

Returns:

  • hidden_groups (sequence of HiddenGroup) – The hidden groups.

  • hid_path_to_group (brainstate.util.PrettyDict) – Mapping from each hidden-state path to its HiddenGroup.
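To make the relationship between the two return values concrete, here is a minimal pure-Python sketch. It assumes that a hidden-state path is a tuple of strings and that each path belongs to exactly one group. ToyHiddenGroup, its index and paths fields, and build_path_to_group are hypothetical illustrations, not braintrace's HiddenGroup class or its implementation.

```python
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple

# Assumption: a hidden-state path is a tuple of attribute names.
Path = Tuple[str, ...]


@dataclass(frozen=True)
class ToyHiddenGroup:
    """Hypothetical stand-in for HiddenGroup: an index plus the paths it holds."""
    index: int
    paths: Tuple[Path, ...]


def build_path_to_group(groups: Sequence[ToyHiddenGroup]) -> Dict[Path, ToyHiddenGroup]:
    """Invert a sequence of groups into a path -> group lookup table."""
    mapping: Dict[Path, ToyHiddenGroup] = {}
    for group in groups:
        for path in group.paths:
            # Each path is assumed to belong to exactly one group.
            if path in mapping:
                raise ValueError(f"path {path} appears in more than one group")
            mapping[path] = group
    return mapping


# Two made-up groups, e.g. from a hypothetical two-layer recurrent model.
groups = [
    ToyHiddenGroup(0, (("layer1", "h"),)),
    ToyHiddenGroup(1, (("layer2", "h"), ("layer2", "c"))),
]
path_to_group = build_path_to_group(groups)
print(path_to_group[("layer2", "c")].index)  # 1
```

In this picture, hid_path_to_group is an inverted index over hidden_groups, so the group that owns a given hidden state can be looked up directly instead of by scanning every group.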

See also

find_hidden_groups_from_minfo

Equivalent helper that starts from a ModuleInfo instead of a model and its arguments.

Examples

>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> hidden_groups, hid_path_to_group = braintrace.find_hidden_groups_from_module(gru, inputs)
>>> len(hidden_groups)
1