# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
#
# Author: Chaoming Wang <chao.brain@qq.com>
# Copyright: 2024, Chaoming Wang
# Date: 2024-04-03
#
# Refinement History:
# [2024-04-03] Created
# [2024-04-06] Added the traceback information for the error messages.
# [2024-04-16] Changed the "op" in the "HiddenWeightOpTracer" to "JaxprEqn".
# Added the support for the "pjit" operator.
# [2024-05] Add the support for vjp_time == 't_minus_1'
# [2024-06] Conditionally support control flows, including `scan`, `while`, and `cond`
# [2024-09] version 0.0.2
# [2024-11-22] compatible with `brainstate>=0.1.0` (#17)
# [2024-11-23] Add the support for vjp_time_ahead > 1, it can combine the
# advantage of etrace learning and backpropagation through time.
# [2024-11-26] version 0.0.3, a complete new revision for better model debugging.
# [2024-12-05] change the ETraceWeight to NonETraceWeight if the hidden states are not found;
# remove the connected hidden states when y=x@w is not shape broadcastable with the hidden states.
# [2024-12-09] small updates, related to the key items in "CompiledVjpGraph"
# [2025-02-06]
# - [x] unify model retrieved states (brainstate.graph.states)
# and compiled states (brainstate.transform.StatefulFunction)
# - [x] add the support for the "HiddenGroupState" and "ETraceTreeState"
# - [x] add the support for the "ElemWiseParam"
# - [x] split into "_etrace_compiler.py", "_etrace_vjp_compiler_graph.py", and "hidden_group.py",
#
# ==============================================================================
# -*- coding: utf-8 -*-
from itertools import combinations
from typing import List, Dict, Sequence, Tuple, Set, Optional, Callable, NamedTuple, Any
import brainstate
import jax.core
import numpy as np
import saiunit as u
from brainstate import HiddenGroupState
from braintrace._compatible_imports import Var, Literal, JaxprEqn, Jaxpr
from braintrace._misc import NotSupportedError
from braintrace._typing import (
PyTree,
HiddenInVar,
HiddenOutVar,
Path,
)
from .base import JaxprEvaluation, find_matched_vars
from .module_info import extract_module_info, ModuleInfo
__all__ = [
'HiddenGroup',
'find_hidden_groups_from_minfo',
'find_hidden_groups_from_module',
]
class HiddenGroup(NamedTuple):
r"""The data structure recording a hidden-group relation.
A hidden group bundles the hidden states that are mutually connected through
a recurrence transition, together with the jaxpr that computes that
transition
.. math::
h_1^t, h_2^t, \ldots = f(h_1^{t-1}, h_2^{t-1}, \ldots, x^t).
Attributes
----------
index : int
Position of this group in the compiled group sequence.
hidden_paths : list of Path
The module path to each hidden state in the group.
hidden_states : list of brainstate.HiddenState
The hidden states in the group.
hidden_invars : list of HiddenInVar
The input jaxpr ``Var`` of each hidden state (at the previous step).
hidden_outvars : list of HiddenOutVar
The output jaxpr ``Var`` of each hidden state (at the current step).
transition_jaxpr : Jaxpr
The jaxpr computing the hidden-state transition for the group.
transition_jaxpr_constvars : list of Var
The other input variables required to evaluate ``transition_jaxpr``.
See Also
--------
find_hidden_groups_from_module : Build hidden groups directly from a model.
Examples
--------
.. code-block:: python
>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> hidden_groups, _ = braintrace.find_hidden_groups_from_module(gru, inputs)
>>> len(hidden_groups)
1
"""
index: int # type: ignore[assignment] # intentional NamedTuple field; shadows tuple.index
# hidden states and their paths
hidden_paths: List[Path] # the hidden state paths
hidden_states: List[brainstate.HiddenState] # the hidden states
# the jax Var at the last time step
hidden_invars: List[HiddenInVar] # the input hidden states
# the jax Var at the current time step
hidden_outvars: List[HiddenOutVar] # the output hidden states
# the jaxpr for computing hidden state transitions
#
# h_1^t, h_2^t, ... = f(h_1^{t-1}, h_2^{t-1}, ..., x)
#
transition_jaxpr: Jaxpr
# the other input variables for transition_jaxpr evaluation
transition_jaxpr_constvars: List[Var]
@property
def varshape(self) -> Tuple[int, ...]:
"""The shape of each state variable.
Returns
-------
tuple of int
The variable shape shared by the hidden states in the group.
"""
return self.hidden_states[0].varshape
@property
def num_state(self) -> int:
"""The number of hidden states.
Returns
-------
int
The total number of hidden states across the group.
"""
return sum([st.num_state for st in self.hidden_states])
[docs]
def check_consistent_varshape(self):
"""Check whether the shapes of the hidden states are consistent.
Raises
------
NotSupportedError
If the shapes of the hidden states are not consistent.
"""
varshapes = set([tuple(st.varshape) for st in self.hidden_states])
if len(varshapes) > 1:
raise NotSupportedError(
f'Error: the shapes of the hidden states are not consistent. \n'
f'{varshapes}'
)
[docs]
def transition(
self,
hidden_vals: Sequence[jax.Array],
input_vals: PyTree,
) -> List[jax.Array]:
r"""Compute the hidden-state transitions.
Evaluates the group transition jaxpr
.. math::
h_1^t, h_2^t, \cdots = f(h_1^{t-1}, h_2^{t-1}, \cdots, x^t).
Parameters
----------
hidden_vals : sequence of jax.Array
The old hidden-state values.
input_vals : PyTree
The input values.
Returns
-------
list of jax.Array
The new hidden-state values.
"""
return jax.core.eval_jaxpr(self.transition_jaxpr, input_vals, *hidden_vals)
[docs]
def diagonal_jacobian(
self,
hidden_vals: Sequence[jax.Array],
input_vals: PyTree,
):
"""Compute the diagonal Jacobian matrix along the last dimension.
Parameters
----------
hidden_vals : sequence of jax.Array
The hidden-state values.
input_vals : PyTree
The input values.
Returns
-------
jax.Array
The diagonal Jacobian matrix, with shape
``(*varshape, num_states, num_states)``.
"""
return jacrev_last_dim(
lambda hid: self.concat_hidden(self.transition(self.split_hidden(hid), input_vals)),
self.concat_hidden(hidden_vals)
)
[docs]
def concat_hidden(self, splitted_hid_vals: Sequence[jax.Array]):
"""Concatenate split hidden-state values into a single array.
Concatenates a sequence of split hidden-state values along the last
axis. For non-``HiddenGroupState`` values, an extra trailing dimension
is added before concatenation.
Parameters
----------
splitted_hid_vals : sequence of jax.Array
A sequence of split hidden-state values, each corresponding to a
hidden state in the group.
Returns
-------
jax.Array
A single array containing all hidden-state values concatenated
along the last axis.
"""
splitted_hid_vals = [
val
if isinstance(st, HiddenGroupState) else
u.math.expand_dims(val, axis=-1)
for val, st in zip(splitted_hid_vals, self.hidden_states)
]
return u.math.concatenate(splitted_hid_vals, axis=-1)
[docs]
def split_hidden(self, concat_hid_vals: jax.Array):
"""Split a concatenated hidden-state array into individual arrays.
Splits a concatenated array of hidden-state values into separate arrays,
one per hidden state in the group. ``HiddenGroupState`` and
non-``HiddenGroupState`` values are handled differently.
Parameters
----------
concat_hid_vals : jax.Array
A concatenated array of hidden-state values. The last dimension is
assumed to contain the concatenated states.
Returns
-------
list of jax.Array
A list of split hidden-state arrays. For non-``HiddenGroupState``
values, the last dimension is squeezed.
"""
num_states = [st.num_state for st in self.hidden_states]
indices = np.cumsum(num_states)
splitted_hid_vals = u.math.split(concat_hid_vals, indices, axis=-1)
splitted_hid_vals = [
val
if isinstance(st, HiddenGroupState) else
u.math.squeeze(val, axis=-1)
for val, st in zip(splitted_hid_vals, self.hidden_states)
]
return splitted_hid_vals
[docs]
def dict(self) -> Dict[str, Any]:
"""Return this group's named fields as a plain dictionary.
Returns
-------
dict
An ordered mapping from field name to value, as produced by the
underlying :class:`typing.NamedTuple`.
"""
return self._asdict()
def __repr__(self) -> str:
return repr(brainstate.util.PrettyMapping(self._asdict(), type_name=self.__class__.__name__))
HiddenGroup.__module__ = 'braintrace'
def jacrev_last_dim(
fn: Callable[..., jax.Array],
hid_vals: jax.Array,
) -> jax.Array:
"""
Compute the Jacobian of a function with respect to its last dimension.
This function calculates the Jacobian matrix of the given function 'fn'
with respect to the last dimension of the input 'hid_vals'. It uses
JAX's vector-Jacobian product (vjp) and vmap for efficient computation.
Args:
fn (Callable[[...], jax.Array]): The function for which to compute
the Jacobian. It should take a JAX array as input and return
a JAX array.
hid_vals (jax.Array): The input values for which to compute the
Jacobian. The last dimension is considered as the dimension
of interest.
Returns:
jax.Array: The Jacobian matrix. Its shape is (*varshape, num_state, num_state),
where varshape is the shape of the input excluding the last dimension,
and num_state is the size of the last dimension.
Raises:
AssertionError: If the number of input and output states are not the same.
"""
new_hid_vals, f_vjp = jax.vjp(fn, hid_vals)
num_state = new_hid_vals.shape[-1]
varshape = new_hid_vals.shape[:-1]
assert num_state == hid_vals.shape[-1], 'Error: the number of input/output states should be the same.'
g_primals = u.math.broadcast_to(u.math.eye(num_state), (*varshape, num_state, num_state))
jac = jax.vmap(f_vjp, in_axes=-2, out_axes=-2)(g_primals)
return jac[0]
class HiddenToHiddenGroupTracer(NamedTuple):
"""
The data structure for the tracing of the hidden-to-hidden states.
Attributes:
hidden_invar (Var): The input variable representing the hidden state.
connected_hidden_outvars (set[Var]): A set of output variables representing the connected hidden states.
other_invars (set[Var]): A set of other input variables involved in the tracing.
invar_needed_in_oth_eqns (set[Var]): A set of variables needed in other equations for trace analysis.
trace (List[JaxprEqn]): A list of JAX equations representing the trace of operations.
"""
hidden_invar: Var
connected_hidden_outvars: set[Var]
other_invars: set[Var]
invar_needed_in_oth_eqns: set[Var]
trace: List[JaxprEqn]
def dict(self) -> Dict[str, Any]:
"""Return this tracer's named fields as a plain dictionary.
Returns
-------
dict
An ordered mapping from field name to value, as produced by the
underlying :class:`typing.NamedTuple`.
"""
return self._asdict()
def __repr__(self) -> str:
return repr(brainstate.util.PrettyMapping(self._asdict(), type_name=self.__class__.__name__))
class Hidden2GroupTransition(NamedTuple):
"""
Represents a hidden state transition in a computational graph.
This class captures the transition of hidden states from one time step to the next
within a neural network model. It includes information about the input hidden state,
the connected output hidden states, and the JAX program representation (jaxpr) that
defines the transition.
Attributes:
hidden_invar (Var): The input variable representing the hidden state at the previous time step.
hidden_path (Path): The path to the hidden state in the model hierarchy.
connected_hidden_outvars (List[Var]): A list of output variables representing the connected hidden states at the current time step.
connected_hidden_paths (List[Path]): A list of paths to the connected hidden states in the model hierarchy.
transition_jaxpr (Jaxpr): The JAX program representation for computing the hidden state transitions.
other_invars (List[Var]): A list of other input variables required for evaluating the transition_jaxpr.
"""
# the hidden state h_i^{t-1}
hidden_invar: Var
hidden_path: Path
# the connected hidden states h_1^t, h_2^t, ...
connected_hidden_outvars: List[Var]
connected_hidden_paths: List[Path]
# the jaxpr for computing hidden state transitions
#
# h_1^t, h_2^t, ... = f(h_i^{t-1}, x)
#
transition_jaxpr: Jaxpr
# the other input variables for jaxpr evaluation
other_invars: List[Var]
def state_transition(
self,
old_hidden_val: jax.Array,
other_input_vals: PyTree,
return_index: Optional[int] = None
) -> List[jax.Array] | jax.Array:
"""
Computing the hidden state transitions :math:`h^t = f(h_i^t, x)`.
Args:
old_hidden_val: The old hidden state value.
other_input_vals: The input values.
return_index: index of the hidden state to return.
Returns:
The new hidden state values.
"""
new_hidden_vals = jax.core.eval_jaxpr(self.transition_jaxpr, other_input_vals, old_hidden_val)
if return_index is not None:
return new_hidden_vals[return_index]
return new_hidden_vals
def dict(self) -> Dict[str, Any]:
"""Return this transition's named fields as a plain dictionary.
Returns
-------
dict
An ordered mapping from field name to value, as produced by the
underlying :class:`typing.NamedTuple`.
"""
return self._asdict()
def __repr__(self) -> str:
return repr(brainstate.util.PrettyMapping(self._asdict(), type_name=self.__class__.__name__))
def _same_recurrence_layer(path1: Path, path2: Path) -> bool:
"""
Check if two hidden state paths belong to the same recurrence layer.
Paths that diverge at a numeric index (e.g., ('layers', 0, ...) vs
('layers', 1, ...)) indicate different sequential layers and should
be in separate groups. Paths that diverge at string keys (e.g.,
('neu', 'V') vs ('neu', 'a')) are within the same layer.
"""
min_len = min(len(path1), len(path2))
for i in range(min_len):
if path1[i] != path2[i]:
return not (isinstance(path1[i], int) or isinstance(path2[i], int))
return True
def _simplify_hid2hid_tracer(
tracer: HiddenToHiddenGroupTracer,
hidden_invar_to_path: Dict[HiddenInVar, Path],
hidden_outvar_to_path: Dict[HiddenOutVar, Path],
path_to_state: Dict[Path, brainstate.HiddenState],
) -> Optional[Hidden2GroupTransition]:
"""
Simplifying the hidden-to-hidden state tracer.
Args:
tracer: The hidden-to-hidden state tracer.
hidden_invar_to_path: The mapping from the hidden input variable to the hidden state path.
hidden_outvar_to_path: The mapping from the hidden output variable to the hidden state path.
path_to_state: The mapping from the hidden state path to the state.
Returns:
The hidden-to-hidden state transition.
"""
#
# [pre-step]
#
# Filter out hidden outvars from different recurrence layers.
# In multi-layer networks, a hidden state from one layer may be
# connected to hidden outvars of other layers through the computation
# graph. These cross-layer connections should not be in the same group.
# Two filters are applied:
# 1. Shape compatibility: outvars must match the invar's shape.
# 2. Layer membership: outvars whose paths diverge from the invar's
# path at a numeric index (e.g., layers.0 vs layers.1) are in
# different sequential layers and are excluded.
invar_path = hidden_invar_to_path[tracer.hidden_invar]
invar_state = path_to_state[invar_path]
compatible_outvars = {
hv for hv in tracer.connected_hidden_outvars
if (path_to_state[hidden_outvar_to_path[hv]].varshape == invar_state.varshape
and _same_recurrence_layer(invar_path, hidden_outvar_to_path[hv]))
}
if not compatible_outvars:
return None
#
# [first step]
#
# Remove the unnecessary equations in the trace.
# The unnecessary equations are the equations
# that do not contain the hidden states.
tracer.invar_needed_in_oth_eqns.clear()
new_trace = []
whole_trace_needed_vars = set(compatible_outvars)
visited_needed_vars = set() # needed_vars has been satisfied
for eqn in reversed(tracer.trace):
need_outvars = []
for outvar in eqn.outvars:
if outvar in whole_trace_needed_vars:
need_outvars.append(outvar)
if len(need_outvars):
visited_needed_vars.update(need_outvars)
new_trace.append(eqn)
whole_trace_needed_vars.update([invar for invar in eqn.invars if isinstance(invar, Var)])
# [second step]
#
# Shape filtering was already done in the pre-step.
hidden_outvars = tuple(compatible_outvars)
# [third step]
#
# Simplify the trace
visited_needed_vars.add(tracer.hidden_invar)
constvars = list(whole_trace_needed_vars.difference(visited_needed_vars))
jaxpr_opt = Jaxpr(
# the const vars are not the hidden states, they are
# intermediate data that are not used in the hidden states
constvars=constvars,
# the invars are always the weight output
invars=[tracer.hidden_invar],
# the outvars are always the connected hidden states of this weight
outvars=list(hidden_outvars),
# the new equations which are simplified
eqns=list(reversed(new_trace)),
)
# [final step]
#
# Change the "HiddenWeightOpTracer" to "Hidden2GroupTransition"
return Hidden2GroupTransition(
hidden_invar=tracer.hidden_invar,
hidden_path=hidden_invar_to_path[tracer.hidden_invar],
connected_hidden_outvars=list(hidden_outvars),
connected_hidden_paths=[hidden_outvar_to_path[var] for var in hidden_outvars],
transition_jaxpr=jaxpr_opt,
other_invars=constvars,
)
class JaxprEvalForHiddenGroup(JaxprEvaluation):
"""
Evaluating the jaxpr for extracting the hidden state ``hidden-to-hidden`` relationships.
Args:
jaxpr: The jaxpr for the model.
hidden_outvar_to_invar: The mapping from the hidden output variable to the hidden input variable.
weight_invars: The weight input variables.
invar_to_hidden_path: The mapping from the weight input variable to the hidden state path.
outvar_to_hidden_path: The mapping from the hidden output variable to the hidden state path.
path_to_state: The mapping from the hidden state path to the state.
"""
__module__ = 'braintrace'
def __init__(
self,
jaxpr: Jaxpr,
hidden_outvar_to_invar: Dict[HiddenOutVar, HiddenInVar],
weight_invars: Set[Var],
invar_to_hidden_path: Dict[HiddenInVar, Path],
outvar_to_hidden_path: Dict[HiddenOutVar, Path],
path_to_state: Dict[Path, brainstate.HiddenState],
):
# the jaxpr of the original model, assuming that the model is well-defined,
# see the doc for the model which can be online learning compiled.
self.jaxpr = jaxpr
# the hidden state groups
self.hidden_outvar_to_invar = hidden_outvar_to_invar
self.hidden_invar_to_outvar = {invar: outvar for outvar, invar in hidden_outvar_to_invar.items()}
hidden_invars = set(hidden_outvar_to_invar.values())
hidden_outvars = set(hidden_outvar_to_invar.keys())
self.path_to_state = path_to_state
# the data structures for the tracing hidden-hidden relationships
self.active_tracers: Dict[Var, HiddenToHiddenGroupTracer] = dict()
super().__init__(
weight_invars=weight_invars,
hidden_invars=hidden_invars,
hidden_outvars=hidden_outvars,
invar_to_hidden_path=invar_to_hidden_path,
outvar_to_hidden_path=outvar_to_hidden_path
)
def compile(self) -> Tuple[
Sequence[HiddenGroup],
Dict[Path, HiddenGroup],
]:
"""
Compiling the jaxpr for the etrace relationships.
"""
# the data structures for the tracing hidden-hidden relationships
self.active_tracers = dict()
# evaluating the jaxpr
self._eval_jaxpr(self.jaxpr)
# post checking
hid_groups, hid_path_to_group = self._post_check()
# reset the temporal data structures
self.active_tracers = dict()
return hid_groups, hid_path_to_group
def _eval_eqn(self, eqn: JaxprEqn) -> None:
"""
Evaluating the normal jaxpr equation.
"""
if eqn.primitive.name == 'stop_gradient':
return
# check whether the invars have one of the hidden states.
# If it is true, add a new tracer.
other_invars = []
hidden_invars = []
for invar in eqn.invars:
if isinstance(invar, Literal):
continue
elif invar in self.hidden_invars:
hidden_invars.append(invar)
else:
other_invars.append(invar)
if len(hidden_invars) > 0:
# A hidden invar may be used in multiple places.
# All places share a common tracer.
if len(hidden_invars) != 1:
paths = [str(self.invar_to_hidden_path[var]) for var in hidden_invars]
hidden_paths = "\n".join(paths)
raise ValueError(
f'Currently, we only support one hidden state in a single equation. \n'
f'{eqn}\n'
f'{hidden_paths}'
)
hidden_var = hidden_invars[0]
hidden_outvars = set([outvar for outvar in eqn.outvars if outvar in self.hidden_outvars])
needed_invars = set([outvar for outvar in eqn.outvars if outvar not in self.hidden_outvars])
if hidden_var in self.active_tracers:
self.active_tracers[hidden_var].trace.append(eqn.replace())
self.active_tracers[hidden_var].other_invars.update(other_invars)
self.active_tracers[hidden_var].invar_needed_in_oth_eqns.update(needed_invars)
self.active_tracers[hidden_var].connected_hidden_outvars.update(hidden_outvars)
else:
tracer = HiddenToHiddenGroupTracer(
hidden_invar=hidden_var,
connected_hidden_outvars=hidden_outvars,
other_invars=set(other_invars),
invar_needed_in_oth_eqns=needed_invars,
trace=[eqn.replace()]
)
self.active_tracers[hidden_var] = tracer
# check whether this equation is used in other tracers
for tracer in tuple(self.active_tracers.values()):
matched = find_matched_vars(eqn.invars, tracer.invar_needed_in_oth_eqns)
# if matched, add the eqn to the trace
# if not matched, skip
if len(matched):
self._add_eqn_in_a_tracer(eqn, tracer)
def _add_eqn_in_a_tracer(
self,
eqn: JaxprEqn,
tracer: HiddenToHiddenGroupTracer
) -> None:
tracer.trace.append(eqn.replace())
tracer.invar_needed_in_oth_eqns.update(eqn.outvars)
# check whether the hidden states are needed in the other equations
for outvar in eqn.outvars:
if outvar in self.hidden_outvars:
tracer.connected_hidden_outvars.add(outvar)
def _post_check(self) -> Tuple[
Sequence[HiddenGroup],
Dict[Path, HiddenGroup],
]:
# [ First step ]
#
# check the following items:
#
# 1. the shape of connected hidden states should be the same
# 2. simplify the trace
# 3. remove the unnecessary hidden states
hidden_to_group_transition = [
t for t in (
_simplify_hid2hid_tracer(
tracer,
self.invar_to_hidden_path,
self.outvar_to_hidden_path,
self.path_to_state,
)
for tracer in self.active_tracers.values()
)
if t is not None
]
# [ second step ]
#
# Find out the hidden group,
# i.e., the hidden states that are connected to each other, the union of all hidden-to-group
outvar_groups: list = [
set(
[self.hidden_invar_to_outvar[transition.hidden_invar]] +
list(transition.connected_hidden_outvars)
)
for transition in hidden_to_group_transition
]
outvar_groups = group_merging(outvar_groups, version=0)
outvar_groups = [list(group) for group in outvar_groups]
invar_groups = [
[self.hidden_outvar_to_invar[outvar] for outvar in group]
for group in outvar_groups
]
# [ third step ]
#
# compile the state transitions in a hidden group
#
# h_1^t, h_2^t, ... h_n^t = f(h_1^t-1, h_2^t-1, ...., h_n^t-1)
#
hidden_invar_to_transition = {
transition.hidden_invar: transition
for transition in hidden_to_group_transition
}
jaxpr_groups = []
for hidden_invars, hidden_outvars in zip(invar_groups, outvar_groups):
jaxpr_groups.append(
write_jaxpr_of_hidden_group_transition(
hidden_invar_to_transition,
hidden_invars,
hidden_outvars,
)
)
# [ fourth step ]
#
# compile HiddenGroup
#
hidden_groups: list = []
for hidden_invars, hidden_outvars, jaxpr in zip(invar_groups, outvar_groups, jaxpr_groups):
group = HiddenGroup(
index=len(hidden_groups),
hidden_invars=list(hidden_invars),
hidden_outvars=list(hidden_outvars),
hidden_paths=[
self.outvar_to_hidden_path[outvar]
for outvar in hidden_outvars
],
hidden_states=[
self.path_to_state[self.outvar_to_hidden_path[outvar]]
for outvar in hidden_outvars
],
transition_jaxpr=jaxpr,
transition_jaxpr_constvars=list(jaxpr.constvars),
)
hidden_groups.append(group)
# [ fifth step ]
#
# transform the hidden group set to the HiddenGroup
#
# hidden outvar to group
#
hidden_path_to_group: Dict[Path, HiddenGroup] = dict()
for group in hidden_groups:
for path in group.hidden_paths:
if path in hidden_path_to_group:
raise ValueError(
f'Error: the hidden state {path} '
f'is found in multiple groups. \n'
f'{hidden_path_to_group[path].hidden_paths} '
f'\n\n'
f'{group.hidden_paths}'
)
hidden_path_to_group[path] = group
return hidden_groups, hidden_path_to_group
def write_jaxpr_of_hidden_group_transition(
hidden_invar_to_transition: Dict[HiddenInVar, Hidden2GroupTransition],
hidden_invars: List[HiddenInVar],
hidden_outvars: List[HiddenOutVar],
) -> Jaxpr:
assert len(hidden_invars) >= 1
#
# step 1:
#
# filter out
#
# 1. all invars + constvars
# 2. equations
# 3. all outvars
#
eqns = []
all_invars = set()
all_outvars = set()
for invar in hidden_invars:
if invar in hidden_invar_to_transition:
transition = hidden_invar_to_transition[invar]
for eq in transition.transition_jaxpr.eqns:
this_eq_exist = [outvar in all_outvars for outvar in eq.outvars]
# this_eq_exist = False
# for outvar in eq.outvars:
# if outvar in all_outvars:
# this_eq_exist = True
# break
if not all(this_eq_exist):
eqns.append(eq.replace())
all_invars.update([invar for invar in eq.invars if not isinstance(invar, Literal)])
all_outvars.update(eq.outvars)
other_invars = list(all_invars.difference(all_outvars).difference(hidden_invars))
#
# step 2:
#
# order the equations so that data dependencies are satisfied
#
new_eqns = []
env = set(list(hidden_invars) + other_invars)
max_iterations = len(eqns) * len(eqns) + 1 # upper bound for topological sort passes
iteration_count = 0
while len(eqns) > 0:
iteration_count += 1
if iteration_count > max_iterations:
unresolved_invars = []
for eqn in eqns:
missing = [v for v in eqn.invars if not isinstance(v, Literal) and v not in env]
unresolved_invars.append((eqn, missing))
raise RuntimeError(
f'Topological sort failed: could not resolve all equation dependencies. '
f'{len(eqns)} equations remain unresolved. '
f'This may indicate a cyclic dependency or missing input variables. '
f'Unresolved equations: {unresolved_invars}'
)
eqn = eqns.pop(0)
if all((invar in env) for invar in eqn.invars if not isinstance(invar, Literal)):
# Execute the equation
new_eqns.append(eqn)
# Add outvars to env
env.update(eqn.outvars)
else:
# If invars are not in env, put the equation back to the queue
eqns.append(eqn)
#
# step 3:
#
# produce the new jaxpr
#
return Jaxpr(
constvars=list(other_invars),
invars=hidden_invars,
outvars=hidden_outvars,
eqns=new_eqns
)
def group_merging(groups, version: int = 1) -> List[frozenset[HiddenOutVar]]:
"""
Merging the hidden groups using the intersection of the hidden states.
For example, if we have the following hidden states:
[(h_1, h_2),
(h_2, h_3),
(h_4, h_5)]
The merged hidden states are:
[(h_1, h_2, h_3),
(h_4, h_5)]
This function takes a list of hidden groups and merges them if they share
any common hidden states. The merging process is controlled by the specified
version of the algorithm.
Args:
groups: A list of hidden groups, where each group is a collection of
hidden states represented as frozensets.
version: An integer specifying the version of the merging algorithm to use.
Default is 1. Version 0 and 1 are supported, with version 1 being
more efficient and readable.
Returns:
A list of merged hidden groups, where each group is a frozenset of
HiddenOutVar objects. The groups are merged based on shared hidden states.
"""
if version == 0:
previous = frozenset([frozenset(g) for g in groups])
while True:
new_groups = []
old_groups = list(previous)
not_merged = list(range(len(old_groups)))
while len(not_merged) > 0:
i = not_merged.pop()
merged = False
for j in tuple(not_merged):
if len(old_groups[i].intersection(old_groups[j])) > 0:
new_groups.append(old_groups[i].union(old_groups[j]))
not_merged.remove(j)
merged = True
if not merged:
new_groups.append(old_groups[i])
new = frozenset([frozenset(g) for g in new_groups])
if new == previous:
break
previous = new
return list(new)
elif version == 1:
# This code has been upgraded for better readability and efficiency.
prev = [frozenset(g) for g in set(map(frozenset, groups))]
while True:
new_groups = []
merged_indices = set()
for i, j in combinations(range(len(prev)), 2):
if i in merged_indices or j in merged_indices:
continue
if prev[i].intersection(prev[j]):
new_groups.append(prev[i].union(prev[j]))
merged_indices.update([i, j])
new_groups.extend(
prev[k]
for k in range(len(prev))
if k not in merged_indices
)
cur = frozenset(new_groups)
if cur == frozenset(prev):
break
prev = list(cur)
return list(cur)
else:
raise ValueError(f'Error: the version {version} is not supported.')
def find_hidden_groups_from_jaxpr(
jaxpr: Jaxpr,
hidden_outvar_to_invar: Dict[HiddenOutVar, HiddenInVar],
weight_invars: Set[Var],
invar_to_hidden_path: Dict[HiddenInVar, Path],
outvar_to_hidden_path: Dict[HiddenOutVar, Path],
path_to_state: Dict[Path, brainstate.State],
) -> Tuple[Sequence[HiddenGroup], brainstate.util.PrettyDict]:
"""
Find hidden groups from the jaxpr.
Args:
jaxpr: The jaxpr for the model.
hidden_outvar_to_invar: Mapping from hidden output variable to hidden input variable.
weight_invars: Set of weight input variables.
invar_to_hidden_path: Mapping from weight input variable to hidden state path.
outvar_to_hidden_path: Mapping from hidden output variable to hidden state path.
path_to_state: Mapping from hidden state path to state.
Returns:
A tuple containing:
- Sequence of HiddenGroup objects
- PrettyDict mapping hidden state paths to hidden groups
"""
evaluator = JaxprEvalForHiddenGroup(
jaxpr=jaxpr,
hidden_outvar_to_invar=hidden_outvar_to_invar,
weight_invars=weight_invars,
invar_to_hidden_path=invar_to_hidden_path,
outvar_to_hidden_path=outvar_to_hidden_path,
path_to_state=path_to_state,
)
hidden_groups, hid_path_to_group = evaluator.compile()
return hidden_groups, brainstate.util.PrettyDict(hid_path_to_group)
[docs]
def find_hidden_groups_from_minfo(
minfo: ModuleInfo
):
"""Find the hidden groups from the model information.
Parameters
----------
minfo : ModuleInfo
The model information.
Returns
-------
hidden_groups : sequence of HiddenGroup
The hidden groups.
hid_path_to_group : dict
Mapping from each hidden-state path to its :class:`HiddenGroup`.
See Also
--------
find_hidden_groups_from_module : Equivalent helper starting from a model.
"""
(
hidden_groups,
hid_path_to_group,
) = find_hidden_groups_from_jaxpr(
jaxpr=minfo.jaxpr,
hidden_outvar_to_invar=minfo.hidden_outvar_to_invar,
weight_invars=set(minfo.weight_invars),
invar_to_hidden_path=minfo.invar_to_hidden_path,
outvar_to_hidden_path=minfo.outvar_to_hidden_path,
path_to_state=minfo.retrieved_model_states,
)
return hidden_groups, hid_path_to_group
[docs]
def find_hidden_groups_from_module(
model: brainstate.nn.Module,
*model_args,
**model_kwargs,
) -> Tuple[Sequence[HiddenGroup], brainstate.util.PrettyDict]:
"""Find hidden groups from a model.
Parameters
----------
model : brainstate.nn.Module
The model.
*model_args
The positional arguments of the model.
**model_kwargs
The keyword arguments of the model.
Returns
-------
hidden_groups : sequence of HiddenGroup
The hidden groups.
hid_path_to_group : brainstate.util.PrettyDict
Mapping from each hidden-state path to its :class:`HiddenGroup`.
See Also
--------
find_hidden_groups_from_minfo : Equivalent helper starting from ``ModuleInfo``.
Examples
--------
.. code-block:: python
>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> hidden_groups, hid_path_to_group = braintrace.find_hidden_groups_from_module(gru, inputs)
>>> len(hidden_groups)
1
"""
minfo = extract_module_info(model, *model_args, **model_kwargs)
return find_hidden_groups_from_minfo(minfo)