HiddenGroup#

class braintrace.HiddenGroup(index: int, hidden_paths: List[Tuple[str, ...]], hidden_states: List[HiddenState], hidden_invars: List[Var], hidden_outvars: List[Var], transition_jaxpr: Jaxpr, transition_jaxpr_constvars: List[Var])#

The data structure recording a hidden-group relation.

A hidden group bundles the hidden states that are mutually connected through a recurrence transition, together with the jaxpr that computes that transition:

\[h_1^t, h_2^t, \ldots = f(h_1^{t-1}, h_2^{t-1}, \ldots, x^t).\]

index#

Position of this group in the compiled group sequence.

Type:

int

hidden_paths#

The module path to each hidden state in the group.

Type:

list of Path

hidden_states#

The hidden states in the group.

Type:

list of brainstate.HiddenState

hidden_invars#

The input jaxpr Var of each hidden state (at the previous step).

Type:

list of HiddenInVar

hidden_outvars#

The output jaxpr Var of each hidden state (at the current step).

Type:

list of HiddenOutVar

transition_jaxpr#

The jaxpr computing the hidden-state transition for the group.

Type:

Jaxpr

transition_jaxpr_constvars#

The other input variables required to evaluate transition_jaxpr.

Type:

list of Var

See also

find_hidden_groups_from_module

Build hidden groups directly from a model.

Examples

>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> hidden_groups, _ = braintrace.find_hidden_groups_from_module(gru, inputs)
>>> len(hidden_groups)
1

check_consistent_varshape()[source]#

Check whether the shapes of the hidden states are consistent.

Raises:

NotSupportedError – If the shapes of the hidden states are not consistent.
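
As a rough illustration of what such a check might involve, the sketch below compares plain shape tuples and raises when they differ. It assumes that "consistent" means every hidden state in the group has the same shape; `check_consistent_varshape_sketch` and `NotSupportedErrorSketch` are hypothetical stand-ins, not the library's implementation.

```python
from typing import Sequence, Tuple


class NotSupportedErrorSketch(Exception):
    """Hypothetical stand-in for the library's NotSupportedError."""


def check_consistent_varshape_sketch(shapes: Sequence[Tuple[int, ...]]) -> Tuple[int, ...]:
    # Assumption: a group is consistent when all states share one shape.
    distinct = set(shapes)
    if len(distinct) > 1:
        raise NotSupportedErrorSketch(
            f"Hidden states have inconsistent shapes: {sorted(distinct)}"
        )
    return shapes[0]


# Two states of shape (4,) pass; mixing (4,) and (5,) raises.
shared = check_consistent_varshape_sketch([(4,), (4,)])
try:
    check_consistent_varshape_sketch([(4,), (5,)])
    raised = False
except NotSupportedErrorSketch:
    raised = True
```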

concat_hidden(splitted_hid_vals)[source]#

Concatenate split hidden-state values into a single array.

Concatenates a sequence of split hidden-state values along the last axis. For non-HiddenGroupState values, an extra trailing dimension is added before concatenation.

Parameters:

splitted_hid_vals (Sequence[Array]) – A sequence of split hidden-state values, each corresponding to a hidden state in the group.

Returns:

A single array containing all hidden-state values concatenated along the last axis.

Return type:

jax.Array
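
The concatenation described above can be pictured with plain NumPy arrays. This is a minimal sketch under stated assumptions: `concat_hidden_sketch` and the `is_group_state` flags are hypothetical, plain states are assumed to have shape `varshape`, and group-style states are assumed to already carry a trailing state axis.

```python
import numpy as np


def concat_hidden_sketch(split_vals, is_group_state):
    parts = []
    for val, grouped in zip(split_vals, is_group_state):
        # Plain states gain a trailing axis of length 1;
        # group-style states are assumed to already have one.
        parts.append(val if grouped else val[..., None])
    # All parts are joined along the last axis.
    return np.concatenate(parts, axis=-1)


v = np.arange(4.0)    # plain state, shape (4,)
w = np.ones((4, 2))   # group-style state, shape (4, 2)
merged = concat_hidden_sketch([v, w], [False, True])
```

Here `merged` has shape `(4, 3)`: one column from `v` followed by the two columns of `w`.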

diagonal_jacobian(hidden_vals, input_vals)[source]#

Compute the diagonal Jacobian matrix along the last dimension.

Parameters:
  • hidden_vals (Sequence[Array]) – The hidden-state values.

  • input_vals (Any) – The input values.

Returns:

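The diagonal Jacobian matrix, with shape (*varshape, num_states, num_states), where varshape is the variable shape shared by the hidden states and num_states is the number of hidden states in the group.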

Return type:

jax.Array
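
To make the output shape concrete, the sketch below builds a Jacobian of shape `(*varshape, num_states, num_states)` for a toy two-state transition using central finite differences in NumPy. It assumes that "diagonal" means the transition acts independently at each position of `varshape`, so only the state-by-state block at each position is kept. `transition_sketch` and `diagonal_jacobian_sketch` are hypothetical and do not reproduce how the library computes the result.

```python
import numpy as np


def transition_sketch(hidden_vals, x):
    # Toy elementwise transition over two hidden states.
    h1, h2 = hidden_vals
    new_h1 = 0.9 * h1 + 0.1 * h2 + x
    new_h2 = 0.5 * h1 * h2
    return [new_h1, new_h2]


def diagonal_jacobian_sketch(fn, hidden_vals, x, eps=1e-6):
    num_states = len(hidden_vals)
    varshape = hidden_vals[0].shape
    jac = np.zeros(varshape + (num_states, num_states))
    for j in range(num_states):
        plus = [h.copy() for h in hidden_vals]
        minus = [h.copy() for h in hidden_vals]
        # Perturbing every position at once is valid only because
        # the transition is elementwise across positions.
        plus[j] = plus[j] + eps
        minus[j] = minus[j] - eps
        out_plus = fn(plus, x)
        out_minus = fn(minus, x)
        for i in range(num_states):
            # Entry [..., i, j] is d(new state i) / d(old state j).
            jac[..., i, j] = (out_plus[i] - out_minus[i]) / (2 * eps)
    return jac


h1 = np.array([1.0, 2.0, 3.0, 4.0])
h2 = np.array([0.5, 0.5, 2.0, 2.0])
x = np.zeros(4)
jac = diagonal_jacobian_sketch(transition_sketch, [h1, h2], x)
```

With `varshape == (4,)` and two states, `jac` has shape `(4, 2, 2)`; each `jac[k]` is the 2×2 matrix of derivatives of the new states with respect to the old states at position `k`.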

dict()[source]#

Return this group’s named fields as a plain dictionary.

Returns:

An ordered mapping from field name to value, as produced by the underlying typing.NamedTuple.

Return type:

Dict[str, Any]
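
Because the class is described as a `typing.NamedTuple`, the behaviour of `dict()` and of the "Alias for field number N" entries can be illustrated with a small standard-library example. `GroupSketch` is a hypothetical two-field stand-in, and implementing `dict()` via `_asdict()` is an assumption about how such a method could be written.

```python
from typing import Any, Dict, List, NamedTuple, Tuple


class GroupSketch(NamedTuple):
    # Hypothetical stand-in with only the first two fields.
    index: int
    hidden_paths: List[Tuple[str, ...]]

    def dict(self) -> Dict[str, Any]:
        # NamedTuple provides _asdict(), which preserves field order.
        return self._asdict()


g = GroupSketch(index=0, hidden_paths=[("cell", "h")])
as_mapping = g.dict()
first_field = g[0]  # positional access: field number 0 is `index`
```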

hidden_invars: List[Var]#

Alias for field number 3

hidden_outvars: List[Var]#

Alias for field number 4

hidden_paths: List[Tuple[str, ...]]#

Alias for field number 1

hidden_states: List[HiddenState]#

Alias for field number 2

index: int#

Alias for field number 0

property num_state: int#

The number of hidden states.

Returns:

The total number of hidden states across the group.

Return type:

int

split_hidden(concat_hid_vals)[source]#

Split a concatenated hidden-state array into individual arrays.

Splits a concatenated array of hidden-state values into separate arrays, one per hidden state in the group. This is the inverse of concat_hidden: for non-HiddenGroupState values, the trailing dimension added during concatenation is squeezed away again.

Parameters:

concat_hid_vals (Array) – A concatenated array of hidden-state values. The last dimension is assumed to contain the concatenated states.

Returns:

A list of split hidden-state arrays. For non-HiddenGroupState values, the last dimension is squeezed.

Return type:

list of jax.Array
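
The splitting step can be sketched with NumPy as follows. The helper `split_hidden_sketch`, its `sizes` argument (the number of trailing columns each state occupies), and the `is_group_state` flags are hypothetical names introduced for illustration; the library may track this information differently.

```python
import numpy as np


def split_hidden_sketch(concat_vals, sizes, is_group_state):
    # Cut points along the last axis, e.g. sizes [1, 2] -> split at column 1.
    offsets = np.cumsum(sizes)[:-1]
    chunks = np.split(concat_vals, offsets, axis=-1)
    # Plain states drop their length-1 trailing axis; group-style states keep it.
    return [
        chunk if grouped else np.squeeze(chunk, axis=-1)
        for chunk, grouped in zip(chunks, is_group_state)
    ]


concat = np.arange(12.0).reshape(4, 3)
plain, grouped = split_hidden_sketch(
    concat, sizes=[1, 2], is_group_state=[False, True]
)
```

In this sketch `plain` has shape `(4,)` and `grouped` has shape `(4, 2)`.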

transition(hidden_vals, input_vals)[source]#

Compute the hidden-state transitions.

Evaluates the group transition jaxpr:

\[h_1^t, h_2^t, \ldots = f(h_1^{t-1}, h_2^{t-1}, \ldots, x^t).\]

Parameters:
  • hidden_vals (Sequence[Array]) – The old hidden-state values.

  • input_vals (Any) – The input values.

Returns:

The new hidden-state values.

Return type:

List[Array]

transition_jaxpr: Jaxpr#

Alias for field number 5

transition_jaxpr_constvars: List[Var]#

Alias for field number 6

property varshape: Tuple[int, ...]#

The shape of each state variable.

Returns:

The variable shape shared by the hidden states in the group.

Return type:

tuple of int