MGUCell

class braintrace.nn.MGUCell(in_size, out_size, w_init=Orthogonal(scale=1.0), b_init=ZeroInit(unit=1), state_init=ZeroInit(unit=1), activation='tanh', name=None)

Minimal Gated Recurrent Unit (MGU) cell.

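The implementation follows the paper "Minimal Gated Unit for Recurrent Neural Networks". A single forget gate controls both how much of the previous state feeds the candidate activation and how the previous and candidate states are mixed: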

\[\begin{split}\begin{aligned} f_{t}&=\sigma (W_{f}x_{t}+U_{f}h_{t-1}+b_{f})\\ {\hat {h}}_{t}&=\phi (W_{h}x_{t}+U_{h}(f_{t}\odot h_{t-1})+b_{h})\\ h_{t}&=(1-f_{t})\odot h_{t-1}+f_{t}\odot {\hat {h}}_{t} \end{aligned}\end{split}\]

where:

  • \(x_{t}\): input vector

  • \(h_{t}\): hidden state vector, which is also the output of the cell

  • \({\hat {h}}_{t}\): candidate activation vector

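  • \(f_{t}\): forget gate vector

  • \(W, U, b\): input weight matrices, recurrent weight matrices, and bias vectors

  • \(\sigma\): logistic sigmoid function

  • \(\phi\): activation function selected by activation ('tanh' by default)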
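The update equations above can be sketched in plain NumPy to make the data flow concrete. This is an illustrative sketch only, not braintrace's implementation: the names `sigmoid` and `mgu_step`, the argument order, and the `(in_size, out_size)` weight layout are assumptions chosen for the example.

```python
import numpy as np


def sigmoid(z):
    """Logistic sigmoid, used for the forget gate."""
    return 1.0 / (1.0 + np.exp(-z))


def mgu_step(x, h_prev, W_f, U_f, b_f, W_h, U_h, b_h, phi=np.tanh):
    """One MGU step; x is (batch, in_size), h_prev is (batch, out_size).

    Hypothetical layout: input weights are (in_size, out_size) and recurrent
    weights are (out_size, out_size), so x @ W plays the role of W x_t.
    """
    f = sigmoid(x @ W_f + h_prev @ U_f + b_f)         # forget gate f_t
    h_hat = phi(x @ W_h + (f * h_prev) @ U_h + b_h)   # candidate activation
    return (1.0 - f) * h_prev + f * h_hat             # new hidden state h_t


rng = np.random.default_rng(0)
in_size, out_size, batch = 4, 3, 2
x = rng.normal(size=(batch, in_size))
h_prev = rng.normal(size=(batch, out_size))

# With all-zero parameters, f_t = 0.5 and the candidate is tanh(0) = 0,
# so the new state is exactly half of the previous state.
zeros_in = np.zeros((in_size, out_size))
zeros_rec = np.zeros((out_size, out_size))
zeros_b = np.zeros(out_size)
h_new = mgu_step(x, h_prev, zeros_in, zeros_rec, zeros_b,
                 zeros_in, zeros_rec, zeros_b)
```

The zero-parameter case is a convenient sanity check: because the same gate \(f_{t}\) scales both the retained state and the candidate, the result reduces to \(h_{t} = 0.5\, h_{t-1}\).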

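Parameters:

in_size – Size of the input; inputs have shape (..., in_size).

out_size – Size of the hidden state and output; outputs have shape (..., out_size).

w_init – Initializer for the weight matrices. Defaults to Orthogonal(scale=1.0).

b_init – Initializer for the bias vectors. Defaults to ZeroInit(unit=1).

state_init – Initializer for the hidden state. Defaults to ZeroInit(unit=1).

activation – Activation function \(\phi\) applied to the candidate activation. Defaults to 'tanh'.

name – Optional name of the cell. Defaults to None.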

Examples

>>> import braintrace
>>> import brainstate
>>>
>>> # Create an MGU cell
>>> mgu_cell = braintrace.nn.MGUCell(in_size=96, out_size=192)
>>> mgu_cell.init_state(batch_size=12)
>>>
>>> # Process one time step for a batch of 12 inputs
>>> x = brainstate.random.randn(12, 96)
>>> h = mgu_cell(x)
>>> print(h.shape)
(12, 192)

init_state(batch_size=None, **kwargs)

Initialize the hidden state. If batch_size is given, the state has a leading batch dimension of that size.

reset_state(batch_size=None, **kwargs)

Reset the hidden state to its initial value, for example before processing a new sequence.

update(x)

Advance the cell by one time step.

Parameters:

x (ArrayLike) – Input for the current step, of shape (..., in_size).

Returns:

The updated hidden state, of shape (..., out_size).

Return type:

ArrayLike
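Because update advances the cell by a single step, a sequence is processed by calling it once per time step and carrying the hidden state forward. The sketch below illustrates that pattern with a small NumPy stand-in. `TinyMGU` is a hypothetical class written for this example: it mirrors the init_state / reset_state / update method names but is not braintrace's code, and its zero initial state and weight scaling are assumptions.

```python
import numpy as np


class TinyMGU:
    """Hypothetical stand-in for a stateful MGU cell (not braintrace's code)."""

    def __init__(self, in_size, out_size, seed=0):
        rng = np.random.default_rng(seed)
        self.out_size = out_size
        # One input matrix and one recurrent matrix each for gate and candidate.
        self.W_f, self.W_h = (0.1 * rng.normal(size=(in_size, out_size)) for _ in range(2))
        self.U_f, self.U_h = (0.1 * rng.normal(size=(out_size, out_size)) for _ in range(2))
        self.b_f = np.zeros(out_size)
        self.b_h = np.zeros(out_size)
        self.h = None

    def init_state(self, batch_size=None):
        # Zero initial state; add a leading batch axis only if batch_size is given.
        shape = (self.out_size,) if batch_size is None else (batch_size, self.out_size)
        self.h = np.zeros(shape)

    reset_state = init_state  # in this sketch, resetting restores the zero state

    def update(self, x):
        f = 1.0 / (1.0 + np.exp(-(x @ self.W_f + self.h @ self.U_f + self.b_f)))
        h_hat = np.tanh(x @ self.W_h + (f * self.h) @ self.U_h + self.b_h)
        self.h = (1.0 - f) * self.h + f * h_hat
        return self.h


cell = TinyMGU(in_size=5, out_size=3)
cell.init_state(batch_size=2)

xs = np.random.default_rng(1).normal(size=(7, 2, 5))   # (time, batch, in_size)
outputs = np.stack([cell.update(x_t) for x_t in xs])    # (time, batch, out_size)

cell.reset_state(batch_size=2)  # start the next sequence from a clean state
```

Each call returns the hidden state for that step, so stacking the returned values yields the full output sequence. With a tanh candidate and a zero initial state, every hidden state stays strictly inside (-1, 1), since each step is a convex combination of the previous state and the candidate.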