Source code for braintrace.nn._rnn

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-

from typing import Callable, Union

import brainstate
import braintools
import saiunit as u

from braintrace._etrace_op import element_wise
from braintrace._typing import ArrayLike
from ._linear import Linear

# Names exported by ``from <this module> import *``.
# NOTE(review): ``CFNCell`` is defined in this module but is not listed
# here -- confirm whether the omission is intentional.
__all__ = [
    'ValinaRNNCell',
    'GRUCell',
    'MGUCell',
    'LSTMCell',
    'URLSTMCell',
    'MinimalRNNCell',
    'MiniGRU',
    'MiniLSTM',
    'LRUCell',
]


class ValinaRNNCell(brainstate.nn.RNNCell):
    r"""Vanilla RNN cell.

    At every step the input is concatenated with the previous hidden state,
    fed through a single :class:`Linear` layer ``W`` and then through the
    activation function :math:`\phi`:

    .. math::

        h_t = \phi(W([x_t; h_{t-1}]))

    Parameters
    ----------
    in_size : brainstate.typing.Size
        The number of input units.
    out_size : brainstate.typing.Size
        The number of hidden units.
    state_init : Callable or ArrayLike, optional
        Initializer of the hidden state. Default is ZeroInit().
    w_init : Callable or ArrayLike, optional
        Initializer of the layer weights. Default is XavierNormal().
    b_init : Callable or ArrayLike, optional
        Initializer of the layer bias. Default is ZeroInit().
    activation : str or Callable, optional
        Either a callable, or the name of a function that is looked up on
        ``brainstate.functional``. Default is 'relu'.
    name : str or None, optional
        The name of the module. Default is None.

    Examples
    --------
    .. code-block:: python

        >>> import braintrace
        >>> import brainstate
        >>>
        >>> # Create a Vanilla RNN cell
        >>> rnn_cell = braintrace.nn.ValinaRNNCell(in_size=32, out_size=64)
        >>> rnn_cell.init_state(batch_size=8)
        >>>
        >>> # Process a sequence of inputs
        >>> x = brainstate.random.randn(8, 32)
        >>> h = rnn_cell(x)
        >>> print(h.shape)
        (8, 64)
    """
    __module__ = 'braintrace.nn'

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        out_size: brainstate.typing.Size,
        state_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        w_init: Union[ArrayLike, Callable] = braintools.init.XavierNormal(),
        b_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        activation: str | Callable = 'relu',
        name: str = None,
    ):
        super().__init__(name=name)

        # sizes and state initializer
        self.in_size = in_size
        self.out_size = out_size
        self._state_initializer = state_init

        # a callable is used directly; a string is resolved by name
        if not isinstance(activation, str):
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation
        else:
            self.activation = getattr(brainstate.functional, activation)

        # one affine map acting on the concatenation [x, h]
        n_hidden = self.out_size[-1]
        n_joint = self.in_size[-1] + n_hidden
        self.W = Linear(n_joint, n_hidden, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the hidden state ``h``.

        Parameters
        ----------
        batch_size : int, optional
            Forwarded to ``braintools.init.param`` together with
            ``out_size``. Default is None.
        **kwargs
            Accepted but not used.
        """
        initial = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h = brainstate.HiddenState(initial)

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the value of the existing hidden state ``h``.

        Parameters
        ----------
        batch_size : int, optional
            Forwarded to ``braintools.init.param`` together with
            ``out_size``. Default is None.
        **kwargs
            Accepted but not used.
        """
        fresh = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h.value = fresh

    def update(self, x):
        r"""Advance the cell by one time step.

        Parameters
        ----------
        x : ArrayLike
            Input for the current step, of shape ``(..., in_size)``.

        Returns
        -------
        ArrayLike
            The updated hidden state, of shape ``(..., out_size)``.
        """
        prev_h = self.h.value
        joint = u.math.concatenate([x, prev_h], axis=-1)
        new_h = self.activation(self.W(joint))
        self.h.value = new_h
        return self.h.value
class GRUCell(brainstate.nn.RNNCell):
    r"""Gated Recurrent Unit (GRU) cell.

    Gated Recurrent Unit (GRU) cell, implemented as in
    `Learning Phrase Representations using RNN Encoder-Decoder for Statistical
    Machine Translation <https://arxiv.org/abs/1406.1078>`_.

    One step of this implementation computes

    .. math::

        \begin{aligned}
        z_t &= \sigma(W_z [x_t; h_{t-1}]) \\
        r_t &= \sigma(W_r [x_t; h_{t-1}]) \\
        \hat{h}_t &= \phi(W_h [x_t; r_t \odot h_{t-1}]) \\
        h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \hat{h}_t
        \end{aligned}

    Parameters
    ----------
    in_size : brainstate.typing.Size
        The number of input units.
    out_size : brainstate.typing.Size
        The number of hidden units.
    w_init : Callable or ArrayLike, optional
        Initializer of the layer weights. Default is Orthogonal().
    b_init : Callable or ArrayLike, optional
        Initializer of the layer biases. Default is ZeroInit().
    state_init : Callable or ArrayLike, optional
        Initializer of the hidden state. Default is ZeroInit().
    activation : str or Callable, optional
        Either a callable, or the name of a function that is looked up on
        ``brainstate.functional``. Default is 'tanh'.
    name : str or None, optional
        The name of the module. Default is None.

    Examples
    --------
    .. code-block:: python

        >>> import braintrace
        >>> import brainstate
        >>>
        >>> # Create a GRU cell
        >>> gru_cell = braintrace.nn.GRUCell(in_size=128, out_size=256)
        >>> gru_cell.init_state(batch_size=16)
        >>>
        >>> # Process a sequence of inputs
        >>> x = brainstate.random.randn(16, 128)
        >>> h = gru_cell(x)
        >>> print(h.shape)
        (16, 256)
    """
    __module__ = 'braintrace.nn'

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        out_size: brainstate.typing.Size,
        w_init: Union[ArrayLike, Callable] = braintools.init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # sizes and state initializer
        self.in_size = in_size
        self.out_size = out_size
        self._state_initializer = state_init

        # a callable is used directly; a string is resolved by name
        if not isinstance(activation, str):
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation
        else:
            self.activation = getattr(brainstate.functional, activation)

        # three layers, each acting on a concatenation of width in + out
        n_hidden = self.out_size[-1]
        n_joint = self.in_size[-1] + n_hidden
        layer_kwargs = {'w_init': w_init, 'b_init': b_init}
        self.Wz = Linear(n_joint, n_hidden, **layer_kwargs)
        self.Wr = Linear(n_joint, n_hidden, **layer_kwargs)
        self.Wh = Linear(n_joint, n_hidden, **layer_kwargs)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        initial = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h = brainstate.HiddenState(initial)

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the value of the existing hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        fresh = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h.value = fresh

    def update(self, x):
        r"""Advance the cell by one time step.

        Parameters
        ----------
        x : ArrayLike
            Input for the current step, of shape ``(..., in_size)``.

        Returns
        -------
        ArrayLike
            The updated hidden state, of shape ``(..., out_size)``.
        """
        prev_h = self.h.value
        gate_in = u.math.concatenate([x, prev_h], axis=-1)
        update_gate = brainstate.nn.sigmoid(self.Wz(gate_in))
        reset_gate = brainstate.nn.sigmoid(self.Wr(gate_in))
        # the candidate only sees the part of the old state let through by r
        gated_h = reset_gate * prev_h
        cand_in = u.math.concatenate([x, gated_h], axis=-1)
        candidate = self.activation(self.Wh(cand_in))
        new_h = (1 - update_gate) * prev_h + update_gate * candidate
        self.h.value = new_h
        return new_h
class CFNCell(brainstate.nn.RNNCell):
    r"""Chaos Free Networks (CFN) cell.

    Chaos Free Networks (CFN) cell, implemented as in
    `A recurrent neural network without chaos <https://arxiv.org/abs/1612.06212>`_.

    One step of this implementation computes

    .. math::

        \begin{aligned}
        f_t &= \sigma(W_f [x_t; h_{t-1}]) \\
        i_t &= \sigma(W_i [x_t; h_{t-1}]) \\
        h_t &= f_t \odot \phi(h_{t-1}) + i_t \odot \phi(W_h x_t)
        \end{aligned}

    Parameters
    ----------
    in_size : brainstate.typing.Size
        The number of input units.
    out_size : brainstate.typing.Size
        The number of hidden units.
    w_init : Callable or ArrayLike, optional
        The input weight initializer. Default is Orthogonal().
    b_init : Callable or ArrayLike, optional
        The bias weight initializer. Default is ZeroInit().
    state_init : Callable or ArrayLike, optional
        The state initializer. Default is ZeroInit().
    activation : str or Callable, optional
        The activation function. It can be a string or a callable function. Default is 'tanh'.
    name : str or None, optional
        The name of the module. Default is None.

    Examples
    --------
    .. code-block:: python

        >>> import braintrace
        >>> import brainstate
        >>>
        >>> # Create a CFN cell
        >>> cfn_cell = braintrace.nn.CFNCell(in_size=64, out_size=128)
        >>> cfn_cell.init_state(batch_size=10)
        >>>
        >>> # Process a sequence of inputs
        >>> x = brainstate.random.randn(10, 64)
        >>> h = cfn_cell(x)
        >>> print(h.shape)
        (10, 128)
    """
    __module__ = 'braintrace.nn'

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        out_size: brainstate.typing.Size,
        w_init: Union[ArrayLike, Callable] = braintools.init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self._state_initializer = state_init
        self.out_size = out_size
        self.in_size = in_size

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(brainstate.functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation

        # weights
        params = dict(w_init=w_init, b_init=b_init)
        self.Wf = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params)
        self.Wi = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params)
        # ``update`` applies Wh to the raw input ``x`` (width in_size), so the
        # layer must map in_size -> out_size. It was previously built as
        # out_size -> out_size, which only worked when in_size == out_size.
        self.Wh = Linear(self.in_size[-1], self.out_size[-1], **params)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the value of the existing hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        self.h.value = braintools.init.param(self._state_initializer, self.out_size, batch_size)

    def update(self, x):
        r"""Advance the cell by one time step.

        Parameters
        ----------
        x : ArrayLike
            Input for the current step, of shape ``(..., in_size)``.

        Returns
        -------
        ArrayLike
            The updated hidden state, of shape ``(..., out_size)``.
        """
        old_h = self.h.value
        xh = u.math.concatenate([x, old_h], axis=-1)
        f = brainstate.nn.sigmoid(self.Wf(xh))
        i = brainstate.nn.sigmoid(self.Wi(xh))
        # the old state enters only through the activation, not through a
        # recurrent weight matrix; the input enters through Wh
        h = f * self.activation(old_h) + i * self.activation(self.Wh(x))
        self.h.value = h
        return h


