# Source code for braintrace._etrace_algorithms.graph_executor

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Chaoming Wang <chao.brain@qq.com>
# Copyright: 2024, Chaoming Wang
# Date: 2024-04-03
#
# ==============================================================================
#
# Refinement History:
#   [2024-04-03] Created
#   [2024-04-06] Added the traceback information for the error messages.
#   [2024-04-16] Changed the "op" in the "HiddenWeightOpTracer" to "JaxprEqn".
#                Added the support for the "pjit" operator.
#   [2024-05] Add the support for vjp_time == 't_minus_1'
#   [2024-06] Conditionally support control flows, including `scan`, `while`, and `cond`
#   [2024-09] version 0.0.2
#   [2024-11-22] compatible with `brainstate>=0.1.0` (#17)
#   [2024-11-23] Add the support for vjp_time_ahead > 1, it can combine the
#                advantage of etrace learning and backpropagation through time.
#   [2024-11] version 0.0.3, a complete new revision for better model debugging.
#
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Dict, Any, Optional

import brainstate

from braintrace._etrace_compiler import ETraceGraph, compile_etrace_graph
from .._input_data import get_single_step_data
from .._typing import Path

__all__ = [
    'ETraceGraphExecutor',
]


class ETraceGraphExecutor:
    r"""
    The eligibility trace graph executor.

    This class is used for computing the weight spatial gradients and the hidden state residuals.
    It is the most foundational class for the ETrace algorithms.

    It is important to note that the graph is built no matter whether the model is
    batched or not. This means that this graph can be applied to any kind of models.
    However, the compilation is sensitive to the shape of hidden states.

    Parameters
    ----------
    model: brainstate.nn.Module
        The model to build the eligibility trace graph. The models should only define the one-step behavior.

    Raises
    ------
    TypeError
        If ``model`` is not an instance of ``brainstate.nn.Module``.
    """
    __module__ = 'braintrace'

    def __init__(
        self,
        model: brainstate.nn.Module,
    ):
        # The original model
        if not isinstance(model, brainstate.nn.Module):
            raise TypeError(
                'The model should be an instance of "brainstate.nn.Module" since '
                'we can extract the program structure from the model for '
                'better debugging.'
            )
        self.model = model

        # the compiled graph; stays ``None`` until ``compile_graph`` is called
        self._compiled_graph: Optional[ETraceGraph] = None
        # lazily-built cache used by ``state_id_to_path``; reset on recompilation
        self._state_id_to_path: Optional[Dict[int, Path]] = None

    @property
    def graph(self) -> ETraceGraph:
        """
        Retrieve the compiled eligibility trace graph for the model.

        This property provides access to the compiled graph, which is a crucial data structure
        for the eligibility trace algorithm. It contains various attributes that describe the
        relationships between the model's variables, states, and operations.

        Returns
        -------
        ETraceGraph
            The compiled graph for the model. This graph includes detailed information about
            the model's structure, such as output variables, state variables,
            hidden-to-hidden variable relationships, and more.

        Raises
        ------
        ValueError
            If the graph has not been compiled yet. Ensure to call the
            :meth:`compile_graph` method before accessing this property.
        """
        if self._compiled_graph is None:
            raise ValueError('The graph is not compiled yet. Please call ".compile_graph()" first.')
        return self._compiled_graph

    @property
    def states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]:
        """
        The states for the model.

        Returns
        -------
        brainstate.util.FlattedDict[Path, brainstate.State]
            The states for the model, read from ``graph.module_info.retrieved_model_states``.

        Raises
        ------
        ValueError
            If the graph has not been compiled yet (raised by :attr:`graph`).
        """
        return self.graph.module_info.retrieved_model_states

    @property
    def path_to_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]:
        """
        The path to the states.

        This is an alias of :attr:`states`.

        Returns
        -------
        brainstate.util.FlattedDict[Path, brainstate.State]
            The path to the states.
        """
        return self.states

    @property
    def state_id_to_path(self) -> Dict[int, Path]:
        """
        The state id to the path.

        The mapping is built on first access from :attr:`states` and cached;
        the cache is cleared whenever :meth:`compile_graph` is called.

        Returns
        -------
        Dict[int, Path]
            The mapping from state id (``id(state)``) to the path.
        """
        if self._state_id_to_path is None:
            self._state_id_to_path = {id(state): path for path, state in self.states.items()}
        return self._state_id_to_path

    def compile_graph(self, *args) -> None:
        r"""
        Build the eligibility trace graph for the model based on the provided inputs.

        This method is crucial for constructing the graph used in the eligibility trace algorithm,
        which is essential for calculating weight spatial gradients and the hidden state Jacobian.

        Parameters
        ----------
        *args
            Positional arguments for the model, which may include inputs, parameters, or other
            necessary data required for graph compilation.

        Returns
        -------
        None
            This method does not return any value. It initializes the compiled graph
            attribute of the instance.
        """
        # invalidate cached mappings on recompilation
        self._state_id_to_path = None

        # process the inputs
        args = get_single_step_data(*args)

        # compile the graph
        self._compiled_graph = compile_etrace_graph(self.model, *args)

    def show_graph(
        self,
        verbose: bool = True,
        return_msg: bool = False,
    ) -> Optional[str]:
        """
        Display the graph illustrating the relationships between weights, operators, and hidden states.

        This function generates a detailed message that describes the structure of the graph,
        including hidden groups, dynamic states, and weight parameters associated with hidden states.
        It can either print this message to the console or return it as a string.

        Parameters
        ----------
        verbose : bool, optional
            If True, the function will print the graph details to the console. Default is True.
        return_msg : bool, optional
            If True, the function will return the graph details as a string. Default is False.

        Returns
        -------
        None or str
            If `return_msg` is True, returns a string containing the graph details.
            Otherwise, returns None.

        Raises
        ------
        ValueError
            If the graph has not been compiled yet (raised by :attr:`graph`).

        Notes
        -----
        The non-etrace weights are listed in the iteration order of the model's
        ``ParamState`` mapping, so the numbering is reproducible across calls.
        """
        graph = self.graph

        # message fragments, joined once at the end
        parts = ['===' * 40 + '\n', 'The hidden groups are:\n\n']

        # hidden groups
        hidden_paths = []
        group_index_by_id = {}
        for group in graph.hidden_groups:
            parts.append(f' Group {group.index}: {group.hidden_paths}\n')
            group_index_by_id[id(group)] = group.index
            hidden_paths.extend(group.hidden_paths)
        parts.append('\n\n')

        # short-term states that do not belong to any hidden group
        short_states = self.states.filter(brainstate.ShortTermState)
        other_states = [path for path in short_states.keys() if path not in hidden_paths]
        if other_states:
            parts.append('The dynamic (non-hidden) states are:\n\n')
            for i, path in enumerate(other_states):
                parts.append(f' Dynamic state {i}: {path}\n')
            parts.append('\n\n')

        # etrace weights: parameters associated with at least one hidden group
        etrace_weight_paths = set()
        relations = graph.hidden_param_op_relations
        if len(relations):
            parts.append('The weight parameters which are associated with the hidden states are:\n\n')
            for i, hp_relation in enumerate(relations):
                etrace_weight_paths.add(hp_relation.path)
                group_indices = [group_index_by_id[id(group)] for group in hp_relation.hidden_groups]
                if len(group_indices) == 1:
                    parts.append(
                        f' Weight {i}: {hp_relation.path} is associated with hidden group {group_indices[0]}\n'
                    )
                else:
                    parts.append(
                        f' Weight {i}: {hp_relation.path} is associated with hidden groups {group_indices}\n'
                    )
            parts.append('\n\n')

        # non etrace weights: filter the ordered keys instead of iterating a set
        # difference, so that the listing order does not vary between runs
        non_etrace_weight_paths = [
            path
            for path in self.states.filter(brainstate.ParamState).keys()
            if path not in etrace_weight_paths
        ]
        if non_etrace_weight_paths:
            parts.append('The non-etrace weight parameters are:\n\n')
            for i, path in enumerate(non_etrace_weight_paths):
                parts.append(f' Weight {i}: {path}\n')
            parts.append('\n\n')

        msg = ''.join(parts)
        if verbose:
            print(msg)
        if return_msg:
            return msg
        return None

    def solve_h2w_h2h_jacobian(
        self,
        *args,
    ) -> Any:
        r"""
        Compute the hidden-to-weight and hidden-to-hidden Jacobian matrices.

        This function is designed to calculate the forward propagation of the hidden-to-weight
        Jacobian and the hidden-to-hidden Jacobian based on the provided inputs and parameters.
        It is a crucial part of the eligibility trace algorithm, which helps in understanding
        the influence of weights and previous hidden states on the current hidden state.

        Parameters
        ----------
        *args
            Positional arguments for the model, which may include inputs, parameters, or other
            necessary data required for the computation of the Jacobians.

        Returns
        -------
        Any
            A tuple containing the following elements:

            - The function output (e.g., model predictions).
            - The updated hidden states after the current computation step.
            - Other states that may be relevant to the model's operation.
            - The spatial gradients of the weights, represented by the hidden-to-weight Jacobian.

        Raises
        ------
        NotImplementedError
            This method must be implemented by subclasses.

        Notes
        -----
        For the state transition function :math:`y, h^t = f(h^{t-1}, \theta, x)`, this function
        aims to solve:

        1. The function output :math:`y`.
        2. The updated hidden states :math:`h^t`.
        3. The Jacobian matrix of hidden-to-weight, i.e., :math:`\partial h^t / \partial \theta^t`.
        4. The Jacobian matrix of hidden-to-hidden, i.e., :math:`\partial h^t / \partial h^{t-1}`.
        """
        raise NotImplementedError('The method "solve_h2w_h2h_jacobian" should be '
                                  'implemented in the subclass.')

    def solve_h2w_h2h_l2h_jacobian(
        self,
        *args,
    ) -> Any:
        r"""
        Compute the hidden-to-weight and hidden-to-hidden Jacobian matrices, along with
        the VJP transformed loss-to-hidden gradients based on the provided inputs.

        This function is designed to calculate both the forward propagation of the
        hidden-to-weight Jacobian and the loss-to-hidden gradients at the current time-step.
        It is essential for understanding the influence of weights and previous hidden states
        on the current hidden state, as well as the impact of the loss on the hidden states.

        Parameters
        ----------
        *args
            Positional arguments for the model, which may include inputs, parameters, or other
            necessary data required for the computation of the Jacobians and gradients.

        Returns
        -------
        Any
            A tuple containing the following elements:

            - The function output (e.g., model predictions).
            - The updated hidden states after the current computation step.
            - Other states that may be relevant to the model's operation.
            - The spatial gradients of the weights, represented by the hidden-to-weight Jacobian.
            - The residuals, which are the partial gradients of the loss with respect to
              the hidden states.

        Raises
        ------
        NotImplementedError
            This method must be implemented by subclasses.

        Notes
        -----
        Particularly, this function aims to solve:

        1. The Jacobian matrix of hidden-to-weight. That is,
           :math:`\partial h / \partial w`, where :math:`h` is the hidden state and
           :math:`w` is the weight.
        2. The Jacobian matrix of hidden-to-hidden. That is,
           :math:`\partial h / \partial h`, where :math:`h` is the hidden state.
        3. The partial gradients of the loss with respect to the hidden states. That is,
           :math:`\partial L / \partial h`, where :math:`L` is the loss and
           :math:`h` is the hidden state.
        """
        # the message names this method itself, so that the error points at
        # the attribute a subclass actually has to override
        raise NotImplementedError('The method "solve_h2w_h2h_l2h_jacobian" '
                                  'should be implemented in the subclass.')