HiddenPerturbation
- class braintrace.HiddenPerturbation(perturb_vars: Sequence[Var], perturb_hidden_paths: Sequence[Tuple[str, ...]], perturb_hidden_states: Sequence[HiddenState], perturb_jaxpr: ClosedJaxpr)
The hidden-perturbation information.
Hidden perturbation adds a perturbation variable to each hidden state in the jaxpr and replaces the hidden states with the perturbed states:
\[h^t = f(x) \;\Rightarrow\; h^t = f(x) + \mathrm{perturb\_var},\]
where \(h^t\) is the hidden state, \(f\) is the function, \(x\) is the input, and \(\mathrm{perturb\_var}\) is the perturbation variable.
- perturb_vars
The perturbation variables.
- Type:
sequence of Var
- perturb_hidden_paths
The hidden-state paths that are perturbed.
- Type:
sequence of Path
- perturb_hidden_states
The hidden states that are perturbed.
- Type:
sequence of brainstate.HiddenState
- perturb_jaxpr
The perturbed jaxpr.
- Type:
ClosedJaxpr
See also
add_hidden_perturbation_in_module
Build perturbations directly from a model.
Notes
Internally, a new variable \(\hat{h}^t = f(x)\) is defined and an extra equation \(h^t = \hat{h}^t + \mathrm{perturb\_var}\) is added. Because the perturbation enters by plain addition, the hidden-state gradient can be read off the perturbation variable:
\[\frac{\partial L^t}{\partial h^t} = \frac{\partial L^t}{\partial \mathrm{perturb\_var}}.\]
Examples
>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> hidden_perturb = braintrace.add_hidden_perturbation_in_module(gru, inputs)
>>> isinstance(hidden_perturb, braintrace.HiddenPerturbation)
True
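The gradient identity from the Notes can be checked with plain JAX, without braintrace. The following is a minimal sketch under stated assumptions: `step`, `loss_from_hidden`, and `loss_with_perturb` are hypothetical names chosen for illustration, not part of the braintrace API, and the toy update `tanh(w @ x)` merely stands in for \(f(x)\).

```python
import jax
import jax.numpy as jnp


def step(x, w):
    # Hypothetical hidden-state update standing in for f(x).
    return jnp.tanh(w @ x)


def loss_from_hidden(h):
    # Toy loss that depends only on the hidden state.
    return jnp.sum(h ** 2)


def loss_with_perturb(perturb, x, w):
    h_hat = step(x, w)      # \hat{h}^t = f(x)
    h = h_hat + perturb     # h^t = \hat{h}^t + perturb_var
    return loss_from_hidden(h)


x = jnp.array([0.5, -1.0, 2.0])
w = jnp.arange(12.0).reshape(4, 3) / 10.0
perturb = jnp.zeros(4)      # evaluating at zero leaves the forward pass unchanged

# Gradient with respect to the perturbation variable ...
grad_perturb = jax.grad(loss_with_perturb)(perturb, x, w)
# ... compared with the gradient taken directly at the hidden state.
grad_hidden = jax.grad(loss_from_hidden)(step(x, w))

print(jnp.allclose(grad_perturb, grad_hidden))
```

Evaluating at a zero perturbation keeps the model output identical to the unperturbed one, so the perturbation serves only as a handle for differentiation.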
- dict()
Return this perturbation’s named fields as a plain dictionary.
- Returns:
An ordered mapping from field name to value, as produced by the underlying
typing.NamedTuple.
- Return type:
- init_perturb_data()
Initialize the perturbation data to zeros.
- Returns:
One zero array per perturbation variable, matching its shape and dtype.
- Return type:
Sequence[Array]
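As a rough illustration of "one zero array per perturbation variable", the sketch below builds zeros from shape and dtype descriptions. It does not call braintrace: `jax.ShapeDtypeStruct` is used here as a stand-in for the shape and dtype carried by each perturbation variable, and `zeros_like_specs` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

# Stand-ins for the shape/dtype information of two perturbation variables (assumed).
specs = [
    jax.ShapeDtypeStruct((4,), jnp.float32),
    jax.ShapeDtypeStruct((2, 3), jnp.int32),
]


def zeros_like_specs(specs):
    # One zero array per spec, matching its shape and dtype.
    return [jnp.zeros(s.shape, s.dtype) for s in specs]


zeros = zeros_like_specs(specs)
for z in zeros:
    print(z.shape, z.dtype)
```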
- perturb_data_to_hidden_group_data(perturb_data, hidden_groups)
Convert the perturbation data to per-hidden-group data.
- Parameters:
perturb_data (Sequence[Array]) – The perturbation values, one per entry of perturb_vars.
hidden_groups (Sequence[HiddenGroup]) – The hidden groups to map the perturbation data onto.
- Returns:
One concatenated perturbation array per hidden group.
- Return type:
Sequence[Array]
- Raises:
AssertionError – If perturb_data does not have the same length as perturb_vars.
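The regrouping described above can be pictured with a small stand-alone sketch. It rests on several assumptions: groups are modelled as plain lists of hidden-state paths rather than real HiddenGroup objects, `group_perturb_data` is a hypothetical helper, and joining along the last axis is a guess, since the axis the real method uses is not stated here.

```python
import jax.numpy as jnp


def group_perturb_data(perturb_data, perturb_paths, groups):
    # Mirrors the documented check: one value per perturbation variable.
    assert len(perturb_data) == len(perturb_paths)
    by_path = dict(zip(perturb_paths, perturb_data))
    # One array per group; the concatenation axis (-1) is an assumption.
    return [
        jnp.concatenate([by_path[p] for p in group], axis=-1)
        for group in groups
    ]


# Hypothetical paths and values for three perturbed hidden states.
paths = [("gru", "h"), ("lstm", "h"), ("lstm", "c")]
data = [jnp.ones(4), jnp.zeros(3), jnp.full(3, 2.0)]
# Two hypothetical groups: the GRU state alone, and the two LSTM states together.
groups = [[("gru", "h")], [("lstm", "h"), ("lstm", "c")]]

grouped = group_perturb_data(data, paths, groups)
for g in grouped:
    print(g.shape)
```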
- perturb_jaxpr: ClosedJaxpr
Alias for field number 3