PresynapticTrace
- class braintrace.PresynapticTrace(init_value, leak)
Leaky presynaptic accumulator used by OTTT and OTPE-Approx.
On each update, the trace multiplies its stored value by a decay factor and then adds the new presynaptic input, following \(\hat{a} \leftarrow \lambda \cdot \hat{a} + x_t\).
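As a rough illustration of the update rule, the sketch below applies it along a time axis with `jax.lax.scan`. The helper `presynaptic_trace_scan` is hypothetical and not part of braintrace. It assumes only the recurrence stated above, not how `PresynapticTrace` is implemented internally.

```python
import jax
import jax.numpy as jnp


def presynaptic_trace_scan(init_value, inputs, leak):
    """Apply a_hat <- leak * a_hat + x_t over the leading (time) axis of `inputs`.

    Hypothetical helper for illustration only; not part of braintrace.
    Returns the final trace and the trace after every step.
    """

    def step(a_hat, x_t):
        a_hat = leak * a_hat + x_t  # multiplicative decay, then accumulate the input
        return a_hat, a_hat  # (new carry, per-step output)

    return jax.lax.scan(step, init_value, inputs)


# Three time steps of all-ones input for a 3-element trace.
final, history = presynaptic_trace_scan(jnp.zeros(3), jnp.ones((3, 3)), 0.5)
print(history[:, 0])  # value of the first trace element after each step
```

The first two steps give 1 and 1.5, matching the doctest in Examples, and the third step gives 1.75.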
- Parameters:
init_value (jax.Array) – Initial value; also dictates the shape and dtype of the trace.
leak (float) – Decay factor \(\lambda\) in (0, 1). In SNN usage, it is taken from the neuron’s membrane leak.
- Raises:
ValueError – If leak is not strictly inside the open interval (0, 1).
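To make the parameter contract concrete, here is a hypothetical stand-in class, `LeakyTraceSketch`, that rejects `leak` outside the open interval (0, 1) and keeps the shape and dtype of `init_value`. It is an assumption-based sketch, not braintrace's implementation; in particular, casting each update back to the initial dtype is one reading of "dictates the shape and dtype of the trace".

```python
import jax.numpy as jnp


class LeakyTraceSketch:
    """Hypothetical stand-in mirroring the documented PresynapticTrace contract.

    Not the braintrace implementation; written only to illustrate the
    parameters and the ValueError condition described above.
    """

    def __init__(self, init_value, leak):
        if not 0.0 < leak < 1.0:
            # Reject leak values outside the open interval (0, 1), endpoints included.
            raise ValueError(f"leak must be in (0, 1), got {leak}")
        self.value = jnp.asarray(init_value)  # shape and dtype come from init_value
        self.leak = leak

    def update(self, x):
        # Decay the stored trace, add the new input, and cast back to the
        # dtype fixed by init_value (an assumption of this sketch).
        self.value = (self.leak * self.value + x).astype(self.value.dtype)
        return self.value


for bad_leak in (0.0, 1.0, 1.5):
    try:
        LeakyTraceSketch(jnp.zeros(3), leak=bad_leak)
    except ValueError as err:
        print(err)
```

Both endpoints 0 and 1 are rejected because the documented interval is open.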
Examples
>>> import jax.numpy as jnp
>>> import braintrace
>>> trace = braintrace.PresynapticTrace(jnp.zeros(3), leak=0.5)
>>> out = trace.update(jnp.ones(3))
>>> print(out)
[1. 1. 1.]
>>> out = trace.update(jnp.ones(3))
>>> print(out)
[1.5 1.5 1.5]
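Unrolling the update rule over \(T\) updates, starting from \(\hat{a}_0\), gives a closed form that accounts for the numbers in the example. This is a worked derivation from the stated recurrence only:

\[
\hat{a}_T = \lambda^{T}\,\hat{a}_0 + \sum_{t=1}^{T} \lambda^{T-t}\, x_t .
\]

With \(\hat{a}_0 = 0\) and a constant input \(x_t = x\), the sum is geometric:

\[
\hat{a}_T = x\,\frac{1 - \lambda^{T}}{1 - \lambda} \;\longrightarrow\; \frac{x}{1 - \lambda} \quad (T \to \infty).
\]

With \(\lambda = 0.5\) and \(x = 1\), the trace takes the values 1, 1.5, 1.75, ... and approaches 2. Since \(0 < \lambda < 1\), the trace stays bounded under bounded input; for \(\lambda \ge 1\), a constant positive input would make it grow without limit.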