HiddenParamOpRelation
- class braintrace.HiddenParamOpRelation(primitive: Primitive, x_var: Var | None, y_var: Var, hidden_groups: List[HiddenGroup], y_to_hidden_group_jaxprs: List[Jaxpr], connected_hidden_paths: List[Tuple[str, ...]], eqn_params: dict, path_classification: Dict[Tuple[str, ...], str] = {}, trainable_vars: Dict[str, Var] = {}, trainable_paths: Dict[str, Tuple[str, ...]] = {}, trainable_leaf_indices: Dict[str, int] = {}, trainable_param_states: Dict[str, ParamState] = {}, trainable_processing_chains: Dict[str, Tuple[Primitive, ...]] = {})
Connection between an ETP primitive, its trainable parameters, and hidden states.
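Records the structural relationship
\[h^t = f(y), \quad y = \mathrm{primitive}(x, \theta)\]
discovered by the compiler for a single ETP primitive equation, where x is the primitive input, θ its trainable parameters, y its output, and f the transition from y to the hidden state h^t.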
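As a toy illustration of the relationship above (not braintrace code), the sketch below uses plain Python lists in place of JAX arrays. A matrix-vector product plays the role of the primitive (loosely analogous to etp_mv_p) and an element-wise tanh plays the role of the transition f. The helper names matvec and transition are hypothetical.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]

def matvec(theta: Matrix, x: Vector) -> Vector:
    """Stand-in for the ETP primitive: y = theta @ x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in theta]

def transition(y: Vector) -> Vector:
    """Stand-in for f: maps the primitive output y to the hidden state h^t."""
    return [math.tanh(v) for v in y]

theta = [[1.0, 0.0], [0.5, -1.0]]  # trainable parameters
x = [2.0, 1.0]                     # primitive input
y = matvec(theta, x)               # y = primitive(x, theta)
h_t = transition(y)                # h^t = f(y)
print(y)    # [2.0, 0.0]
print(h_t)
```

In this structure h^t depends on θ only through y.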
- primitive
The JAX primitive (etp_mm_p, etp_mv_p, etc.).
- Type: Primitive
- x_var
Jaxpr Var for the input (None for element-wise ops).
- Type: Var or None
- y_var
Jaxpr Var for the primitive output.
- Type: Var
- hidden_groups
Hidden groups that this op feeds into.
- Type: list of HiddenGroup
- path_classification
Mapping {hidden_path: PathClassification.*} for each connected hidden state. Populated by the path-classification pass.
- Type: Dict[Tuple[str, ...], str]
- trainable_vars
Per-key dict mapping a primitive-chosen key name (e.g. 'weight', 'bias', 'lora_b', 'lora_a') to its jaxpr Var, with one entry per declared trainable input.
- Type: Dict[str, Var]
- trainable_paths
Per-key dict mapping each key to the owning ParamState's module path. When two keys trace to the same ParamState (e.g. a merged {weight, bias} Linear), their entries share a path.
- Type: Dict[str, Tuple[str, ...]]
- trainable_leaf_indices
Per-key dict mapping each key to its leaf index in jax.tree.leaves of the owning ParamState.
- Type: Dict[str, int]
- trainable_processing_chains
Per-key dict mapping each key to its backward-trace processing chain (the primitives traversed from the trainable invar back to the originating ParamState invar).
- Type: Dict[str, Tuple[Primitive, ...]]
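The per-key dictionaries above share one key space. The following is a minimal sketch under assumptions. It uses a hypothetical Linear layer at module path ('layer',) whose single ParamState holds a dict {'weight': ..., 'bias': ...}. Both keys then map to the same path, and each key gets the index its value occupies in the flattened leaves. JAX flattens dicts in sorted-key order, and the hypothetical helper flat_leaf_indices emulates that for this one-level case. Strings stand in for jaxpr Vars.

```python
from typing import Dict, Tuple

def flat_leaf_indices(param_value: Dict[str, object]) -> Dict[str, int]:
    """Emulate jax.tree.leaves ordering for a flat dict: keys are visited in sorted order."""
    return {key: i for i, key in enumerate(sorted(param_value))}

# Hypothetical merged {weight, bias} Linear: one ParamState, two trainable keys.
linear_value = {"weight": "W", "bias": "b"}
leaf_index = flat_leaf_indices(linear_value)

# Strings stand in for jaxpr Vars.
trainable_vars: Dict[str, str] = {"weight": "w_var", "bias": "b_var"}
trainable_paths: Dict[str, Tuple[str, ...]] = {
    "weight": ("layer",),
    "bias": ("layer",),  # same ParamState, so the same module path
}
trainable_leaf_indices: Dict[str, int] = {k: leaf_index[k] for k in trainable_vars}
print(trainable_leaf_indices)  # {'weight': 1, 'bias': 0}
```

Because 'bias' sorts before 'weight', the bias leaf comes first even though 'weight' is listed first.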
- dict()
Return this relation's named fields as a plain dictionary.
- Returns: An ordered mapping from field name to value, as produced by the underlying typing.NamedTuple.
- Return type: dict
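Since the relation is a typing.NamedTuple, dict() can be pictured as a thin wrapper over NamedTuple._asdict(). The sketch below uses a hypothetical, trimmed-down ToyRelation with only the first four fields and placeholder values. It is not the braintrace class, and the assumption that dict() delegates to _asdict() is inferred from the description above. Since Python 3.8, _asdict() returns a regular dict, which preserves field order.

```python
from typing import Any, Dict, List, NamedTuple, Optional

class ToyRelation(NamedTuple):
    """Hypothetical stand-in with the first four fields of HiddenParamOpRelation."""
    primitive: str
    x_var: Optional[str]
    y_var: str
    hidden_groups: List[str]

    def dict(self) -> Dict[str, Any]:
        # Assumed behavior: delegate to the NamedTuple helper.
        return self._asdict()

rel = ToyRelation(primitive="etp_mv_p", x_var="x", y_var="y", hidden_groups=["g0"])
d = rel.dict()
print(list(d))  # ['primitive', 'x_var', 'y_var', 'hidden_groups']
```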
- hidden_groups: List[HiddenGroup]
Alias for field number 3
- primitive: Primitive
Alias for field number 0
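The "Alias for field number N" entries are the accessors that typing.NamedTuple generates, so attribute access and positional indexing reach the same slot. The following is a minimal sketch with a hypothetical four-field ToyRelation that holds placeholder string values rather than real Primitive or Var objects. It follows the field order of the signature above.

```python
from typing import List, NamedTuple, Optional

class ToyRelation(NamedTuple):
    """Hypothetical stand-in mirroring the first four fields of HiddenParamOpRelation."""
    primitive: str             # field number 0
    x_var: Optional[str]       # field number 1
    y_var: str                 # field number 2
    hidden_groups: List[str]   # field number 3

rel = ToyRelation("etp_mm_p", "x", "y", ["g0", "g1"])

# Attribute access and positional indexing refer to the same field.
print(rel.primitive is rel[0])       # True
print(rel.hidden_groups is rel[3])   # True

# Tuple unpacking follows the same field order.
primitive, x_var, y_var, hidden_groups = rel
print(ToyRelation._fields.index("y_var"))  # 2
```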
- y_to_hidden_groups(y_val, const_vals, concat_hidden_vals=True)
Evaluate the transition jaxprs (y_to_hidden_group_jaxprs) that map y to hidden-group values.
- Parameters:
  - y_val (jax.Array) – The value of the primitive output y.
  - const_vals (dict) – Mapping from each transition-jaxpr constvar to its value.
  - concat_hidden_vals (bool, optional) – If True, concatenate each group's hidden values into a single array via HiddenGroup.concat_hidden(). Default: True.
- Returns: One entry per hidden group: either a list of per-state arrays (when concat_hidden_vals is False) or a single concatenated array (when True).
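The control flow of y_to_hidden_groups can be sketched without JAX. The sketch below is an illustration under assumptions, not the braintrace implementation. Each hidden group gets one plain Python transition function standing in for its transition jaxpr, and that function returns one value per hidden state. const_vals is looked up by a hypothetical constvar name ("decay"). Concatenation is list flattening, standing in for HiddenGroup.concat_hidden(). The names toy_y_to_hidden_groups and two_state_group are hypothetical.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical stand-in for one transition jaxpr: takes y and the constants,
# returns one list of floats per hidden state in the group.
GroupTransition = Callable[[List[float], Dict[str, float]], List[List[float]]]

def toy_y_to_hidden_groups(
    transitions: Sequence[GroupTransition],
    y_val: List[float],
    const_vals: Dict[str, float],
    concat_hidden_vals: bool = True,
) -> list:
    results = []
    for transition in transitions:  # one entry per hidden group
        per_state = transition(y_val, const_vals)
        if concat_hidden_vals:
            # Stand-in for HiddenGroup.concat_hidden(): join the states into one flat list.
            results.append([v for state in per_state for v in state])
        else:
            results.append(per_state)
    return results

def two_state_group(y: List[float], consts: Dict[str, float]) -> List[List[float]]:
    """Hypothetical group with two hidden states, h1 and h2, both driven by y."""
    h1 = [consts["decay"] * yi for yi in y]
    h2 = [0.1 * yi for yi in y]
    return [h1, h2]

y = [1.0, 2.0]
consts = {"decay": 0.5}
print(toy_y_to_hidden_groups([two_state_group], y, consts))
print(toy_y_to_hidden_groups([two_state_group], y, consts, concat_hidden_vals=False))
```

With concat_hidden_vals=False, the caller keeps per-state structure. With the default, each group collapses to a single value.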
- y_var: Var
Alias for field number 2