# Source code for braintrace._etrace_compiler.hidden_pertubation

# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from typing import Dict, Set, Sequence, NamedTuple, Any

import brainstate
import jax.core
import saiunit as u

from braintrace._compatible_imports import (
    Var,
    JaxprEqn,
    Jaxpr,
    ClosedJaxpr,
    new_var
)
from braintrace._misc import (
    git_issue_addr,
)
from braintrace._typing import (
    HiddenInVar,
    HiddenOutVar,
    Path,
)
from .base import (
    JaxprEvaluation,
)
from .hidden_group import (
    HiddenGroup,
)
from .module_info import (
    extract_module_info,
    ModuleInfo,
)

__all__ = [
    'HiddenPerturbation',
    'add_hidden_perturbation_from_minfo',
    'add_hidden_perturbation_in_module',
]


class HiddenPerturbation(NamedTuple):
    r"""The hidden-perturbation information.

    Hidden perturbation adds a perturbation variable to each hidden state in the
    jaxpr and replaces the hidden states with the perturbed states:

    .. math::

        h^t = f(x) \;\Rightarrow\; h^t = f(x) + \mathrm{perturb\_var},

    where :math:`h` is the hidden state, :math:`f` is the function, :math:`x` is
    the input, and :math:`\mathrm{perturb\_var}` is the perturbation variable.

    Attributes
    ----------
    perturb_vars : sequence of Var
        The perturbation variables, one per perturbed hidden state.
    perturb_hidden_paths : sequence of Path
        The hidden-state paths that are perturbed. When the instance is built
        by the compiler in this module, entry ``i`` is the path of the hidden
        state perturbed by ``perturb_vars[i]``.
    perturb_hidden_states : sequence of brainstate.HiddenState
        The hidden states that are perturbed, in the same order as
        ``perturb_hidden_paths``.
    perturb_jaxpr : ClosedJaxpr
        The perturbed jaxpr. When built by the compiler in this module, its
        inputs are the inputs of the original jaxpr followed by
        ``perturb_vars``; its outputs are those of the original jaxpr.

    See Also
    --------
    add_hidden_perturbation_in_module : Build perturbations directly from a model.

    Notes
    -----
    Internally a new variable :math:`\hat{h}^t = f(x)` is defined and an extra
    equation :math:`h^t = \hat{h}^t + \mathrm{perturb\_var}` is added. The
    perturbation lets the hidden-state gradient be read off the perturbation
    variable

    .. math::

        \frac{\partial L^t}{\partial h^t}
        = \frac{\partial L^t}{\partial \mathrm{perturb\_var}}.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import braintrace
        >>> gru = braintrace.nn.GRUCell(3, 4)
        >>> _ = brainstate.nn.init_all_states(gru)
        >>> inputs = brainstate.random.randn(3)
        >>> hidden_perturb = braintrace.add_hidden_perturbation_in_module(gru, inputs)
        >>> isinstance(hidden_perturb, braintrace.HiddenPerturbation)
        True
    """
    perturb_vars: Sequence[Var]  # the perturbation variables (extra jaxpr inputs)
    perturb_hidden_paths: Sequence[Path]  # the hidden state paths that are perturbed
    perturb_hidden_states: Sequence[brainstate.HiddenState]  # the hidden states that are perturbed
    perturb_jaxpr: ClosedJaxpr  # the perturbed jaxpr