class MGUCell(brainstate.nn.RNNCell):
    r"""Minimal Gated Recurrent Unit (MGU) cell.

    Minimal Gated Recurrent Unit (MGU) cell, implemented as in
    `Minimal Gated Unit for Recurrent Neural Networks <https://arxiv.org/abs/1603.09420>`_.

    .. math::

        \begin{aligned}
        f_{t}&=\sigma (W_{f}x_{t}+U_{f}h_{t-1}+b_{f})\\
        {\hat {h}}_{t}&=\phi (W_{h}x_{t}+U_{h}(f_{t}\odot h_{t-1})+b_{h})\\
        h_{t}&=(1-f_{t})\odot h_{t-1}+f_{t}\odot {\hat {h}}_{t}
        \end{aligned}

    where:

    - :math:`x_{t}`: input vector
    - :math:`h_{t}`: output vector
    - :math:`{\hat {h}}_{t}`: candidate activation vector
    - :math:`f_{t}`: forget vector
    - :math:`W, U, b`: parameter matrices and vector

    In this implementation each pair :math:`(W, U)` is realised as one
    :class:`Linear` layer applied to the concatenation of its two operands.

    Parameters
    ----------
    in_size : brainstate.typing.Size
        The number of input units.
    out_size : brainstate.typing.Size
        The number of hidden units.
    w_init : Callable or ArrayLike, optional
        The input weight initializer. Default is Orthogonal().
    b_init : Callable or ArrayLike, optional
        The bias weight initializer. Default is ZeroInit().
    state_init : Callable or ArrayLike, optional
        The state initializer. Default is ZeroInit().
    activation : str or Callable, optional
        The activation function. It can be a string or a callable function. Default is 'tanh'.
    name : str or None, optional
        The name of the module. Default is None.

    Examples
    --------
    .. code-block:: python

        >>> import braintrace
        >>> import brainstate
        >>>
        >>> # Create an MGU cell
        >>> mgu_cell = braintrace.nn.MGUCell(in_size=96, out_size=192)
        >>> mgu_cell.init_state(batch_size=12)
        >>>
        >>> # Process a sequence of inputs
        >>> x = brainstate.random.randn(12, 96)
        >>> h = mgu_cell(x)
        >>> print(h.shape)
        (12, 192)
    """
    __module__ = 'braintrace.nn'

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        out_size: brainstate.typing.Size,
        w_init: Union[ArrayLike, Callable] = braintools.init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self._state_initializer = state_init
        self.out_size = out_size
        self.in_size = in_size

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(brainstate.functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation

        # weights
        params = dict(w_init=w_init, b_init=b_init)
        self.Wf = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params)
        self.Wh = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the value of the existing hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        self.h.value = braintools.init.param(self._state_initializer, self.out_size, batch_size)

    def update(self, x):
        r"""Advance the cell by one time step.

        Parameters
        ----------
        x : ArrayLike
            Input for the current step, of shape ``(..., in_size)``.

        Returns
        -------
        ArrayLike
            The updated hidden state, of shape ``(..., out_size)``.
        """
        old_h = self.h.value
        xh = u.math.concatenate([x, old_h], axis=-1)
        f = brainstate.nn.sigmoid(self.Wf(xh))
        # the single gate f both masks the old state for the candidate and
        # mixes the candidate into the new state
        fh = f * old_h
        h = self.activation(self.Wh(u.math.concatenate([x, fh], axis=-1)))
        self.h.value = (1 - f) * self.h.value + f * h
        return self.h.value
