find_hidden_groups_from_module
- class braintrace.find_hidden_groups_from_module(model, *model_args, **model_kwargs)
Find hidden groups from a model.
- Parameters:
model (Module) – The model.
*model_args – The positional arguments of the model.
**model_kwargs – The keyword arguments of the model.
- Return type:
Tuple[Sequence[HiddenGroup], PrettyDict]
- Returns:
hidden_groups (sequence of HiddenGroup) – The hidden groups.
hid_path_to_group (brainstate.util.PrettyDict) – Mapping from each hidden-state path to its HiddenGroup.
See also
find_hidden_groups_from_minfo – Equivalent helper starting from ModuleInfo.
Examples
>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> hidden_groups, hid_path_to_group = braintrace.find_hidden_groups_from_module(gru, inputs)
>>> len(hidden_groups)
1
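The relationship between the two return values can be illustrated with a library-free sketch. This is not braintrace's implementation: `StandInGroup`, `build_path_lookup`, the tuple-of-strings path format, and the example paths are all hypothetical stand-ins, chosen only to show how a path-to-group mapping lets several hidden-state paths resolve to the same group object.

```python
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple

# Assumption: a hidden-state path is modelled as a tuple of attribute names.
Path = Tuple[str, ...]


@dataclass(frozen=True)
class StandInGroup:
    """Hypothetical stand-in for HiddenGroup; not the real class."""
    index: int
    paths: Tuple[Path, ...]


def build_path_lookup(groups: Sequence[StandInGroup]) -> Dict[Path, StandInGroup]:
    """Map every hidden-state path to the group that contains it."""
    lookup: Dict[Path, StandInGroup] = {}
    for group in groups:
        for path in group.paths:
            lookup[path] = group
    return lookup


# Two made-up groups: one holding two states, one holding a single state.
groups = [
    StandInGroup(index=0, paths=(("cell", "h"), ("cell", "c"))),
    StandInGroup(index=1, paths=(("readout", "v"),)),
]
path_to_group = build_path_lookup(groups)

# Paths that belong to the same group resolve to the same object.
same_group = path_to_group[("cell", "h")] is path_to_group[("cell", "c")]
```

In this sketch the sequence plays the role of `hidden_groups` and the dictionary plays the role of `hid_path_to_group`; the real objects returned by braintrace may carry additional fields and use a different path representation.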