HiddenParamOpRelation#

class braintrace.HiddenParamOpRelation(primitive: Primitive, x_var: Var | None, y_var: Var, hidden_groups: List[HiddenGroup], y_to_hidden_group_jaxprs: List[Jaxpr], connected_hidden_paths: List[Tuple[str, ...]], eqn_params: dict, path_classification: Dict[Tuple[str, ...], str] = {}, trainable_vars: Dict[str, Var] = {}, trainable_paths: Dict[str, Tuple[str, ...]] = {}, trainable_leaf_indices: Dict[str, int] = {}, trainable_param_states: Dict[str, ParamState] = {}, trainable_processing_chains: Dict[str, Tuple[Primitive, ...]] = {})#

Connection between an ETP primitive, its trainable parameters, and hidden states.

Records the structural relationship

\[h^t = f(y), \quad y = \mathrm{primitive}(x, \theta)\]

discovered by the compiler for a single ETP primitive equation.
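The relationship above can be sketched as follows. This is a minimal illustration only: the `primitive` and `f` functions below are hypothetical stand-ins (a plain matrix-vector product and a tanh), not braintrace's ETP primitives or a compiler-discovered transition.

```python
import jax.numpy as jnp

# Hypothetical stand-in for an ETP primitive: y = primitive(x, theta).
def primitive(x, theta):
    return x @ theta

# Hypothetical transition from the primitive output to the hidden state: h = f(y).
def f(y):
    return jnp.tanh(y)

x = jnp.ones((4,))             # input x
theta = jnp.full((4, 3), 0.5)  # trainable parameter theta
y = primitive(x, theta)        # primitive output y, shape (3,)
h = f(y)                       # hidden state h^t, shape (3,)
```

In the terms of this class, `x` and `y` would correspond to x_var and y_var, and `f` to one of the transition jaxprs.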

primitive#

The JAX primitive of this equation (for example, etp_mm_p or etp_mv_p).

Type:

Primitive

x_var#

Jaxpr Var for the input (None for element-wise ops).

Type:

Var or None

y_var#

Jaxpr Var for the primitive output.

Type:

Var

hidden_groups#

Hidden groups that this op feeds into.

Type:

list of HiddenGroup

y_to_hidden_group_jaxprs#

One transition jaxpr per hidden group, each mapping y to that group's hidden states.

Type:

list of Jaxpr

connected_hidden_paths#

Hidden-state paths connected to this op.

Type:

list of Path (each a tuple of str)

eqn_params#

Static parameters of the primitive equation.

Type:

dict

path_classification#

Mapping from each connected hidden-state path to its classification, one of the PathClassification.* values (stored as a str). Populated by the path-classification pass.

Type:

dict

trainable_vars#

Per-key dict mapping a primitive-chosen key name (e.g. 'weight', 'bias', 'lora_b', 'lora_a') to its jaxpr Var, with one entry per declared trainable input.

Type:

dict

trainable_paths#

Per-key dict mapping each key to the owning ParamState’s module path. When two keys trace to the same ParamState (e.g. a merged {weight, bias} Linear), the entries share a path.

Type:

dict
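As a small sketch of how shared paths might be detected, the snippet below groups keys by their path. The dictionary contents are made-up illustrative values, not output from the compiler.

```python
from collections import defaultdict

# Illustrative values only; real keys and paths come from the compiler.
trainable_paths = {
    "weight": ("net", "linear"),
    "bias": ("net", "linear"),
    "lora_a": ("net", "lora"),
}

# Group trainable keys by the module path of their owning ParamState.
keys_by_path = defaultdict(list)
for key, path in trainable_paths.items():
    keys_by_path[path].append(key)

# Paths owned by more than one key indicate a merged ParamState.
shared = {path: keys for path, keys in keys_by_path.items() if len(keys) > 1}
```

Here `shared` contains only the ("net", "linear") path, with the keys "weight" and "bias".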

trainable_leaf_indices#

Per-key dict mapping each key to the leaf index in jax.tree.leaves of the owning ParamState.

Type:

dict
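To illustrate what a leaf index refers to, the sketch below flattens a parameter pytree and looks up one leaf by index. It assumes the ParamState holds a dict of arrays; the `param_value` and index values are hypothetical, and `jax.tree_util.tree_leaves` is used as the long-standing equivalent of `jax.tree.leaves`.

```python
import jax
import jax.numpy as jnp

# Hypothetical value held by a merged {weight, bias} ParamState.
param_value = {"weight": jnp.ones((3, 2)), "bias": jnp.zeros((2,))}

# Dict pytrees are flattened in sorted key order: "bias" first, then "weight".
leaves = jax.tree_util.tree_leaves(param_value)

# Illustrative leaf indices matching that flattening order.
trainable_leaf_indices = {"bias": 0, "weight": 1}

weight = leaves[trainable_leaf_indices["weight"]]
bias = leaves[trainable_leaf_indices["bias"]]
```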

trainable_param_states#

Per-key dict mapping each key to the actual ParamState object.

Type:

dict

trainable_processing_chains#

Per-key dict mapping each key to the backward-trace processing chain (primitives traversed from the trainable invar back to the originating ParamState invar).

Type:

dict

connected_hidden_paths: List[Tuple[str, ...]]#

Alias for field number 5

dict()[source]#

Return this relation’s named fields as a plain dictionary.

Returns:

An ordered mapping from field name to value, as produced by the underlying typing.NamedTuple.

Return type:

Dict[str, Any]
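The "Alias for field number N" entries and the dict() method both follow from the typing.NamedTuple base. A minimal sketch with a hypothetical three-field stand-in (`ToyRelation` is not part of braintrace, and its `dict` body is an assumption based on the description above):

```python
from typing import Any, Dict, NamedTuple, Optional

# Hypothetical cut-down stand-in for HiddenParamOpRelation.
class ToyRelation(NamedTuple):
    primitive: str
    x_var: Optional[str]
    y_var: str

    def dict(self) -> Dict[str, Any]:
        # Assumed to wrap NamedTuple._asdict(), which preserves field order.
        return self._asdict()

rel = ToyRelation(primitive="etp_mm_p", x_var=None, y_var="y")

# "Alias for field number 2" means positional and named access agree.
same = rel[2] is rel.y_var

as_dict = rel.dict()
```

The resulting `as_dict` lists the fields in declaration order: primitive, x_var, y_var.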

eqn_params: dict#

Alias for field number 6

hidden_groups: List[HiddenGroup]#

Alias for field number 3

path_classification: Dict[Tuple[str, ...], str]#

Alias for field number 7

primitive: Primitive#

Alias for field number 0

trainable_leaf_indices: Dict[str, int]#

Alias for field number 10

trainable_param_states: Dict[str, ParamState]#

Alias for field number 11

trainable_paths: Dict[str, Tuple[str, ...]]#

Alias for field number 9

trainable_processing_chains: Dict[str, Tuple[Primitive, ...]]#

Alias for field number 12

trainable_vars: Dict[str, Var]#

Alias for field number 8

x_var: Var | None#

Alias for field number 1

y_to_hidden_group_jaxprs: List[Jaxpr]#

Alias for field number 4

y_to_hidden_groups(y_val, const_vals, concat_hidden_vals=True)[source]#

Evaluate the transition jaxprs mapping y to hidden-group values.

Parameters:
  • y_val (jax.Array) – The value of the primitive output y.

  • const_vals (dict) – Mapping from each transition-jaxpr constvar to its value.

  • concat_hidden_vals (bool, optional) – If True, concatenate each group’s hidden values into a single array via HiddenGroup.concat_hidden(). Defaults to True.

Returns:

One entry per hidden group: either a list of per-state arrays (when concat_hidden_vals is False) or a single concatenated array (when True).

Return type:

list
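The following sketch shows the general idea of evaluating a transition jaxpr for a single hidden group with plain JAX. It does not call braintrace: `transition`, `scale`, and the stacking axis are assumptions made for illustration, and the actual layout produced by HiddenGroup.concat_hidden() may differ.

```python
import jax
import jax.numpy as jnp
from jax.core import eval_jaxpr

scale = jnp.array([1.0, 2.0, 3.0])

# Hypothetical transition from y to a group of two hidden states.
def transition(y):
    return jnp.tanh(y * scale), jnp.maximum(y, 0.0)

y_val = jnp.array([0.5, -1.0, 2.0])
closed = jax.make_jaxpr(transition)(y_val)
jaxpr = closed.jaxpr

# Mapping from each constvar of the transition jaxpr to its value.
const_vals = dict(zip(jaxpr.constvars, closed.consts))
consts = [const_vals[v] for v in jaxpr.constvars]

# Without concatenation: a list with one array per hidden state.
per_state = eval_jaxpr(jaxpr, consts, y_val)

# With concatenation (axis chosen here is an assumption): one array per group.
stacked = jnp.stack(per_state, axis=-1)
```

For a relation with several hidden groups, the same evaluation would be repeated once per transition jaxpr, giving one list entry per group.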

y_var: Var#

Alias for field number 2