ETraceVjpGraphExecutor

class braintrace.ETraceVjpGraphExecutor(model, vjp_method='single-step')

The eligibility trace graph executor for the VJP-based online learning algorithms.

This class executes the eligibility trace graph for VJP-based online learning algorithms.

Parameters:
  • model (Module) – The model from which the eligibility trace graph is built. The model should define only its one-step behavior.

  • vjp_method (str) –

    The method for computing the VJP. It should be either "single-step" or "multi-step". Default is "single-step".

    • "single-step": The VJP is computed at the current time step, i.e., \(\partial L^t/\partial h^t\).

    • "multi-step": The VJP is computed at multiple time steps, i.e., \(\partial L^t/\partial h^{t-k}\), where \(k\) is determined by the number of time steps in the input data.
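To make the two modes concrete, here is a minimal NumPy sketch on a hypothetical scalar linear recurrence \(h^t = a h^{t-1} + w x^t\) with per-step loss \(L^t = \tfrac{1}{2}(h^t)^2\). Both the model and the loss are assumptions for illustration and are not part of braintrace. The single-step VJP is \(\partial L^t/\partial h^t = h^t\). The multi-step VJP chains the hidden-to-hidden Jacobian \(\partial h^t/\partial h^{t-1} = a\), giving \(\partial L^t/\partial h^{t-k} = h^t a^k\). In the sketch, \(k\) is assumed to range over the steps of the input sequence.

```python
import numpy as np


def run_rnn(h0, a, w, xs):
    """Hypothetical scalar recurrence h^t = a * h^{t-1} + w * x^t; returns every h^t."""
    hs = []
    h = h0
    for x in xs:
        h = a * h + w * x
        hs.append(h)
    return np.array(hs)


# Assumed per-step loss L^t = 0.5 * (h^t)^2, so dL^t/dh^t = h^t.
a, w, h0 = 0.9, 0.5, 0.0
xs = np.array([1.0, -2.0, 0.5, 3.0])
hs = run_rnn(h0, a, w, xs)

# "single-step": gradient of the last loss with respect to the last hidden state.
single_step = hs[-1]

# "multi-step": gradient of the same loss with respect to h^{t-k}, k = 0..K-1,
# obtained by chaining the hidden-to-hidden Jacobian dh^t/dh^{t-1} = a.
K = len(xs)
multi_step = np.array([hs[-1] * a**k for k in range(K)])
```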

compile_graph(*args)

Build the eligibility trace graph for the model from the given inputs.

This is the central method of the executor. It builds the graph from which the weight spatial gradients and the hidden-state Jacobian are computed.

Parameters:

*args – The positional arguments for the model.

Return type:

None

property is_multi_step_vjp: bool

Whether the VJP method is multi-step.

Returns:

Whether the VJP method is multi-step.

Return type:

bool

property is_single_step_vjp: bool

Whether the VJP method is single-step.

Returns:

Whether the VJP method is single-step.

Return type:

bool

solve_h2w_h2h_jacobian(*args, etrace_stepper=None, init_etrace=None)

Solve the hidden-to-weight and hidden-to-hidden Jacobians for the given inputs and parameters.

This function is typically used to compute the hidden-to-weight Jacobian during forward propagation.

Parameters:
  • *args – The positional arguments for the model.

  • etrace_stepper (Optional[Callable]) – A per-step eligibility-trace update callback with signature (etrace_carry, (x_dict, df_dict, diag_list)) -> (new_carry, None). When it is provided and the input spans multiple steps, the trace roll is fused into the model-forward scan (a single loop with no stacked Jacobians), and the final trace is returned in the last slot. When None (default), the per-step Jacobians are stacked and returned in the Jacobian slots.

  • init_etrace (Any) – The initial eligibility-trace carry, threaded through the scan when etrace_stepper is given. Ignored otherwise.

Returns:

The outputs, hidden states, other states, the spatial gradients of the weights, the hidden-to-hidden Jacobian, and the final eligibility trace. When etrace_stepper is given, the two Jacobian slots are None and the last slot holds the fused trace; otherwise, the last slot is None. Single-step results are returned if the inputs contain only single-step data; otherwise, multi-step results are returned.

Return type:

Tuple[Any, Dict[Tuple[str, ...], Any], Dict[Tuple[str, ...], Any], Tuple[Dict[int, Array], Dict[Tuple[int, str], Array]], Sequence[Array], Any]
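The sketch below shows what fusing the trace roll into the scan means. It uses a callback with the documented signature (etrace_carry, (x_dict, df_dict, diag_list)) -> (new_carry, None). Several parts of it are assumptions made for illustration, and braintrace's actual trace updates may differ. These assumptions are the contents of `x_dict`, `df_dict`, and `diag_list`, the diagonal-Jacobian update rule \(e^t = \mathrm{diag}^t \odot e^{t-1} + \mathrm{df}^t (x^t)^\top\), and the `scan` helper, which is a pure-Python stand-in for `jax.lax.scan`. The sketch checks that rolling the trace step by step gives the same result as first stacking every per-step term and rolling afterwards. The fused ordering is what lets an implementation avoid materializing the stacked Jacobians.

```python
import numpy as np


def scan(f, init, xs):
    """Pure-Python stand-in for jax.lax.scan: threads a carry through xs."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def etrace_stepper(etrace_carry, step):
    """Hypothetical stepper: (etrace_carry, (x_dict, df_dict, diag_list)) -> (new_carry, None).

    Assumed contents (illustrative, not braintrace's definitions):
      x_dict[key]  -- input x^t to weight `key`, shape (n_in,)
      df_dict[key] -- derivative of h^t w.r.t. that weight's output W x^t, shape (n_out,)
      diag_list[0] -- diagonal of dh^t/dh^{t-1}, shape (n_out,)
    """
    x_dict, df_dict, diag_list = step
    diag = diag_list[0]
    new_carry = {
        # e^t = diag * e^{t-1} + outer(df^t, x^t), scaling each row by the diagonal
        key: diag[:, None] * etrace_carry[key] + np.outer(df_dict[key], x_dict[key])
        for key in etrace_carry
    }
    return new_carry, None


rng = np.random.default_rng(0)
T, n_in, n_out = 5, 3, 4
steps = [
    (
        {"w": rng.normal(size=n_in)},
        {"w": rng.normal(size=n_out)},
        [rng.uniform(0.5, 1.0, size=n_out)],
    )
    for _ in range(T)
]
init_etrace = {"w": np.zeros((n_out, n_in))}

# Fused: the trace is updated inside the single loop as each step's terms arrive.
fused_trace, outputs = scan(etrace_stepper, init_etrace, steps)

# Unfused: stack every per-step term first, then roll the trace in a second pass.
stacked_x = np.stack([s[0]["w"] for s in steps])    # (T, n_in)
stacked_df = np.stack([s[1]["w"] for s in steps])   # (T, n_out)
stacked_diag = np.stack([s[2][0] for s in steps])   # (T, n_out)
unfused_trace = np.zeros((n_out, n_in))
for t in range(T):
    unfused_trace = stacked_diag[t][:, None] * unfused_trace + np.outer(stacked_df[t], stacked_x[t])
```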

Notes

For the state transition function \(y, h^t = f(h^{t-1}, \theta, x)\), this function computes:

  1. The function output \(y\).

  2. The updated hidden states \(h^t\).

  3. The Jacobian matrix of hidden-to-weight, i.e., \(\partial h^t / \partial \theta\).

  4. The Jacobian matrix of hidden-to-hidden, i.e., \(\partial h^t / \partial h^{t-1}\).
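The following NumPy sketch is a concrete instance of these four quantities. It uses a hypothetical transition \(h^t = \tanh(W h^{t-1} + U x)\) with output \(y = \sum_i h^t_i\), and it takes \(\theta = W\). Let \(d = 1 - (h^t)^2\). The hidden-to-hidden Jacobian is then \(\mathrm{diag}(d)\,W\), and the hidden-to-weight Jacobian is \(\partial h^t_i/\partial W_{jk} = \delta_{ij}\, d_i\, h^{t-1}_k\). The dense tensors are for illustration only. The executor's actual return structure (see the return type above) differs.

```python
import numpy as np


def f(h_prev, W, U, x):
    """Hypothetical one-step transition: h^t = tanh(W h^{t-1} + U x), y = sum(h^t)."""
    h = np.tanh(W @ h_prev + U @ x)
    y = h.sum()
    return y, h


def solve_jacobians(h_prev, W, U, x):
    """Return y, h^t, dh^t/dW (theta = W), and dh^t/dh^{t-1} for the transition f."""
    y, h = f(h_prev, W, U, x)
    d = 1.0 - h**2  # tanh'(pre-activation), shape (n,)
    n = h.shape[0]
    # Hidden-to-weight: dh_i/dW_jk = delta_ij * d_i * h_prev_k, shape (n, n, n)
    h2w = np.einsum("ij,i,k->ijk", np.eye(n), d, h_prev)
    # Hidden-to-hidden: dh_i/dh_prev_k = d_i * W_ik, i.e. diag(d) @ W, shape (n, n)
    h2h = d[:, None] * W
    return y, h, h2w, h2h


rng = np.random.default_rng(1)
n, m = 3, 2
h_prev = rng.normal(size=n)
W = 0.5 * rng.normal(size=(n, n))
U = rng.normal(size=(n, m))
x = rng.normal(size=m)
y, h, h2w, h2h = solve_jacobians(h_prev, W, U, x)
```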

solve_h2w_h2h_l2h_jacobian(*args, etrace_stepper=None, init_etrace=None)

Solve the hidden-to-weight and hidden-to-hidden Jacobians and the VJP-transformed loss-to-hidden gradients for the given inputs.

This function is typically used to compute both the hidden-to-weight Jacobian during forward propagation and the loss-to-hidden gradients at the current time step.

Parameters:
  • *args – The positional arguments for the model.

  • etrace_stepper (Optional[Callable]) – A per-step eligibility-trace update callback with signature (etrace_carry, (x_dict, df_dict, diag_list)) -> (new_carry, None). When it is provided and the input spans multiple steps, the trace roll is fused into the over-time scan (so the per-step Jacobians are never stacked), and the final trace is returned in the last slot instead. The callback and init_etrace are captured by closure rather than passed to jax.vjp, so they never participate in reverse-mode differentiation.

  • init_etrace (Any) – The initial eligibility-trace carry, threaded through the scan when etrace_stepper is given. Ignored otherwise.

Returns:

The outputs, hidden states, other states, the spatial gradients of the weights, the hidden-to-hidden Jacobian, the residuals, and the final eligibility trace. When etrace_stepper is given, the two Jacobian slots are None and the last slot holds the fused trace; otherwise, the last slot is None.

Return type:

Tuple[Any, Dict[Tuple[str, ...], Any], Dict[Tuple[str, ...], Any], Tuple[Dict[int, Array], Dict[Tuple[int, str], Array]], Sequence[Array], VjpResiduals, Any]

Notes

Specifically, this function computes:

  1. The Jacobian matrix of hidden-to-weight. That is, \(\partial h / \partial w\), where \(h\) is the hidden state and \(w\) is the weight.

  2. The Jacobian matrix of hidden-to-hidden. That is, \(\partial h^t / \partial h^{t-1}\), where \(h^t\) is the hidden state at time step \(t\).

  3. The partial gradients of the loss with respect to the hidden states. That is, \(\partial L / \partial h\), where \(L\) is the loss and \(h\) is the hidden state.
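These quantities combine into an online weight gradient. The hidden-to-weight and hidden-to-hidden Jacobians roll forward an eligibility trace \(e^t = (\partial h^t/\partial h^{t-1})\, e^{t-1} + \partial h^t/\partial w\). The loss-to-hidden gradient is then contracted with that trace, \((\partial L/\partial h^t)\, e^t\). The sketch below applies this to a hypothetical scalar linear recurrence with loss \(L = \tfrac{1}{2}(h^T - 1)^2\). It is a textbook forward-mode (RTRL-style) illustration, not braintrace's implementation. For this linear scalar model the result equals the exact gradient. For general models, the accuracy depends on how the trace is approximated.

```python
import numpy as np


def forward_with_trace(h0, a, w, xs):
    """Hypothetical scalar recurrence h^t = a*h^{t-1} + w*x^t, carrying e^t = dh^t/dw.

    Forward-mode trace update:
      e^t = (dh^t/dh^{t-1}) * e^{t-1} + (direct dh^t/dw) = a * e^{t-1} + x^t
    """
    h, e = h0, 0.0
    for x in xs:
        h = a * h + w * x
        e = a * e + x
    return h, e


a, w, h0 = 0.9, 0.5, 0.0
xs = [1.0, -2.0, 0.5, 3.0]
target = 1.0

h_T, e_T = forward_with_trace(h0, a, w, xs)
dL_dh = h_T - target   # loss-to-hidden gradient for L = 0.5 * (h^T - target)^2
grad_w = dL_dh * e_T   # dL/dw = (dL/dh^T) * (dh^T/dw)
```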