# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Callable, Dict, Tuple, Any, List, Optional, Sequence
import brainstate
import jax
import jax.numpy as jnp
import saiunit as u
from braintrace._input_data import has_multistep_data
from braintrace._state_managment import assign_state_values_v2
from braintrace._typing import (
Path,
PyTree,
Outputs,
WeightID,
WeightVals,
HiddenVals,
StateVals,
ETraceVals,
Hid2WeightJacobian,
dG_Inputs,
dG_Weight,
dG_Hidden,
dG_State,
)
from .base import ETraceAlgorithm
from .vjp_graph_executor import ETraceVjpGraphExecutor
# Public API of this module: only the names listed here are exported by
# ``from <module> import *``.
__all__ = [
'ETraceVjpAlgorithm', # the base class for the eligibility trace algorithm with the VJP gradient computation
]
class ETraceVjpAlgorithm(ETraceAlgorithm):
    r"""
    The base class for the eligibility trace algorithm supporting the VJP gradient
    computation (reverse-mode differentiation).

    The term ``VJP`` comes from two aspects. First, this module is designed to be
    compatible with JAX's VJP mechanism, so the gradient is computed according to the
    reverse-mode differentiation interface, like :func:`jax.grad`, :func:`jax.vjp`, or
    :func:`jax.jacrev`. The true update function is defined as a custom VJP function
    ``._true_update_fun()``, which receives the inputs, the hidden states, other states,
    and etrace variables at the last time step, and returns the outputs, the hidden
    states, other states, and etrace variables at the current time step. Second, the
    algorithm computes the spatial gradient :math:`\partial L^t / \partial H^t` using the
    standard back-propagation algorithm, which enhances the accuracy and the stability of
    the gradient computation.

    Parameters
    ----------
    model : brainstate.nn.Module
        The model function, which receives the input arguments and returns the model output.
    name : str, optional
        The name of the etrace algorithm.
    vjp_method : str, optional
        The method for computing the VJP. It should be either ``"single-step"`` or
        ``"multi-step"``. Default is ``"single-step"``.

        - ``"single-step"``: The VJP is computed at the current time step, i.e.,
          :math:`\partial L^t/\partial h^t`.
        - ``"multi-step"``: The VJP is computed at multiple time steps, i.e.,
          :math:`\partial L^t/\partial h^{t-k}`, where :math:`k` is determined by the
          data input.

    Notes
    -----
    For each subclass (or the instance of an etrace algorithm), the following methods
    define the custom VJP rule:

    - ``._update_fn()``: update the eligibility trace states and return the outputs, hidden
      states, other states, and etrace data.
    - ``._update_fn_fwd()``: the forward pass of the custom VJP rule.
    - ``._update_fn_bwd()``: the backward pass of the custom VJP rule.

    This class provides a default implementation for the ``._update_fn()``,
    ``._update_fn_fwd()``, and ``._update_fn_bwd()`` methods. To implement a new etrace
    algorithm, users just need to override the following methods:

    - ``._solve_weight_gradients()``: solve the gradients of the learnable weights / parameters.
    - ``._update_etrace_data()``: update the eligibility trace data.
    - ``._assign_etrace_data()``: assign the eligibility trace data to the states.
    - ``._get_etrace_data()``: get the eligibility trace data.

    Two further hooks have working defaults and may optionally be overridden:

    - ``._make_etrace_stepper()``: return a per-step trace update callback (default ``None``).
    - ``._compute_learning_signal()``: post-process the learning signal (default: identity).
    """

    __module__ = 'braintrace'

    graph_executor: ETraceVjpGraphExecutor

    def __init__(
        self,
        model: brainstate.nn.Module,
        name: Optional[str] = None,
        vjp_method: str = 'single-step'
    ):
        # the VJP method
        assert vjp_method in ('single-step', 'multi-step'), (
            'The VJP method should be either "single-step" or "multi-step". '
            f'While we got {vjp_method}. '
        )
        self.vjp_method = vjp_method

        # graph
        graph_executor = ETraceVjpGraphExecutor(model, vjp_method=vjp_method)

        # super initialization
        super().__init__(model=model, name=name, graph_executor=graph_executor)

        # the update rule: ``_update_fn`` is the primal function, while
        # ``_update_fn_fwd`` / ``_update_fn_bwd`` define its custom VJP.
        self._true_update_fun = jax.custom_vjp(self._update_fn)
        self._true_update_fun.defvjp(
            fwd=self._update_fn_fwd,
            bwd=self._update_fn_bwd
        )

    def _assert_compiled(self):
        """
        Check that the algorithm has been compiled.

        Raises
        ------
        ValueError
            If ``self.is_compiled`` is falsy.
        """
        if not self.is_compiled:
            raise ValueError('The etrace algorithm has not been compiled. Please call `compile_graph()` first. ')

    def update(self, *args) -> Any:
        r"""
        Update the model states and the eligibility trace.

        The input arguments ``args`` here support very complex data structures, including
        the combination of :py:class:`SingleStepData` and :py:class:`MultiStepData`.

        - :py:class:`SingleStepData`: indicating the data at the single time step,
          :math:`x_t`.
        - :py:class:`MultiStepData`: indicating the data at multiple time steps,
          :math:`[x_{t-k}, ..., x_t]`.

        Parameters
        ----------
        *args
            The input arguments.

        Returns
        -------
        Any
            The model output.

        Raises
        ------
        ValueError
            If the algorithm has not been compiled.

        Notes
        -----
        Suppose all inputs have the shape of ``(10,)``.

        If the input arguments are given by:

        .. code-block:: python

            x = [jnp.ones((10,)), jnp.zeros((10,))]

        Then, two input arguments are considered as the :py:class:`SingleStepData`.

        If the input arguments are given by:

        .. code-block:: python

            x = [braintrace.SingleStepData(jnp.ones((10,))),
                 braintrace.SingleStepData(jnp.zeros((10,)))]

        This is the same as the previous case, they are all considered as the input at
        the current time step.

        If the input arguments are given by:

        .. code-block:: python

            x = [braintrace.MultiStepData(jnp.ones((5, 10))),
                 jnp.zeros((10,))]

        or,

        .. code-block:: python

            x = [braintrace.MultiStepData(jnp.ones((5, 10))),
                 braintrace.SingleStepData(jnp.zeros((10,)))]

        Then, the first input argument is considered as the :py:class:`MultiStepData`, and
        its data will be fed into the model within five consecutive steps, and the second
        input argument will be fed into the model at each of these five consecutive steps.
        """

        # ----------------------------------------------------------------------------------------------
        #
        # This method is the main function to
        #
        # - update the model
        # - update the eligibility trace states
        # - compute the weight gradients
        #
        # The key here is that we change the object-oriented attributes as the function arguments.
        # Therefore, the function arguments are the states of the current time step, and the function
        # returns the states of the next time step.
        #
        # Particularly, the model calls the "_true_update_fun()" function to update the states.
        #
        # ----------------------------------------------------------------------------------------------
        #
        # This function needs to process the following multiple cases:
        #
        # 1. if vjp_method = 'single-step', input = SingleStepData, then output is single step
        #
        # 2. if vjp_method = 'single-step', input = MultiStepData, then output is multiple step data
        #
        # 3. if vjp_method = 'multi-step', input = SingleStepData, then output is single step
        #
        # 4. if vjp_method = 'multi-step', input = MultiStepData, then output is multiple step data
        #

        # check the compilation
        self._assert_compiled()

        # state values
        weight_vals = {
            key: st.value
            for key, st in self.param_states.items()
        }
        hidden_vals = {
            key: st.value
            for key, st in self.hidden_states.items()
        }
        other_vals = {
            key: st.value
            for key, st in self.other_states.items()
        }

        # etrace data
        last_etrace_vals = self._get_etrace_data()

        # update all states
        #
        # [KEY] The key here is that we change the object-oriented attributes as the function arguments.
        #       Therefore, the function arguments are the states of the current time step, and the function
        #       returns the states of the next time step.
        #
        # NOTE(review): the original comment here stated "out: is always multiple step",
        # which appears to conflict with cases 1 and 3 listed above -- confirm against
        # the graph executor.
        (
            out,
            hidden_vals,
            other_vals,
            new_etrace_vals
        ) = self._true_update_fun(
            args,
            weight_vals,
            hidden_vals,
            other_vals,
            last_etrace_vals,
            self.running_index.value
        )

        # assign/restore the weight values back
        #
        # [KEY] assuming the weight values are not changed
        #       This is a key assumption in the RTRL algorithm.
        #       This is very important for the implementation.
        assign_state_values_v2(self.param_states, weight_vals, write=False)

        # assign the new hidden and state values
        assign_state_values_v2(self.hidden_states, hidden_vals)
        assign_state_values_v2(self.other_states, other_vals)

        #
        # assign the new etrace values
        #
        # "self._assign_etrace_data()" is a protocol method that should be implemented in the subclass.
        # Its logic may be different for different etrace algorithms.
        #
        self._assign_etrace_data(new_etrace_vals)  # call the protocol method

        # update the running index; negative values are clamped to 0
        # (presumably a guard against integer wrap-around -- confirm).
        running_index = self.running_index.value + 1
        self.running_index.value = jax.lax.stop_gradient(jnp.where(running_index >= 0, running_index, 0))

        # return the model output
        return out

    def _update_fn(
        self,
        args,
        weight_vals: WeightVals,
        hidden_vals: HiddenVals,
        oth_state_vals: StateVals,
        etrace_vals: ETraceVals,
        running_index,
    ) -> Tuple[Outputs, HiddenVals, StateVals, ETraceVals]:
        """
        The main function to update the [model] and the [eligibility trace] states.

        Particularly, ``self.graph_executor.solve_h2w_h2h_jacobian()`` is called to:

        - compute the model output, the hidden states, and the other states
        - compute the hidden-to-weight Jacobian and the hidden-to-hidden Jacobian

        Then, unless the executor already returned a final trace (fused path),
        ``self._update_etrace_data`` is called to:

        - update the eligibility trace data

        Moreover, this function returns:

        - the model output
        - the updated hidden states
        - the updated other states
        - the updated eligibility trace states

        Note that the weight values are assumed not changed in this function.
        """

        input_is_multi_step = has_multistep_data(*args)

        # state value assignment
        assign_state_values_v2(self.param_states, weight_vals, write=False)
        assign_state_values_v2(self.hidden_states, hidden_vals, write=False)
        assign_state_values_v2(self.other_states, oth_state_vals, write=False)

        # When the trace roll can be fused into the executor's over-time scan
        # (multi-step input + a fusable subclass), hand the per-step stepper down
        # so the executor updates the trace in-loop and returns the final trace,
        # avoiding a second scan over stacked Jacobians.
        etrace_stepper = self._make_etrace_stepper(weight_vals) if input_is_multi_step else None

        # necessary jacobian information of the weights
        (
            out,
            hidden_vals,
            oth_state_vals,
            hid2weight_jac_single_or_multi_steps,
            hid2hid_jac_single_or_multi_steps,
            final_etrace,
        ) = self.graph_executor.solve_h2w_h2h_jacobian(
            *args,
            etrace_stepper=etrace_stepper,
            init_etrace=etrace_vals if etrace_stepper is not None else None,
        )

        if final_etrace is not None:
            # fused path: the executor already rolled the eligibility trace in-loop.
            etrace_vals = final_etrace
        else:
            # eligibility trace update
            #
            # "self._update_etrace_data()" is a protocol method that should be implemented in the subclass.
            # Its logic may be different for different etrace algorithms.
            #
            etrace_vals = self._update_etrace_data(
                running_index,
                etrace_vals,
                hid2weight_jac_single_or_multi_steps,
                hid2hid_jac_single_or_multi_steps,
                weight_vals,
                input_is_multi_step,
            )

        # returns
        return out, hidden_vals, oth_state_vals, etrace_vals

    def _update_fn_fwd(
        self,
        args,
        weight_vals: WeightVals,
        hidden_vals: HiddenVals,
        othstate_vals: StateVals,
        etrace_vals: ETraceVals,
        running_index: int,
    ) -> Tuple[Tuple[Outputs, HiddenVals, StateVals, ETraceVals], Any]:
        """
        The forward function to update the [model] and the [eligibility trace] states when computing
        the VJP gradients.

        Particularly, ``self.graph_executor.solve_h2w_h2h_l2h_jacobian()`` is called to:

        - compute the model output, the hidden states, and the other states
        - compute the hidden-to-weight Jacobian and the hidden-to-hidden Jacobian
        - compute the loss-to-hidden or loss-to-weight Jacobian

        Then, unless the executor already returned a final trace (fused path),
        ``self._update_etrace_data`` is called to:

        - update the eligibility trace data

        The forward function returns two parts of data:

        - The first part is the functional returns (same as the ``self._update_fn()`` function):

          * the model output
          * the updated hidden states
          * the updated other states
          * the updated eligibility trace states

        - The second part is the data used for backward gradient computation:

          * the residuals of the model
          * the eligibility trace data at the last time step (multi-step VJP) or at
            the current time step (otherwise)
          * the weight id to its value mapping
          * the running index
          * the original ``args`` tuple, consumed by ``_compute_learning_signal``
        """

        input_is_multi_step = has_multistep_data(*args)

        # state value assignment
        assign_state_values_v2(self.param_states, weight_vals, write=False)
        assign_state_values_v2(self.hidden_states, hidden_vals, write=False)
        assign_state_values_v2(self.other_states, othstate_vals, write=False)

        # As in ``_update_fn``: when fusable + multi-step, push the stepper down so
        # the executor rolls the trace inside the same scan that builds the VJP
        # residual. The trace carry is detached (stop_gradient), so it never enters
        # the residual jaxpr.
        etrace_stepper = self._make_etrace_stepper(weight_vals) if input_is_multi_step else None

        # necessary gradients of the weights
        (
            out,
            hiddens,
            oth_states,
            hid2weight_jac_single_or_multi_steps,
            hid2hid_jac_single_or_multi_steps,
            residuals,
            final_etrace,
        ) = self.graph_executor.solve_h2w_h2h_l2h_jacobian(
            *args,
            etrace_stepper=etrace_stepper,
            init_etrace=etrace_vals if etrace_stepper is not None else None,
        )

        if final_etrace is not None:
            # fused path: the executor already rolled the eligibility trace in-loop.
            new_etrace_vals = final_etrace
        else:
            # eligibility trace update
            #
            # "self._update_etrace_data()" is a protocol method that should be implemented in the subclass.
            # Its logic may be different for different etrace algorithms.
            #
            new_etrace_vals = self._update_etrace_data(
                running_index,
                etrace_vals,
                hid2weight_jac_single_or_multi_steps,
                hid2hid_jac_single_or_multi_steps,
                weight_vals,
                input_is_multi_step
            )

        # returns
        old_etrace_vals = etrace_vals
        fwd_out = (out, hiddens, oth_states, new_etrace_vals)
        fwd_res = (
            residuals,
            (
                old_etrace_vals
                if self.graph_executor.is_multi_step_vjp else
                new_etrace_vals
            ),
            weight_vals,
            running_index,
            args,  # threaded to _update_fn_bwd for the learning-signal hook
        )
        return fwd_out, fwd_res

    def _update_fn_bwd(
        self,
        fwd_res,
        grads,
    ) -> Tuple[dG_Inputs, dG_Weight, dG_Hidden, dG_State, None, None]:
        """
        The backward function to compute the VJP gradients when the learning signal arrives at
        this time step.

        There are four steps:

        1. Interpret the forward results (eligibility trace)
        2. Interpret the top-down gradients (learning signal)
        3. Compute the gradients of input arguments
           (maybe unnecessary, but it can be optimized away by the XLA compiler)
        4. Compute the gradients of the weights

        Raises
        ------
        TypeError
            If the tree structure of the incoming gradients (without the etrace
            entry) differs from the output tree stored in the residuals.
        """

        # [1] Interpret the fwd results
        #
        (
            residuals,  # the residuals of the VJP computation, for computing the gradients of input arguments
            etrace_vals_at_t_or_t_minus_1,  # the eligibility trace data at the current or last time step
            weight_vals,  # the weight id to its value mapping
            running_index,  # the running index
            args,  # original update(*args) tuple, used by _compute_learning_signal
        ) = fwd_res
        (
            jaxpr,
            in_tree,
            out_tree,
            consts
        ) = residuals

        # [2] Interpret the top-down gradient signals
        #
        # Since
        #
        #     dg_out, dg_hiddens, dg_others, dg_etrace = grads
        #
        # we need to remove the "dg_etrace" item from the gradients for matching
        # the jaxpr vjp gradients.
        #
        grad_flat, grad_tree = jax.tree.flatten((grads[:-1],))

        # [3] Compute the gradients of the input arguments
        #     It may be unnecessary, but it can be optimized away by the XLA compiler after it is computed.
        #
        # The input argument gradients are computed through the normal back-propagation algorithm.
        #
        if out_tree != grad_tree:
            raise TypeError(
                f'Gradient tree should be the same as the function output tree. '
                f'While we got: \n'
                f'out_tree  = {out_tree}\n!=\n'
                f'grad_tree = {grad_tree}'
            )
        cts_out = jax.core.eval_jaxpr(jaxpr, consts, *grad_flat)

        #
        # We compute:
        #
        #   - the gradients of input arguments,
        #     maybe necessary to propagate the gradients to the last layer
        #
        #   - the gradients of the hidden states at the last time step,
        #     maybe unnecessary but can be optimized away by the XLA compiler
        #
        #   - the gradients of the non-etrace parameters, defined by "NonTempParam"
        #
        #   - the gradients of the other states
        #
        #   - the gradients of the loss-to-hidden at the current time step
        #
        # the `_jaxpr_compute_model_with_vjp()` in `ETraceGraphExecutor`
        (
            dg_args,
            dg_last_hiddens,
            dg_non_etrace_params,
            dg_etrace_params,
            dg_oth_states,
            dg_hid_perturb_or_dl2h
        ) = jax.tree.unflatten(in_tree, cts_out)

        #
        # get the gradients of the hidden states at the last time step
        #
        if self.graph_executor.is_single_step_vjp:
            # TODO: the correspondence between the hidden states and the gradients
            #       should be checked.
            #
            assert len(dg_etrace_params) == 0  # gradients all etrace weights are updated by the RTRL algorithm
            assert self.graph.hidden_perturb is not None
            assert len(self.graph.hidden_perturb.perturb_vars) == len(dg_hid_perturb_or_dl2h)
            dl2h_at_t_or_t_minus_1 = self.graph.hidden_perturb.perturb_data_to_hidden_group_data(
                dg_hid_perturb_or_dl2h, self.graph.hidden_groups,
            )

        else:
            assert len(dg_last_hiddens) == len(self.hidden_states)
            assert set(dg_last_hiddens.keys()) == set(self.hidden_states.keys()), (
                f'The hidden states should be the same. Bug got \n'
                f'{set(dg_last_hiddens.keys())}\n'
                f'!=\n'
                f'{set(self.hidden_states.keys())}'
            )
            dl2h_at_t_or_t_minus_1 = [
                group.concat_hidden(
                    [
                        # dimensionless processing
                        u.get_mantissa(dg_last_hiddens[path])
                        for path in group.hidden_paths
                    ]
                )
                for group in self.graph.hidden_groups
            ]

        #
        # Hook: subclasses may replace the reverse-AD learning signal with an
        # alternative (e.g. target projection in OSTTP, κ-filtered signal in EProp).
        #
        dl2h_at_t_or_t_minus_1 = self._compute_learning_signal(
            dl2h_at_t_or_t_minus_1, args
        )

        #
        # [4] Compute the gradients of the weights
        #
        # the gradients of the weights are computed through the RTRL algorithm.
        #
        # "self._solve_weight_gradients()" is a protocol method that should be implemented in the subclass.
        # Its logic may be different for different etrace algorithms.
        #
        dg_weights = self._solve_weight_gradients(
            running_index,
            etrace_vals_at_t_or_t_minus_1,
            dl2h_at_t_or_t_minus_1,
            weight_vals,
            dg_non_etrace_params,
            dg_etrace_params,
        )

        # Note that there are no gradients flowing through the etrace data and the running index.
        dg_etrace = None
        dg_running_index = None

        return (
            dg_args,
            dg_weights,
            dg_last_hiddens,
            dg_oth_states,
            dg_etrace,
            dg_running_index
        )

    def _compute_learning_signal(
        self,
        dl_to_hidden_from_autodiff: Sequence[jax.Array],
        args: tuple,
    ) -> Sequence[jax.Array]:
        """Override hook. Return the learning signal used by `_solve_weight_gradients`.

        Default returns the reverse-AD gradient unchanged. Subclasses that need
        target projection (OSTTP) or any other alternative can override this.

        Args:
            dl_to_hidden_from_autodiff: Sequence of per-hidden-group gradients produced
                by reverse-AD inside `_update_fn_bwd`.
            args: The ``args`` tuple saved by `_update_fn_fwd` for this step, i.e. the
                arguments of the corresponding `update()` call, made available so
                subclasses can pull auxiliary tensors (e.g. ``y_target``) from it.

        Returns:
            Sequence of per-hidden-group gradient arrays, one per HiddenGroup. Must
            match the shape and length of ``dl_to_hidden_from_autodiff``.
        """
        return dl_to_hidden_from_autodiff

    def _solve_weight_gradients(
        self,
        running_index: int,
        # The eligibility-trace container and the weight-value mapping are keyed
        # differently per algorithm (e.g. Path- vs WeightID-keyed), so this
        # abstract hook leaves their concrete types implementation-defined.
        etrace_h2w_at_t: Any,
        dl_to_hidden_groups: Sequence[jax.Array],
        weight_vals: Any,
        dl_to_nonetws_at_t: Dict[Path, PyTree],
        dl_to_etws_at_t: Optional[Dict[Path, PyTree]],
    ):
        r"""
        The method to solve the weight gradients, i.e., :math:`\partial L / \partial W`.

        .. note::

            This is the protocol method that should be implemented in the subclass.

        Particularly, the weight gradients are computed through:

        .. math::

            \frac{\partial L^t}{\partial W} = \frac{\partial L^t}{\partial h^t} \frac{\partial h^t}{\partial W}

        Or,

        .. math::

            \frac{\partial L^t}{\partial W} = \frac{\partial L^{t-1}}{\partial h^{t-1}}
                                              \frac{\partial h^{t-1}}{\partial W}
                                              + \frac{\partial L^t}{\partial W^t}

        Args:
            running_index: int, the running index.
            etrace_h2w_at_t: Any, the eligibility trace data (which track the hidden-to-weight
                Jacobian) that have accumulated until the time ``t``.
            dl_to_hidden_groups: Sequence[jax.Array], the gradients of the loss-to-hidden
                at the time ``t``, one array per hidden group.
            weight_vals: Any, the weight values (keying is implementation-defined).
            dl_to_nonetws_at_t: Dict[Path, PyTree], the gradients of the loss-to-non-etrace
                parameters at the time ``t``, i.e., :math:`\partial L^t / \partial W^t`.
            dl_to_etws_at_t: Optional[Dict[Path, PyTree]], the gradients of the loss-to-etrace
                parameters at the time ``t``, i.e., :math:`\partial L^t / \partial W^t`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def _make_etrace_stepper(self, weight_vals: WeightVals) -> Optional[Callable]:
        """Return a per-step eligibility-trace update callback, or ``None``.

        When a subclass can express its trace roll as a pure step function with
        signature ``(etrace_carry, (x_dict, df_dict, diag_list)) -> (new_carry, None)``,
        it should override this to return that callback (typically the same
        ``partial`` it builds inside :meth:`_update_etrace_data`). For multi-step
        input the graph executor then fuses the roll into its over-time scan,
        eliminating the separate trace scan and the stacked per-step Jacobians.

        Returning ``None`` (the default) keeps the legacy two-pass behavior: the
        executor stacks the Jacobians and :meth:`_update_etrace_data` rolls the
        trace in a second scan. Subclasses whose update cannot be written as such a
        step function (or that do not support multi-step input) leave this ``None``.

        Parameters
        ----------
        weight_vals : WeightVals
            The current parameter values, captured by the returned callback.

        Returns
        -------
        Callable or None
            The per-step stepper, or ``None`` to disable scan fusion.
        """
        return None

    def _update_etrace_data(
        self,
        running_index: Optional[int],
        # The eligibility-trace container type is implementation-defined.
        etrace_vals_util_t_1: Any,
        hid2weight_jac_single_or_multi_times: Hid2WeightJacobian,
        hid2hid_jac_single_or_multi_times: Sequence[jax.Array],
        weight_vals: WeightVals,
        input_is_multi_step: bool,
    ) -> Any:
        """
        The method to update the eligibility trace data.

        .. note::

            This is the protocol method that should be implemented in the subclass.

        Args:
            running_index: Optional[int], the running index.
            etrace_vals_util_t_1: Any, the history eligibility trace data that have
                accumulated until :math:`t-1`.
            hid2weight_jac_single_or_multi_times: Hid2WeightJacobian, the hidden-to-weight
                Jacobian data at the time :math:`t` (single step) or over multiple steps.
            hid2hid_jac_single_or_multi_times: Sequence[jax.Array], the data for computing
                the hidden-to-hidden Jacobian at the time :math:`t` (single step) or over
                multiple steps.
            weight_vals: WeightVals, the weight values.
            input_is_multi_step: bool, whether the inputs of this update contain
                multi-step data (the result of ``has_multistep_data(*args)``).

        Returns:
            The updated eligibility trace data that have accumulated until :math:`t`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def _get_etrace_data(self) -> Any:
        """
        Get the eligibility trace data at the last time-step.

        .. note::

            This is the protocol method that should be implemented in the subclass.

        Returns:
            ETraceVals, the eligibility trace data.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def _assign_etrace_data(self, etrace_vals: Any) -> None:
        """
        Assign the eligibility trace data to the states at the current time-step.

        .. note::

            This is the protocol method that should be implemented in the subclass.

        Args:
            etrace_vals: ETraceVals, the eligibility trace data.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError