Linear
- class braintrace.nn.Linear(in_size, out_size, w_init=KaimingNormal(scale=2.0, mode='fan_in', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([3273 6067]), unit=Unit("1")), b_init=ZeroInit(unit=Unit("1")), w_mask=None, name=None, param_type=<class 'brainstate.ParamState'>)
Linear transformation layer.
Applies a linear transformation to the incoming data: \(y = xW + b\)
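The formula can be checked by hand with plain arrays. The following is a minimal sketch in NumPy (the same expressions work with `jax.numpy`), not the layer's actual implementation. It assumes the weight is stored with shape `(in_size, out_size)`, which is what \(y = xW + b\) implies; the names `W` and `b` are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
in_size, out_size = 10, 5

x = np.ones((32, in_size))                 # batch of 32 input vectors
W = rng.normal(size=(in_size, out_size))   # assumed layout: (in_size, out_size)
b = np.zeros(out_size)                     # zero bias, one entry per output feature

# y = xW + b: each input row is multiplied by W, then b is broadcast over the batch.
y = x @ W + b
```

Because `x` is all ones here, every row of `y` equals the column sums of `W`.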
- Parameters:
in_size (Union[int,Sequence[int],integer,Sequence[integer]]) – The input feature size.
out_size (Union[int,Sequence[int],integer,Sequence[integer]]) – The output feature size.
w_init (Union[Callable,Array,ndarray,bool,number,bool,int,float,complex,Quantity]) – Weight initializer. Default is KaimingNormal().
b_init (Union[Array,ndarray,bool,number,bool,int,float,complex,Quantity,Callable,None]) – Bias initializer. If None, no bias is added. Default is ZeroInit().
w_mask (Union[Array,ndarray,bool,number,bool,int,float,complex,Quantity,Callable,None]) – Optional mask for the weights. If provided, the weights are multiplied element-wise by this mask.
param_type (type) – Type of parameter state. Default is ParamState.
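To illustrate the effect of `w_mask`, here is a small NumPy sketch of an element-wise weight mask. It is written under the assumption that the mask has the same shape as the weight and is applied before the matrix multiplication; the arrays `W`, `mask`, and `b` are hypothetical stand-ins, not the layer's internal state.

```python
import numpy as np

x = np.array([[1.0, 2.0, 3.0]])        # one input with in_size = 3
W = np.ones((3, 2))                    # assumed weight layout: (in_size, out_size)
b = np.zeros(2)

# A 0/1 mask with the same shape as W: zeros remove individual connections.
mask = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [1.0, 1.0]])

# Element-wise masking, then the usual linear transform.
y = x @ (W * mask) + b
```

Output 0 only receives inputs 0 and 2, and output 1 only receives inputs 1 and 2, so masked connections contribute nothing to the result.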
- w_mask
Weight mask if provided.
- Type:
ArrayLike or None
- weight
Parameter state containing ‘weight’ and optionally ‘bias’.
- Type:
ParamState
Examples
>>> import braintrace as braintrace
>>> import jax.numpy as jnp
>>>
>>> # Create a linear layer
>>> layer = braintrace.nn.Linear((10,), (5,))
>>> x = jnp.ones((32, 10))
>>> y = layer(x)
>>> y.shape
(32, 5)
>>>
>>> # Linear layer without bias
>>> layer = braintrace.nn.Linear((10,), (5,), b_init=None)
>>> y = layer(x)
>>> y.shape
(32, 5)
- update(x)
Apply the linear transform through the ETP matmul primitive.
Routing the matrix multiplication through braintrace.matmul() (instead of a plain JAX dot) is what makes weight eligible for online-learning trace computation.
- Parameters:
x (ArrayLike) – Input array, of shape (..., in_size).
- Returns:
The transformed output, of shape (..., out_size).
- Return type:
ArrayLike
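The `(..., in_size)` to `(..., out_size)` shape contract means any number of leading dimensions passes through unchanged. The sketch below shows only that shape behaviour, using an ordinary NumPy matrix product; it does not reproduce `braintrace.matmul()` or its trace computation, and the weight layout `(in_size, out_size)` is an assumption.

```python
import numpy as np

in_size, out_size = 10, 5
W = np.full((in_size, out_size), 0.1)   # assumed layout: (in_size, out_size)
b = np.zeros(out_size)

# Two leading dimensions, e.g. (batch, time), followed by the feature axis.
x = np.ones((4, 7, in_size))

# The matrix product contracts only the last axis of x; leading axes are kept.
y = x @ W + b
```

Here `y` has shape `(4, 7, 5)`: only the trailing feature axis changes size.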