ETPPrimitive

class braintrace.ETPPrimitive(name)

A JAX Primitive with ETP rule registration helpers.

Returned by register_primitive(). Supports all standard JAX primitive methods (bind, def_impl, …) and adds five convenience methods that install ETP-specific rules into the global registries.

See also

register_primitive

Factory that creates and returns an instance.

Examples

>>> import jax.numpy as jnp
>>> import braintrace
>>>
>>> # Register a primitive whose forward delegates to a standard op.
>>> def my_impl(x, w):
...     return x @ w
>>> my_p = braintrace.register_primitive('etp_demo_mm', my_impl, batched=True)
>>> y = my_p.bind(jnp.ones((2, 3)), jnp.ones((3, 4)))
>>> print(y.shape)
(2, 4)

register_etp_rules(*, yw_to_w=None, xy_to_dw=None, init_drtrl=None, init_pp=None)

Install multiple ETP rules in one call.

Any argument left as None is skipped.

Parameters:
  • yw_to_w (Callable) – D-RTRL trace propagation rule. Default None.

  • xy_to_dw (Callable) – Weight-gradient rule. Default None.

  • init_drtrl (Callable) – D-RTRL trace initialiser. Default None.

  • init_pp (Callable) – pp_prop (IO-dim) df trace initialiser. Default None.
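The None-skipping behaviour of register_etp_rules can be pictured with a small stand-in. The sketch below is hypothetical: RuleHolder and the module-level RULES table are invented for illustration and are not braintrace's implementation. They only show that each non-None argument is forwarded to its own per-kind table, while None arguments leave the tables untouched.

```python
from typing import Callable, Dict, Optional

# Hypothetical global registry: rule kind -> {primitive name -> rule}.
# This is an illustration only, not braintrace's actual registry.
RULES: Dict[str, Dict[str, Callable]] = {}


class RuleHolder:
    """Hypothetical stand-in mimicking the register_etp_rules dispatch."""

    def __init__(self, name: str):
        self.name = name

    def _install(self, kind: str, fn: Callable) -> None:
        # Each rule kind gets its own table, keyed by the primitive's name.
        RULES.setdefault(kind, {})[self.name] = fn

    def register_etp_rules(
        self,
        *,
        yw_to_w: Optional[Callable] = None,
        xy_to_dw: Optional[Callable] = None,
        init_drtrl: Optional[Callable] = None,
        init_pp: Optional[Callable] = None,
    ) -> None:
        pairs = (
            ("yw_to_w", yw_to_w),
            ("xy_to_dw", xy_to_dw),
            ("init_drtrl", init_drtrl),
            ("init_pp", init_pp),
        )
        for kind, fn in pairs:
            if fn is not None:  # arguments left as None are skipped
                self._install(kind, fn)


# Usage: only the weight-gradient rule is supplied.
holder = RuleHolder("etp_demo_mm")
holder.register_etp_rules(xy_to_dw=lambda x, hidden_dim, w, **params: x)
```

Here only the xy_to_dw table gains an entry for 'etp_demo_mm'; the other three kinds are never created because their arguments were None.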

register_init_drtrl(fn)

Install a D-RTRL trace initialiser.

Parameters:

fn (Callable) – Rule with signature (x_var, y_var, weight_var, num_hidden_state) -> zeros.
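A minimal sketch of a D-RTRL initialiser for a dense weight, under stated assumptions. The *_var arguments are assumed to expose an .aval with .shape and .dtype (as jaxpr variables do), and the (num_hidden_state, *weight_shape) layout is an illustrative guess, not necessarily the layout braintrace expects. init_drtrl_dense is a hypothetical name, and jax.ShapeDtypeStruct objects stand in for real variables.

```python
from types import SimpleNamespace

import jax
import jax.numpy as jnp


def init_drtrl_dense(x_var, y_var, weight_var, num_hidden_state):
    """Hypothetical D-RTRL trace initialiser for a dense weight (sketch)."""
    # Assumption: weight_var.aval has .shape and .dtype, as jaxpr vars do.
    # Assumption: one weight-shaped trace slice per hidden state.
    aval = weight_var.aval
    return jnp.zeros((num_hidden_state,) + tuple(aval.shape), dtype=aval.dtype)


# Stand-in variables: jax.ShapeDtypeStruct provides .shape and .dtype.
x_var = SimpleNamespace(aval=jax.ShapeDtypeStruct((3,), jnp.float32))
y_var = SimpleNamespace(aval=jax.ShapeDtypeStruct((4,), jnp.float32))
w_var = SimpleNamespace(aval=jax.ShapeDtypeStruct((3, 4), jnp.float32))
trace = init_drtrl_dense(x_var, y_var, w_var, num_hidden_state=2)
```

With a 3×4 weight and two hidden states this returns a zero array of shape (2, 3, 4).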

register_init_pp(fn)

Install a pp_prop (IO-dim) df trace initialiser.

Parameters:

fn (Callable) – Rule with signature (x_var, y_var, weight_var, num_hidden_state) -> zeros.
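A minimal sketch of a pp_prop df-trace initialiser, assuming that the IO-dim trace spans only the output (y) dimension with one slice per hidden state. As before, the *_var arguments are assumed to expose .aval with .shape and .dtype; init_pp_dense is a hypothetical name, and the layout is illustrative rather than braintrace's.

```python
from types import SimpleNamespace

import jax
import jax.numpy as jnp


def init_pp_dense(x_var, y_var, weight_var, num_hidden_state):
    """Hypothetical pp_prop df trace initialiser (sketch)."""
    # Assumption: the df trace lives on the output dimension only, so it is
    # sized from y_var rather than from the full weight.
    aval = y_var.aval
    return jnp.zeros((num_hidden_state,) + tuple(aval.shape), dtype=aval.dtype)


x_var = SimpleNamespace(aval=jax.ShapeDtypeStruct((3,), jnp.float32))
y_var = SimpleNamespace(aval=jax.ShapeDtypeStruct((4,), jnp.float32))
w_var = SimpleNamespace(aval=jax.ShapeDtypeStruct((3, 4), jnp.float32))
df_trace = init_pp_dense(x_var, y_var, w_var, num_hidden_state=2)
```

For a 3×4 weight and two hidden states this holds 2 × 4 = 8 values, whereas a trace with the full weight shape would need 2 × 3 × 4 = 24.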

register_xy_to_dw(fn)

Install a weight-gradient rule.

Parameters:

fn (Callable) – Rule with signature (x, hidden_dim, w, **params) -> dw.
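For a matmul primitive such as y = x @ w, a weight-gradient rule can be sketched as an outer product of x with a per-output signal, summed over any batch dimensions. This assumes hidden_dim has the same trailing dimension as y; mm_xy_to_dw is a hypothetical name, and **params (the primitive's parameters) is accepted but ignored. The last lines check the rule against jax.grad of sum(hidden_dim * (x @ w)), whose gradient with respect to w is x.T @ hidden_dim.

```python
import jax
import jax.numpy as jnp


def mm_xy_to_dw(x, hidden_dim, w, **params):
    """Hypothetical weight-gradient rule for y = x @ w (sketch)."""
    # Flatten leading batch dims so both operands are 2-D.
    x2 = x.reshape(-1, x.shape[-1])                     # (batch, in)
    h2 = hidden_dim.reshape(-1, hidden_dim.shape[-1])   # (batch, out)
    # Outer products summed over the batch give an (in, out) update.
    return (x2.T @ h2).astype(w.dtype)


# Check against autodiff: d/dw sum(h * (x @ w)) == x.T @ h.
x = jnp.arange(6.0).reshape(2, 3)
h = jnp.ones((2, 4))
w = jnp.zeros((3, 4))
dw = mm_xy_to_dw(x, h, w)
auto = jax.grad(lambda w_: jnp.sum(h * (x @ w_)))(w)
```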

register_yw_to_w(fn)

Install a D-RTRL trace propagation rule.

Parameters:

fn (Callable) – Rule with signature (hidden_dim, trace, **params) -> trace.
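For the same matmul primitive, a propagation rule can be sketched as scaling each output column of a weight-shaped trace by a per-output factor. This assumes hidden_dim holds one factor per output unit (plausibly a diagonal hidden-to-hidden Jacobian) and that the trace's last axis is the output axis; mm_yw_to_w is a hypothetical name, and **params is ignored.

```python
import jax.numpy as jnp


def mm_yw_to_w(hidden_dim, trace, **params):
    """Hypothetical D-RTRL trace propagation rule for y = x @ w (sketch)."""
    # Broadcasting multiplies column j of the (in, out) trace by hidden_dim[j].
    return trace * hidden_dim


trace = jnp.ones((3, 4))                   # weight-shaped trace, (in, out)
decay = jnp.array([0.0, 0.5, 1.0, 2.0])    # one factor per output unit
new_trace = mm_yw_to_w(decay, trace)
```

Because the factor is applied per output unit, the rule never materialises a full (out, out) Jacobian.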