HiddenPerturbation

class braintrace.HiddenPerturbation(perturb_vars: Sequence[Var], perturb_hidden_paths: Sequence[Tuple[str, ...]], perturb_hidden_states: Sequence[HiddenState], perturb_jaxpr: ClosedJaxpr)

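A named tuple holding the result of adding hidden-state perturbations to a jaxpr: the perturbation variables, the paths and states of the perturbed hidden states, and the perturbed jaxpr.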

Hidden perturbation adds a perturbation variable to each hidden state in the jaxpr and replaces the hidden states with the perturbed states:

\[h^t = f(x) \;\Rightarrow\; h^t = f(x) + \mathrm{perturb\_var},\]

where \(h^t\) is the hidden state at step \(t\), \(f\) is the function that computes it, \(x\) is the input, and \(\mathrm{perturb\_var}\) is the perturbation variable added to that hidden state.
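The rewrite can be pictured with a plain-Python sketch that uses neither braintrace nor JAX. The names `f` and `perturbed_step`, and the use of lists of floats for the hidden state, are hypothetical and chosen only to illustrate the formula; the library applies the equivalent rewrite to the jaxpr itself.

```python
import math
from typing import List


def f(x: List[float]) -> List[float]:
    """Hypothetical hidden-state update: an element-wise tanh."""
    return [math.tanh(v) for v in x]


def perturbed_step(x: List[float], perturb_var: List[float]) -> List[float]:
    """Perturbed update h^t = f(x) + perturb_var, applied element-wise."""
    return [h + p for h, p in zip(f(x), perturb_var)]


x = [0.5, -1.0, 2.0]

# A zero perturbation leaves the hidden state unchanged.
zeros = [0.0] * len(x)
h_unperturbed = perturbed_step(x, zeros)

# A non-zero perturbation shifts the hidden state by exactly that amount.
h_shifted = perturbed_step(x, [0.1, 0.0, -0.2])
```

With a zero perturbation the perturbed step reproduces the original hidden state, which is why zero is the natural starting value for the perturbation data.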

perturb_vars

The perturbation variables.

Type:

sequence of Var

perturb_hidden_paths

The hidden-state paths that are perturbed.

Type:

sequence of Path

perturb_hidden_states

The hidden states that are perturbed.

Type:

sequence of brainstate.HiddenState

perturb_jaxpr

The perturbed jaxpr.

Type:

ClosedJaxpr

See also

add_hidden_perturbation_in_module

Build perturbations directly from a model.

Notes

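Internally, a new variable \(\hat{h}^t = f(x)\) is defined and an extra equation \(h^t = \hat{h}^t + \mathrm{perturb\_var}\) is added. Because the perturbation enters additively, the gradient of the loss with respect to the hidden state equals its gradient with respect to the perturbation variable, so it can be read off that variable directly: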

\[\frac{\partial L^t}{\partial h^t} = \frac{\partial L^t}{\partial \mathrm{perturb\_var}}.\]
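The identity can be checked numerically with a small plain-Python sketch. Everything here is an illustrative assumption: the scalar update `f`, the quadratic `loss`, and the finite-difference helper `central_difference` are hypothetical stand-ins, and in practice the gradient would come from JAX automatic differentiation rather than finite differences.

```python
import math


def f(x: float) -> float:
    """Hypothetical hidden-state update."""
    return math.tanh(x)


def loss(h: float) -> float:
    """Hypothetical loss that depends only on the hidden state."""
    return 0.5 * (h - 1.0) ** 2


def loss_with_perturbation(x: float, perturb_var: float) -> float:
    h_hat = f(x)             # h_hat^t = f(x)
    h = h_hat + perturb_var  # h^t = h_hat^t + perturb_var
    return loss(h)


def central_difference(fn, at: float, eps: float = 1e-6) -> float:
    """Approximate d fn / d at with a symmetric finite difference."""
    return (fn(at + eps) - fn(at - eps)) / (2.0 * eps)


x = 0.3
h = f(x)

# dL/dh: differentiate the loss directly with respect to the hidden state.
grad_wrt_hidden = central_difference(loss, h)

# dL/d(perturb_var): differentiate through the perturbed update at zero perturbation.
grad_wrt_perturb = central_difference(lambda p: loss_with_perturbation(x, p), 0.0)
```

Both quantities agree up to floating-point error because \(\partial h^t / \partial \mathrm{perturb\_var}\) is the identity.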

Examples

>>> import brainstate
>>> import braintrace
>>> gru = braintrace.nn.GRUCell(3, 4)
>>> _ = brainstate.nn.init_all_states(gru)
>>> inputs = brainstate.random.randn(3)
>>> hidden_perturb = braintrace.add_hidden_perturbation_in_module(gru, inputs)
>>> isinstance(hidden_perturb, braintrace.HiddenPerturbation)
True

dict()[source]

Return this perturbation’s named fields as a plain dictionary.

Returns:

An ordered mapping from field name to value, as produced by the underlying typing.NamedTuple.

Return type:

Dict[str, Any]
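As a rough illustration of how such a method behaves on a `typing.NamedTuple`, the sketch below defines a hypothetical `ToyPerturbation` with the same field names and order but placeholder string values. Implementing `dict()` by delegating to `_asdict()` is an assumption made for this sketch, not a statement about the braintrace source.

```python
from typing import Any, Dict, NamedTuple, Sequence, Tuple


class ToyPerturbation(NamedTuple):
    """Hypothetical stand-in with the same field order as HiddenPerturbation."""

    perturb_vars: Sequence[str]
    perturb_hidden_paths: Sequence[Tuple[str, ...]]
    perturb_hidden_states: Sequence[str]
    perturb_jaxpr: str

    def dict(self) -> Dict[str, Any]:
        # _asdict() is the standard NamedTuple method; keys follow declaration order.
        return self._asdict()


toy = ToyPerturbation(
    perturb_vars=["p0"],
    perturb_hidden_paths=[("gru", "h")],
    perturb_hidden_states=["h0"],
    perturb_jaxpr="<jaxpr placeholder>",
)
fields = toy.dict()
```

Because the class is a named tuple, each field is also reachable by position (`toy[0]` is `toy.perturb_vars`), which is what the "Alias for field number N" entries describe.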

eval_jaxpr(inputs, perturb_data)

Evaluate the perturbed jaxpr.

Parameters:
  • inputs (Sequence[Array]) – The flat input values of the original jaxpr.

  • perturb_data (Sequence[Array]) – The perturbation values, one per entry of perturb_vars.

Returns:

The outputs of the perturbed jaxpr.

Return type:

Sequence[Array]

init_perturb_data()

Initialize the perturbation data to zeros.

Returns:

One zero array per perturbation variable, matching its shape and dtype.

Return type:

Sequence[Array]

perturb_data_to_hidden_group_data(perturb_data, hidden_groups)

Convert the perturbation data to per-hidden-group data.

Parameters:
  • perturb_data (Sequence[Array]) – The perturbation values, one per entry of perturb_vars.

  • hidden_groups (Sequence[HiddenGroup]) – The hidden groups to map the perturbation data onto.

Returns:

One concatenated perturbation array per hidden group.

Return type:

Sequence[Array]

Raises:

AssertionError – If perturb_data does not have the same length as perturb_vars.
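The grouping step might look roughly like the plain-Python sketch below. It rests on several assumptions that the text above does not confirm: each hidden group is represented here simply as a list of hidden-state paths, perturbation values are flat lists of floats, concatenation happens along that single flat axis, and `perturb_hidden_paths` stands in for `perturb_vars` in the length check. The helper name `group_perturb_data` is hypothetical; the real `HiddenGroup` type and concatenation axis may differ.

```python
from typing import Dict, List, Sequence, Tuple

Path = Tuple[str, ...]


def group_perturb_data(
    perturb_data: Sequence[List[float]],
    perturb_hidden_paths: Sequence[Path],
    groups: Sequence[Sequence[Path]],
) -> List[List[float]]:
    """Hypothetical helper: concatenate perturbation data per group of paths."""
    # Mirror the documented check: one perturbation value per perturbed hidden state.
    assert len(perturb_data) == len(perturb_hidden_paths), "length mismatch"
    by_path: Dict[Path, List[float]] = dict(zip(perturb_hidden_paths, perturb_data))
    # For each group, concatenate the data of its hidden states in group order.
    return [[v for path in group for v in by_path[path]] for group in groups]


paths = [("cell", "h"), ("cell", "c"), ("readout", "v")]
data = [[1.0, 2.0], [3.0, 4.0], [5.0]]
groups = [[("cell", "h"), ("cell", "c")], [("readout", "v")]]

grouped = group_perturb_data(data, paths, groups)
```

Here the two states of the first group are merged into one list of four values, while the single-state group keeps its data unchanged.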

perturb_hidden_paths: Sequence[Tuple[str, ...]]

Alias for field number 1

perturb_hidden_states: Sequence[HiddenState]

Alias for field number 2

perturb_jaxpr: ClosedJaxpr

Alias for field number 3

perturb_vars: Sequence[Var]

Alias for field number 0