KappaFilter
- class braintrace.KappaFilter(init_value, kappa)
Low-pass output-side filter used by EProp.
Each call to update smooths the incoming output-side signal \(x\) with the recursion \(x_{\mathrm{filt}} \leftarrow (1-\kappa) \cdot x + \kappa \cdot x_{\mathrm{filt}}\) and returns the new filtered value.
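The update rule above can be sketched in a few lines of NumPy. This is a minimal illustration that assumes nothing beyond the formula itself: kappa_filter_step is a hypothetical helper, not part of braintrace, and NumPy stands in for jax.numpy so the snippet runs without JAX. With a zero initial state and kappa = 0.5 it produces the same values as the Examples section.

```python
import numpy as np


def kappa_filter_step(x, x_filt, kappa):
    """Hypothetical helper (not braintrace API): one step of
    x_filt <- (1 - kappa) * x + kappa * x_filt."""
    return (1.0 - kappa) * x + kappa * x_filt


# Zero initial state, kappa = 0.5, constant input of ones (as in the Examples).
state = np.zeros(3)
history = []
for _ in range(3):
    state = kappa_filter_step(np.ones(3), state, kappa=0.5)
    history.append(state)

# Print the filtered state after each update.
for step, value in enumerate(history, start=1):
    print(step, value)
```

A larger kappa gives the previous state more weight, so the output tracks changes in the input more slowly.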
- Parameters:
init_value (jax.Array) – Initial value; also dictates the shape and dtype of the filtered state.
kappa (float) – Decay factor \(\kappa\) in [0, 1). A value of 0 disables filtering.
- Raises:
ValueError – If kappa is not inside the half-open interval [0, 1).
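The constraint on kappa can be illustrated with a small validator. check_kappa is a hypothetical stand-in written for illustration, not the braintrace implementation. It only mirrors the documented rule that kappa must lie in [0, 1). The snippet also checks that kappa = 0 makes the update return the input unchanged, which is what "disables filtering" means for this formula.

```python
import numpy as np


def check_kappa(kappa):
    """Hypothetical validator mirroring the documented constraint 0 <= kappa < 1."""
    if not (0.0 <= kappa < 1.0):
        raise ValueError(f"kappa must lie in [0, 1), got {kappa!r}")
    return float(kappa)


# kappa = 0: the update reduces to x_filt <- x, whatever the previous state was.
x = np.array([1.0, -2.0, 3.0])
x_filt = np.array([10.0, 10.0, 10.0])
kappa = check_kappa(0.0)
x_filt = (1.0 - kappa) * x + kappa * x_filt
print(x_filt)  # equal to x

# Values outside [0, 1) are rejected with ValueError.
rejected = []
for bad in (-0.1, 1.0, 1.5):
    try:
        check_kappa(bad)
    except ValueError as err:
        rejected.append(bad)
        print(f"rejected kappa={bad}: {err}")
```

With kappa = 1 the formula would reduce to x_filt <- x_filt, so the state would never respond to its input.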
Examples
>>> import jax.numpy as jnp
>>> import braintrace
>>> filt = braintrace.KappaFilter(jnp.zeros(3), kappa=0.5)
>>> out = filt.update(jnp.ones(3))
>>> print(out)
[0.5 0.5 0.5]
>>> out = filt.update(jnp.ones(3))
>>> print(out)
[0.75 0.75 0.75]
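For a constant input the recursion can be unrolled, which accounts for the 0.5, 0.75 sequence in the example. Assume the input \(x\) is the same on every call and write \(x_{\mathrm{filt}}^{(n)}\) for the state after \(n\) updates. Then, by induction on \(n\),

\[
x_{\mathrm{filt}}^{(n)} = \left(1-\kappa^{n}\right) x + \kappa^{n}\, x_{\mathrm{filt}}^{(0)},
\]

since the inductive step gives

\[
(1-\kappa)\, x + \kappa \left[\left(1-\kappa^{n}\right) x + \kappa^{n}\, x_{\mathrm{filt}}^{(0)}\right] = \left(1-\kappa^{n+1}\right) x + \kappa^{n+1}\, x_{\mathrm{filt}}^{(0)}.
\]

With \(x_{\mathrm{filt}}^{(0)} = 0\), \(x = 1\), and \(\kappa = 0.5\), the first two calls give \(1 - 0.5 = 0.5\) and \(1 - 0.25 = 0.75\). Because \(0 \le \kappa < 1\), \(\kappa^{n} \to 0\), so the filtered state converges to the constant input.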