HiddenGroup
- class braintrace.HiddenGroup(index: int, hidden_paths: List[Tuple[str, ...]], hidden_states: List[HiddenState], hidden_invars: List[Var], hidden_outvars: List[Var], transition_jaxpr: Jaxpr, transition_jaxpr_constvars: List[Var])
The data structure recording a hidden-group relation.
A hidden group bundles the hidden states that are mutually connected through a recurrence transition, together with the jaxpr that computes that transition:
\[h_1^t, h_2^t, \ldots = f(h_1^{t-1}, h_2^{t-1}, \ldots, x^t).\]
- hidden_invars
The input jaxpr Var of each hidden state (at the previous step).
- Type:
list of HiddenInVar
- hidden_outvars
The output jaxpr Var of each hidden state (at the current step).
- Type:
list of HiddenOutVar
- transition_jaxpr
The jaxpr computing the hidden-state transition for the group.
- Type:
Jaxpr
- transition_jaxpr_constvars
The other input variables required to evaluate transition_jaxpr.
- Type:
list of Var
See also
find_hidden_groups_from_module
Build hidden groups directly from a model.
Examples
>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> hidden_groups, _ = braintrace.find_hidden_groups_from_module(gru, inputs)
>>> len(hidden_groups)
1
- check_consistent_varshape()
Check whether the shapes of the hidden states are consistent.
- Raises:
NotSupportedError – If the shapes of the hidden states are not consistent.
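The kind of check described above can be sketched in plain Python. This is an illustration only: the criterion used here (all shapes must be identical) is an assumption, and `ShapeMismatchError` and `check_consistent_shapes` are hypothetical names standing in for braintrace's own error type and method.

```python
from typing import Sequence, Tuple


class ShapeMismatchError(ValueError):
    """Hypothetical stand-in for braintrace's NotSupportedError."""


def check_consistent_shapes(shapes: Sequence[Tuple[int, ...]]) -> None:
    """Raise if the given shapes are not all identical (assumed criterion)."""
    if not shapes:
        return
    first = tuple(shapes[0])
    for i, shape in enumerate(shapes[1:], start=1):
        if tuple(shape) != first:
            raise ShapeMismatchError(
                f"hidden state {i} has shape {tuple(shape)}, expected {first}"
            )


# Matching shapes pass silently.
check_consistent_shapes([(4,), (4,)])

# A mismatch raises the (hypothetical) error.
try:
    check_consistent_shapes([(4,), (5,)])
    mismatch_detected = False
except ShapeMismatchError:
    mismatch_detected = True
```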
- concat_hidden(splitted_hid_vals)
Concatenate split hidden-state values into a single array.
Concatenates a sequence of split hidden-state values along the last axis. For non-HiddenGroupState values, an extra trailing dimension is added before concatenation.
- Parameters:
splitted_hid_vals (Sequence[Array]) – A sequence of split hidden-state values, each corresponding to a hidden state in the group.
- Returns:
A single array containing all hidden-state values concatenated along the last axis.
- Return type:
jax.Array
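The concatenation rule described above can be sketched with `jax.numpy`. This is a minimal illustration, not the library's implementation: `concat_hidden_sketch` and its `is_group_state` flags are hypothetical, standing in for however the group distinguishes HiddenGroupState values from ordinary ones.

```python
import jax.numpy as jnp


def concat_hidden_sketch(values, is_group_state):
    """Concatenate values along the last axis.

    Values flagged as group states are assumed to already carry a trailing
    state axis; all others get one added first.
    """
    expanded = [
        v if grouped else jnp.expand_dims(v, axis=-1)
        for v, grouped in zip(values, is_group_state)
    ]
    return jnp.concatenate(expanded, axis=-1)


h = jnp.zeros((4,))    # ordinary hidden state: one value per unit
g = jnp.ones((4, 2))   # already has a trailing state axis of size 2
concat = concat_hidden_sketch([h, g], is_group_state=[False, True])
print(concat.shape)
```

Here the ordinary state contributes one column and the grouped state contributes two, so the last axis of the result has size three.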
- diagonal_jacobian(hidden_vals, input_vals)
Compute the diagonal Jacobian matrix along the last dimension.
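As general background on what a diagonal Jacobian is, the sketch below compares two ways of obtaining it in JAX for an element-wise function. It does not reproduce how `diagonal_jacobian` is implemented; the toy function `f` and all variable names are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def f(h):
    # Toy element-wise transition: each output depends only on the matching input.
    return jnp.tanh(0.5 * h)


h = jnp.array([0.0, 1.0, 2.0])

# Full Jacobian (3 x 3), then keep only its diagonal.
full_jac = jax.jacfwd(f)(h)
diag_from_full = jnp.diagonal(full_jac)

# For an element-wise function the Jacobian is diagonal, so a single
# Jacobian-vector product with a vector of ones recovers the same values.
_, diag_from_jvp = jax.jvp(f, (h,), (jnp.ones_like(h),))
```

The second form avoids materialising the full matrix, which matters when the last dimension is large.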
- dict()
Return this group’s named fields as a plain dictionary.
- Returns:
An ordered mapping from field name to value, as produced by the underlying typing.NamedTuple.
- Return type:
dict
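For readers unfamiliar with how a `typing.NamedTuple` exposes its fields as a mapping, here is a small standard-library sketch. `ToyGroup` is a hypothetical two-field stand-in, not the real `HiddenGroup`, and the use of `_asdict()` is an assumption about what the method wraps.

```python
from typing import List, NamedTuple


class ToyGroup(NamedTuple):
    """Hypothetical two-field stand-in for a NamedTuple-based record."""
    index: int
    hidden_paths: List[str]


group = ToyGroup(index=0, hidden_paths=["h", "c"])

# NamedTuple's built-in conversion keeps fields in declaration order.
fields = group._asdict()
print(list(fields))
```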
- property num_state: int
The number of hidden states.
- Returns:
The total number of hidden states across the group.
- Return type:
int
- split_hidden(concat_hid_vals)
Split a concatenated hidden-state array into individual arrays.
Splits a concatenated array of hidden-state values into separate arrays, one per hidden state in the group.
HiddenGroupState and non-HiddenGroupState values are handled differently.
- Parameters:
concat_hid_vals (Array) – A concatenated array of hidden-state values. The last dimension is assumed to contain the concatenated states.
- Returns:
A list of split hidden-state arrays. For non-HiddenGroupState values, the last dimension is squeezed.
- Return type:
list of jax.Array
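The splitting rule can be sketched as the inverse of concatenation along the last axis. This is an illustration under stated assumptions: `split_hidden_sketch`, the `widths` argument (how many trailing columns each state occupies), and the `is_group_state` flags are all hypothetical; the real method presumably derives this information from the group's hidden states.

```python
import jax.numpy as jnp


def split_hidden_sketch(concat, widths, is_group_state):
    """Split the last axis into per-state pieces.

    Pieces not flagged as group states are assumed to have width 1 and
    get that trailing axis squeezed away.
    """
    # Cumulative split points, excluding the final boundary.
    offsets, total = [], 0
    for w in widths[:-1]:
        total += w
        offsets.append(total)
    pieces = jnp.split(concat, offsets, axis=-1)
    return [
        p if grouped else jnp.squeeze(p, axis=-1)
        for p, grouped in zip(pieces, is_group_state)
    ]


concat = jnp.arange(12.0).reshape(4, 3)
h, g = split_hidden_sketch(concat, widths=[1, 2], is_group_state=[False, True])
print(h.shape, g.shape)
```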
- transition(hidden_vals, input_vals)
Compute the hidden-state transitions.
Evaluates the group transition jaxpr:
\[h_1^t, h_2^t, \cdots = f(h_1^{t-1}, h_2^{t-1}, \cdots, x^t).\]
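To make the idea of evaluating a transition jaxpr concrete, the sketch below traces a toy two-state recurrence with `jax.make_jaxpr` and evaluates the result with `jax.core.eval_jaxpr`. The function `step` and its variables are hypothetical. The sketch assumes `jax.core.eval_jaxpr` is available in the installed JAX version, and it makes no claim about how `transition` itself arranges its arguments.

```python
import jax
import jax.numpy as jnp


def step(h, c, x):
    # Toy recurrence over two hidden states (h, c) driven by input x.
    new_c = 0.9 * c + x
    new_h = jnp.tanh(new_c + h)
    return new_h, new_c


h_prev = jnp.zeros(4)
c_prev = jnp.ones(4)
x_t = jnp.full((4,), 0.5)

# Trace the function once to obtain a closed jaxpr (jaxpr plus constants).
closed = jax.make_jaxpr(step)(h_prev, c_prev, x_t)

# Evaluate the jaxpr directly: constants first, then the positional inputs.
h_new, c_new = jax.core.eval_jaxpr(closed.jaxpr, closed.consts, h_prev, c_prev, x_t)
```

Evaluating the traced jaxpr gives the same values as calling `step` directly; the jaxpr form is what allows the transition to be stored and replayed as data.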
- transition_jaxpr: Jaxpr
Alias for field number 5