FixedRandomFeedback

class braintrace.FixedRandomFeedback(n_target, n_layer, key, init_scale=0.1)

Frozen random feedback matrix with a stop-gradient guard.

The feedback matrix \(B \in \mathbb{R}^{n_{\mathrm{target}} \times n_{\mathrm{layer}}}\) is sampled once, at construction time, and frozen with jax.lax.stop_gradient(), so no gradient flows into it during training. It is used by OSTTP (per-HiddenGroup target projection) and by EProp-random-feedback.
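The freezing behaviour can be illustrated in plain JAX. The sketch below does not call braintrace: make_feedback and projection_loss are hypothetical helpers, and the sampling rule (standard-normal entries multiplied by init_scale) is an assumption based on the parameter description. It shows that wrapping B in jax.lax.stop_gradient() yields a zero gradient for B, while the gradient with respect to the target still flows.

```python
import jax
import jax.numpy as jnp


def make_feedback(key, n_target, n_layer, init_scale=0.1):
    # Hypothetical helper (assumption): standard-normal draws scaled by init_scale.
    return init_scale * jax.random.normal(key, (n_target, n_layer))


def projection_loss(B, y_target):
    # stop_gradient makes B a constant as far as autodiff is concerned.
    frozen_B = jax.lax.stop_gradient(B)
    return jnp.sum(y_target @ frozen_B)


key = jax.random.PRNGKey(0)
B = make_feedback(key, n_target=2, n_layer=3)
y = jnp.ones(2)

# Gradient with respect to B is identically zero because of stop_gradient.
grad_B = jax.grad(projection_loss, argnums=0)(B, y)
# Gradient with respect to y is unaffected: d/dy sum(y @ B) = B.sum(axis=1).
grad_y = jax.grad(projection_loss, argnums=1)(B, y)

print(grad_B.shape, grad_y.shape)
```

Because stop_gradient only blocks differentiation, the forward value of the projection is unchanged; B simply behaves as a constant under jax.grad.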

Parameters:
  • n_target (int) – Number of target dimensions (the row count of B).

  • n_layer (int) – Number of layer dimensions (the column count of B).

  • key (jax.Array) – A JAX PRNG key used to sample the feedback matrix.

  • init_scale (float) – Scale (standard deviation) applied to the sampled normal entries. Defaults to 0.1.

B

The frozen feedback matrix of shape (n_target, n_layer).

Type:

jax.Array

n_target

Number of target dimensions.

Type:

int

n_layer

Number of layer dimensions.

Type:

int

Examples

>>> import jax
>>> import braintrace
>>> fb = braintrace.FixedRandomFeedback(2, 3, jax.random.PRNGKey(0))
>>> print(fb.B.shape)
(2, 3)
>>> y = jax.numpy.ones(2)
>>> print(fb.project(y).shape)
(3,)

project(y_target)

Project the target onto the frozen feedback matrix.

Parameters:

y_target (jax.Array) – The target to project. Both layouts are accepted: unbatched, with shape (n_target,), or batched, with the batch dimension(s) leading and a last axis of size n_target.

Returns:

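The projection y_target @ B, with shape (n_layer,) for an unbatched input or (batch, n_layer) for a batched one. B is frozen, so no gradient flows into it.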

Return type:

jax.Array
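How one matrix product covers both layouts can be sketched in plain JAX. Here project is a hypothetical stand-in written for illustration, not braintrace's implementation; it assumes only what the return description states, namely y_target @ B with B frozen. The @ operator contracts the last axis of y_target with the rows of B, so any leading batch axes pass through unchanged.

```python
import jax
import jax.numpy as jnp


def project(B, y_target):
    # Hypothetical stand-in for FixedRandomFeedback.project: y_target @ B, B frozen.
    # Contracts the last axis of y_target (size n_target) with the rows of B.
    return y_target @ jax.lax.stop_gradient(B)


n_target, n_layer, batch = 2, 3, 4
B = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (n_target, n_layer))

single = project(B, jnp.ones(n_target))            # unbatched: shape (n_layer,)
batched = project(B, jnp.ones((batch, n_target)))  # batched: shape (batch, n_layer)

print(single.shape, batched.shape)
```

With identical inputs, each row of the batched result equals the unbatched projection, since the batch axis is never mixed into the contraction.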