braintrace.register_primitive

braintrace.register_primitive(name, impl_fn, *, batched=False, gradient_enabled=False, trainable_invars_fn=None, x_invar_index=0, y_outvar_index=0)

Create an ETPPrimitive whose standard JAX rules are all derived automatically from impl_fn.

Only the four ETP-specific rules must be written by hand; register them through the returned primitive’s register_* methods.

Parameters:
  • name (str) – Primitive name (e.g. 'etp_mm').

  • impl_fn (Callable) – Implementation of the primitive. Every standard JAX rule listed under Notes is derived from this function.

  • batched (bool, optional) – Whether this primitive operates on batched inputs. Default False.

  • gradient_enabled (bool, optional) – If True, the compiler evaluates this primitive when walking from y to h. Intended for identity-like ops such as etp_elemwise_p. Default False.

  • trainable_invars_fn (Callable or None, optional) – Function mapping eqn.params to a {key: invar_index} dict that declares the primitive’s full trainable-input layout. The compiler and executors use it to support primitives with several trainable inputs (e.g. {weight, bias} for Linear, {B, A, bias} for LoRA). If None, the compiler falls back to the single-weight layout {'weight': 1}. Default None.

  • x_invar_index (int or None, optional) – Position of the input x in eqn.invars, or None for primitives with no external input (currently only etp_elemwise_p). Default 0.

  • y_outvar_index (int, optional) – Position of the output y in eqn.outvars. Default 0.
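To make the trainable-input layout concrete, here is a minimal pure-Python sketch of a trainable_invars_fn for a Linear-like primitive, plus a helper showing how such a layout and x_invar_index index into eqn.invars. Everything here is an assumption for illustration: the invar ordering (x, weight, bias), the 'has_bias' entry in eqn.params, and the names linear_trainable_invars, split_invars, linear_impl and 'etp_linear' are hypothetical, and braintrace’s own Linear primitive may use a different layout.

```python
def linear_trainable_invars(params):
    """Hypothetical trainable_invars_fn for a Linear-like primitive.

    Assumes eqn.invars are ordered (x, weight[, bias]) and that eqn.params carries
    a boolean 'has_bias' entry; both are illustrative, not braintrace conventions.
    """
    layout = {'weight': 1}  # same single-weight layout as the documented fallback
    if params.get('has_bias', False):
        layout['bias'] = 2
    return layout


def split_invars(invars, layout, x_invar_index=0):
    """Hypothetical helper: pick out x and the trainable inputs from eqn.invars."""
    # x_invar_index=None mirrors primitives that take no external input.
    x = None if x_invar_index is None else invars[x_invar_index]
    trainable = {key: invars[idx] for key, idx in layout.items()}
    return x, trainable


# The function itself is what gets passed at registration time, e.g.:
# braintrace.register_primitive('etp_linear', linear_impl,
#                               trainable_invars_fn=linear_trainable_invars)
# (linear_impl is a placeholder for your implementation function.)
```

Returning a plain dict keeps the layout declarative, so the same primitive can expose a bias slot or not depending on its parameters.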

Returns:

The registered primitive.

Return type:

ETPPrimitive

Notes

The following standard JAX rules are installed automatically:

  • impl — eager execution.

  • abstract_eval — via jax.eval_shape(impl).

  • lowering — via mlir.lower_fun(impl).

  • JVP — via jax.jvp(impl).

  • transpose — derived by JAX from the JVP.

  • batching — via jax.vmap(impl).
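As an illustration of what "auto-derived" means, the sketch below wires the rules from the list above onto a plain JAX primitive using only public JAX APIs (jax.eval_shape, mlir.lower_fun, jax.jvp, jax.vmap). It is not braintrace’s implementation: make_auto_primitive, mm_impl and etp_mm_sketch are hypothetical names, and the ETP-specific rules are omitted entirely. No transpose rule is registered; reverse mode works because the JVP is expressed in standard JAX primitives that JAX already knows how to transpose.

```python
import functools

import jax
import jax.numpy as jnp
from jax.core import ShapedArray
from jax.interpreters import ad, batching, mlir

try:  # newer JAX releases expose Primitive here
    from jax.extend.core import Primitive
except ImportError:  # fallback for older JAX releases
    from jax.core import Primitive


def make_auto_primitive(name, impl):
    """Hypothetical helper: build a Primitive whose standard rules all derive from impl."""
    prim = Primitive(name)

    # impl: eager execution simply calls the Python implementation.
    prim.def_impl(impl)

    # abstract_eval: trace impl on shapes and dtypes only, then convert the
    # resulting ShapeDtypeStruct back into an abstract value.
    def abstract_eval(*avals, **params):
        specs = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in avals]
        out = jax.eval_shape(functools.partial(impl, **params), *specs)
        return ShapedArray(out.shape, out.dtype)

    prim.def_abstract_eval(abstract_eval)

    # lowering: inline the lowering of impl wherever the primitive appears under jit.
    mlir.register_lowering(prim, mlir.lower_fun(impl, multiple_results=False))

    # JVP: differentiate impl with jax.jvp. Symbolic-zero tangents (inputs that are
    # not being differentiated) are replaced by dense zeros first. The tangent
    # computation is built from standard JAX primitives, so reverse mode transposes
    # those; this sketch registers no transpose rule for prim itself.
    def jvp_rule(primals, tangents, **params):
        tangents = [jnp.zeros_like(p) if type(t) is ad.Zero else t
                    for p, t in zip(primals, tangents)]
        return jax.jvp(functools.partial(impl, **params), tuple(primals), tuple(tangents))

    ad.primitive_jvps[prim] = jvp_rule

    # batching: vmap impl over the batched axes (None marks unbatched inputs);
    # vmap places the batch axis of the output at position 0.
    def batch_rule(args, dims, **params):
        out = jax.vmap(functools.partial(impl, **params), in_axes=tuple(dims))(*args)
        return out, 0

    batching.primitive_batchers[prim] = batch_rule
    return prim


# Usage: a matmul primitive in the spirit of etp_mm (hypothetical names).
def mm_impl(x, w):
    return jnp.dot(x, w)


etp_mm_sketch = make_auto_primitive("etp_mm_sketch", mm_impl)
```

With this in place, etp_mm_sketch.bind(x, w) runs eagerly, under jax.jit, under jax.grad and under jax.vmap without any hand-written rule, which is the convenience register_primitive describes.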