MiniGRU

class braintrace.nn.MiniGRU(in_size, out_size, w_init=Orthogonal(scale=1.0), b_init=ZeroInit(unit=1), state_init=ZeroInit(unit=1), name=None)

Minimal GRU cell.

A simplified GRU, implemented following MinimalRNN: Toward More Interpretable and Trainable Recurrent Neural Networks.

At each time step \(t\), a single update gate \(\mathbf{z}_t\) interpolates element-wise between the previous hidden state and a linear projection of the current input. The hidden state is updated as:

\[\mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \mathbf{W}_x \mathbf{x}_t\]

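where \(\mathbf{z}_t=\sigma(\mathbf{W}_z[\mathbf{x}_t; \mathbf{h}_{t-1}])\) is the update gate, \(\sigma\) is the logistic sigmoid, \([\cdot\,;\cdot]\) denotes concatenation, and \(\odot\) is element-wise multiplication. Unlike a standard GRU, there is no reset gate, and the candidate state \(\mathbf{W}_x \mathbf{x}_t\) does not depend on \(\mathbf{h}_{t-1}\).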
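To make the update rule concrete, here is a minimal NumPy sketch of a single step. It illustrates the equation above and is not braintrace's implementation: the helper name `minigru_step`, the row-vector weight layout (`x @ W_x` in place of \(\mathbf{W}_x \mathbf{x}_t\)), the random weight scaling, and the absence of bias terms are all assumptions made for the example.

```python
import numpy as np


def sigmoid(a):
    """Logistic sigmoid, applied element-wise."""
    return 1.0 / (1.0 + np.exp(-a))


def minigru_step(x, h_prev, W_x, W_z):
    """One MiniGRU step following the equation (hypothetical helper).

    Row-vector convention (assumed): x has shape (..., in_size), h_prev has
    shape (..., out_size), W_x has shape (in_size, out_size) and W_z has shape
    (in_size + out_size, out_size). No bias terms, as in the equation.
    """
    # Update gate z_t = sigma(W_z [x_t; h_{t-1}]), concatenating on the feature axis.
    z = sigmoid(np.concatenate([x, h_prev], axis=-1) @ W_z)
    # Candidate state: a linear projection of the input only (no h_{t-1} term).
    candidate = x @ W_x
    # Element-wise interpolation between the previous state and the candidate.
    return (1.0 - z) * h_prev + z * candidate


rng = np.random.default_rng(0)
in_size, out_size, batch = 80, 160, 32
W_x = rng.standard_normal((in_size, out_size)) / np.sqrt(in_size)
W_z = rng.standard_normal((in_size + out_size, out_size)) / np.sqrt(in_size + out_size)

h = np.zeros((batch, out_size))            # zero initial state
x = rng.standard_normal((batch, in_size))  # one time step for a batch of 32
h = minigru_step(x, h, W_x, W_z)
print(h.shape)  # (32, 160)
```

Because \(\mathbf{z}_t\) lies in \((0, 1)\), each entry of the new state is a convex combination of the old state and the input projection. Since `@` broadcasts over leading axes, the same function also accepts inputs with extra leading dimensions, in line with the `(..., in_size)` shape accepted by `update`.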

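Parameters:

in_size – Size of the input features (last axis of x).
out_size – Size of the hidden state (last axis of the output).
w_init – Initializer for the weight matrices. Defaults to Orthogonal(scale=1.0).
b_init – Initializer for the bias. Defaults to ZeroInit(unit=1).
state_init – Initializer for the hidden state. Defaults to ZeroInit(unit=1).
name – Optional name of the module.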

Examples

>>> import braintrace
>>> import brainstate
>>>
>>> # Create a MiniGRU cell: 80 input features, 160 hidden units
>>> minigru_cell = braintrace.nn.MiniGRU(in_size=80, out_size=160)
>>> minigru_cell.init_state(batch_size=32)
>>>
>>> # Advance the cell by one time step on a batch of 32 inputs
>>> x = brainstate.random.randn(32, 80)
>>> h = minigru_cell(x)
>>> print(h.shape)
(32, 160)
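init_state(batch_size=None, **kwargs)

Create the hidden state, filled using state_init. When batch_size is given, the state gets a leading batch dimension of that size, so init_state(batch_size=32) pairs with inputs of shape (32, in_size), as in the example above.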

reset_state(batch_size=None, **kwargs)

Reset the hidden state to its initial value, optionally for a new batch_size (for example, before processing a new sequence).
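The stateful methods can be pictured with a small toy class. The sketch below is a hypothetical pure-NumPy stand-in (the class `ToyMiniGRU` and its internals are invented for illustration and are not braintrace's code): `init_state` creates the state, each call advances it by one step through `update`, and `reset_state` returns it to the initial value before a new sequence. The zero initial state mirrors the `ZeroInit` default of `state_init` and is an assumption.

```python
import numpy as np


class ToyMiniGRU:
    """Hypothetical NumPy stand-in mimicking the stateful cell interface.

    Not braintrace's implementation: it only shows how the hidden state is
    created by init_state, advanced by update, and restored by reset_state.
    """

    def __init__(self, in_size, out_size, seed=0):
        rng = np.random.default_rng(seed)
        self.out_size = out_size
        self.W_x = rng.standard_normal((in_size, out_size)) / np.sqrt(in_size)
        self.W_z = rng.standard_normal((in_size + out_size, out_size)) / np.sqrt(in_size + out_size)
        self.h = None

    def _initial_state(self, batch_size):
        # Zero state (assumed, mirroring state_init=ZeroInit); optional batch axis.
        shape = (self.out_size,) if batch_size is None else (batch_size, self.out_size)
        return np.zeros(shape)

    def init_state(self, batch_size=None):
        self.h = self._initial_state(batch_size)

    def reset_state(self, batch_size=None):
        self.h = self._initial_state(batch_size)

    def update(self, x):
        # Same rule as the equation: gate from [x; h], candidate from x only.
        z = 1.0 / (1.0 + np.exp(-(np.concatenate([x, self.h], axis=-1) @ self.W_z)))
        self.h = (1.0 - z) * self.h + z * (x @ self.W_x)
        return self.h

    # Calling the object advances it by one step, like cell(x) in the example.
    __call__ = update


cell = ToyMiniGRU(in_size=80, out_size=160)
cell.init_state(batch_size=32)
xs = np.random.default_rng(1).standard_normal((10, 32, 80))  # (time, batch, in_size)
outputs = np.stack([cell(x_t) for x_t in xs])  # state carried across 10 steps
print(outputs.shape)  # (10, 32, 160)
cell.reset_state(batch_size=32)  # start the next sequence from the initial state
print(np.abs(cell.h).max())  # 0.0
```

Processing a sequence is then a loop over time with the state held inside the object. In this toy, `init_state` and `reset_state` do the same thing; that simplification may not hold for the real methods.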

update(x)

Advance the cell by one time step.

Parameters:

x (ArrayLike) – Input for the current step, of shape (..., in_size).

Returns:

The updated hidden state, of shape (..., out_size).

Return type:

ArrayLike