[docs] def eval_jaxpr( self, inputs: Sequence[jax.Array], perturb_data: Sequence[jax.Array] ) -> Sequence[jax.Array]: """Evaluate the perturbed jaxpr. Parameters ---------- inputs : sequence of jax.Array The flat input values of the original jaxpr. perturb_data : sequence of jax.Array The perturbation values, one per entry of ``perturb_vars``. Returns ------- sequence of jax.Array The outputs of the perturbed jaxpr. """ return jax.core.eval_jaxpr( self.perturb_jaxpr.jaxpr, self.perturb_jaxpr.consts, *(tuple(inputs) + tuple(perturb_data)) )
[docs] def init_perturb_data(self) -> Sequence[jax.Array]: """Initialize the perturbation data to zeros. Returns ------- sequence of jax.Array One zero array per perturbation variable, matching its shape and dtype. """ return [jax.numpy.zeros(v.aval.shape, dtype=v.aval.dtype) for v in self.perturb_vars]
[docs] def perturb_data_to_hidden_group_data( self, perturb_data: Sequence[jax.Array], hidden_groups: Sequence[HiddenGroup], ) -> Sequence[jax.Array]: """Convert the perturbation data to per-hidden-group data. Parameters ---------- perturb_data : sequence of jax.Array The perturbation values, one per entry of ``perturb_vars``. hidden_groups : sequence of HiddenGroup The hidden groups to map the perturbation data onto. Returns ------- sequence of jax.Array One concatenated perturbation array per hidden group. Raises ------ AssertionError If ``perturb_data`` does not have the same length as ``perturb_vars``. """ assert len(perturb_data) == len(self.perturb_vars), ( f'The length of the perturb data is not correct. ' f'Expected: {len(self.perturb_vars)}, ' f'Got: {len(perturb_data)}' ) path_to_perturb_data = { path: data for path, data in zip(self.perturb_hidden_paths, perturb_data) } return [ group.concat_hidden( [ # dimensionless processing u.get_mantissa(path_to_perturb_data[path]) for path in group.hidden_paths ] ) for group in hidden_groups ]
[docs] def dict(self) -> Dict[str, Any]: """Return this perturbation's named fields as a plain dictionary. Returns ------- dict An ordered mapping from field name to value, as produced by the underlying :class:`typing.NamedTuple`. """ return self._asdict()
    def __repr__(self) -> str:
        """Return the repr of a ``PrettyMapping`` built from the named fields.

        The mapping is labelled with this class's name via ``type_name``.
        """
        return repr(brainstate.util.PrettyMapping(self._asdict(), type_name=self.__class__.__name__))


# Present the class under the public ``braintrace`` namespace instead of this
# private module.
HiddenPerturbation.__module__ = 'braintrace'


class JaxprEvalForHiddenPerturbation(JaxprEvaluation):
    """
    Adding perturbations to the hidden states in the jaxpr, and replacing the
    hidden states with the perturbed states.

    Call :meth:`compile` to walk the equations of ``closed_jaxpr`` and obtain
    a :class:`HiddenPerturbation` holding the revised jaxpr.

    Args:
        closed_jaxpr: The closed jaxpr for the model.
        hidden_outvar_to_invar: The mapping from the hidden output variable
            to the hidden input variable. Its key order fixes the order of
            the perturbation variables.
        weight_invars: The weight input variables.
        invar_to_hidden_path: The mapping from the hidden input variable to
            the hidden state path.
        outvar_to_hidden_path: The mapping from the outvar to the hidden
            state path.
        path_to_state: The mapping from the hidden state path to the state.
    """
    __module__ = 'braintrace'

    def __init__(
        self,
        closed_jaxpr: ClosedJaxpr,
        hidden_outvar_to_invar: Dict[HiddenOutVar, HiddenInVar],
        weight_invars: Set[Var],
        invar_to_hidden_path: Dict[HiddenInVar, Path],
        outvar_to_hidden_path: Dict[Var, Path],
        path_to_state: Dict[Path, brainstate.HiddenState],
    ):
        # necessary data structures
        self.closed_jaxpr = closed_jaxpr

        # initialize the super class
        # Use dict.fromkeys to deduplicate while preserving insertion order
        # (avoids non-deterministic set iteration for jaxpr variable ordering)
        hidden_outvars_ordered = list(dict.fromkeys(hidden_outvar_to_invar.keys()))
        hidden_invars_ordered = list(dict.fromkeys(hidden_outvar_to_invar.values()))
        super().__init__(
            weight_invars=weight_invars,
            hidden_invars=set(hidden_invars_ordered),
            hidden_outvars=set(hidden_outvars_ordered),
            invar_to_hidden_path=invar_to_hidden_path,
            outvar_to_hidden_path=outvar_to_hidden_path
        )
        # Keep an ordered version for deterministic iteration in compile()
        self._hidden_outvars_ordered = hidden_outvars_ordered
        self.path_to_state = path_to_state

    def compile(self) -> HiddenPerturbation:
        """Build the perturbed jaxpr and package it as a ``HiddenPerturbation``.

        Returns:
            The perturbation variables, the perturbed hidden paths and
            states (all in the order of the hidden outvars given at
            construction), and the revised closed jaxpr whose inputs are the
            original inputs followed by the perturbation variables.

        Raises:
            ValueError: If some hidden outvar was never matched by an
                equation during the traversal, i.e. it is still pending
                after ``_eval_jaxpr`` returns.
        """
        # new invars, the var order is the same as the hidden_outvars (use ordered list for determinism)
        self.perturb_invars = {
            v: self._new_var_like(v)
            for v in self._hidden_outvars_ordered
        }

        # the hidden states that are not found in the code
        # (entries are removed as their defining equations are encountered)
        self.hidden_jaxvars_to_remove = set(self.hidden_outvars)

        # final revised equations
        self.revised_eqns: list = []

        # revising equations
        self._eval_jaxpr(self.closed_jaxpr.jaxpr)

        # [final checking]
        # If there are hidden states that are not found in the code, we raise an error.
        if len(self.hidden_jaxvars_to_remove) > 0:
            hid_paths = [self.outvar_to_hidden_path[v] for v in self.hidden_jaxvars_to_remove]
            hid_info = '\n'.join([f'{v} -> {path}' for v, path in zip(self.hidden_jaxvars_to_remove, hid_paths)])
            raise ValueError(
                f'Error: we did not found your defined hidden state '
                f'(see the following information) in the code. \n'
                f'Please report an issue to the developers at {git_issue_addr}. \n'
                f'The missed hidden states are: \n'
                f'{hid_info}'
            )

        # new jaxpr: same constvars/outvars, perturbation variables appended to the inputs
        jaxpr = Jaxpr(
            constvars=list(self.closed_jaxpr.jaxpr.constvars),
            invars=list(self.closed_jaxpr.jaxpr.invars) + list(self.perturb_invars.values()),
            outvars=list(self.closed_jaxpr.jaxpr.outvars),
            eqns=self.revised_eqns
        )
        revised_closed_jaxpr = ClosedJaxpr(jaxpr, self.closed_jaxpr.literals)

        # finalizing
        perturb_hidden_paths = [self.outvar_to_hidden_path[v] for v in self._hidden_outvars_ordered]
        perturb_hidden_states = [self.path_to_state[self.outvar_to_hidden_path[v]] for v in self._hidden_outvars_ordered]
        info = HiddenPerturbation(
            perturb_vars=brainstate.util.PrettyList(self.perturb_invars.values()),
            perturb_hidden_paths=brainstate.util.PrettyList(perturb_hidden_paths),
            perturb_hidden_states=brainstate.util.PrettyList(perturb_hidden_states),
            perturb_jaxpr=revised_closed_jaxpr
        )

        # clear the temporary working state used during compilation
        self.perturb_invars = dict()
        self.revised_eqns = []
        self.hidden_jaxvars_to_remove = set()
        return info

    def _eval_pjit(self, eqn: JaxprEqn) -> None:
        """
        Evaluating the pjit primitive.

        Note: Unlike the base class, the perturbation compiler must process ALL
        equations (including etrace ops without gradient) to maintain a valid jaxpr.
        Skipping equations would leave variable references unresolved.
        """
        self._eval_eqn(eqn)

    def _add_perturb_eqn(
        self,
        eqn: JaxprEqn,
        perturb_var: Var
    ) -> None:
        """Replace a hidden-state equation by the equation plus a perturbation.

        Args:
            eqn: The equation whose first output is the hidden variable.
            perturb_var: The perturbation variable added to that output.
        """
        # ------------------------------------------------
        #
        # For the hidden var eqn, we want to add a perturbation:
        #    y = f(x) =>  y = f(x) + perturb_var
        #
        # Particularly, we first define a new variable
        #    new_outvar = f(x)
        #
        # Then, we add a new equation for the perturbation
        #    y = new_outvar + perturb_var
        #
        # ------------------------------------------------

        hidden_var = eqn.outvars[0]

        # First step, define the hidden var as a new variable
        new_outvar = self._new_var_like(perturb_var)
        old_eqn = eqn.replace(outvars=[new_outvar])
        self.revised_eqns.append(old_eqn)

        # Second step, add the perturbation equation; it writes to the original
        # hidden variable, so downstream equations keep referring to it unchanged
        new_eqn = jax.core.new_jaxpr_eqn(
            [new_outvar, perturb_var],
            [hidden_var],
            jax.lax.add_p,
            {},
            set(),
            eqn.source_info.replace()
        )
        self.revised_eqns.append(new_eqn)

    def _eval_eqn(self, eqn: JaxprEqn) -> None:
        """Record one equation, perturbing it if it produces a pending hidden state.

        Only equations with exactly one output are candidates. A matched
        hidden variable is removed from the pending set, so it is perturbed
        at most once; every other equation is copied over unchanged.
        """
        if len(eqn.outvars) == 1:
            if eqn.outvars[0] in self.hidden_jaxvars_to_remove:
                hidden_var = eqn.outvars[0]
                self.hidden_jaxvars_to_remove.remove(hidden_var)
                self._add_perturb_eqn(eqn, self.perturb_invars[hidden_var])
                return
        self.revised_eqns.append(eqn.replace())

    def _new_var_like(self, v):
        """Create a fresh variable with the same shape and dtype as ``v``."""
        return new_var('', jax.core.ShapedArray(v.aval.shape, v.aval.dtype))


def add_hidden_perturbation_in_jaxpr(
    closed_jaxpr: ClosedJaxpr,
    hidden_outvar_to_invar: Dict[HiddenOutVar, HiddenInVar],
    weight_invars: Set[Var],
    invar_to_hidden_path: Dict[HiddenInVar, Path],
    outvar_to_hidden_path: Dict[Var, Path],
    path_to_state: Dict[Path, brainstate.HiddenState],
) -> HiddenPerturbation:
    """
    Adding perturbations to the hidden states in the jaxpr, and replacing the
    hidden states with the perturbed states.

    This is a thin wrapper that constructs a
    ``JaxprEvalForHiddenPerturbation`` and returns the result of its
    ``compile()``.

    Args:
        closed_jaxpr: The closed jaxpr for the model.
        hidden_outvar_to_invar: The mapping from the hidden output variable
            to the hidden input variable.
        weight_invars: The weight input variables.
        invar_to_hidden_path: The mapping from the hidden input variable to
            the hidden state path.
        outvar_to_hidden_path: The mapping from the outvar to the hidden
            state path.
        path_to_state: The mapping from the hidden state path to the state.

    Returns:
        The hidden-perturbation information, including the revised closed
        jaxpr with the perturbations.
    """
    return JaxprEvalForHiddenPerturbation(
        closed_jaxpr=closed_jaxpr,
        hidden_outvar_to_invar=hidden_outvar_to_invar,
        weight_invars=weight_invars,
        invar_to_hidden_path=invar_to_hidden_path,
        outvar_to_hidden_path=outvar_to_hidden_path,
        path_to_state=path_to_state,
    ).compile()
def add_hidden_perturbation_from_minfo(
    minfo: ModuleInfo,
) -> HiddenPerturbation:
    """Build hidden-state perturbations from extracted module information.

    Forwards the jaxpr and the variable/path mappings stored on ``minfo`` to
    ``add_hidden_perturbation_in_jaxpr``.

    Parameters
    ----------
    minfo : ModuleInfo
        The model information. Its ``weight_invars`` are converted to a set
        and its ``retrieved_model_states`` are used as the path-to-state
        mapping.

    Returns
    -------
    HiddenPerturbation
        The hidden-perturbation information.

    See Also
    --------
    add_hidden_perturbation_in_module : Equivalent helper starting from a model.
    """
    # collect the pieces of ``minfo`` under the keyword names the jaxpr-level
    # helper expects
    compile_kwargs = {
        'closed_jaxpr': minfo.closed_jaxpr,
        'hidden_outvar_to_invar': minfo.hidden_outvar_to_invar,
        'weight_invars': set(minfo.weight_invars),
        'invar_to_hidden_path': minfo.invar_to_hidden_path,
        'outvar_to_hidden_path': minfo.outvar_to_hidden_path,
        'path_to_state': minfo.retrieved_model_states,
    }
    return add_hidden_perturbation_in_jaxpr(**compile_kwargs)
def add_hidden_perturbation_in_module(
    model: brainstate.nn.Module,
    *model_args,
    **model_kwargs,
) -> HiddenPerturbation:
    """Build hidden-state perturbations directly from a model.

    The module information is first extracted with ``extract_module_info``
    using the given call arguments, and is then handed to
    ``add_hidden_perturbation_from_minfo``.

    Parameters
    ----------
    model : brainstate.nn.Module
        The neural-network module to which hidden-state perturbations are added.
    *model_args
        Positional arguments passed on to ``extract_module_info`` together
        with the model.
    **model_kwargs
        Keyword arguments passed on to ``extract_module_info`` together with
        the model.

    Returns
    -------
    HiddenPerturbation
        Information about the perturbations added to the hidden states,
        including the perturbed variables, paths, states, and the revised
        jaxpr.

    See Also
    --------
    add_hidden_perturbation_from_minfo : Equivalent helper starting from ``ModuleInfo``.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import braintrace
        >>> gru = braintrace.nn.GRUCell(3, 4)
        >>> _ = brainstate.nn.init_all_states(gru)
        >>> inputs = brainstate.random.randn(3)
        >>> hidden_perturb = braintrace.add_hidden_perturbation_in_module(gru, inputs)
        >>> len(hidden_perturb.perturb_vars)
        1
    """
    return add_hidden_perturbation_from_minfo(
        extract_module_info(model, *model_args, **model_kwargs)
    )