braintrace.register_primitive
- braintrace.register_primitive(name, impl_fn, *, batched=False, gradient_enabled=False, trainable_invars_fn=None, x_invar_index=0, y_outvar_index=0)
Create an ETPPrimitive with all JAX rules auto-derived.

Only the four ETP-specific rules need hand-writing; register them through the returned primitive's register_* methods.

- Parameters:
  - name (str) – Primitive name (e.g. 'etp_mm').
  - impl_fn (Callable) – Implementation function.
  - batched (bool, optional) – Whether this primitive operates on batched inputs. Default False.
  - gradient_enabled (bool, optional) – If True, the compiler will evaluate this primitive when walking y -> h (identity-like ops such as etp_elemwise_p). Default False.
  - trainable_invars_fn (Callable or None, optional) – Function eqn.params -> {key: invar_index} declaring the primitive's full trainable-input layout. Used by the compiler and executors to support N-trainable-input primitives (e.g. {weight, bias} for Linear, {B, A, bias} for LoRA). If None, the compiler falls back to the single-weight {'weight': 1} layout. Default None.
  - x_invar_index (int or None, optional) – Position of the input x in eqn.invars, or None for primitives with no external input (currently only etp_elemwise_p). Default 0.
  - y_outvar_index (int, optional) – Position of the output y in eqn.outvars. Default 0.
- Returns:
The registered primitive.
- Return type:
ETPPrimitive
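
As a rough illustration of the trainable_invars_fn contract described above, the sketch below writes a layout function for a Linear-style primitive with a weight and a bias. The names linear_impl, linear_trainable_invars, register_linear and the primitive name 'etp_linear' are hypothetical. The input ordering (x, weight, bias) is an assumption chosen to agree with the default x_invar_index=0 and the fallback {'weight': 1} layout. The registration call uses only the keyword arguments from the signature on this page and is wrapped in a function, so the rest of the snippet runs without braintrace installed.

```python
import jax.numpy as jnp


def linear_impl(x, weight, bias):
    # Hypothetical implementation function.
    # Assumed input order: x is invar 0, weight is invar 1, bias is invar 2.
    return x @ weight + bias


def linear_trainable_invars(params):
    # Maps each trainable input's key to its position in eqn.invars.
    # `params` stands for eqn.params; it is ignored because this layout is fixed.
    return {"weight": 1, "bias": 2}


def register_linear():
    # Requires braintrace; arguments follow the signature documented on this page.
    import braintrace

    return braintrace.register_primitive(
        "etp_linear",  # hypothetical primitive name
        linear_impl,
        trainable_invars_fn=linear_trainable_invars,
        x_invar_index=0,
        y_outvar_index=0,
    )


# The layout function and the implementation can be checked on their own.
layout = linear_trainable_invars({})
y = linear_impl(jnp.ones((3,)), jnp.ones((3, 2)), jnp.zeros((2,)))
```

A layout function that returns only {"weight": 1} would describe the same layout the compiler falls back to when trainable_invars_fn is None.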
Notes
The following standard JAX rules are installed automatically:
- impl: eager execution.
- abstract_eval: via jax.eval_shape(impl).
- lowering: via mlir.lower_fun(impl).
- JVP: via jax.jvp(impl).
- transpose: derived by JAX from the JVP.
- batching: via jax.vmap(impl).
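
The derivations listed above can be tried directly on a plain function, without registering anything. The sketch below applies the same JAX transformations to a hypothetical impl (a vector-matrix product). It illustrates what each rule computes and is not braintrace's internal code. The lowering rule is left out because mlir.lower_fun is an internal JAX utility, and jax.vjp stands in for the transpose rule, since reverse mode is what JAX obtains by transposing the JVP.

```python
import jax
import jax.numpy as jnp


def impl(x, w):
    # Hypothetical implementation function: x has shape (3,), w has shape (3, 2).
    return x @ w


x = jnp.ones((3,))
w = jnp.ones((3, 2))

# abstract_eval: output shape and dtype, computed without running impl on data.
out_aval = jax.eval_shape(impl, x, w)

# JVP: forward-mode derivative along a tangent in x (w is held fixed).
primals_out, tangents_out = jax.jvp(
    impl, (x, w), (jnp.ones_like(x), jnp.zeros_like(w))
)

# Transpose (reverse mode): cotangents for x and w from an output cotangent.
y, vjp_fn = jax.vjp(impl, x, w)
dx, dw = vjp_fn(jnp.ones_like(y))

# Batching: map impl over a leading batch axis of x, sharing w across the batch.
batched_out = jax.vmap(impl, in_axes=(0, None))(jnp.ones((4, 3)), w)
```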