add_hidden_perturbation_in_module

class braintrace.add_hidden_perturbation_in_module(model, *model_args, **model_kwargs)

Add hidden-state perturbations to a model.

Adds perturbations to the hidden states of the given module and replaces the hidden states with the perturbed states.

Parameters:
  • model (Module) – The neural-network module to which hidden-state perturbations are added.

  • *model_args – Additional positional arguments passed to the model.

  • **model_kwargs – Additional keyword arguments passed to the model.

Returns:

Information about the perturbations added to the hidden states, including the perturbed variables, paths, states, and the revised jaxpr.

Return type:

HiddenPerturbation

See also

add_hidden_perturbation_from_minfo

Equivalent helper that starts from a ModuleInfo rather than a module.

Examples

>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> hidden_perturb = braintrace.add_hidden_perturbation_in_module(gru, inputs)
>>> len(hidden_perturb.perturb_vars)
1
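
The following is a conceptual sketch only, not braintrace's implementation. It illustrates one common reason for adding a perturbation to a hidden state: if the perturbation is additive and zero-valued, the forward result is unchanged, while the derivative of a loss with respect to the perturbation equals the derivative with respect to the hidden state itself. The names step, loss, and grad_wrt_perturbation, the toy update rule, and the use of finite differences in place of automatic differentiation are all assumptions made for illustration.

```python
import math


def step(h, x, w=0.5, eps=0.0):
    """Toy recurrent update (hypothetical); eps is added to the new hidden state."""
    return math.tanh(w * h + x) + eps


def loss(h_new):
    """Toy quadratic loss on the hidden state (hypothetical)."""
    return 0.5 * h_new ** 2


def grad_wrt_perturbation(h, x, delta=1e-6):
    """Central finite difference of the loss with respect to eps, around eps = 0."""
    up = loss(step(h, x, eps=+delta))
    down = loss(step(h, x, eps=-delta))
    return (up - down) / (2 * delta)


h, x = 0.2, 0.3

# With eps = 0 the perturbed hidden state equals the unperturbed one.
h_new = step(h, x)

# For L = 0.5 * h_new**2, the derivative dL/dh_new is h_new itself.
analytic = h_new

# The derivative with respect to the perturbation matches dL/dh_new.
numeric = grad_wrt_perturbation(h, x)
```

In this sketch, numeric and analytic agree up to rounding error. The perturbation acts as a handle for reading off the hidden-state gradient without changing the forward computation.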