MinimalRNNCell

class braintrace.nn.MinimalRNNCell(in_size, out_size, w_init=Orthogonal(scale=1.0), b_init=ZeroInit(unit=1), state_init=ZeroInit(unit=1), phi=None, name=None)

Minimal RNN Cell.

Minimal RNN cell, implemented as described in the paper "MinimalRNN: Toward More Interpretable and Trainable Recurrent Neural Networks".

At each step \(t\), the model first maps its input \(\mathbf{x}_t\) to a latent space through \(\mathbf{z}_t=\Phi(\mathbf{x}_t)\). Here \(\Phi(\cdot)\) can be any sufficiently flexible function, such as a neural network. By default, \(\Phi(\cdot)\) is a fully connected layer with tanh activation.
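As a rough illustration of the default \(\Phi(\cdot)\), the NumPy sketch below applies a fully connected layer followed by tanh to a batch of inputs. It is not braintrace code: the names `phi`, `W_x` and `b_z`, the scaled Gaussian initialization, and the row-vector layout (`x @ W_x`) are assumptions chosen for the example.

```python
import numpy as np


def phi(x, W_x, b_z):
    """Fully connected layer followed by tanh, i.e. z = tanh(x W_x + b_z).

    x:   inputs of shape (..., in_size)
    W_x: hypothetical weight matrix of shape (in_size, out_size)
    b_z: hypothetical bias vector of shape (out_size,)
    Returns z of shape (..., out_size) with entries in [-1, 1].
    """
    return np.tanh(x @ W_x + b_z)


rng = np.random.default_rng(0)
in_size, out_size, batch = 100, 200, 24
# Scaled Gaussian weights (an assumption; not the library's initializer)
W_x = rng.standard_normal((in_size, out_size)) / np.sqrt(in_size)
b_z = np.zeros(out_size)
x = rng.standard_normal((batch, in_size))
z = phi(x, W_x, b_z)
print(z.shape)  # (24, 200)
```

Because of the tanh, every entry of \(\mathbf{z}_t\) is bounded in magnitude by 1, whatever the scale of \(\mathbf{x}_t\).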

Given the latent representation \(\mathbf{z}_t\) of the input, MinimalRNN then updates its hidden state as:

\[\mathbf{h}_t=\mathbf{u}_t\odot\mathbf{h}_{t-1}+(\mathbf{1}-\mathbf{u}_t)\odot\mathbf{z}_t\]

where \(\mathbf{u}_t=\sigma(\mathbf{U}_h\mathbf{h}_{t-1}+\mathbf{U}_z\mathbf{z}_t+\mathbf{b}_u)\) is the update gate, \(\sigma\) is the logistic sigmoid, and \(\odot\) denotes element-wise multiplication.
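The two equations above can be transcribed almost literally. The following is a hypothetical NumPy sketch, not the braintrace implementation: the parameter names `U_h`, `U_z` and `b_u` follow the formula, while the helper names `sigmoid` and `minimal_rnn_step`, the batch-of-row-vectors layout, and the random test values are assumptions.

```python
import numpy as np


def sigmoid(a):
    # Logistic sigmoid sigma(a) = 1 / (1 + exp(-a))
    return 1.0 / (1.0 + np.exp(-a))


def minimal_rnn_step(h_prev, z, U_h, U_z, b_u):
    """One MinimalRNN state update for a batch of row vectors.

    h_prev: previous hidden state h_{t-1}, shape (batch, n)
    z:      latent input z_t = Phi(x_t), shape (batch, n)
    U_h, U_z: gate weight matrices, shape (n, n)
    b_u:    gate bias, shape (n,)
    """
    # u_t = sigma(U_h h_{t-1} + U_z z_t + b_u); for row vectors, U v is v @ U.T
    u = sigmoid(h_prev @ U_h.T + z @ U_z.T + b_u)
    # h_t = u_t * h_{t-1} + (1 - u_t) * z_t, element-wise
    return u * h_prev + (1.0 - u) * z


rng = np.random.default_rng(0)
batch, n = 4, 8
h_prev = rng.uniform(-1.0, 1.0, size=(batch, n))
z = np.tanh(rng.standard_normal((batch, n)))
U_h = rng.standard_normal((n, n)) / np.sqrt(n)
U_z = rng.standard_normal((n, n)) / np.sqrt(n)
b_u = np.zeros(n)
h = minimal_rnn_step(h_prev, z, U_h, U_z, b_u)
print(h.shape)  # (4, 8)
```

Since every entry of \(\mathbf{u}_t\) lies in \((0,1)\), each entry of \(\mathbf{h}_t\) is a convex combination of the corresponding entries of \(\mathbf{h}_{t-1}\) and \(\mathbf{z}_t\), so it always lies between them.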

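Parameters:

in_size – Size of the input features; inputs have shape (..., in_size).
out_size – Size of the hidden state; outputs have shape (..., out_size).
w_init – Initializer for the weight matrices.
b_init – Initializer for the bias vectors.
state_init – Initializer for the hidden state.
phi – The input mapping \(\Phi(\cdot)\). If None, a fully connected layer with tanh activation is used.
name – Optional name of the module.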

Examples

>>> import braintrace
>>> import brainstate
>>>
>>> # Create a Minimal RNN cell
>>> minrnn_cell = braintrace.nn.MinimalRNNCell(in_size=100, out_size=200)
>>> minrnn_cell.init_state(batch_size=24)
>>>
>>> # Process a sequence of inputs
>>> x = brainstate.random.randn(24, 100)
>>> h = minrnn_cell(x)
>>> print(h.shape)
(24, 200)
init_state(batch_size=None, **kwargs)

Initialize the hidden state. If batch_size is given, the state gets a leading batch dimension of that size.

reset_state(batch_size=None, **kwargs)

Reset the hidden state to its initial value, optionally with a leading batch dimension of size batch_size.

update(x)

Advance the cell by one time step.

Parameters:

x (ArrayLike) – Input for the current step, of shape (..., in_size).

Returns:

The updated hidden state, of shape (..., out_size).

Return type:

ArrayLike
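To see how init_state and update fit together over a whole sequence, here is a self-contained NumPy sketch that mimics that interface. The class `NumpyMinimalRNNCell` is hypothetical and not part of braintrace. It uses scaled Gaussian weights instead of the documented `Orthogonal` default, a zero initial state (matching the `ZeroInit` default of `state_init` in the signature), and the tanh fully connected \(\Phi(\cdot)\) described above.

```python
import numpy as np


class NumpyMinimalRNNCell:
    """Hypothetical NumPy MinimalRNN cell for illustration (not braintrace code).

    Mimics the documented init_state/update interface. Weights use scaled
    Gaussian initialization rather than the Orthogonal default.
    """

    def __init__(self, in_size, out_size, seed=0):
        rng = np.random.default_rng(seed)
        self.out_size = out_size
        # Phi: fully connected layer followed by tanh
        self.W_x = rng.standard_normal((out_size, in_size)) / np.sqrt(in_size)
        self.b_z = np.zeros(out_size)
        # Update-gate parameters U_h, U_z, b_u
        self.U_h = rng.standard_normal((out_size, out_size)) / np.sqrt(out_size)
        self.U_z = rng.standard_normal((out_size, out_size)) / np.sqrt(out_size)
        self.b_u = np.zeros(out_size)
        self.h = None

    def init_state(self, batch_size=None):
        # Zero initial state, with an optional leading batch dimension
        if batch_size is None:
            self.h = np.zeros(self.out_size)
        else:
            self.h = np.zeros((batch_size, self.out_size))

    def update(self, x):
        # z_t = Phi(x_t) = tanh(W_x x_t + b_z), applied to each row of x
        z = np.tanh(x @ self.W_x.T + self.b_z)
        # u_t = sigmoid(U_h h_{t-1} + U_z z_t + b_u)
        a = self.h @ self.U_h.T + z @ self.U_z.T + self.b_u
        u = 1.0 / (1.0 + np.exp(-a))
        # h_t = u_t * h_{t-1} + (1 - u_t) * z_t
        self.h = u * self.h + (1.0 - u) * z
        return self.h


cell = NumpyMinimalRNNCell(in_size=100, out_size=200)
cell.init_state(batch_size=24)
# A sequence of 50 time steps: shape (time, batch, in_size)
xs = np.random.default_rng(1).standard_normal((50, 24, 100))
hs = np.stack([cell.update(x_t) for x_t in xs])
print(hs.shape)  # (50, 24, 200)
```

Under these equations the hidden state stays bounded: starting from \(\mathbf{h}_0=\mathbf{0}\), each \(\mathbf{h}_t\) is a convex combination of \(\mathbf{h}_{t-1}\) and \(\mathbf{z}_t\), whose entries lie in \([-1,1]\), so every entry of \(\mathbf{h}_t\) also remains in \([-1,1]\).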