PresynapticTrace

class braintrace.PresynapticTrace(init_value, leak)

Leaky presynaptic accumulator used by OTTT and OTPE-Approx.

At each step the trace decays multiplicatively by \(\lambda\) and then adds the current presynaptic input: \(\hat{a} \leftarrow \lambda \cdot \hat{a} + x_t\).
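Unrolling the recursion shows how \(\lambda\) weights past inputs. The indexed notation here is introduced for illustration only: \(\hat{a}_0\) is the initial value and \(\hat{a}_t\) is the trace after step \(t\).

```latex
\hat{a}_t = \lambda\,\hat{a}_{t-1} + x_t
\quad\Longrightarrow\quad
\hat{a}_T = \lambda^{T}\,\hat{a}_0 + \sum_{t=1}^{T} \lambda^{T-t}\, x_t

% Constant input x_t = x and 0 < \lambda < 1 (geometric series):
\hat{a}_T = \lambda^{T}\,\hat{a}_0 + x\,\frac{1-\lambda^{T}}{1-\lambda}
\;\xrightarrow{\;T\to\infty\;}\; \frac{x}{1-\lambda}
```

With \(\hat{a}_0 = 0\), \(\lambda = 0.5\) and \(x = 1\), this gives 1 and 1.5 for \(T = 1, 2\), matching the example further down, and the trace saturates at 2. For \(\lambda\) in (0, 1) the weights on older inputs shrink geometrically, so a bounded input keeps the trace bounded.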

Parameters:
  • init_value (jax.Array) – Initial value; also dictates the shape and dtype of the trace.

  • leak (float) – Decay factor \(\lambda\); must lie in the open interval (0, 1). In spiking neural network (SNN) usage, it is typically taken from the neuron’s membrane leak.

Raises:

ValueError – If leak is not strictly inside the open interval (0, 1).
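The documented behavior (validation of leak, dtype fixed by init_value, and the update rule) can be mirrored by a small stand-in class. This is a sketch under stated assumptions, not braintrace's implementation: the class name MinimalPresynapticTrace is hypothetical, and it assumes the trace is kept in a plain attribute and cast back to the initial dtype after each step.

```python
import jax.numpy as jnp


class MinimalPresynapticTrace:
    """Hypothetical stand-in mirroring the documented PresynapticTrace behavior."""

    def __init__(self, init_value, leak: float):
        # Mirror the documented ValueError: leak must lie strictly inside (0, 1).
        if not 0.0 < leak < 1.0:
            raise ValueError(f"leak must lie strictly inside (0, 1), got {leak!r}")
        self.leak = float(leak)
        # init_value fixes the shape and dtype of the stored trace.
        self.value = jnp.asarray(init_value)

    def update(self, x):
        # One step of a_hat <- leak * a_hat + x, cast back to the trace's dtype.
        new_value = self.leak * self.value + x
        self.value = new_value.astype(self.value.dtype)
        return self.value


trace = MinimalPresynapticTrace(jnp.zeros(3), leak=0.5)
first = trace.update(jnp.ones(3))   # 0.5 * 0 + 1 -> [1. 1. 1.]
second = trace.update(jnp.ones(3))  # 0.5 * 1 + 1 -> [1.5 1.5 1.5]

try:
    MinimalPresynapticTrace(jnp.zeros(3), leak=1.0)
    rejected = False
except ValueError:
    rejected = True
```

A leak of exactly 0 or 1 is rejected just like values outside the interval, since the check uses strict inequalities.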

Examples

>>> import jax.numpy as jnp
>>> import braintrace
>>> trace = braintrace.PresynapticTrace(jnp.zeros(3), leak=0.5)
>>> out = trace.update(jnp.ones(3))
>>> print(out)
[1. 1. 1.]
>>> out = trace.update(jnp.ones(3))
>>> print(out)
[1.5 1.5 1.5]

update(x)

Apply one accumulation step \(\hat{a} \leftarrow \lambda \cdot \hat{a} + x\).

Parameters:

x (jax.Array) – The new presynaptic input added to the decayed trace.

Returns:

The updated trace value.

Return type:

jax.Array
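Because update is a pure recurrence, the same computation can be written functionally and scanned over a time axis, a common JAX pattern. This sketch does not call braintrace; run_trace is a hypothetical helper, and it assumes the inputs are stacked along a leading time axis. It also checks the scanned result against the unrolled sum \(\sum_t \lambda^{T-t} x_t\) for a zero initial trace.

```python
import jax
import jax.numpy as jnp


def run_trace(xs, init_value, leak):
    """Hypothetical helper: apply a_hat <- leak * a_hat + x_t along axis 0 of xs."""

    def step(a_hat, x_t):
        a_hat = leak * a_hat + x_t
        # Carry the new trace forward and also record it as this step's output.
        return a_hat, a_hat

    final, history = jax.lax.scan(step, init_value, xs)
    return final, history


leak = 0.5
xs = jnp.ones((4, 3))  # 4 time steps of 3 presynaptic inputs
final, history = run_trace(xs, jnp.zeros(3), leak)
# history[:, 0] holds 1.0, 1.5, 1.75, 1.875

# Unrolled form for a zero initial trace: a_T = sum_t leak**(T - t) * x_t.
T = xs.shape[0]
weights = leak ** jnp.arange(T - 1, -1, -1)  # [leak**3, leak**2, leak, 1]
closed_form = jnp.tensordot(weights, xs, axes=1)
```

Keeping the carry explicit makes the step straightforward to wrap in jax.jit, and the recorded history exposes the trace value at every time step.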