element_wise
class braintrace.element_wise(weight, fn=<function <lambda>>)
ETP-aware element-wise operation.
Applies `fn` to `weight` and passes the result through a marker primitive. The operation is treated as diagonal in the hidden-state space, so the weight participates in eligibility-trace computation as a per-element trainable parameter.

- Parameters:
  - weight (ArrayLike) – Weight parameter.
  - fn (Callable, optional) – Element-wise function applied to `weight` before the primitive binds. Default is the identity `lambda w: w`.
- Returns:
  `fn(weight)`, with the same shape as `weight`.
- Return type:
  ArrayLike
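Treating the operation as diagonal presumably relies on `fn` acting on each entry of `weight` independently. The sketch below uses plain JAX to illustrate this; it does not reproduce braintrace internals. The helper `is_elementwise` is hypothetical and not part of braintrace. It checks that the Jacobian of a candidate `fn` is diagonal. When it is, the diagonal holds the per-element derivative, which fits in an array shaped like `weight`.

```python
import jax
import jax.numpy as jnp


def is_elementwise(fn, w, atol=1e-6):
    """Hypothetical helper (not part of braintrace): return True if the
    Jacobian of fn at w is diagonal, i.e. each output entry depends only
    on the matching input entry."""
    jac = jax.jacfwd(fn)(w)                      # shape (n, n) for a 1-D w
    off_diag = jac - jnp.diag(jnp.diag(jac))     # zero out the diagonal
    return bool(jnp.all(jnp.abs(off_diag) <= atol))


w = jnp.array([-1.0, 0.0, 0.5, 2.0])

print(is_elementwise(jnp.tanh, w))       # True: tanh acts entry by entry
print(is_elementwise(lambda v: v, w))    # True: the default identity
print(is_elementwise(jnp.cumsum, w))     # False: each output mixes earlier entries

# For an element-wise fn the Jacobian's diagonal is the per-element
# derivative, so a vector shaped like w stores it without loss.
diag = jnp.diag(jax.jacfwd(jnp.tanh)(w))
print(jnp.allclose(diag, 1 - jnp.tanh(w) ** 2))  # True
```

`jnp.cumsum` fails the check because its Jacobian is lower-triangular rather than diagonal. Under this reading of the docstring, it would not be a suitable `fn`.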
Examples
>>> import brainstate
>>> import braintrace
>>>
>>> brainstate.environ.set(precision=64)
>>> w = brainstate.random.randn(5)
>>> y = braintrace.element_wise(w)
>>> print(y.shape)
(5,)
>>>
>>> # Apply a non-linearity to the weight
>>> import jax.numpy as jnp
>>> y1 = braintrace.element_wise(w, fn=jnp.tanh)
>>> print(y1.shape)
(5,)
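A small example shows why a per-element trace is enough. The sketch below is hand-rolled and uses a hypothetical diagonal recurrence `h_t = a * h_{t-1} + fn(w) * x_t`, where `a`, `w`, and each `x_t` are vectors of the same shape. This is not braintrace's algorithm. The names `run_with_trace` and `loss`, the model, and the squared-sum loss are all illustrative assumptions. The point is that when every operation is element-wise, the trace `e_t = dh_t/dw` stays a vector shaped like `w`. It also reproduces the gradient that reverse-mode autodiff computes through the unrolled loop.

```python
import jax
import jax.numpy as jnp


def run_with_trace(w, a, xs, fn=jnp.tanh):
    """Hypothetical diagonal RNN: h_t = a * h_{t-1} + fn(w) * x_t.

    Also carries a per-element eligibility trace e_t = dh_t/dw. Because
    every operation is element-wise, dh_t/dw is diagonal and the trace
    is stored as a vector shaped like w.
    """
    dfn = jax.vmap(jax.grad(fn))(w)   # per-element derivative fn'(w)
    h = jnp.zeros_like(w)
    e = jnp.zeros_like(w)
    for x in xs:                      # iterate over time steps (rows of xs)
        h = a * h + fn(w) * x
        e = a * e + dfn * x           # forward-mode trace update
    return h, e


def loss(w, a, xs, fn=jnp.tanh):
    # Illustrative loss on the final state, used to check the trace.
    h, _ = run_with_trace(w, a, xs, fn)
    return jnp.sum(h ** 2)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (5,))
a = jnp.full((5,), 0.9)               # assumed per-element decay
xs = jax.random.normal(key_x, (10, 5))

h_T, e_T = run_with_trace(w, a, xs)
grad_trace = 2 * h_T * e_T            # dL/dh_T times diagonal dh_T/dw
grad_autodiff = jax.grad(loss)(w, a, xs)
print(jnp.allclose(grad_trace, grad_autodiff, rtol=1e-4, atol=1e-4))  # True
```

The trace is exact here only because the recurrence weight `a` is also diagonal. With a dense recurrent matrix, `dh_t/dw` would in general no longer be diagonal.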