class LSTMCell(brainstate.nn.RNNCell):
    r"""Long short-term memory (LSTM) RNN core.

    The implementation is based on (zaremba, et al., 2014) [1]_. Given
    :math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})` the core
    computes

    .. math::

       \begin{array}{ll}
       i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
       f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\
       g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
       o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
       c_t = f_t c_{t-1} + i_t g_t \\
       h_t = o_t \tanh(c_t)
       \end{array}

    where :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and
    output gate activations, and :math:`g_t` is a vector of cell updates.
    ``tanh`` is replaced by ``activation`` when another function is given.

    The output is equal to the new hidden, :math:`h_t`.

    Parameters
    ----------
    in_size : brainstate.typing.Size
        The dimension of the input vector.
    out_size : brainstate.typing.Size
        The number of hidden unit in the node.
    w_init : Callable or ArrayLike, optional
        Initializer of the layer weights. Default is XavierNormal().
    b_init : Callable or ArrayLike, optional
        Initializer of the layer biases. Default is ZeroInit().
    state_init : Callable or ArrayLike, optional
        Initializer of both ``h`` and ``c``. Default is ZeroInit().
    activation : str or Callable, optional
        Either a callable, or the name of a function that is looked up on
        ``brainstate.functional``. Default is 'tanh'.
    name : str or None, optional
        The name of the module. Default is None.

    Notes
    -----
    Forget gate offset: following (Jozefowicz, et al., 2015) [2]_, a constant
    1.0 is added to the forget-gate pre-activation inside ``update`` (which
    acts like an offset of one on :math:`b_f`) in order to reduce the scale
    of forgetting in the beginning of the training.

    References
    ----------
    .. [1] Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. "Recurrent neural
           network regularization." arXiv preprint arXiv:1409.2329 (2014).
    .. [2] Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. "An empirical
           exploration of recurrent network architectures." In International conference
           on machine learning, pp. 2342-2350. PMLR, 2015.

    Examples
    --------
    .. code-block:: python

        >>> import braintrace
        >>> import brainstate
        >>>
        >>> # Create an LSTM cell
        >>> lstm_cell = braintrace.nn.LSTMCell(in_size=256, out_size=512)
        >>> lstm_cell.init_state(batch_size=20)
        >>>
        >>> # Process a sequence of inputs
        >>> x = brainstate.random.randn(20, 256)
        >>> h = lstm_cell(x)
        >>> print(h.shape)
        (20, 512)
    """
    __module__ = 'braintrace.nn'

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        out_size: brainstate.typing.Size,
        w_init: Union[ArrayLike, Callable] = braintools.init.XavierNormal(),
        b_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # sizes and state initializer
        self.in_size = in_size
        self.out_size = out_size
        self._state_initializer = state_init

        # a callable is used directly; a string is resolved by name
        if not isinstance(activation, str):
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation
        else:
            self.activation = getattr(brainstate.functional, activation)

        # four layers, each acting on the concatenation [x, h]
        n_hidden = self.out_size[-1]
        n_joint = self.in_size[-1] + n_hidden
        layer_kwargs = {'w_init': w_init, 'b_init': b_init}
        self.Wi = Linear(n_joint, n_hidden, **layer_kwargs)
        self.Wg = Linear(n_joint, n_hidden, **layer_kwargs)
        self.Wf = Linear(n_joint, n_hidden, **layer_kwargs)
        self.Wo = Linear(n_joint, n_hidden, **layer_kwargs)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the cell state ``c`` and the hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        initial_c = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.c = brainstate.HiddenState(initial_c)
        initial_h = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h = brainstate.HiddenState(initial_h)

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the values of the existing states ``c`` and ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        fresh_c = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.c.value = fresh_c
        fresh_h = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h.value = fresh_h

    def update(self, x):
        r"""Advance the cell by one time step.

        Updates both the cell state ``c`` and the hidden state ``h`` in place.

        Parameters
        ----------
        x : ArrayLike
            Input for the current step, of shape ``(..., in_size)``.

        Returns
        -------
        ArrayLike
            The updated hidden state ``h``, of shape ``(..., out_size)``.
        """
        prev_h, prev_c = self.h.value, self.c.value
        joint = u.math.concatenate([x, prev_h], axis=-1)

        # pre-activations of the input gate, cell update, forget and output gates
        i_pre = self.Wi(joint)
        g_pre = self.Wg(joint)
        f_pre = self.Wf(joint)
        o_pre = self.Wo(joint)

        # the constant 1. shifts the forget gate towards "remember"
        retained = brainstate.nn.sigmoid(f_pre + 1.) * prev_c
        written = brainstate.nn.sigmoid(i_pre) * self.activation(g_pre)
        new_c = retained + written
        new_h = brainstate.nn.sigmoid(o_pre) * self.activation(new_c)

        self.h.value = new_h
        self.c.value = new_c
        return new_h
