ETraceVjpGraphExecutor
- class braintrace.ETraceVjpGraphExecutor(model, vjp_method='single-step')
The eligibility trace graph executor for the VJP-based online learning algorithms.
This class is used for executing the eligibility trace graph for the VJP-based online learning algorithms, including:
- pp_prop (aliases ES_D_RTRL / IODimVjpAlgorithm) for the algorithm with input-output dimensional complexity.
- ParamDimVjpAlgorithm (alias D_RTRL) for the algorithm with parameter dimensional complexity.
- Parameters:
model (Module) – The model used to build the eligibility trace graph. The model should only define the one-step behavior.
vjp_method (str) – The method for computing the VJP. It should be either "single-step" or "multi-step". Default is "single-step".
"single-step": The VJP is computed at the current time step, i.e., \(\partial L^t/\partial h^t\).
"multi-step": The VJP is computed across multiple time steps, i.e., \(\partial L^t/\partial h^{t-k}\), where \(k\) is determined by the input data.
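The two quantities behind `vjp_method` can be illustrated outside the library. The NumPy sketch below uses a hypothetical linear recurrence \(h^t = A h^{t-1} + x^t\) with a loss only on the final state, so that \(\partial L^T/\partial h^{T-k} = (A^\top)^k\,\partial L^T/\partial h^T\), and checks both against finite differences. It is not braintrace code and assumes nothing about how the executor actually computes either VJP.

```python
import numpy as np

# Hypothetical toy recurrence (not braintrace code): h^t = A @ h^{t-1} + x^t,
# with a loss L^T = 0.5 * ||h^T||^2 on the final hidden state only.
rng = np.random.default_rng(0)
n, T = 3, 4
A = 0.5 * rng.standard_normal((n, n))
xs = rng.standard_normal((T, n))

# Collect hidden states h^0..h^T (h^0 is the zero initial state).
hs = [np.zeros(n)]
for t in range(T):
    hs.append(A @ hs[-1] + xs[t])

def loss(h):
    return 0.5 * float(h @ h)

def unroll_from(h_start, k):
    """Replay the last k steps starting from h^{T-k}, returning h^T."""
    h = h_start
    for t in range(T - k, T):
        h = A @ h + xs[t]
    return h

# Single-step VJP: dL^T/dh^T, which is simply h^T for this loss.
single_step = hs[T]

def multi_step(k):
    """Multi-step VJP dL^T/dh^{T-k} = (A^T)^k @ dL^T/dh^T (h^T is linear in h^{T-k})."""
    return np.linalg.matrix_power(A.T, k) @ single_step

def finite_difference(k, eps=1e-6):
    """Central-difference estimate of dL^T/dh^{T-k}, used as a sanity check."""
    base = hs[T - k]
    grad = np.zeros(n)
    for i in range(n):
        e = np.zeros(n)
        e[i] = eps
        grad[i] = (loss(unroll_from(base + e, k)) - loss(unroll_from(base - e, k))) / (2 * eps)
    return grad
```

With `k = 0` the multi-step quantity reduces to the single-step one; larger `k` reaches further back in time.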
- compile_graph(*args)
Build the eligibility trace graph for the model from the given inputs.
This is the most important method for the eligibility trace graph. It builds the graph for the model, which is used for computing the weight spatial gradients and the hidden state Jacobian.
- Parameters:
*args – The positional arguments for the model.
- Return type:
- property is_multi_step_vjp: bool
Whether the VJP method is multi-step.
- Returns:
Whether the VJP method is multi-step.
- Return type:
bool
- property is_single_step_vjp: bool
Whether the VJP method is single-step.
- Returns:
Whether the VJP method is single-step.
- Return type:
bool
- solve_h2w_h2h_jacobian(*args, etrace_stepper=None, init_etrace=None)
Solve the hidden-to-weight and hidden-to-hidden Jacobians for the given inputs and parameters.
This function is typically used to compute the forward propagation of the hidden-to-weight Jacobian.
- Parameters:
*args – The positional arguments for the model.
etrace_stepper (Optional[Callable]) – A per-step eligibility-trace update callback with signature (etrace_carry, (x_dict, df_dict, diag_list)) -> (new_carry, None). When it is provided together with multi-step input, the trace roll is fused into the model-forward scan (a single loop, with no stacked Jacobians) and the final trace is returned in the last slot. When None (default), the per-step Jacobians are stacked and returned in the Jacobian slots.
init_etrace (Any) – The initial eligibility-trace carry, threaded through the scan when etrace_stepper is given. Ignored otherwise.
- Returns:
The outputs, hidden states, other states, the spatial gradients of the weights, the hidden-to-hidden Jacobian, and the final eligibility trace. When etrace_stepper is given, the two Jacobian slots are None and the last slot holds the fused trace; otherwise the last slot is None. Single-step results are returned if the inputs do not contain multi-step data; otherwise multi-step results are returned.
- Return type:
Tuple[Any, Dict[Tuple[str, ...], Any], Dict[Tuple[str, ...], Any], Tuple[Dict[int, Array], Dict[Tuple[int, str], Array]], Sequence[Array], Any]
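The difference between the fused and unfused paths can be sketched in plain NumPy. Here `scan` is a minimal stand-in for `jax.lax.scan`, and the toy model, the body of `etrace_stepper`, and the contents of `x_dict`, `df_dict` and `diag_list` are all hypothetical; only the callback's signature comes from the parameter description. The sketch assumes an element-wise (diagonal) recurrence, for which the diagonal trace roll is exact.

```python
import numpy as np

def scan(f, init, xs):
    """Minimal pure-Python stand-in for jax.lax.scan (collects outputs in a list)."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

def etrace_stepper(etrace_carry, jac):
    """Hypothetical callback with the documented signature
    (etrace_carry, (x_dict, df_dict, diag_list)) -> (new_carry, None).
    The payload layout used here is an assumption made for this sketch."""
    x_dict, df_dict, diag_list = jac
    diag = diag_list[0]  # diagonal of dh^t/dh^{t-1} (assumed element-wise recurrence)
    new_carry = {
        name: diag[:, None] * trace + np.outer(df_dict[name], x_dict[name])
        for name, trace in etrace_carry.items()
    }
    return new_carry, None

rng = np.random.default_rng(1)
n_in, n_h, T = 2, 3, 5
W = rng.standard_normal((n_h, n_in))
a = rng.uniform(0.1, 0.9, n_h)  # element-wise recurrent weights
xs = rng.standard_normal((T, n_in))

def model_step(h, x):
    """Toy model h^t = tanh(a * h^{t-1} + W @ x^t); also returns its Jacobian pieces."""
    h_new = np.tanh(a * h + W @ x)
    df = 1.0 - h_new ** 2  # derivative of tanh w.r.t. its pre-activation
    return h_new, ({"W": x}, {"W": df}, [df * a])

h0 = np.zeros(n_h)
init_etrace = {"W": np.zeros((n_h, n_in))}

# Unfused: keep every per-step Jacobian, then roll the trace in a second pass.
_, per_step_jacs = scan(model_step, h0, xs)
trace_unfused, _ = scan(etrace_stepper, init_etrace, per_step_jacs)

# Fused: roll the trace inside the same loop as the model forward pass.
def fused_step(carry, x):
    h, etrace = carry
    h_new, jac = model_step(h, x)
    etrace, _ = etrace_stepper(etrace, jac)
    return (h_new, etrace), None

(h_T, trace_fused), _ = scan(fused_step, (h0, init_etrace), xs)
```

Both paths yield the same final trace. The trade-off is memory: the unfused path holds all T sets of per-step Jacobians at once, while the fused path only keeps the current step's pieces and the running trace.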
Notes
For the state transition function \(y, h^t = f(h^{t-1}, \theta, x)\), this function aims to solve:
The function output \(y\).
The updated hidden states \(h^t\).
The Jacobian matrix of hidden-to-weight, i.e., \(\partial h^t / \partial \theta^t\).
The Jacobian matrix of hidden-to-hidden, i.e., \(\partial h^t / \partial h^{t-1}\).
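For a concrete instance of the four quantities listed above, the following NumPy sketch assumes a hypothetical tanh RNN cell with \(\theta = W_x\) (the recurrent and readout weights are held fixed) and computes the Jacobians analytically. `solve_jacobians` is an illustrative name, not part of braintrace, and the dense hidden-to-weight tensor it builds is only practical for tiny models.

```python
import numpy as np

rng = np.random.default_rng(2)
n_in, n_h, n_out = 2, 3, 1
W_h = 0.5 * rng.standard_normal((n_h, n_h))   # recurrent weights (held fixed here)
W_o = rng.standard_normal((n_out, n_h))        # readout weights (held fixed here)

def f(h_prev, W_x, x):
    """Toy transition y, h^t = f(h^{t-1}, theta, x) with theta = W_x."""
    h = np.tanh(W_h @ h_prev + W_x @ x)
    y = W_o @ h
    return y, h

def solve_jacobians(h_prev, W_x, x):
    """Hypothetical helper returning y, h^t, dh^t/dW_x and dh^t/dh^{t-1}."""
    y, h = f(h_prev, W_x, x)
    d = 1.0 - h ** 2                                    # tanh'(pre-activation)
    h2h = d[:, None] * W_h                              # (n_h, n_h): diag(d) @ W_h
    h2w = np.einsum("i,ij,k->ijk", d, np.eye(n_h), x)   # (n_h, n_h, n_in): d_i * delta_ij * x_k
    return y, h, h2w, h2h

h_prev = rng.standard_normal(n_h)
W_x = rng.standard_normal((n_h, n_in))
x = rng.standard_normal(n_in)
y, h, h2w, h2h = solve_jacobians(h_prev, W_x, x)
```

The hidden-to-weight tensor is mostly zeros because each hidden unit only depends on its own row of \(W_x\); eligibility-trace methods exploit exactly this kind of structure instead of storing the full tensor.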
- solve_h2w_h2h_l2h_jacobian(*args, etrace_stepper=None, init_etrace=None)
Solve the hidden-to-weight and hidden-to-hidden Jacobians and the VJP-transformed loss-to-hidden gradients for the given inputs.
This function is typically used to compute both the forward propagation of the hidden-to-weight Jacobian and the loss-to-hidden gradients at the current time step.
- Parameters:
*args – The positional arguments for the model.
etrace_stepper (Optional[Callable]) – A per-step eligibility-trace update callback with signature (etrace_carry, (x_dict, df_dict, diag_list)) -> (new_carry, None). When it is provided together with multi-step input, the trace roll is fused into the over-time scan (so the per-step Jacobians are never stacked) and the final trace is returned in the last slot instead. The callback and init_etrace are captured by closure rather than passed to jax.vjp, so they never participate in reverse-mode differentiation.
init_etrace (Any) – The initial eligibility-trace carry, threaded through the scan when etrace_stepper is given. Ignored otherwise.
- Returns:
The outputs, hidden states, other states, the spatial gradients of the weights, the hidden-to-hidden Jacobian, the residuals, and the final eligibility trace. When etrace_stepper is given, the two Jacobian slots are None and the last slot holds the fused trace; otherwise the last slot is None.
- Return type:
Tuple[Any, Dict[Tuple[str, ...], Any], Dict[Tuple[str, ...], Any], Tuple[Dict[int, Array], Dict[Tuple[int, str], Array]], Sequence[Array], VjpResiduals, Any]
Notes
In particular, this function solves for:
The Jacobian matrix of hidden-to-weight. That is, \(\partial h / \partial w\), where \(h\) is the hidden state and \(w\) is the weight.
The Jacobian matrix of hidden-to-hidden. That is, \(\partial h^t / \partial h^{t-1}\), where \(h^t\) and \(h^{t-1}\) are the hidden states at consecutive time steps.
The partial gradients of the loss with respect to the hidden states. That is, \(\partial L / \partial h\), where \(L\) is the loss and \(h\) is the hidden state.
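How the loss-to-hidden gradient combines with the forward-propagated Jacobians can be shown with a small NumPy sketch. Under an assumed element-wise (diagonal) recurrence, the trace `etrace[i, j]` equals \(\partial h_i^t/\partial W_{ij}\), so scaling it by \(\partial L^T/\partial h^T\) gives the weight gradient of the final-step loss without backpropagating through time. The model, the loss and every name here are hypothetical; this illustrates the general online-learning idea, not the executor's internals.

```python
import numpy as np

rng = np.random.default_rng(3)
n_in, n_h, T = 2, 3, 6
W = rng.standard_normal((n_h, n_in))      # trainable input weights
a = rng.uniform(0.1, 0.9, n_h)            # element-wise (diagonal) recurrence
w_o = rng.standard_normal(n_h)            # fixed readout: y^t = w_o @ h^t
xs = rng.standard_normal((T, n_in))
target = 0.3

def loss(h):
    """Final-step loss L^T = 0.5 * (w_o @ h^T - target)^2."""
    return 0.5 * (w_o @ h - target) ** 2

def loss_to_hidden(h):
    """dL^T/dh^T for the loss above (what a VJP through the readout would return)."""
    return (w_o @ h - target) * w_o

h = np.zeros(n_h)
etrace = np.zeros((n_h, n_in))            # etrace[i, j] tracks dh_i^t / dW_ij
for x in xs:
    h = np.tanh(a * h + W @ x)
    d = 1.0 - h ** 2
    # Diagonal hidden-to-hidden Jacobian times the old trace,
    # plus the immediate hidden-to-weight term d_i * x_j.
    etrace = (d * a)[:, None] * etrace + np.outer(d, x)

# Online weight gradient: row i of the trace is scaled by dL^T/dh_i^T.
grad_W = loss_to_hidden(h)[:, None] * etrace
```

Because \(h_k\) does not depend on \(W_{ij}\) for \(k \neq i\) in this toy model, the contraction is exact here; with dense recurrence the same contraction only yields an approximation.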