add_hidden_perturbation_in_module
- class braintrace.add_hidden_perturbation_in_module(model, *model_args, **model_kwargs)
Add hidden-state perturbations to a model.
Adds perturbations to the hidden states of the given module and replaces the hidden states with the perturbed states.
- Parameters:
model (Module) – The neural-network module to which hidden-state perturbations are added.
*model_args – Additional positional arguments passed to the model.
**model_kwargs – Additional keyword arguments passed to the model.
- Returns:
Information about the perturbations added to the hidden states, including the perturbed variables, paths, states, and the revised jaxpr.
- Return type:
See also
add_hidden_perturbation_from_minfo – Equivalent helper starting from ModuleInfo.
Examples
>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> hidden_perturb = braintrace.add_hidden_perturbation_in_module(gru, inputs)
>>> len(hidden_perturb.perturb_vars)
1
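The general idea of replacing a hidden state with a perturbed copy can be illustrated without braintrace. The sketch below is not the library's implementation: `cell`, `perturbed_cell`, and `loss` are hypothetical names for a toy one-unit recurrent update, and the motivation it shows (a zero-valued additive perturbation leaves the output unchanged, while the derivative of a loss with respect to that perturbation equals the derivative with respect to the hidden state) is an assumption about why such perturbations are useful, not something stated on this page.

```python
import math


def cell(x, h, w=0.5, u=-0.3):
    """Toy one-unit recurrent update (hypothetical, not a braintrace cell)."""
    return math.tanh(w * x + u * h)


def perturbed_cell(x, h, eps):
    """Same update, with the new hidden state replaced by hidden + eps."""
    return cell(x, h) + eps


def loss(h_new):
    """Simple quadratic loss on the hidden state."""
    return 0.5 * h_new ** 2


x, h = 1.0, 0.2
h_new = cell(x, h)

# With a zero perturbation, the perturbed state equals the original state.
unchanged = perturbed_cell(x, h, 0.0) == h_new

# Central finite difference of the loss with respect to the perturbation.
delta = 1e-6
grad_eps = (
    loss(perturbed_cell(x, h, delta)) - loss(perturbed_cell(x, h, -delta))
) / (2 * delta)

# Analytic gradient of the loss with respect to the hidden state:
# d(0.5 * h^2)/dh = h.
grad_h = h_new
```

Here `grad_eps` and `grad_h` agree up to floating-point error, which is the property that makes an additive perturbation a convenient handle on a hidden state inside a traced computation.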