class URLSTMCell(brainstate.nn.RNNCell):
    r"""Update-Reset LSTM (URLSTM) cell.

    An LSTM variant in which two gates that share one bias vector ``b``
    (with opposite sign) are combined into a single effective gate ``g``
    that controls the cell state. One step of this implementation computes

    .. math::

        \begin{aligned}
        f_t &= \sigma(W_f [x_t; h_{t-1}] + b) \\
        r_t &= \sigma(W_r [x_t; h_{t-1}] - b) \\
        g_t &= 2 r_t \odot f_t + (1 - 2 r_t) \odot f_t^2 \\
        c_t &= g_t \odot c_{t-1} + (1 - g_t) \odot \phi(W_u [x_t; h_{t-1}]) \\
        h_t &= \sigma(W_o [x_t; h_{t-1}]) \odot \phi(c_t)
        \end{aligned}

    The four linear layers have no bias of their own. ``b`` is initialised
    as :math:`-\log(1/v - 1)` with :math:`v` drawn uniformly from
    :math:`[1/n, 1 - 1/n]`, where :math:`n` is the number of hidden units.

    Parameters
    ----------
    in_size : brainstate.typing.Size
        The dimension of the input vector.
    out_size : brainstate.typing.Size
        The number of hidden units in the node.
    w_init : Callable or ArrayLike, optional
        Initializer of the layer weights. Default is XavierNormal().
    state_init : Callable or ArrayLike, optional
        Initializer of both ``h`` and ``c``. Default is ZeroInit().
    activation : str or Callable, optional
        Either a callable, or the name of a function that is looked up on
        ``brainstate.functional``. Default is 'tanh'.
    name : str or None, optional
        The name of the module. Default is None.

    Examples
    --------
    .. code-block:: python

        >>> import braintrace
        >>> import brainstate
        >>>
        >>> # Create a URLSTM cell
        >>> urlstm_cell = braintrace.nn.URLSTMCell(in_size=128, out_size=256)
        >>> urlstm_cell.init_state(batch_size=16)
        >>>
        >>> # Process a sequence of inputs
        >>> x = brainstate.random.randn(16, 128)
        >>> h = urlstm_cell(x)
        >>> print(h.shape)
        (16, 256)
    """
    # NOTE(review): the gating above looks like the UR-LSTM of Gu et al.
    # (2020), where "UR" stands for uniform gate initialization and refine
    # gate rather than "Update-Reset" -- confirm and adjust the title.
    __module__ = 'braintrace.nn'

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        out_size: brainstate.typing.Size,
        w_init: Union[ArrayLike, Callable] = braintools.init.XavierNormal(),
        state_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # sizes and state initializer
        self.in_size = in_size
        self.out_size = out_size
        self._state_initializer = state_init

        # a callable is used directly; a string is resolved by name
        if not isinstance(activation, str):
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation
        else:
            self.activation = getattr(brainstate.functional, activation)

        # four bias-free layers on [x, h]; the only bias is the shared one below
        n_hidden = self.out_size[-1]
        n_joint = self.in_size[-1] + n_hidden
        layer_kwargs = {'w_init': w_init, 'b_init': None}
        self.Wu = Linear(n_joint, n_hidden, **layer_kwargs)
        self.Wf = Linear(n_joint, n_hidden, **layer_kwargs)
        self.Wr = Linear(n_joint, n_hidden, **layer_kwargs)
        self.Wo = Linear(n_joint, n_hidden, **layer_kwargs)
        self.bias = brainstate.ParamState(self._forget_bias())

    def _forget_bias(self):
        """Sample the initial value of the shared gate bias.

        Draws ``v`` uniformly from ``[1/n, 1 - 1/n]`` (``n`` being the last
        entry of ``out_size``) and returns ``-log(1/v - 1)``.
        """
        n_hidden = self.out_size[-1]
        lower = 1 / n_hidden
        upper = 1 - 1 / n_hidden
        draws = brainstate.random.uniform(lower, upper, (n_hidden,))
        return -u.math.log(1 / draws - 1)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the cell state ``c`` and the hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        initial_c = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.c = brainstate.HiddenState(initial_c)
        initial_h = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h = brainstate.HiddenState(initial_h)

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the values of the existing states ``c`` and ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        fresh_c = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.c.value = fresh_c
        fresh_h = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h.value = fresh_h

    def update(self, x: ArrayLike) -> ArrayLike:
        r"""Advance the cell by one time step.

        Updates both the cell state ``c`` and the hidden state ``h`` in place.

        Parameters
        ----------
        x : ArrayLike
            Input for the current step, of shape ``(..., in_size)``.

        Returns
        -------
        ArrayLike
            The updated hidden state ``h``, of shape ``(..., out_size)``.
        """
        prev_h, prev_c = self.h.value, self.c.value
        joint = u.math.concatenate([x, prev_h], axis=-1)

        f_pre = self.Wf(joint)
        r_pre = self.Wr(joint)
        u_pre = self.Wu(joint)
        o_pre = self.Wo(joint)

        # the same bias enters the two gates with opposite sign
        f_gate = brainstate.nn.sigmoid(element_wise(self.bias.value) + f_pre)
        r_gate = brainstate.nn.sigmoid(-(element_wise(self.bias.value) + (-r_pre)))
        effective = 2 * r_gate * f_gate + (1 - 2 * r_gate) * f_gate ** 2

        new_c = effective * prev_c + (1 - effective) * self.activation(u_pre)
        new_h = brainstate.nn.sigmoid(o_pre) * self.activation(new_c)

        self.h.value = new_h
        self.c.value = new_c
        return new_h
