KappaFilter

class braintrace.KappaFilter(init_value, kappa)

Low-pass output-side filter used by EProp.

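On each update the filter smooths the output-side signal with an exponential moving average, \(x_{\mathrm{filt}} \leftarrow (1-\kappa) \cdot x + \kappa \cdot x_{\mathrm{filt}}\), where \(x\) is the new input and \(x_{\mathrm{filt}}\) is the filtered state.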
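Unrolling the recursion makes the moving-average interpretation explicit. Here \(x^{(n)}\) denotes the \(n\)-th input and \(x_{\mathrm{filt}}^{(0)}\) the initial value; this step-index notation is introduced only for the derivation:

```latex
\begin{aligned}
x_{\mathrm{filt}}^{(n)} &= (1-\kappa)\, x^{(n)} + \kappa\, x_{\mathrm{filt}}^{(n-1)} \\
                        &= \kappa^{n}\, x_{\mathrm{filt}}^{(0)} + (1-\kappa) \sum_{k=0}^{n-1} \kappa^{k}\, x^{(n-k)} .
\end{aligned}
```

An input's weight shrinks by a factor \(\kappa\) per step, so larger \(\kappa\) means heavier smoothing. For a constant input \(x^{(n)} = 1\) starting from \(x_{\mathrm{filt}}^{(0)} = 0\), the geometric sum gives \(x_{\mathrm{filt}}^{(n)} = 1 - \kappa^{n}\); with \(\kappa = 0.5\) the outputs are 0.5, 0.75, 0.875, and so on.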

Parameters:
  • init_value (jax.Array) – Initial value; also dictates the shape and dtype of the filtered state.

  • kappa (float) – Decay factor \(\kappa\) in [0, 1). A value of 0 disables filtering (the output then equals the input); values closer to 1 smooth more strongly.

Raises:

ValueError – If kappa is not inside the half-open interval [0, 1).
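The documented range check and the \(\kappa = 0\) pass-through can be illustrated with a small stand-alone sketch. `validate_kappa` and `filter_step` are hypothetical helpers written for this illustration only; they are not part of braintrace and merely mirror the behaviour described above.

```python
import jax.numpy as jnp


def validate_kappa(kappa: float) -> float:
    """Hypothetical helper mirroring the documented check: kappa must lie in [0, 1)."""
    if not (0.0 <= kappa < 1.0):
        raise ValueError(f"kappa must be in [0, 1), got {kappa!r}")
    return float(kappa)


def filter_step(x_filt, x, kappa):
    """Hypothetical helper: one step of x_filt <- (1 - kappa) * x + kappa * x_filt."""
    return (1.0 - kappa) * x + kappa * x_filt


# kappa = 0 is accepted and turns the filter into a pass-through.
kappa = validate_kappa(0.0)
x = jnp.array([1.0, 2.0, 3.0])
out = filter_step(jnp.zeros(3), x, kappa)  # equals x exactly

# kappa = 1 is rejected by the range check.
try:
    validate_kappa(1.0)
except ValueError as err:
    print(err)  # prints: kappa must be in [0, 1), got 1.0
```

With \(\kappa = 1\) the update would reduce to \(x_{\mathrm{filt}} \leftarrow x_{\mathrm{filt}}\), so new input would never enter the state, which illustrates why an open upper bound is sensible.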

Examples

>>> import jax.numpy as jnp
>>> import braintrace
>>> filt = braintrace.KappaFilter(jnp.zeros(3), kappa=0.5)
>>> out = filt.update(jnp.ones(3))
>>> print(out)
[0.5 0.5 0.5]
>>> out = filt.update(jnp.ones(3))
>>> print(out)
[0.75 0.75 0.75]

update(x)

Apply one low-pass step \(x_{\mathrm{filt}} \leftarrow (1-\kappa) x + \kappa\, x_{\mathrm{filt}}\), keep the result as the new filtered state, and return it.

Parameters:

x (jax.Array) – The new input mixed into the filtered state.

Returns:

The updated, filtered value.

Return type:

jax.Array
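update applies a single step to stored state. When a whole time-major sequence has to be filtered in one functional call (for example inside jit-compiled code), one option is a scan over time. The sketch below is a stand-alone JAX equivalent of repeated update calls under that assumption; `kappa_filter_sequence` is a hypothetical name and the function does not use braintrace.KappaFilter.

```python
import jax
import jax.numpy as jnp


def kappa_filter_sequence(init_value, xs, kappa):
    """Filter xs[t] over the leading (time) axis with
    x_filt <- (1 - kappa) * x + kappa * x_filt.

    Hypothetical stand-alone sketch, independent of braintrace.
    Returns the final state and the filtered value at every step.
    """

    def step(x_filt, x):
        x_filt = (1.0 - kappa) * x + kappa * x_filt
        # Carry the new state forward and also emit it as this step's output.
        return x_filt, x_filt

    final, history = jax.lax.scan(step, init_value, xs)
    return final, history


xs = jnp.ones((3, 2))  # 3 time steps, 2 features
final, history = kappa_filter_sequence(jnp.zeros(2), xs, 0.5)
# history[:, 0] holds 0.5, 0.75, 0.875, i.e. 1 - 0.5**n for n = 1, 2, 3.
print(history[:, 0])
```

Because the carry keeps the dtype and shape of init_value, the initial value should already have the floating-point dtype expected for the filtered signal.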