Linear

class braintrace.nn.Linear(in_size, out_size, w_init=KaimingNormal(scale=2.0, mode='fan_in', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([3273 6067]), unit=Unit("1")), b_init=ZeroInit(unit=Unit("1")), w_mask=None, name=None, param_type=brainstate.ParamState)

Linear transformation layer.

Applies a linear transformation to the incoming data, \(y = xW + b\), where \(W\) is the weight and \(b\) is the bias (omitted when b_init=None).
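To make the formula concrete, here is a minimal plain-JAX sketch of \(y = xW + b\) for a batch of 32 inputs with 10 features mapped to 5 outputs. It does not call braintrace: `W`, `b`, and the inputs are random stand-ins rather than the layer's initialized parameters, and the (in, out) weight layout is an assumption chosen to match the \(xW\) ordering and the in_axis=-2, out_axis=-1 defaults of w_init.

```python
import jax
import jax.numpy as jnp

# Illustrative stand-ins for the layer's parameters (NOT braintrace's initialized values).
in_features, out_features, batch = 10, 5, 32
key_w, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(key_w, (in_features, out_features))  # assumed weight layout: (in, out)
b = jax.random.normal(key_b, (out_features,))               # bias, shape (out,)
x = jax.random.normal(key_x, (batch, in_features))          # inputs, shape (batch, in)

# y = xW + b: matrix product, then the bias broadcast over the batch axis.
y = x @ W + b

# The same computation written element-wise: y[n, j] = sum_i x[n, i] * W[i, j] + b[j].
y_explicit = jnp.einsum("ni,ij->nj", x, W) + b[None, :]

print(y.shape)                                        # (32, 5)
print(bool(jnp.allclose(y, y_explicit, atol=1e-4)))   # True
```

The einsum form spells out which axes are contracted: the last axis of `x` against the first axis of `W`, leaving one output per column of `W`.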

Attributes:
in_size (tuple) – Input feature size.

out_size (tuple) – Output feature size.

w_mask (ArrayLike or None) – Weight mask, if one was provided; otherwise None.

weight (ParamState) – Parameter state containing ‘weight’ and, when the layer has a bias, ‘bias’.

Examples

>>> import braintrace
>>> import jax.numpy as jnp
>>>
>>> # Create a linear layer
>>> layer = braintrace.nn.Linear((10,), (5,))
>>> x = jnp.ones((32, 10))
>>> y = layer(x)
>>> y.shape
(32, 5)
>>>
>>> # Linear layer without bias
>>> layer = braintrace.nn.Linear((10,), (5,), b_init=None)
>>> y = layer(x)
>>> y.shape
(32, 5)

update(x)

Apply the linear transformation through the ETP matmul primitive.

Routing the matrix multiplication through braintrace.matmul() rather than a plain JAX dot product is what allows the weight attribute to take part in online-learning trace computation.

Parameters:

x (ArrayLike) – Input array, of shape (..., in_size).

Returns:

The transformed output, of shape (..., out_size).

Return type:

ArrayLike
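The shape contract of update, from (..., in_size) to (..., out_size), can be illustrated with a hypothetical plain-JAX reference function. `linear_reference` and the dict layout of `params` are assumptions modeled on the description of the weight attribute (a ‘weight’ entry and an optional ‘bias’), and the sketch assumes a single trailing feature axis. Because it uses `jnp.matmul` rather than braintrace.matmul(), it only illustrates the arithmetic and shapes; it does not make the weight eligible for trace computation.

```python
from typing import Optional

import jax.numpy as jnp


def linear_reference(params: dict, x: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical plain-JAX reference for the (..., in) -> (..., out) contract.

    `params` is an assumed layout: a 'weight' entry of shape (in, out) and an
    optional 'bias' entry of shape (out,). Not part of the braintrace API.
    """
    # Contracts the last axis of x with axis 0 of the weight; leading axes are kept.
    y = jnp.matmul(x, params["weight"])
    bias: Optional[jnp.ndarray] = params.get("bias")
    if bias is not None:
        y = y + bias  # broadcasts over all leading axes
    return y


W = jnp.ones((10, 5))
b = jnp.arange(5.0)

# Two leading dimensions (2, 3) in front of the 10 input features.
x = jnp.ones((2, 3, 10))
y_with_bias = linear_reference({"weight": W, "bias": b}, x)
y_no_bias = linear_reference({"weight": W}, x)  # no 'bias' entry, loosely analogous to b_init=None

print(y_with_bias.shape)  # (2, 3, 5)
# Every output row of y_with_bias equals 10 + b, and every entry of y_no_bias equals 10.
```

Leaving out the ‘bias’ entry mirrors, in spirit, the second example above, where the layer is built with b_init=None and the output shape is unchanged.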