class MinimalRNNCell(brainstate.nn.RNNCell):
    r"""Minimal RNN Cell.

    Minimal RNN Cell, implemented as in
    `MinimalRNN: Toward More Interpretable and Trainable Recurrent Neural Networks
    <https://arxiv.org/abs/1711.06788>`_

    At each step :math:`t`, the model first maps its input :math:`\mathbf{x}_t`
    to a latent space through :math:`\mathbf{z}_t=\Phi(\mathbf{x}_t)`.
    :math:`\Phi(\cdot)` can be any callable, such as a neural network. When
    ``phi`` is not given, a single :class:`Linear` layer from ``in_size`` to
    ``out_size`` is used and no nonlinearity is applied to its output.

    Given the latent representation :math:`\mathbf{z}_t` of the input, MinimalRNN
    then updates its states simply as:

    .. math::

        \mathbf{h}_t=\mathbf{u}_t\odot\mathbf{h}_{t-1}+(\mathbf{1}-\mathbf{u}_t)\odot\mathbf{z}_t

    where :math:`\mathbf{u}_t=\sigma(\mathbf{U}_h\mathbf{h}_{t-1}+\mathbf{U}_z\mathbf{z}_t+\mathbf{b}_u)`
    is the update gate, realised here as one :class:`Linear` layer on the
    concatenation :math:`[\mathbf{z}_t; \mathbf{h}_{t-1}]`.

    Parameters
    ----------
    in_size : brainstate.typing.Size
        The number of input units.
    out_size : brainstate.typing.Size
        The number of hidden units.
    w_init : Callable or ArrayLike, optional
        Initializer of the layer weights. Default is Orthogonal().
    b_init : Callable or ArrayLike, optional
        Initializer of the layer biases. Default is ZeroInit().
    state_init : Callable or ArrayLike, optional
        Initializer of the hidden state. Default is ZeroInit().
    phi : Callable or None, optional
        The map from input to latent space; its output must have
        ``out_size`` features. Default is None (a ``Linear`` layer is built).
    name : str or None, optional
        The name of the module. Default is None.

    Examples
    --------
    .. code-block:: python

        >>> import braintrace
        >>> import brainstate
        >>>
        >>> # Create a Minimal RNN cell
        >>> minrnn_cell = braintrace.nn.MinimalRNNCell(in_size=100, out_size=200)
        >>> minrnn_cell.init_state(batch_size=24)
        >>>
        >>> # Process a sequence of inputs
        >>> x = brainstate.random.randn(24, 100)
        >>> h = minrnn_cell(x)
        >>> print(h.shape)
        (24, 200)
    """
    __module__ = 'braintrace.nn'

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        out_size: brainstate.typing.Size,
        w_init: Union[ArrayLike, Callable] = braintools.init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        phi: Callable = None,
        name: str = None,
    ):
        super().__init__(name=name)

        # sizes and state initializer
        self.in_size = in_size
        self.out_size = out_size
        self._state_initializer = state_init

        layer_kwargs = {'w_init': w_init, 'b_init': b_init}
        n_hidden = self.out_size[-1]

        # input -> latent map: user supplied, or a Linear layer by default
        latent_map = Linear(self.in_size[-1], n_hidden, **layer_kwargs) if phi is None else phi
        assert callable(latent_map), f"The phi function should be a callable function. But got {latent_map}"
        self.phi = latent_map

        # update gate acting on the concatenation [z, h]
        self.W_u = Linear(n_hidden * 2, n_hidden, **layer_kwargs)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        initial = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h = brainstate.HiddenState(initial)

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the value of the existing hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        fresh = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h.value = fresh

    def update(self, x):
        r"""Advance the cell by one time step.

        Parameters
        ----------
        x : ArrayLike
            Input for the current step, of shape ``(..., in_size)``.

        Returns
        -------
        ArrayLike
            The updated hidden state, of shape ``(..., out_size)``.
        """
        latent = self.phi(x)
        prev_h = self.h.value
        gate_in = u.math.concatenate([latent, prev_h], axis=-1)
        gate = brainstate.nn.sigmoid(self.W_u(gate_in))
        # gate -> 1 keeps the old state, gate -> 0 adopts the latent input
        self.h.value = gate * prev_h + (1 - gate) * latent
        return self.h.value
