SignedWLinear

class braintrace.nn.SignedWLinear(in_size, out_size, w_init=KaimingNormal(scale=2.0, mode='fan_in', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([3273 6067]), unit=Unit("1")), w_sign=None, name=None, param_type=<class 'brainstate.ParamState'>)

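Linear layer whose effective weights have fixed signs.

The layer stores an unconstrained weight matrix but computes with the elementwise product of its absolute value and a sign matrix, w_sign. Each effective weight therefore keeps the sign set by w_sign (or is non-negative when w_sign is None), regardless of how training changes the stored values.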
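A minimal NumPy sketch of the effective-weight computation described above, assuming (as the update() description states) that the effective weight is |W| multiplied elementwise by w_sign. The helper name effective_weight is hypothetical and not part of braintrace; jax.numpy provides the same abs and multiplication operations.

```python
import numpy as np

def effective_weight(w, w_sign=None):
    """Hypothetical helper: |w|, optionally multiplied elementwise by w_sign."""
    magnitude = np.abs(w)        # discard the sign of the stored parameter
    if w_sign is None:
        return magnitude         # every effective weight is >= 0
    return magnitude * w_sign    # the sign now comes only from w_sign

rng = np.random.default_rng(0)
w = rng.normal(size=(10, 5))     # stored weights have mixed signs
w_sign = -np.ones((10, 5))       # request all-negative (inhibitory) connections

print(bool(np.all(effective_weight(w) >= 0)))          # True
print(bool(np.all(effective_weight(w, w_sign) <= 0)))  # True
```

Because the sign is reapplied on every forward pass, an optimizer step that flips the sign of a stored entry changes only the magnitude of the corresponding effective weight, never its sign.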

Parameters:
  • in_size (Union[int, Sequence[int], integer, Sequence[integer]]) – The input feature size.

  • out_size (Union[int, Sequence[int], integer, Sequence[integer]]) – The output feature size.

  • w_init (Union[Callable, Array, ndarray, bool_, number, bool, int, float, complex, Quantity]) – Weight initializer. Default is KaimingNormal().

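  • w_sign (Union[Array, ndarray, bool_, number, bool, int, float, complex, Quantity, None]) – Sign matrix applied elementwise to the absolute weights. If None, the effective weights are the absolute values of the stored weights, so all are non-negative. If provided, it should have the same shape as the weight matrix, (input features, output features), and typically contains +1 and -1 entries.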

  • name (Optional[str]) – Name of the module.

  • param_type (type) – Type of parameter state. Default is ParamState.
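As an illustration of a mixed-sign w_sign, the sketch below builds a matrix in which each input unit is either excitatory (+1) or inhibitory (-1) for all of its outgoing connections, a Dale's-law-style constraint. The helper dale_sign_matrix and the 8/2 split are hypothetical choices; the code assumes the weight matrix is laid out as (input features, output features), matching the (10, 5) w_sign used in the examples below.

```python
import numpy as np

def dale_sign_matrix(n_in, n_out, n_excitatory):
    """Hypothetical helper: +1 rows for excitatory inputs, -1 rows for inhibitory ones."""
    # One sign per input unit: the first n_excitatory are +1, the rest -1.
    row_sign = np.where(np.arange(n_in) < n_excitatory, 1.0, -1.0)
    # Broadcast each input's sign across all of its output connections.
    return np.broadcast_to(row_sign[:, None], (n_in, n_out)).copy()

w_sign = dale_sign_matrix(n_in=10, n_out=5, n_excitatory=8)
print(w_sign.shape)  # (10, 5)
```

The resulting matrix would then be passed as w_sign in the same way as the all-negative matrix in the examples below.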

in_size

Input feature size.

Type:

tuple

out_size

Output feature size.

Type:

tuple

w_sign

Sign matrix for weights.

Type:

ArrayLike or None

weight

Parameter state containing the weight values.

Type:

ParamState

Examples

>>> import braintrace
>>> import jax.numpy as jnp
>>>
>>> # Create a signed weight linear layer with all positive weights
>>> layer = braintrace.nn.SignedWLinear((10,), (5,))
>>> x = jnp.ones((32, 10))
>>> y = layer(x)
>>> y.shape
(32, 5)
>>>
>>> # With custom sign matrix (e.g., inhibitory connections)
>>> w_sign = jnp.ones((10, 5)) * -1.0  # all negative
>>> layer = braintrace.nn.SignedWLinear((10,), (5,), w_sign=w_sign)
>>> y = layer(x)
>>> y.shape
(32, 5)

update(x)

Apply the sign-constrained linear transform through ETP matmul.

The absolute value of the stored weight is taken and, if w_sign is given, multiplied elementwise by it. The resulting matrix is passed to braintrace.matmul(), so the weight participates in the online-learning trace computation.
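A plain-NumPy sketch of this forward pass, with the @ operator standing in for braintrace.matmul(), so none of the online-learning trace bookkeeping is reproduced. The function name signed_linear_update is hypothetical; it illustrates that the product contracts only the last axis of x, so leading batch axes pass through unchanged, consistent with the (..., in_size) to (..., out_size) shapes documented below.

```python
import numpy as np

def signed_linear_update(x, w, w_sign=None):
    """Hypothetical stand-in for SignedWLinear.update: x @ (|w| * w_sign)."""
    w_eff = np.abs(w) if w_sign is None else np.abs(w) * w_sign
    # x: (..., in_size), w_eff: (in_size, out_size) -> result: (..., out_size)
    return x @ w_eff

rng = np.random.default_rng(0)
w = rng.normal(size=(10, 5))
x = np.ones((2, 3, 10))          # two leading batch axes
y = signed_linear_update(x, w)
print(y.shape)                   # (2, 3, 5)
```

With x all ones and no w_sign, every output equals a column sum of |w|, so all entries of y are non-negative.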

Parameters:

x (ArrayLike) – Input array, of shape (..., in_size).

Returns:

The transformed output, of shape (..., out_size).

Return type:

ArrayLike