FixedRandomFeedback
class braintrace.FixedRandomFeedback(n_target, n_layer, key, init_scale=0.1)
Frozen random feedback matrix with a stop-gradient guard.
The feedback matrix \(B \in \mathbb{R}^{n_{\mathrm{target}} \times n_{\mathrm{layer}}}\) is sampled once at construction and frozen via jax.lax.stop_gradient(). It is used by OSTTP (per-HiddenGroup target projection) and by EProp-random-feedback.

Parameters:
- n_target (int) – Number of target dimensions (the row count of B).
- n_layer (int) – Number of layer dimensions (the column count of B).
- key (jax.Array) – A JAX PRNG key used to sample the feedback matrix.
- init_scale (float) – Standard-deviation scaling applied to the sampled normal entries. Default is 0.1.
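Per the parameter descriptions, the entries of B are normal draws scaled by init_scale, so init_scale acts as their standard deviation. The following is a minimal pure-JAX sketch of that sampling step, assuming standard-normal draws via jax.random.normal; the helper sample_feedback is hypothetical and not part of braintrace.

```python
import jax
import jax.numpy as jnp


def sample_feedback(n_target: int, n_layer: int, key: jax.Array,
                    init_scale: float = 0.1) -> jax.Array:
    """Hypothetical helper: draw an (n_target, n_layer) matrix whose entries
    are N(0, 1) samples multiplied by init_scale (std = init_scale)."""
    return init_scale * jax.random.normal(key, (n_target, n_layer))


key = jax.random.PRNGKey(0)
B1 = sample_feedback(2, 3, key)
B2 = sample_feedback(2, 3, key)
print(B1.shape)                        # (2, 3)
print(bool(jnp.array_equal(B1, B2)))   # True: same key -> same matrix

# With many entries, the empirical standard deviation approaches init_scale.
big = sample_feedback(500, 400, jax.random.PRNGKey(1), init_scale=0.1)
print(float(jnp.std(big)))
```

Because the matrix is a pure function of key, reusing a key reproduces the same feedback weights, which keeps runs repeatable.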
B
The frozen feedback matrix of shape (n_target, n_layer).
Type: jax.Array
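The stop-gradient guard means that no gradient reaches B, even when B sits inside a function that is differentiated. A minimal pure-JAX illustration follows; it assumes the projection is y @ B, which matches the shapes in the example below (a length-2 input maps to a length-3 output). The function project here is a stand-in, not braintrace's own method.

```python
import jax
import jax.numpy as jnp


def project(B: jax.Array, y: jax.Array) -> jax.Array:
    # stop_gradient is the identity in the forward pass but passes back
    # a zero cotangent, so B is effectively frozen under differentiation.
    B_frozen = jax.lax.stop_gradient(B)
    return y @ B_frozen  # (n_target,) @ (n_target, n_layer) -> (n_layer,)


B = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (2, 3))
y = jnp.ones(2)


def loss(B, y):
    return jnp.sum(project(B, y) ** 2)


grad_B, grad_y = jax.grad(loss, argnums=(0, 1))(B, y)
print(grad_B)  # all zeros: B receives no gradient
print(grad_y)  # generally non-zero: gradients still flow to the input
```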
Examples
>>> import jax
>>> import braintrace
>>> fb = braintrace.FixedRandomFeedback(2, 3, jax.random.PRNGKey(0))
>>> print(fb.B.shape)
(2, 3)
>>> y = jax.numpy.ones(2)
>>> print(fb.project(y).shape)
(3,)
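In e-prop with random feedback, the learning signal for each hidden unit is the output error projected through a fixed random matrix rather than through the transposed readout weights. A minimal sketch of that use in plain JAX, assuming the projection computes error @ B as the (2,) to (3,) shapes in the example suggest; learning_signal, y_pred and y_target are hypothetical names, not braintrace API.

```python
import jax
import jax.numpy as jnp

n_target, n_layer = 2, 3
# Fixed random feedback matrix: sampled once, never trained.
B = jax.lax.stop_gradient(
    0.1 * jax.random.normal(jax.random.PRNGKey(0), (n_target, n_layer)))


def learning_signal(y_pred: jax.Array, y_target: jax.Array) -> jax.Array:
    # Output error (n_target,) mapped to one signal per hidden unit (n_layer,).
    return (y_pred - y_target) @ B


y_pred = jnp.array([0.8, 0.1])
y_target = jnp.array([1.0, 0.0])
L = learning_signal(y_pred, y_target)
print(L.shape)  # (3,)

# A perfect prediction yields a zero learning signal.
print(bool(jnp.all(learning_signal(y_target, y_target) == 0)))  # True
```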