ETPPrimitive
- class braintrace.ETPPrimitive(name)
A JAX Primitive with ETP rule registration helpers. Returned by register_primitive(). Supports every standard JAX primitive operation (bind, def_impl, …) and adds five convenience methods for installing ETP-specific rules into the global registries.
See also
register_primitive: Factory that creates and returns an instance.
Examples
```python
>>> import jax.numpy as jnp
>>> import braintrace
>>>
>>> # Register a primitive whose forward delegates to a standard op.
>>> def my_impl(x, w):
...     return x @ w
>>> my_p = braintrace.register_primitive('etp_demo_mm', my_impl, batched=True)
>>> y = my_p.bind(jnp.ones((2, 3)), jnp.ones((3, 4)))
>>> print(y.shape)
(2, 4)
```
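The text says the helper methods install rules "into the global registries" but does not show their layout. The following is a minimal pure-Python sketch of one plausible arrangement, assuming one dict per rule kind keyed by primitive name. ToyETPPrimitive, INIT_DRTRL_RULES, INIT_PP_RULES and _install are hypothetical names, not braintrace internals, and only the two initialiser methods documented on this page are mirrored.

```python
# Hypothetical sketch, NOT braintrace's implementation: one global
# registry per rule kind, keyed by the primitive's name (an assumption;
# the real registries may be keyed differently).
from typing import Callable, Dict

INIT_DRTRL_RULES: Dict[str, Callable] = {}  # hypothetical registry
INIT_PP_RULES: Dict[str, Callable] = {}  # hypothetical registry


class ToyETPPrimitive:
    """Stand-in for a primitive that carries rule-registration helpers."""

    def __init__(self, name: str) -> None:
        self.name = name

    def _install(self, registry: Dict[str, Callable], fn: Callable) -> None:
        # In this sketch a later registration for the same primitive
        # replaces the earlier one.
        registry[self.name] = fn

    def register_init_drtrl(self, fn: Callable) -> None:
        self._install(INIT_DRTRL_RULES, fn)

    def register_init_pp(self, fn: Callable) -> None:
        self._install(INIT_PP_RULES, fn)
```

Because the registries in this sketch are module-level, a rule installed through one primitive object is visible to any code that looks it up by that primitive's name, which is the point of keeping them global rather than on the instance.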
- register_etp_rules(*, yw_to_w=None, xy_to_dw=None, init_drtrl=None, init_pp=None)
Install multiple ETP rules in one call.
Any argument left as None is skipped.
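The skip-on-None behaviour can be pictured with a small free-function sketch. It assumes one registry dict per rule kind, keyed by primitive name, and replaces the bound primitive with a hypothetical prim_name argument; REGISTRIES and this standalone register_etp_rules are illustrative only, not braintrace's code.

```python
# Hypothetical sketch of the "skip None" semantics of register_etp_rules.
# REGISTRIES and the prim_name parameter are assumptions for illustration;
# the real method is bound to an ETPPrimitive instance.
from typing import Callable, Dict, Optional

REGISTRIES: Dict[str, Dict[str, Callable]] = {
    'yw_to_w': {},
    'xy_to_dw': {},
    'init_drtrl': {},
    'init_pp': {},
}


def register_etp_rules(
    prim_name: str,
    *,
    yw_to_w: Optional[Callable] = None,
    xy_to_dw: Optional[Callable] = None,
    init_drtrl: Optional[Callable] = None,
    init_pp: Optional[Callable] = None,
) -> None:
    rules = {
        'yw_to_w': yw_to_w,
        'xy_to_dw': xy_to_dw,
        'init_drtrl': init_drtrl,
        'init_pp': init_pp,
    }
    for kind, fn in rules.items():
        if fn is None:
            continue  # arguments left as None are skipped, not cleared
        REGISTRIES[kind][prim_name] = fn
```

In this sketch "skipped" means an omitted rule leaves any earlier registration for that kind untouched, so rules can be installed across several calls.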
- register_init_drtrl(fn)
Install a D-RTRL trace initialiser.
- Parameters:
fn (Callable) – Rule with signature (x_var, y_var, weight_var, num_hidden_state) -> zeros.
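A rule matching this signature might look like the sketch below. Two assumptions are not stated in this reference: that each *_var exposes .aval.shape and .aval.dtype the way JAX jaxpr variables do, and that the D-RTRL trace has the weight's shape plus a trailing axis of length num_hidden_state. FakeAval, FakeVar and init_drtrl_rule are hypothetical names, and NumPy stands in for jax.numpy so the sketch runs without JAX.

```python
# Hypothetical D-RTRL trace initialiser. Assumptions (not from this page):
#   * each *_var exposes .aval.shape and .aval.dtype, like JAX jaxpr vars;
#   * the trace is the weight's shape plus a trailing num_hidden_state axis.
from dataclasses import dataclass
from typing import Tuple

import numpy as np


@dataclass(frozen=True)
class FakeAval:
    """Stand-in for a JAX abstract value (shape and dtype only)."""
    shape: Tuple[int, ...]
    dtype: np.dtype


@dataclass(frozen=True)
class FakeVar:
    """Stand-in for a jaxpr variable carrying an abstract value."""
    aval: FakeAval


def init_drtrl_rule(x_var, y_var, weight_var, num_hidden_state):
    w = weight_var.aval
    # np.zeros stands in for jnp.zeros; x_var and y_var are unused here.
    return np.zeros(w.shape + (num_hidden_state,), dtype=w.dtype)
```

Under these assumptions the rule would be installed with my_p.register_init_drtrl(init_drtrl_rule); for a (3, 4) weight and two hidden states it returns a zero array of shape (3, 4, 2).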
- register_init_pp(fn)
Install a pp_prop (IO-dim) df trace initialiser.
- Parameters:
fn (Callable) – Rule with signature (x_var, y_var, weight_var, num_hidden_state) -> zeros.
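A pp_prop initialiser shares the D-RTRL signature but, going by the "IO-dim" label, would size its df trace from the layer's input/output dimensions rather than the full weight. The sketch below assumes the df trace has the output y's shape plus a trailing num_hidden_state axis and that vars expose .aval.shape and .aval.dtype like JAX jaxpr variables. fake_var and init_pp_rule are hypothetical, and NumPy stands in for jax.numpy.

```python
# Hypothetical pp_prop (IO-dim) df-trace initialiser. Assumptions (not from
# this page): vars expose .aval.shape/.aval.dtype like JAX jaxpr vars, and
# the df trace is sized by the output y with a trailing num_hidden_state axis.
from types import SimpleNamespace

import numpy as np


def fake_var(shape, dtype='float32'):
    """Stand-in for a jaxpr variable carrying an abstract value."""
    aval = SimpleNamespace(shape=tuple(shape), dtype=np.dtype(dtype))
    return SimpleNamespace(aval=aval)


def init_pp_rule(x_var, y_var, weight_var, num_hidden_state):
    y = y_var.aval
    # The weight is not consulted: in this sketch the df trace scales with
    # the output dimension only, not with the full weight matrix.
    return np.zeros(y.shape + (num_hidden_state,), dtype=y.dtype)
```

Under these assumptions the rule would be installed with my_p.register_init_pp(init_pp_rule). For the (2, 3) @ (3, 4) example it returns a (2, 4, 2) zero array, compared with a weight-sized trace under the D-RTRL layout assumed for register_init_drtrl.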