LeakyRateReadout
- class braintrace.nn.LeakyRateReadout(in_size, out_size, tau=Quantity(5., 'ms'), w_init=KaimingNormal(mode=fan_in, nonlinearity=relu, unit=1), r_init=ZeroInit(unit=1), name=None)
Leaky dynamics for the readout module used in Real-Time Recurrent Learning.
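The LeakyRateReadout class implements leaky integration of continuous input signals, modeling the dynamics of rate-based neurons. At each time step the input is projected through a learnable weight matrix and added to the previous readout state, which first decays by a factor determined by tau (see update()). The output is therefore a continuous signal that tracks an exponentially weighted history of the input.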
This class is part of the BrainTrace project and integrates with the Brain Dynamics Programming ecosystem, providing a biologically inspired approach to neural computation.
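Unrolling the recurrence applied by update() shows what "leaky integration" means: starting from a zero state, the readout equals an exponentially weighted sum of past projected inputs, \(r_t = \sum_{k=1}^{t} \mathrm{decay}^{t-k} W^\top x_k\). The NumPy sketch below checks this numerically. It is an illustration, not braintrace code; `W`, `decay` (fixed to 0.9 here), and the random inputs are hypothetical stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)
T, in_size, out_size = 20, 4, 3
decay = 0.9                                     # hypothetical per-step decay factor
W = rng.standard_normal((in_size, out_size))    # hypothetical weights, shape (in_size, out_size)
xs = rng.standard_normal((T, in_size))          # one input vector per time step

# Step-by-step recurrence: r_t = decay * r_{t-1} + W^T x_t, starting from r_0 = 0.
r = np.zeros(out_size)
for x in xs:
    r = decay * r + x @ W                       # for a 1-D x, x @ W equals W^T x

# Unrolled form: input xs[k] (0-indexed) is weighted by decay ** (T - 1 - k).
weights = decay ** np.arange(T - 1, -1, -1)
r_closed = (weights[:, None] * (xs @ W)).sum(axis=0)

print(np.allclose(r, r_closed))  # True
```

Because the input term is not scaled by (1 − decay), a constant input x drives the state toward \(W^\top x / (1 - \mathrm{decay})\), the sum of the geometric series.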
- Parameters:
in_size (Union[int, Sequence[int], integer, Sequence[integer]]) – The size of the input to the readout module.
out_size (Union[int, Sequence[int], integer, Sequence[integer]]) – The size of the output from the readout module.
tau (Union[Array, ndarray, bool, number, bool, int, float, complex, Quantity]) – The time constant for the leaky integration dynamics. Default is 5 ms.
w_init (Callable) – A callable for initializing the weights of the readout module. Default is KaimingNormal().
r_init (Callable) – A callable for initializing the state of the readout module. Default is ZeroInit().
name (Optional[str]) – An optional name for the module. Default is None.
- tau
The time constant for leaky integration.
- Type:
ArrayLike
- decay
The per-step decay factor computed from tau; it scales the previous readout state in each call to update().
- Type:
ArrayLike
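This page does not state how decay is derived from tau. A common choice, and the one assumed in the sketch below, is the exact one-step solution of \(dr/dt = -r/\tau\), namely \(\mathrm{decay} = e^{-\Delta t/\tau}\); treat this formula as an assumption rather than the library's confirmed implementation. The sketch evaluates it for the values used on this page (tau = 5 ms, dt = 0.1 ms).

```python
import math

# Assumption: decay = exp(-dt / tau), the exact one-step solution of dr/dt = -r / tau.
tau_ms = 5.0    # default time constant from the signature
dt_ms = 0.1     # time step set in the example via brainstate.environ.set(dt=...)

decay = math.exp(-dt_ms / tau_ms)
print(round(decay, 4))                      # 0.9802

# After tau / dt steps (here 50) with no input, the state shrinks by a factor of e.
steps = round(tau_ms / dt_ms)
print(steps, round(decay ** steps, 4))      # 50 0.3679
```

Under this assumption, a smaller dt or a larger tau pushes decay toward 1, so the readout retains past input over more steps.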
- r
The readout state variable.
- Type:
HiddenState
- weight_op
The parameter object that holds the weight matrix and the operation that applies it to the input.
- Type:
ParamState
Examples
>>> import braintrace
>>> import brainstate
>>> import saiunit as u
>>>
>>> brainstate.environ.set(dt=0.1 * u.ms)
>>> # Create a leaky rate readout layer
>>> readout = braintrace.nn.LeakyRateReadout(
...     in_size=256,
...     out_size=10,
...     tau=5.0 * u.ms
... )
>>> readout.init_state(batch_size=32)
>>>
>>> # Process input through the readout layer
>>> x = brainstate.random.randn(32, 256)
>>> output = readout(x)
>>> print(output.shape)
(32, 10)
- update(x)
Advance the readout by one time step of leaky integration.
Applies \(r_t = \mathrm{decay} \cdot r_{t-1} + W^\top x_t\) and stores the result in the readout state.
- Parameters:
x (ArrayLike) – Input for the current step, of shape (..., in_size).
- Returns:
The updated readout state, of shape (..., out_size).
- Return type:
ArrayLike
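The NumPy sketch below mirrors the documented update for arbitrary leading (batch) dimensions and then applies it once per time step over a short sequence. `leaky_readout_step` is a hypothetical reference function, not part of braintrace. It assumes the weight matrix has shape (in_size, out_size), so that \(W^\top x_t\) is computed as `x @ W` on the last axis; scaled random weights stand in for the initializer, and decay assumes \(e^{-\Delta t/\tau}\).

```python
import numpy as np

def leaky_readout_step(r_prev, x, W, decay):
    """Hypothetical reference for one update: r_t = decay * r_{t-1} + W^T x_t.

    r_prev: (..., out_size), x: (..., in_size), W: (in_size, out_size).
    """
    # matmul contracts the last axis of x with W and broadcasts over leading axes
    return decay * r_prev + x @ W

rng = np.random.default_rng(1)
T, batch, in_size, out_size = 5, 32, 256, 10
W = rng.standard_normal((in_size, out_size)) / np.sqrt(in_size)  # stand-in weights
decay = np.exp(-0.1 / 5.0)   # assumes decay = exp(-dt / tau) with dt = 0.1 ms, tau = 5 ms

r = np.zeros((batch, out_size))                # zero initial state, analogous to ZeroInit
xs = rng.standard_normal((T, batch, in_size))
outputs = []
for x_t in xs:                                 # one update per time step
    r = leaky_readout_step(r, x_t, W, decay)
    outputs.append(r)
outputs = np.stack(outputs)
print(outputs.shape)  # (5, 32, 10)
```

Keeping the state outside the function makes the time dependence explicit; the module version instead stores it in the HiddenState r between calls.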