URLSTMCell

class braintrace.nn.URLSTMCell(in_size, out_size, w_init=XavierNormal(scale=1.0, unit=1), state_init=ZeroInit(unit=1), activation='tanh', name=None)

Update-Reset LSTM (URLSTM) cell.

A variant of LSTM that uses update and reset gates for more flexible control over the cell state dynamics.
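The description above does not state the gate equations. As an illustration only, the sketch below implements one published LSTM variant that is also abbreviated UR-LSTM: the refine-gate update with uniform gate initialization from Gu et al. (2020), "Improving the Gating Mechanism of Recurrent Neural Networks". Whether braintrace's URLSTMCell uses exactly these equations is an assumption that this page does not confirm. `ur_lstm_step` and `uniform_gate_bias` are hypothetical names, and reusing the negated forget-gate bias for the refine gate is a choice made by this sketch.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def uniform_gate_bias(out_size, rng):
    # Uniform gate initialization: sample u ~ U(1/d, 1 - 1/d) and invert the
    # sigmoid, so that sigmoid(bias) is spread uniformly over that interval.
    u = rng.uniform(1.0 / out_size, 1.0 - 1.0 / out_size, size=out_size)
    return -np.log(1.0 / u - 1.0)


def ur_lstm_step(x, h, c, W, bias, activation=np.tanh):
    """One refine-gate LSTM step (hypothetical sketch, not braintrace's code).

    x: (..., in_size); h, c: (..., out_size)
    W: (in_size + out_size, 4 * out_size); bias: (out_size,)
    """
    xh = np.concatenate([x, h], axis=-1)
    f, r, u, o = np.split(xh @ W, 4, axis=-1)
    f = sigmoid(f + bias)  # forget gate
    r = sigmoid(r - bias)  # refine gate (assumed: reuses the negated bias)
    # Effective gate: g = r * (1 - (1 - f)^2) + (1 - r) * f^2
    g = 2.0 * r * f + (1.0 - 2.0 * r) * f ** 2
    c_new = g * c + (1.0 - g) * activation(u)
    h_new = sigmoid(o) * activation(c_new)
    return h_new, c_new


rng = np.random.default_rng(0)
in_size, out_size, batch = 128, 256, 16
W = rng.normal(0.0, 0.05, size=(in_size + out_size, 4 * out_size))
bias = uniform_gate_bias(out_size, rng)
h = np.zeros((batch, out_size))
c = np.zeros((batch, out_size))
for _ in range(10):  # unroll over 10 time steps
    x = rng.normal(size=(batch, in_size))
    h, c = ur_lstm_step(x, h, c, W, bias)
print(h.shape)  # (16, 256)
```

The effective gate g interpolates between f² (when r = 0) and 1 - (1 - f)² (when r = 1), so the refine gate can push g closer to 0 or 1 than f alone would.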

Parameters:
  • in_size (Union[int, Sequence[int], integer, Sequence[integer]]) – The dimension of the input vector.

  • out_size (Union[int, Sequence[int], integer, Sequence[integer]]) – The number of hidden units, i.e. the size of the hidden state h and the cell state c.

  • w_init (Union[Array, ndarray, bool, number, bool, int, float, complex, Quantity, Callable]) – The input weight initializer. Default is XavierNormal().

  • state_init (Union[Array, ndarray, bool, number, bool, int, float, complex, Quantity, Callable]) – The state initializer. Default is ZeroInit().

  • activation (Union[str, Callable]) – The activation function. It can be a string or a callable function. Default is ‘tanh’.

  • name (str) – The name of the module. Default is None.
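The activation parameter accepts either a string or a callable. The minimal sketch below shows one way such a parameter can be normalized to a callable. `resolve_activation` and its name table are hypothetical; the set of strings braintrace actually recognizes is not listed on this page.

```python
from typing import Callable, Union

import numpy as np

# Hypothetical lookup table for string names.
_ACTIVATIONS = {
    "tanh": np.tanh,
    "relu": lambda z: np.maximum(z, 0.0),
    "sigmoid": lambda z: 1.0 / (1.0 + np.exp(-z)),
}


def resolve_activation(activation: Union[str, Callable]) -> Callable:
    """Return a callable for an activation given as a string or a callable."""
    if callable(activation):
        return activation  # user-supplied function is used as-is
    if isinstance(activation, str):
        try:
            return _ACTIVATIONS[activation]
        except KeyError:
            raise ValueError(f"unknown activation: {activation!r}") from None
    raise TypeError("activation must be a str or a callable")


act = resolve_activation("tanh")
print(act(np.array([0.0])))  # [0.]
```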

Examples

>>> import braintrace
>>> import brainstate
>>>
>>> # Create a URLSTM cell
>>> urlstm_cell = braintrace.nn.URLSTMCell(in_size=128, out_size=256)
>>> urlstm_cell.init_state(batch_size=16)
>>>
>>> # Process a sequence of 10 time steps, one step per call
>>> xs = brainstate.random.randn(10, 16, 128)
>>> for x in xs:
...     h = urlstm_cell(x)
>>> print(h.shape)
(16, 256)

init_state(batch_size=None, **kwargs)

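Initialize the hidden state h and the cell state c with state_init. If batch_size is given, both states get a leading batch dimension of that size.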

reset_state(batch_size=None, **kwargs)

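Reset the hidden state h and the cell state c to their initial values from state_init, optionally with a leading batch dimension of size batch_size.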
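A minimal sketch of the init_state / reset_state split, assuming that batch_size=None means no batch axis and that reset_state overwrites the values held by existing state containers instead of creating new ones. `StateBox`, `zero_init`, and `ToyCellState` are hypothetical stand-ins, not braintrace classes.

```python
import numpy as np


class StateBox:
    """Minimal mutable container standing in for a framework state object."""

    def __init__(self, value):
        self.value = value


def zero_init(shape):
    # Stand-in for ZeroInit(): an all-zeros array of the requested shape.
    return np.zeros(shape)


class ToyCellState:
    """Hypothetical sketch of init_state / reset_state semantics for h and c."""

    def __init__(self, out_size, state_init=zero_init):
        self.out_size = out_size
        self.state_init = state_init

    def _shape(self, batch_size):
        # Assumed convention: no batch axis when batch_size is None.
        if batch_size is None:
            return (self.out_size,)
        return (batch_size, self.out_size)

    def init_state(self, batch_size=None):
        # Create fresh containers for the hidden and cell states.
        self.h = StateBox(self.state_init(self._shape(batch_size)))
        self.c = StateBox(self.state_init(self._shape(batch_size)))

    def reset_state(self, batch_size=None):
        # Keep the same containers but overwrite their values.
        self.h.value = self.state_init(self._shape(batch_size))
        self.c.value = self.state_init(self._shape(batch_size))


cell = ToyCellState(out_size=256)
cell.init_state(batch_size=16)
print(cell.h.value.shape)  # (16, 256)
cell.reset_state()
print(cell.h.value.shape)  # (256,)
```

Keeping the containers and replacing only their values means any code that already holds a reference to h or c sees the reset values.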

update(x)

Advance the cell by one time step.

Updates both the cell state c and the hidden state h in place.

Parameters:

x (Union[Array, ndarray, bool, number, bool, int, float, complex, Quantity]) – Input for the current step, of shape (..., in_size).

Returns:

The updated hidden state h, of shape (..., out_size).

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex, Quantity]
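The update() contract above (read h and c, write new values back, return h, with inputs of shape (..., in_size)) can be sketched with a toy cell. `ToyGatedCell` is hypothetical, and its single gate is a simplified placeholder rather than URLSTMCell's equations; it only illustrates the stored-state update and how leading axes carry through to the output.

```python
import numpy as np


class ToyGatedCell:
    """Hypothetical cell mirroring only the documented update() contract."""

    def __init__(self, in_size, out_size, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.normal(0.0, 0.1, size=(in_size + out_size, 2 * out_size))
        self.out_size = out_size

    def init_state(self, batch_shape=()):
        # States carry any leading axes followed by out_size.
        self.h = np.zeros(batch_shape + (self.out_size,))
        self.c = np.zeros(batch_shape + (self.out_size,))

    def update(self, x):
        # x has shape (..., in_size); its leading axes must match the states'.
        xh = np.concatenate([x, self.h], axis=-1)
        g_pre, u = np.split(xh @ self.W, 2, axis=-1)
        g = 1.0 / (1.0 + np.exp(-g_pre))              # placeholder gate in (0, 1)
        self.c = g * self.c + (1.0 - g) * np.tanh(u)  # overwrite cell state
        self.h = np.tanh(self.c)                      # overwrite hidden state
        return self.h


cell = ToyGatedCell(in_size=128, out_size=256)
cell.init_state(batch_shape=(16,))
h = cell.update(np.ones((16, 128)))
print(h.shape, h is cell.h)  # (16, 256) True
```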