GRUCell

class braintrace.nn.GRUCell(in_size, out_size, w_init=Orthogonal(scale=1.0), b_init=ZeroInit(unit=1), state_init=ZeroInit(unit=1), activation='tanh', name=None)

Gated Recurrent Unit (GRU) cell.

The implementation follows "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" (Cho et al., 2014).
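For orientation, the following is a minimal NumPy sketch of a single GRU step in the formulation of the cited paper: a reset gate r, an update gate z, a candidate state, and an interpolation between the previous state and the candidate. It is not braintrace's implementation. The helper names (sigmoid, init_params, gru_step), the per-gate parameter layout, the added bias terms, and the Gaussian initialization are all assumptions made for illustration, and the library may order or fuse its weights differently.

```python
import numpy as np


def sigmoid(a):
    """Logistic function used by both gates."""
    return 1.0 / (1.0 + np.exp(-a))


def init_params(in_size, out_size, rng):
    """Hypothetical layout: one (W, U, b) triple per gate (not braintrace's)."""
    def triple():
        W = rng.normal(scale=0.1, size=(in_size, out_size))   # input weights
        U = rng.normal(scale=0.1, size=(out_size, out_size))  # recurrent weights
        b = np.zeros(out_size)                                # bias (assumed)
        return W, U, b
    return {"r": triple(), "z": triple(), "h": triple()}


def gru_step(x, h, params):
    """One GRU step: x has shape (..., in_size), h has shape (..., out_size)."""
    Wr, Ur, br = params["r"]
    Wz, Uz, bz = params["z"]
    Wh, Uh, bh = params["h"]
    r = sigmoid(x @ Wr + h @ Ur + br)              # reset gate
    z = sigmoid(x @ Wz + h @ Uz + bz)              # update gate
    h_tilde = np.tanh(x @ Wh + (r * h) @ Uh + bh)  # candidate state
    return z * h + (1.0 - z) * h_tilde             # keep z of old, rest from candidate


rng = np.random.default_rng(0)
params = init_params(in_size=128, out_size=256, rng=rng)
x = rng.normal(size=(16, 128))   # one time step for a batch of 16
h = np.zeros((16, 256))          # zero initial state
h_new = gru_step(x, h, params)
print(h_new.shape)
```

Because the new state is a convex combination of the previous state and a tanh output, its entries stay within [-1, 1] when the initial state is zero.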

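Parameters:

in_size – Size of the input features; update expects inputs of shape (..., in_size).

out_size – Size of the hidden state; update returns a state of shape (..., out_size).

w_init – Weight initializer. Defaults to Orthogonal(scale=1.0).

b_init – Bias initializer. Defaults to ZeroInit(unit=1).

state_init – Initializer for the hidden state. Defaults to ZeroInit(unit=1).

activation – Activation function. Defaults to 'tanh'.

name – Optional name for the module. Defaults to None.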

Examples

>>> import braintrace
>>> import brainstate
>>>
>>> # Create a GRU cell
>>> gru_cell = braintrace.nn.GRUCell(in_size=128, out_size=256)
>>> gru_cell.init_state(batch_size=16)
>>>
>>> # Process one time step for a batch of inputs
>>> x = brainstate.random.randn(16, 128)
>>> h = gru_cell(x)
>>> print(h.shape)
(16, 256)

init_state(batch_size=None, **kwargs)

Initialize the hidden state. If batch_size is given, the state has a leading batch dimension of that size.

reset_state(batch_size=None, **kwargs)

Reset the hidden state to its initial value.

update(x)

Advance the cell by one time step.

Parameters:

x (ArrayLike) – Input for the current step, of shape (..., in_size).

Returns:

The updated hidden state, of shape (..., out_size).

Return type:

ArrayLike
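Since update advances the cell by a single step and keeps the hidden state between calls, processing a sequence means calling the cell once per time step. The sketch below illustrates that driving pattern with a hypothetical stand-in, ToyRecurrentCell, which only mirrors the init_state / reset_state / update interface documented here. It is a plain tanh recurrence rather than a GRU, and its internals, including how reset_state behaves, are assumptions rather than braintrace's behavior.

```python
import numpy as np


class ToyRecurrentCell:
    """Hypothetical stand-in with the same method names as the cell above."""

    def __init__(self, in_size, out_size, seed=0):
        rng = np.random.default_rng(seed)
        self.out_size = out_size
        self.W = rng.normal(scale=0.1, size=(in_size, out_size))   # input weights
        self.U = rng.normal(scale=0.1, size=(out_size, out_size))  # recurrent weights
        self.h = None                                              # hidden state

    def init_state(self, batch_size=None):
        # Zero state, with a leading batch dimension when batch_size is given.
        shape = (self.out_size,) if batch_size is None else (batch_size, self.out_size)
        self.h = np.zeros(shape)

    def reset_state(self, batch_size=None):
        # Assumed semantics: return the state to its initial (zero) value.
        self.init_state(batch_size)

    def update(self, x):
        # One time step: x has shape (..., in_size); returns (..., out_size).
        self.h = np.tanh(x @ self.W + self.h @ self.U)
        return self.h

    def __call__(self, x):
        return self.update(x)


cell = ToyRecurrentCell(in_size=128, out_size=256)
cell.init_state(batch_size=16)

# A sequence of 10 time steps, each a batch of 16 inputs of size 128.
xs = np.random.default_rng(1).normal(size=(10, 16, 128))

# Step through time; the cell carries its state from one call to the next.
outputs = np.stack([cell(x_t) for x_t in xs])
print(outputs.shape)

# Reset before feeding an unrelated sequence so no state leaks across sequences.
cell.reset_state(batch_size=16)
```

An explicit Python loop keeps the example easy to follow; under a JIT-compiled framework a scan-style loop primitive would typically replace it.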