element_wise

class braintrace.element_wise(weight, fn=lambda w: w)

ETP-aware element-wise operation.

Applies fn to weight and passes the result through a marker primitive. The operation is treated as diagonal in the hidden-state space, so the weight participates in eligibility-trace computation as a per-element trainable parameter.
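Treating the operation as diagonal rests on a simple fact. When fn is element-wise, output i depends only on weight i, so the Jacobian of fn(weight) with respect to weight is a diagonal matrix. All of its derivative information therefore fits in a vector with the same shape as weight. The sketch below checks this numerically with central finite differences in plain NumPy. The helpers `numerical_jacobian` and `is_diagonal` are hypothetical illustrations, not part of braintrace. `np.cumsum` appears only as a counter-example of a function that is not element-wise.

```python
import numpy as np


def numerical_jacobian(f, w, eps=1e-6):
    """Central-difference Jacobian J[i, j] = d f(w)[i] / d w[j] (hypothetical helper)."""
    w = np.asarray(w, dtype=np.float64)
    n = w.size
    jac = np.zeros((n, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = eps
        jac[:, j] = (f(w + step) - f(w - step)) / (2 * eps)
    return jac


def is_diagonal(mat, tol=1e-8):
    """True if every off-diagonal entry is numerically zero."""
    off_diag = mat - np.diag(np.diag(mat))
    return bool(np.all(np.abs(off_diag) < tol))


w = np.array([-1.0, 0.0, 0.5, 2.0, 3.0])

# Element-wise fn: only the diagonal is populated, so a per-element
# vector (same shape as w) holds all the derivative information.
jac_tanh = numerical_jacobian(np.tanh, w)
print(is_diagonal(jac_tanh))                                           # True
print(np.allclose(np.diag(jac_tanh), 1 - np.tanh(w) ** 2, atol=1e-6))  # True

# Counter-example: cumsum mixes elements, so its Jacobian is
# lower-triangular rather than diagonal.
jac_cumsum = numerical_jacobian(np.cumsum, w)
print(is_diagonal(jac_cumsum))                                         # False
```

This explains why fn should be a genuinely element-wise function. A function that mixes entries, such as a cumulative sum, would have off-diagonal derivatives that a per-element treatment cannot represent.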

Parameters:
  • weight (ArrayLike) – Weight parameter.

  • fn (Callable, optional) – Element-wise function applied to weight before the primitive binds. Default is the identity lambda w: w.

Returns:

fn(weight), with the same shape as weight.

Return type:

ArrayLike
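The documented return value is fn(weight), with the same shape as weight. Its forward value can therefore be mimicked without braintrace. The plain-NumPy sketch below is a minimal reference under that reading. It returns fn(weight) and adds a shape check that reflects the element-wise requirement on fn. `element_wise_reference` is a hypothetical name used for illustration, not a braintrace function. The sketch omits the marker primitive entirely, on the assumption that the primitive affects how the operation is traced rather than the returned values.

```python
import numpy as np


def element_wise_reference(weight, fn=lambda w: w):
    """Hypothetical forward-value reference for braintrace.element_wise.

    Returns fn(weight) and checks that fn preserved the shape, as an
    element-wise function should. The braintrace marker primitive is
    deliberately omitted (assumed to affect tracing only, not values).
    """
    weight = np.asarray(weight)
    out = np.asarray(fn(weight))
    if out.shape != weight.shape:
        raise ValueError(
            f"fn changed the shape from {weight.shape} to {out.shape}; "
            "element_wise expects an element-wise function"
        )
    return out


rng = np.random.default_rng(0)
w = rng.standard_normal(5)

print(element_wise_reference(w).shape)                                  # (5,)
print(np.allclose(element_wise_reference(w, fn=np.tanh), np.tanh(w)))  # True

# A reduction is not element-wise, so the shape check rejects it.
try:
    element_wise_reference(w, fn=np.sum)
except ValueError as err:
    print("rejected:", err)
```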

Examples

>>> import brainstate
>>> import braintrace
>>>
>>> brainstate.environ.set(precision=64)
>>> w = brainstate.random.randn(5)
>>> y = braintrace.element_wise(w)
>>> print(y.shape)
(5,)
>>>
>>> # Apply a non-linearity to the weight
>>> import jax.numpy as jnp
>>> y1 = braintrace.element_wise(w, fn=jnp.tanh)
>>> print(y1.shape)
(5,)