class MiniGRU(brainstate.nn.RNNCell):
    r"""Minimal GRU cell.

    Minimal GRU Cell, a simplified version of GRU implemented as in
    `MinimalRNN: Toward More Interpretable and Trainable Recurrent Neural Networks
    <https://arxiv.org/abs/1711.06788>`_

    A single update gate decides, per unit, how much of the previous hidden
    state is replaced by a linear projection of the current input:

    .. math::

        \mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \mathbf{W}_x \mathbf{x}_t

    where :math:`\mathbf{z}_t=\sigma(\mathbf{W}_z[\mathbf{x}_t; \mathbf{h}_{t-1}])`
    is the update gate.

    Parameters
    ----------
    in_size : brainstate.typing.Size
        The number of input units.
    out_size : brainstate.typing.Size
        The number of hidden units.
    w_init : Callable or ArrayLike, optional
        Initializer of the layer weights. Default is Orthogonal().
    b_init : Callable or ArrayLike, optional
        Initializer of the layer biases. Default is ZeroInit().
    state_init : Callable or ArrayLike, optional
        Initializer of the hidden state. Default is ZeroInit().
    name : str or None, optional
        The name of the module. Default is None.

    Examples
    --------
    .. code-block:: python

        >>> import braintrace
        >>> import brainstate
        >>>
        >>> # Create a Mini GRU cell
        >>> minigru_cell = braintrace.nn.MiniGRU(in_size=80, out_size=160)
        >>> minigru_cell.init_state(batch_size=32)
        >>>
        >>> # Process a sequence of inputs
        >>> x = brainstate.random.randn(32, 80)
        >>> h = minigru_cell(x)
        >>> print(h.shape)
        (32, 160)
    """
    __module__ = 'braintrace.nn'

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        out_size: brainstate.typing.Size,
        w_init: Union[ArrayLike, Callable] = braintools.init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        name: str = None,
    ):
        super().__init__(name=name)

        # sizes and state initializer
        self.in_size = in_size
        self.out_size = out_size
        self._state_initializer = state_init

        layer_kwargs = {'w_init': w_init, 'b_init': b_init}
        n_input = self.in_size[-1]
        n_hidden = self.out_size[-1]

        # input projection (candidate state)
        self.W_x = Linear(n_input, n_hidden, **layer_kwargs)

        # update gate on the concatenation [x, h]
        self.W_z = Linear(n_input + n_hidden, n_hidden, **layer_kwargs)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        initial = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h = brainstate.HiddenState(initial)

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the value of the existing hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        fresh = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h.value = fresh

    def update(self, x):
        r"""Advance the cell by one time step.

        Parameters
        ----------
        x : ArrayLike
            Input for the current step, of shape ``(..., in_size)``.

        Returns
        -------
        ArrayLike
            The updated hidden state, of shape ``(..., out_size)``.
        """
        prev_h = self.h.value
        gate_in = u.math.concatenate([x, prev_h], axis=-1)
        gate = brainstate.nn.sigmoid(self.W_z(gate_in))
        kept = (1 - gate) * prev_h
        written = gate * self.W_x(x)
        self.h.value = kept + written
        return self.h.value
