diagnostics
Feature relevance diagnostics for MultiHeadMLPLearner.
Extracts per-feature, per-head relevance metrics from existing learner state without modifying the update/predict hot path. Designed for periodic diagnostic reporting in daemon deployments (e.g. rlsecd).
Tier 1 metrics are zero-cost (state extraction only, no forward pass). Tier 2 metrics (feature sensitivity) require a Jacobian computation.
FeatureRelevance
Per-feature and per-head relevance metrics extracted from learner state.
All fields are derived from existing MultiHeadMLPState arrays.
No forward pass is required.
Attributes:
weight_relevance: Path-norm relevance from input features to each head.
Shape (n_heads, feature_dim).
step_size_activity: Mean absolute step-size on input layer per feature.
Shape (feature_dim,).
trace_activity: Mean absolute trunk trace magnitude on input layer
per feature. Shape (feature_dim,).
normalizer_mean: Per-feature normalizer mean estimate, or None if no
normalizer. Shape (feature_dim,).
normalizer_std: Per-feature normalizer std estimate, or None if no
normalizer. Shape (feature_dim,).
head_reliance: L1 norm of each head's weight vector over the last
hidden layer. Shape (n_heads, hidden_dim_last).
head_mean_step_size: Mean step-size per head, or None if optimizer
has no per-weight step-sizes. Shape (n_heads,).
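The path-norm behind weight_relevance can be illustrated with a small NumPy sketch. It makes several assumptions:

- It uses an L1 path-norm: the sum, over every path from input feature f to head h, of the product of absolute weights along that path.
- Weight matrices are stored as (out_dim, in_dim).
- Biases and activations are ignored.

The function name path_norm_relevance is hypothetical, and the library's exact definition may differ. Because the sum over paths factorizes layer by layer, it reduces to a chain of matrix products of absolute weights. This is why the metric needs no forward pass.

```python
import numpy as np


def path_norm_relevance(trunk_weights, head_weights):
    """Hypothetical L1 path-norm from each input feature to each head.

    trunk_weights: list of trunk weight matrices; layer i has shape
        (out_i, in_i), and the first layer has in_0 == feature_dim.
    head_weights: array of shape (n_heads, hidden_dim_last).
    Returns an array of shape (n_heads, feature_dim).
    """
    # Start at the heads and propagate absolute weights back toward the input.
    rel = np.abs(head_weights)          # (n_heads, hidden_dim_last)
    for W in reversed(trunk_weights):
        rel = rel @ np.abs(W)           # (n_heads, in_i)
    return rel


# Tiny example: 3 features, two hidden layers of width 2, 2 heads.
W1 = np.array([[1.0, -2.0, 0.0],
               [0.0, 1.0, 3.0]])
W2 = np.array([[1.0, 1.0],
               [-1.0, 2.0]])
heads = np.array([[1.0, 0.0],
                  [0.0, -1.0]])
rel = path_norm_relevance([W1, W2], heads)
# rel == [[1, 3, 3], [1, 4, 6]]
```

Taking absolute values before multiplying prevents positive and negative paths from cancelling. A feature wired strongly into a head through opposing weights therefore still registers as relevant.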
compute_feature_relevance(state)
Extract per-feature relevance metrics from multi-head learner state.
All metrics are computed from existing state arrays via small matrix multiplications. Typical cost: ~10-50 µs after JIT compilation for a (64, 64) trunk with 5 heads and 12 features.
Args:
state: Current multi-head MLP learner state.
Returns:
FeatureRelevance dataclass with all Tier 1 metrics.
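Apart from the path-norm, the Tier 1 fields are plain reductions over state arrays. The sketch below makes these assumptions:

- Input-layer step-sizes and trunk traces use a (hidden_dim_first, feature_dim) layout.
- Per-head step-sizes have shape (n_heads, hidden_dim_last).

The function and argument names are hypothetical, not actual MultiHeadMLPState fields.

```python
import numpy as np


def tier1_reductions(input_step_sizes, input_traces, head_step_sizes=None):
    """Hypothetical per-feature and per-head reductions over learner state.

    input_step_sizes, input_traces: arrays of shape (hidden_dim_first, feature_dim).
    head_step_sizes: array of shape (n_heads, hidden_dim_last), or None when
        the optimizer keeps no per-weight step-sizes.
    """
    return {
        # Average |value| over the hidden units fed by each input feature.
        "step_size_activity": np.mean(np.abs(input_step_sizes), axis=0),
        "trace_activity": np.mean(np.abs(input_traces), axis=0),
        # One scalar per head; None mirrors the optional field.
        "head_mean_step_size": (
            None if head_step_sizes is None else np.mean(head_step_sizes, axis=1)
        ),
    }


metrics = tier1_reductions(
    input_step_sizes=np.array([[0.1, 0.2], [0.3, 0.4]]),
    input_traces=np.array([[-1.0, 2.0], [3.0, -4.0]]),
    head_step_sizes=np.array([[0.01, 0.03], [0.02, 0.02]]),
)
# step_size_activity ~ [0.2, 0.3], trace_activity == [2.0, 3.0],
# head_mean_step_size ~ [0.02, 0.02]
```

Averaging over axis 0 collapses the hidden units, leaving one value per input feature. This matches the documented (feature_dim,) shapes.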
Source code in src/alberta_framework/core/diagnostics.py
compute_feature_sensitivity(learner, state, observation)
Compute per-head sensitivity to each input feature via Jacobian.
Uses jax.jacrev to compute d(pred_h)/d(obs_f) for all heads
and features. This is a Tier 2 metric: it requires one forward pass
plus one reverse-mode (vector-Jacobian product) pass per output,
i.e. 5 for 5 heads. Typical cost: ~100-500 µs for a (64, 64)
trunk.
jacrev is used because the output dimension (n_heads) is smaller than
the input dimension (feature_dim). Reverse mode therefore needs fewer
passes than forward mode, which needs one pass per input feature.
Args:
learner: The multi-head MLP learner instance.
state: Current learner state.
observation: Input feature vector, shape (feature_dim,).
Returns:
Jacobian array of shape (n_heads, feature_dim) where entry
[h, f] is the sensitivity of head h's prediction to
feature f at this observation.
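The same computation can be shown as a self-contained JAX sketch on a toy multi-head MLP. The model (a tanh trunk feeding one linear output per head), the params layout, and the names multi_head_predict and feature_sensitivity are hypothetical stand-ins for the learner's own prediction function. Passing argnums=1 makes jax.jacrev differentiate with respect to the observation only. The result is cross-checked against jax.jacfwd, which should agree up to floating-point error.

```python
import jax
import jax.numpy as jnp


def multi_head_predict(params, obs):
    """Hypothetical multi-head MLP: tanh trunk, one linear output per head."""
    h = obs
    for W, b in params["trunk"]:
        h = jnp.tanh(W @ h + b)
    return params["heads_W"] @ h + params["heads_b"]  # shape (n_heads,)


def feature_sensitivity(params, obs):
    # Reverse mode: one forward pass, then one VJP per head output (vmapped).
    # Result[h, f] = d(pred_h) / d(obs_f).
    return jax.jacrev(multi_head_predict, argnums=1)(params, obs)


feature_dim, hidden, n_heads = 12, 64, 5
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "trunk": [
        (0.1 * jax.random.normal(k1, (hidden, feature_dim)), jnp.zeros(hidden)),
        (0.1 * jax.random.normal(k2, (hidden, hidden)), jnp.zeros(hidden)),
    ],
    "heads_W": 0.1 * jax.random.normal(k3, (n_heads, hidden)),
    "heads_b": jnp.zeros(n_heads),
}
obs = jax.random.normal(k4, (feature_dim,))

J = feature_sensitivity(params, obs)  # shape (5, 12)
# Forward mode gives the same Jacobian but needs one JVP per input feature (12 here).
J_fwd = jax.jacfwd(multi_head_predict, argnums=1)(params, obs)
```

Because the Jacobian depends on obs, it is a local measure. Sensitivities computed at different observations can differ substantially for a nonlinear trunk.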
Source code in src/alberta_framework/core/diagnostics.py
relevance_to_dict(relevance, feature_names=None, head_names=None)
Convert FeatureRelevance to a JSON-serializable dict.
Produces a structured dict suitable for logging or inspection.
When normalizer state is available, the dict also includes
normalized_weight_relevance, which scales weight relevance by the
normalizer std to give relevance in raw input units.
Args:
relevance: FeatureRelevance from compute_feature_relevance.
feature_names: Optional list of feature names. If None, uses
"feature_0", "feature_1", etc.
head_names: Optional list of head names. If None, uses
"head_0", "head_1", etc.
Returns:
Nested dict with "trunk" and "per_head" sections.
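The naming defaults and JSON-safe conversion described above can be sketched in a few lines. Two parts of it are assumptions:

- The placement of metrics, with per-feature step-size activity under "trunk" and weight relevance under each head.
- The helper name to_named_dict.

Only two of the FeatureRelevance fields are shown.

```python
import json

import numpy as np


def to_named_dict(weight_relevance, step_size_activity,
                  feature_names=None, head_names=None):
    """Hypothetical helper: label array entries with feature and head names."""
    n_heads, feature_dim = weight_relevance.shape
    if feature_names is None:
        feature_names = [f"feature_{i}" for i in range(feature_dim)]
    if head_names is None:
        head_names = [f"head_{h}" for h in range(n_heads)]
    return {
        "trunk": {
            # float() turns array scalars into Python floats; json cannot encode
            # types such as numpy float32 directly.
            "step_size_activity": {
                name: float(v) for name, v in zip(feature_names, step_size_activity)
            },
        },
        "per_head": {
            head: {
                "weight_relevance": {
                    name: float(v) for name, v in zip(feature_names, row)
                }
            }
            for head, row in zip(head_names, weight_relevance)
        },
    }


d = to_named_dict(
    np.array([[1.0, 3.0], [2.0, 0.5]]),
    np.array([0.1, 0.2]),
    head_names=["anomaly", "severity"],  # hypothetical head names
)
encoded = json.dumps(d)
```

Keying by name rather than index keeps logged reports readable when the feature set or head order changes between deployments.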