Source code for braintrace._compile

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Type, Union

from ._etrace_algorithms import (
    ETraceAlgorithm,
    D_RTRL,
    pp_prop,
    EProp,
    OSTLRecurrent,
    OSTLFeedforward,
    OTPE,
    OTTT,
    OSTTP,
)

__all__ = ['compile']

# Lookup table used by ``_resolve_algorithm``: every accepted lowercase name
# (the canonical spelling plus its aliases) maps to the algorithm class it
# selects. A bare ``ostl`` key is deliberately absent -- the ambiguous OSTL
# factory was removed in 0.2.0, so callers must say ``ostl_recurrent`` or
# ``ostl_feedforward`` explicitly.
_ALGORITHM_REGISTRY = dict(
    # D-RTRL
    d_rtrl=D_RTRL,
    # pp_prop and its ES-D-RTRL aliases
    pp_prop=pp_prop,
    es_d_rtrl=pp_prop,
    esd_rtrl=pp_prop,
    # e-prop
    eprop=EProp,
    e_prop=EProp,
    # OSTL variants (no ambiguous bare alias)
    ostl_recurrent=OSTLRecurrent,
    ostl_feedforward=OSTLFeedforward,
    # remaining online algorithms
    otpe=OTPE,
    ottt=OTTT,
    osttp=OSTTP,
)


def _resolve_algorithm(
    algorithm: Union[str, Type[ETraceAlgorithm]]
) -> Type[ETraceAlgorithm]:
    """Resolve ``algorithm`` to an :class:`ETraceAlgorithm` subclass.

    Parameters
    ----------
    algorithm : type or str
        Either an :class:`ETraceAlgorithm` subclass (returned unchanged) or a
        registered string name, e.g. ``'D_RTRL'``, ``'eprop'``, ``'ottt'``.
        String lookup is case-insensitive and ignores leading/trailing
        whitespace.

    Returns
    -------
    type
        The resolved :class:`ETraceAlgorithm` subclass.

    Raises
    ------
    ValueError
        If ``algorithm`` is a string that is not a registered name. The
        message lists every valid name.
    TypeError
        If ``algorithm`` is a class that is not an ``ETraceAlgorithm`` subclass,
        or is neither a class nor a string.
    """
    # Classes are accepted as-is, but only if they implement the algorithm API.
    if isinstance(algorithm, type):
        if issubclass(algorithm, ETraceAlgorithm):
            return algorithm
        raise TypeError(
            f'algorithm class must be a subclass of ETraceAlgorithm, got {algorithm!r}.'
        )
    if isinstance(algorithm, str):
        # Normalize so that e.g. ' D_RTRL ' and 'd_rtrl' hit the same key.
        key = algorithm.strip().lower()
        try:
            return _ALGORITHM_REGISTRY[key]
        except KeyError:
            valid = ', '.join(sorted(_ALGORITHM_REGISTRY))
            # ``from None`` suppresses the internal KeyError context, so the
            # user sees a single, self-explanatory ValueError traceback.
            raise ValueError(
                f'Unknown algorithm name {algorithm!r}. Valid names: {valid}. '
                f'Or pass an ETraceAlgorithm subclass directly.'
            ) from None
    raise TypeError(
        f'algorithm must be an ETraceAlgorithm subclass or a registered string name, '
        f'got {type(algorithm)}.'
    )


def compile(model, algorithm, *example_inputs, **options):
    """Construct an online-learning algorithm for ``model`` and eagerly build its
    eligibility-trace graph, returning a ready-to-``update`` learner.

    Parameters
    ----------
    model : brainstate.nn.Module
        The recurrent model. Its states must already be initialized, e.g. via
        ``brainstate.nn.init_all_states(model)``.
    algorithm : type or str
        An :class:`ETraceAlgorithm` subclass, or a registered string name
        (case-insensitive), e.g. ``'D_RTRL'``, ``'eprop'``, ``'ottt'``.
    *example_inputs
        Example call inputs (arrays / :class:`SingleStepData` /
        :class:`MultiStepData`), matching what ``learner.update(...)`` will
        later receive. Forwarded to :meth:`ETraceAlgorithm.compile_graph` to
        trace the jaxpr graph. At least one is required.
    **options
        Keyword options forwarded to the algorithm constructor, e.g.
        ``vjp_method``, ``leak``, ``fast_solve``, ``trace_dtype``, ``feedback``.

    Returns
    -------
    ETraceAlgorithm
        The compiled learner; call ``.update(*inputs)`` to train.

    Raises
    ------
    ValueError
        If ``algorithm`` is an unknown string name, or no ``example_inputs``
        are given. ``algorithm`` is resolved first, so an invalid algorithm is
        reported before a missing example input.
    TypeError
        If ``algorithm`` is neither an ``ETraceAlgorithm`` subclass nor a string.

    Examples
    --------
    .. code-block:: python

        >>> import braintrace
        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> model = MyRNN()
        >>> brainstate.nn.init_all_states(model, batch_size=1)
        >>> x0 = jnp.ones((3,))
        >>> learner = braintrace.compile(model, 'D_RTRL', x0, vjp_method='multi-step')
        >>> y = learner.update(x0)
    """
    cls = _resolve_algorithm(algorithm)
    # The graph is traced eagerly below, which requires concrete example inputs.
    if not example_inputs:
        raise ValueError(
            'compile() needs at least one example input to build the graph '
            'eagerly, e.g. compile(model, "D_RTRL", x0). Pass the same inputs '
            'you will give to learner.update(...).'
        )
    learner = cls(model, **options)
    learner.compile_graph(*example_inputs)
    return learner