class MiniLSTM(brainstate.nn.RNNCell):
    r"""Minimal LSTM cell.

    Minimal LSTM Cell, a simplified version of LSTM implemented as in
    `MinimalRNN: Toward More Interpretable and Trainable Recurrent Neural Networks
    <https://arxiv.org/abs/1711.06788>`_

    Two independent gates control how much of the previous hidden state is
    kept and how much of a linear projection of the input is added:

    .. math::

        \mathbf{h}_t = \mathbf{f}_t \odot \mathbf{h}_{t-1} + \mathbf{i}_t \odot \mathbf{W}_x \mathbf{x}_t

    where :math:`\mathbf{f}_t=\sigma(\mathbf{W}_f[\mathbf{x}_t; \mathbf{h}_{t-1}])`
    and :math:`\mathbf{i}_t=\sigma(\mathbf{W}_i[\mathbf{x}_t; \mathbf{h}_{t-1}])`
    are the forget and input gates, respectively.

    Parameters
    ----------
    in_size : brainstate.typing.Size
        The number of input units.
    out_size : brainstate.typing.Size
        The number of hidden units.
    w_init : Callable or ArrayLike, optional
        Initializer of the layer weights. Default is Orthogonal().
    b_init : Callable or ArrayLike, optional
        Initializer of the layer biases. Default is ZeroInit().
    state_init : Callable or ArrayLike, optional
        Initializer of the hidden state. Default is ZeroInit().
    name : str or None, optional
        The name of the module. Default is None.

    Examples
    --------
    .. code-block:: python

        >>> import braintrace
        >>> import brainstate
        >>>
        >>> # Create a Mini LSTM cell
        >>> minilstm_cell = braintrace.nn.MiniLSTM(in_size=150, out_size=300)
        >>> minilstm_cell.init_state(batch_size=40)
        >>>
        >>> # Process a sequence of inputs
        >>> x = brainstate.random.randn(40, 150)
        >>> h = minilstm_cell(x)
        >>> print(h.shape)
        (40, 300)
    """
    __module__ = 'braintrace.nn'

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        out_size: brainstate.typing.Size,
        w_init: Union[ArrayLike, Callable] = braintools.init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(),
        name: str = None,
    ):
        super().__init__(name=name)

        # sizes and state initializer
        self.in_size = in_size
        self.out_size = out_size
        self._state_initializer = state_init

        layer_kwargs = {'w_init': w_init, 'b_init': b_init}
        n_input = self.in_size[-1]
        n_hidden = self.out_size[-1]

        # input projection (candidate state)
        self.W_x = Linear(n_input, n_hidden, **layer_kwargs)

        # forget and input gates on the concatenation [x, h]
        self.W_f = Linear(n_input + n_hidden, n_hidden, **layer_kwargs)
        self.W_i = Linear(n_input + n_hidden, n_hidden, **layer_kwargs)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        initial = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h = brainstate.HiddenState(initial)

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the value of the existing hidden state ``h``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        fresh = braintools.init.param(self._state_initializer, self.out_size, batch_size)
        self.h.value = fresh

    def update(self, x):
        r"""Advance the cell by one time step.

        Parameters
        ----------
        x : ArrayLike
            Input for the current step, of shape ``(..., in_size)``.

        Returns
        -------
        ArrayLike
            The updated hidden state, of shape ``(..., out_size)``.
        """
        prev_h = self.h.value
        gate_in = u.math.concatenate([x, prev_h], axis=-1)
        forget_gate = brainstate.nn.sigmoid(self.W_f(gate_in))
        input_gate = brainstate.nn.sigmoid(self.W_i(gate_in))
        kept = forget_gate * prev_h
        written = input_gate * self.W_x(x)
        self.h.value = kept + written
        return self.h.value
def glorot_init(s):
    """Return standard-normal samples of shape ``s`` divided by ``sqrt(s[0])``.

    Parameters
    ----------
    s : sequence of int
        The shape of the array to create; its first entry sets the scale.
    """
    samples = brainstate.random.randn(*s)
    scale = u.math.sqrt(s[0])
    return samples / scale
