# Source code for braintrace.nn._readout
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
import numbers
from typing import Callable, Optional
import brainstate
import braintools
import saiunit as u
from braintrace._etrace_op import matmul
from braintrace._typing import Size, ArrayLike
__all__ = [
'LeakyRateReadout',
]
class LeakyRateReadout(brainstate.nn.Module):
    """Leaky-integrator readout layer used in Real-Time Recurrent Learning.

    At every step the layer projects its input through a weight matrix and
    adds the result to a decayed copy of its previous output, producing a
    continuous, rate-like readout signal.

    Parameters
    ----------
    in_size : Size
        The size of the input to the readout module. An integer is wrapped
        into a 1-tuple; only the first entry is used for the weight shape.
    out_size : Size
        The size of the output from the readout module. An integer is
        wrapped into a 1-tuple; only the first entry is used for the weight
        shape.
    tau : ArrayLike, optional
        The time constant of the leaky integration. Default is 5 ms.
    w_init : Callable, optional
        Initializer for the readout weights. Default is ``KaimingNormal()``.
    r_init : Callable, optional
        Initializer for the readout state. Default is ``ZeroInit()``.
    name : str or None, optional
        An optional name for the module. Default is None.

    Attributes
    ----------
    in_size : tuple of int
        The size of the input.
    out_size : tuple of int
        The size of the output.
    tau : ArrayLike
        The time constant, obtained by passing ``tau`` through
        ``braintools.init.param`` with ``out_size`` as the target size.
    decay : ArrayLike
        The per-step decay factor ``exp(-dt / tau)``, where ``dt`` is read
        from ``brainstate.environ.get_dt()`` when the module is constructed.
    r_init : Callable
        The initializer used by ``init_state`` and ``reset_state``.
    W : ParamState
        The readout weights, initialized with shape
        ``(in_size[0], out_size[0])``.
    r : HiddenState
        The readout state variable. It is created by ``init_state`` and
        does not exist before that call.

    Notes
    -----
    ``decay`` is computed once in ``__init__``, so the time step has to be
    available from ``brainstate.environ`` before the module is constructed,
    and changing it afterwards does not update ``decay``.

    Examples
    --------
    .. code-block:: python

        >>> import braintrace
        >>> import brainstate
        >>> import saiunit as u
        >>>
        >>> brainstate.environ.set(dt=0.1 * u.ms)
        >>> # Create a leaky rate readout layer
        >>> readout = braintrace.nn.LeakyRateReadout(
        ...     in_size=256,
        ...     out_size=10,
        ...     tau=5.0 * u.ms
        ... )
        >>> readout.init_state(batch_size=32)
        >>>
        >>> # Process input through the readout layer
        >>> x = brainstate.random.randn(32, 256)
        >>> output = readout(x)
        >>> print(output.shape)
        (32, 10)
    """
    __module__ = 'braintrace.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        tau: ArrayLike = 5. * u.ms,
        w_init: Callable = braintools.init.KaimingNormal(),
        r_init: Callable = braintools.init.ZeroInit(),
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # Sizes: accept either a bare integer or any iterable of integers.
        if isinstance(in_size, numbers.Integral):
            self.in_size = (in_size,)
        else:
            self.in_size = tuple(in_size)
        if isinstance(out_size, numbers.Integral):
            self.out_size = (out_size,)
        else:
            self.out_size = tuple(out_size)

        # Time constant, initialized against the output size.
        self.tau = braintools.init.param(tau, self.out_size)

        # decay = exp(-dt / tau). Dividing tau by dt first lets saiunit
        # cancel the time units, so the exponent is a plain number.
        dt = brainstate.environ.get_dt()
        tau_in_steps = u.maybe_decimal(self.tau / dt)
        self.decay = u.math.exp(-1.0 / tau_in_steps)

        # Kept so that init_state / reset_state can rebuild the state later.
        self.r_init = r_init

        # Trainable projection from input features to readout units.
        weight_shape = (self.in_size[0], self.out_size[0])
        initial_weight = braintools.init.param(w_init, weight_shape)
        self.W = brainstate.ParamState(initial_weight)

    def init_state(self, batch_size=None, **kwargs):
        """Create the readout state ``r`` from ``r_init``.

        Parameters
        ----------
        batch_size : optional
            Forwarded to ``braintools.init.param`` together with
            ``out_size``. Default is None.
        **kwargs
            Accepted but not used by this method.
        """
        initial_r = braintools.init.param(self.r_init, self.out_size, batch_size)
        self.r = brainstate.HiddenState(initial_r)

    def reset_state(self, batch_size=None, **kwargs):
        """Overwrite the value of the existing readout state ``r``.

        The new value is produced by ``r_init`` exactly as in
        ``init_state``. ``init_state`` must have been called beforehand,
        because this method assigns to ``self.r.value``.

        Parameters
        ----------
        batch_size : optional
            Forwarded to ``braintools.init.param`` together with
            ``out_size``. Default is None.
        **kwargs
            Accepted but not used by this method.
        """
        fresh_r = braintools.init.param(self.r_init, self.out_size, batch_size)
        self.r.value = fresh_r

    def update(self, x):
        r"""Advance the readout by one time step of leaky integration.

        Applies :math:`r_t = \mathrm{decay} \cdot r_{t-1} + W^\top x_t` and
        stores the result in the readout state.

        Parameters
        ----------
        x : ArrayLike
            Input for the current step, of shape ``(..., in_size)``.

        Returns
        -------
        ArrayLike
            The updated readout state, of shape ``(..., out_size)``.
        """
        # Decayed previous state first, then the input projection.
        leaked = self.decay * self.r.value
        projected = matmul(x, self.W.value)
        new_r = leaked + projected
        self.r.value = new_r
        return new_r