class LRUCell(brainstate.nn.Module):
    r"""Linear Recurrent Unit (LRU) layer.

    `Linear Recurrent Unit <https://arxiv.org/abs/2303.06349>`_ (LRU) layer,
    which uses diagonal complex-valued state transitions for efficient
    sequence modeling.

    .. math::

       h_{t+1} = \lambda * h_t + \exp(\gamma^{\mathrm{log}}) B x_{t+1} \\
       \lambda = \text{diag}(\exp(-\exp(\nu^{\mathrm{log}}) + i \exp(\theta^\mathrm{log}))) \\
       y_t = Re[C h_t + D x_t]

    The complex hidden state is stored as two real-valued states, ``h_re``
    and ``h_im``, and the complex matrices ``B`` and ``C`` as pairs of real
    bias-free :class:`Linear` layers.

    Parameters
    ----------
    d_model : int
        Input and output dimensions.
    d_hidden : int
        Hidden state dimension.
    r_min : float, optional
        Smallest lambda norm. Default is 0.0.
    r_max : float, optional
        Largest lambda norm. Default is 1.0.
    max_phase : float, optional
        Max phase lambda. Default is 6.28.

    Examples
    --------
    .. code-block:: python

        >>> import braintrace
        >>> import brainstate
        >>>
        >>> # Create an LRU cell
        >>> lru_cell = braintrace.nn.LRUCell(d_model=64, d_hidden=128)
        >>> lru_cell.init_state(batch_size=16)
        >>>
        >>> # Process a sequence of inputs
        >>> x = brainstate.random.randn(16, 64)
        >>> y = lru_cell(x)
        >>> print(y.shape)
        (16, 64)
    """

    def __init__(
        self,
        d_model: int,  # input and output dimensions
        d_hidden: int,  # hidden state dimension
        r_min: float = 0.0,  # smallest lambda norm
        r_max: float = 1.0,  # largest lambda norm
        max_phase: float = 6.28,  # max phase lambda
    ):
        super().__init__()

        self.in_size = d_model
        self.out_size = d_hidden
        self.d_hidden = d_hidden
        self.d_model = d_model
        self.r_min = r_min
        self.r_max = r_max
        self.max_phase = max_phase

        # -------- recurrent weight matrix --------

        # theta parameter: log of a phase drawn uniformly from [0, max_phase)
        phase_draw = brainstate.random.uniform(size=d_hidden)
        theta_log = u.math.log(max_phase * phase_draw)
        self.theta_log = brainstate.ParamState(theta_log)

        # nu parameter: chosen so that exp(-exp(nu_log)) ** 2 is uniform
        # between r_min ** 2 and r_max ** 2
        radius_draw = brainstate.random.uniform(size=d_hidden)
        squared_norm = radius_draw * (r_max ** 2 - r_min ** 2) + r_min ** 2
        nu_log = u.math.log(-0.5 * u.math.log(squared_norm))
        self.nu_log = brainstate.ParamState(nu_log)

        # -------- input weight matrix --------

        # gamma parameter: log of sqrt(1 - |lambda|^2), used to scale the input
        diag_lambda = u.math.exp(-u.math.exp(nu_log) + 1j * u.math.exp(theta_log))
        input_scale = u.math.sqrt(1 - u.math.abs(diag_lambda) ** 2)
        gamma_log = u.math.log(input_scale)
        self.gamma_log = brainstate.ParamState(gamma_log)

        # Glorot initialized Input/Output projection matrices
        proj_kwargs = {'w_init': glorot_init, 'b_init': None}
        self.B_re = Linear(d_model, d_hidden, **proj_kwargs)
        self.B_im = Linear(d_model, d_hidden, **proj_kwargs)

        # -------- output weight matrix --------
        self.C_re = Linear(d_hidden, d_model, **proj_kwargs)
        self.C_im = Linear(d_hidden, d_model, **proj_kwargs)

        # Parameter for skip connection
        self.D = brainstate.ParamState(brainstate.random.randn(d_model))

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the zero-initialised states ``h_re`` and ``h_im``.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        zeros_re = braintools.init.param(braintools.init.ZeroInit(), self.d_hidden, batch_size)
        self.h_re = brainstate.HiddenState(zeros_re)
        zeros_im = braintools.init.param(braintools.init.ZeroInit(), self.d_hidden, batch_size)
        self.h_im = brainstate.HiddenState(zeros_im)

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the values of ``h_re`` and ``h_im`` with zeros.

        ``batch_size`` is forwarded to ``braintools.init.param``; extra
        keyword arguments are accepted but not used.
        """
        zeros_re = braintools.init.param(braintools.init.ZeroInit(), self.d_hidden, batch_size)
        self.h_re.value = zeros_re
        zeros_im = braintools.init.param(braintools.init.ZeroInit(), self.d_hidden, batch_size)
        self.h_im.value = zeros_im

    def update(self, inputs):
        r"""Advance the unit by one time step.

        Updates the real and imaginary parts of the complex hidden state
        (``h_re`` and ``h_im``) in place and returns the projected output.

        Parameters
        ----------
        inputs : ArrayLike
            Input for the current step, of shape ``(..., d_model)``.

        Returns
        -------
        ArrayLike
            The real-valued output ``y``, of shape ``(..., d_model)``.
        """
        # modulus and phase of lambda, and the input scaling factor
        modulus = u.math.exp(-u.math.exp(element_wise(self.nu_log.value)))
        phase = u.math.exp(element_wise(self.theta_log.value))
        in_scale = u.math.exp(element_wise(self.gamma_log.value))

        # real and imaginary parts of lambda
        lam_re = modulus * u.math.cos(phase)
        lam_im = modulus * u.math.sin(phase)

        # Both new parts are computed from the OLD state before either state
        # is written, so the imaginary update never sees the new real part.
        prev_re = self.h_re.value
        prev_im = self.h_im.value
        next_re = lam_re * prev_re - lam_im * prev_im + in_scale * self.B_re(inputs)
        next_im = lam_im * prev_re + lam_re * prev_im + in_scale * self.B_im(inputs)
        self.h_re.value = next_re
        self.h_im.value = next_im

        # real part of C h, plus the element-wise skip connection D * x
        y = self.C_re(self.h_re.value) - self.C_im(self.h_im.value) + inputs * element_wise(self.D.value)
        return y