invarlock 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (48):
  1. invarlock/__init__.py +1 -1
  2. invarlock/_data/runtime/tiers.yaml +57 -30
  3. invarlock/adapters/__init__.py +1 -1
  4. invarlock/calibration/spectral_null.py +15 -10
  5. invarlock/calibration/variance_ve.py +0 -2
  6. invarlock/cli/commands/calibrate.py +6 -2
  7. invarlock/cli/commands/certify.py +58 -39
  8. invarlock/cli/commands/doctor.py +3 -1
  9. invarlock/cli/commands/explain_gates.py +57 -8
  10. invarlock/cli/commands/report.py +1 -1
  11. invarlock/cli/commands/run.py +159 -61
  12. invarlock/cli/commands/verify.py +78 -4
  13. invarlock/cli/config.py +21 -5
  14. invarlock/core/api.py +45 -5
  15. invarlock/core/auto_tuning.py +65 -20
  16. invarlock/core/contracts.py +7 -1
  17. invarlock/core/registry.py +2 -2
  18. invarlock/core/runner.py +314 -50
  19. invarlock/eval/bench.py +0 -13
  20. invarlock/eval/data.py +73 -283
  21. invarlock/eval/metrics.py +134 -4
  22. invarlock/eval/primary_metric.py +23 -0
  23. invarlock/eval/tail_stats.py +230 -0
  24. invarlock/guards/_estimators.py +154 -0
  25. invarlock/guards/policies.py +16 -6
  26. invarlock/guards/rmt.py +625 -544
  27. invarlock/guards/spectral.py +348 -110
  28. invarlock/guards/tier_config.py +32 -30
  29. invarlock/guards/variance.py +5 -29
  30. invarlock/guards_ref/rmt_ref.py +23 -23
  31. invarlock/model_profile.py +42 -15
  32. invarlock/reporting/certificate.py +225 -46
  33. invarlock/reporting/certificate_schema.py +2 -1
  34. invarlock/reporting/dataset_hashing.py +15 -2
  35. invarlock/reporting/guards_analysis.py +197 -274
  36. invarlock/reporting/normalizer.py +6 -0
  37. invarlock/reporting/policy_utils.py +38 -36
  38. invarlock/reporting/primary_metric_utils.py +71 -17
  39. invarlock/reporting/render.py +61 -0
  40. invarlock/reporting/report.py +1 -1
  41. invarlock/reporting/report_types.py +5 -2
  42. invarlock/reporting/validate.py +1 -18
  43. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/METADATA +6 -6
  44. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/RECORD +48 -46
  45. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/WHEEL +0 -0
  46. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/entry_points.txt +0 -0
  47. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/licenses/LICENSE +0 -0
  48. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/top_level.txt +0 -0
@@ -623,6 +623,9 @@ def compute_primary_metric_from_report(
623
623
  "preview": float("nan"),
624
624
  "final": float("nan"),
625
625
  "ratio_vs_baseline": float("nan"),
626
+ "invalid": True,
627
+ "degraded": True,
628
+ "degraded_reason": "non_finite_pm",
626
629
  }
627
630
  # For accuracy kinds, derive counts from input_ids if aggregates are missing
628
631
  if kind in {"accuracy", "vqa_accuracy"}:
@@ -661,6 +664,11 @@ def compute_primary_metric_from_report(
661
664
  final_point = metric.point_from_windows(windows=final_win)
662
665
 
663
666
  ratio_vs_baseline = float("nan")
667
+ baseline_has_reference = False
668
+
669
+ def _is_finite(value: Any) -> bool:
670
+ return isinstance(value, (int, float)) and math.isfinite(float(value))
671
+
664
672
  if isinstance(baseline, dict):
665
673
  try:
666
674
  base_metrics = (
@@ -686,14 +694,25 @@ def compute_primary_metric_from_report(
686
694
  is_ppl_like = str(kind).lower().startswith("ppl")
687
695
  if is_ppl_like and base_ref > 0:
688
696
  ratio_vs_baseline = float(final_point) / float(base_ref)
697
+ baseline_has_reference = True
689
698
  elif (
690
699
  str(kind).lower() in {"accuracy", "vqa_accuracy"}
691
700
  and 0 <= base_ref <= 1
692
701
  ):
693
702
  ratio_vs_baseline = float(final_point) - float(base_ref)
703
+ baseline_has_reference = True
694
704
  except Exception:
695
705
  ratio_vs_baseline = float("nan")
696
706
 
707
+ invalid = not (_is_finite(preview_point) and _is_finite(final_point))
708
+ degraded_reason = None
709
+ if invalid:
710
+ degraded_reason = "non_finite_pm"
711
+ elif baseline_has_reference and not _is_finite(ratio_vs_baseline):
712
+ degraded_reason = "non_finite_delta"
713
+
714
+ degraded = bool(degraded_reason or invalid)
715
+
697
716
  payload = {
698
717
  "kind": metric.kind,
699
718
  "unit": metric.unit,
@@ -705,7 +724,11 @@ def compute_primary_metric_from_report(
705
724
  "preview": preview_point,
706
725
  "final": final_point,
707
726
  "ratio_vs_baseline": ratio_vs_baseline,
727
+ "invalid": invalid,
728
+ "degraded": degraded,
708
729
  }
730
+ if degraded and degraded_reason:
731
+ payload["degraded_reason"] = degraded_reason
709
732
  # Carry counts for accuracy to aid gating
710
733
  if kind in {"accuracy", "vqa_accuracy"}:
711
734
  if "n_prev" in locals() and n_prev is not None:
@@ -0,0 +1,230 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from collections.abc import Mapping, Sequence
5
+ from typing import Any
6
+
7
# Names exported by ``from <module> import *``; this is the module's public API.
__all__ = [
    "compute_tail_summary",
    "evaluate_metric_tail",
]
11
+
12
+
13
def _as_finite_float(value: Any) -> float | None:
    """Convert ``value`` with ``float()`` and return it only if it is finite.

    Any exception raised by the conversion (for example for ``None`` or a
    non-numeric string) is swallowed and reported as ``None``. NaN and +/-inf
    are also reported as ``None``.
    """
    try:
        number = float(value)
    except Exception:
        return None
    if math.isnan(number) or math.isinf(number):
        return None
    return number
19
+
20
+
21
def _linear_quantile(sorted_values: Sequence[float], q: float) -> float:
    """Return the q-quantile of already-sorted values by linear interpolation.

    ``q`` is expected in [0, 1]; values at or below 0 return the first element
    and values at or above 1 return the last. An empty input yields NaN, and a
    single-element input yields that element for any ``q``.
    """
    count = len(sorted_values)
    if not count:
        return float("nan")
    if count == 1 or q <= 0.0:
        return float(sorted_values[0])
    if q >= 1.0:
        return float(sorted_values[-1])

    # Fractional rank on the 0..count-1 index scale.
    position = float(q) * float(count - 1)
    lower = math.floor(position)
    upper = math.ceil(position)
    low_val = float(sorted_values[lower])
    if lower == upper:
        return low_val
    high_val = float(sorted_values[upper])
    return low_val + (position - float(lower)) * (high_val - low_val)
41
+
42
+
43
def compute_tail_summary(
    deltas: Sequence[float] | Sequence[Any],
    *,
    quantiles: Sequence[float] = (0.5, 0.9, 0.95, 0.99),
    epsilon: float = 1e-4,
    weights: Sequence[float] | Sequence[Any] | None = None,
) -> dict[str, Any]:
    """Compute deterministic tail summaries for Δlog-loss samples.

    - Non-numeric or non-finite deltas are dropped before any statistic.
    - Quantiles are computed unweighted using linear interpolation on sorted values.
    - tail_mass is Pr[delta > epsilon] (unweighted).
    - tail_mass_weighted is included when weights are provided and the total
      weight over the retained deltas is positive.

    Args:
        deltas: Per-sample Δlog-loss values.
        quantiles: Quantile levels. Each is clamped to [0, 1] and reported under
            the key ``q<round(100 * level)>``. Entries that cannot be converted
            to float are skipped.
        epsilon: Deadband for tail_mass. Non-finite or negative values fall
            back to 0.
        weights: Optional per-delta weights, aligned by position with
            ``deltas``. Non-finite or negative weights count as 0. If
            ``weights`` is shorter than ``deltas``, the unmatched deltas still
            contribute to the unweighted statistics with weight 0. Surplus
            weights are ignored.

    Returns:
        Dict with ``n``, ``epsilon``, ``max``, ``tail_mass``, optionally
        ``tail_mass_weighted`` / ``tail_mass_weighted_by``, and one ``q*`` entry
        per usable quantile. With no usable deltas, ``max`` and every quantile
        are NaN and ``tail_mass`` is 0.0.
    """
    from itertools import zip_longest

    eps = _as_finite_float(epsilon)
    if eps is None or eps < 0.0:
        eps = 0.0

    values: list[float] = []
    paired_weights: list[float] | None = None

    if weights is None:
        for d in deltas:
            dv = _as_finite_float(d)
            if dv is not None:
                values.append(dv)
    else:
        paired_weights = []
        # zip_longest (not zip): a short ``weights`` must not silently truncate
        # the *unweighted* statistics. A missing delta (fill None) is skipped,
        # and a missing weight (fill None) becomes weight 0.
        for d, w in zip_longest(deltas, weights):
            dv = _as_finite_float(d)
            if dv is None:
                continue
            wv = _as_finite_float(w)
            if wv is None or wv < 0.0:
                wv = 0.0
            values.append(dv)
            paired_weights.append(wv)

    # Clamp each requested level once; keys are derived from the clamped level.
    levels: list[float] = []
    for q in quantiles:
        try:
            qf = float(q)
        except Exception:
            continue
        levels.append(max(0.0, min(1.0, qf)))

    def _label(level: float) -> str:
        return f"q{int(round(100.0 * level))}"

    n = len(values)
    summary: dict[str, Any] = {
        "n": n,
        "epsilon": float(eps),
    }
    if n == 0:
        summary["max"] = float("nan")
        summary["tail_mass"] = 0.0
        for level in levels:
            summary[_label(level)] = float("nan")
        return summary

    values_sorted = sorted(values)
    summary["max"] = float(values_sorted[-1])
    tail_ct = sum(1 for v in values if v > eps)
    summary["tail_mass"] = float(tail_ct / n)

    if paired_weights is not None:
        total_w = 0.0
        tail_w = 0.0
        for v, w in zip(values, paired_weights):
            total_w += w
            if v > eps:
                tail_w += w
        if total_w > 0.0:
            summary["tail_mass_weighted"] = float(tail_w / total_w)
            summary["tail_mass_weighted_by"] = "weights"

    for level in levels:
        summary[_label(level)] = float(_linear_quantile(values_sorted, level))

    return summary
124
+
125
+
126
def evaluate_metric_tail(
    *,
    deltas: Sequence[float] | Sequence[Any],
    policy: Mapping[str, Any] | None = None,
    weights: Sequence[float] | Sequence[Any] | None = None,
) -> dict[str, Any]:
    """Evaluate a tail policy against Δlog-loss samples.

    Policy keys:
    - mode: "off" | "warn" | "fail" (default: "warn"; unknown values become "warn")
    - min_windows: int (default: 1; non-integer values fall back to 1, minimum 1)
    - quantile: float in [0, 1] (default: 0.95; clamped)
    - quantile_max: float threshold in Δlog-loss (optional)
    - epsilon: float deadband for tail_mass (default: 1e-4; negative/non-finite -> 0)
    - mass_max: float in [0, 1] (optional; clamped)

    The policy is *evaluated* only when mode != "off", at least one of
    quantile_max / mass_max is set, and at least ``min_windows`` usable deltas
    exist. Every threshold breach is recorded in ``violations`` and makes
    ``passed`` False. ``warned`` is True only for an evaluated, failing policy
    in "warn" mode.

    Returns:
        Dict with ``mode``, ``evaluated``, ``passed``, ``warned``,
        ``violations``, the resolved ``policy``, and the ``stats`` from
        ``compute_tail_summary``.
    """
    pol = dict(policy or {})
    mode = str(pol.get("mode", "warn") or "warn").strip().lower()
    if mode not in {"off", "warn", "fail"}:
        mode = "warn"

    try:
        min_windows_i = int(pol.get("min_windows", 1))
    except Exception:
        min_windows_i = 1
    min_windows_i = max(1, min_windows_i)

    q = _as_finite_float(pol.get("quantile", 0.95))
    if q is None:
        q = 0.95
    q = max(0.0, min(1.0, float(q)))

    eps = _as_finite_float(pol.get("epsilon", 1e-4))
    if eps is None or eps < 0.0:
        eps = 0.0

    qmax = _as_finite_float(pol.get("quantile_max"))
    mmax = _as_finite_float(pol.get("mass_max"))
    if mmax is not None:
        mmax = max(0.0, min(1.0, float(mmax)))

    def _label(level: float) -> str:
        # Same key scheme compute_tail_summary applies to clamped levels.
        return f"q{int(round(100.0 * level))}"

    q_label = _label(q)
    # Quantiles are keyed by rounded percent, so a default level that rounds to
    # the requested key (e.g. 0.95 vs a requested 0.949) would overwrite it and
    # the gate would test the wrong quantile. Drop such colliding defaults.
    defaults = {lvl for lvl in (0.5, 0.9, 0.95, 0.99) if _label(lvl) != q_label}
    quantiles = sorted(defaults | {q})
    stats = compute_tail_summary(
        deltas, quantiles=tuple(quantiles), epsilon=float(eps), weights=weights
    )
    n = int(stats.get("n", 0) or 0)

    thresholds_present = (qmax is not None) or (mmax is not None)
    evaluated = bool(mode != "off" and thresholds_present and n >= min_windows_i)

    violations: list[dict[str, Any]] = []
    if evaluated:
        q_obs = stats.get(q_label)
        if not (isinstance(q_obs, int | float) and math.isfinite(float(q_obs))):
            q_obs = float("nan")
        if qmax is not None and math.isfinite(q_obs) and q_obs > float(qmax):
            violations.append(
                {
                    "type": "quantile_max_exceeded",
                    "quantile": float(q),
                    "observed": float(q_obs),
                    "threshold": float(qmax),
                }
            )

        tail_mass = stats.get("tail_mass")
        if (
            mmax is not None
            and isinstance(tail_mass, int | float)
            and math.isfinite(float(tail_mass))
            and float(tail_mass) > float(mmax)
        ):
            violations.append(
                {
                    "type": "tail_mass_exceeded",
                    "epsilon": float(eps),
                    "observed": float(tail_mass),
                    "threshold": float(mmax),
                }
            )

    # Any violation is a failure; an unevaluated policy trivially passes.
    passed = not violations
    warned = bool(evaluated and (not passed) and mode == "warn")

    return {
        "mode": mode,
        "evaluated": evaluated,
        "passed": bool(passed),
        "warned": warned,
        "violations": violations,
        "policy": {
            "mode": mode,
            "min_windows": int(min_windows_i),
            "quantile": float(q),
            "quantile_max": float(qmax) if qmax is not None else None,
            "epsilon": float(eps),
            "mass_max": float(mmax) if mmax is not None else None,
        },
        "stats": stats,
    }
@@ -0,0 +1,154 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import Any
5
+
6
+ import torch
7
+
8
# Names exported by ``from <module> import *``; this is the module's public API.
__all__ = [
    "power_iter_sigma_max",
    "frobenius_norm_sq",
    "row_col_norm_extrema",
    "stable_rank_estimate",
]
14
+
15
+
16
def _as_matrix(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` as a 2-D matrix, keeping the leading dimension.

    - 2-D tensors are returned unchanged (the same object).
    - Other ranks are flattened to ``(shape[0], prod(shape[1:]))``. A 1-D
      tensor becomes a column ``(n, 1)``, a 0-d scalar becomes ``(1, 1)``, and
      zero-sized tensors keep a well-defined shape such as ``(3, 0)``.

    ``reshape`` is used rather than ``view`` so that non-contiguous inputs
    (e.g. permuted or sliced tensors) work. It returns a view when the strides
    allow and a copy otherwise.
    """
    if tensor.ndim == 2:
        return tensor
    if tensor.ndim == 0:
        return tensor.reshape(1, 1)
    # Explicit column count instead of -1: -1 is ambiguous for zero-element tensors.
    return tensor.reshape(tensor.shape[0], math.prod(tensor.shape[1:]))
20
+
21
+
22
def power_iter_sigma_max(
    matrix: Any,
    *,
    iters: int,
    init: str = "ones",
    eps: float = 1e-12,
) -> float:
    """Estimate the largest singular value (spectral norm) via fixed-iter power iteration.

    Contract properties (vNext):
    - fixed iteration budget (no convergence stopping)
    - deterministic initialization (`init`)
    - device-resident matvecs (no `.cpu()` transfers); one scalar `.item()` per
      iteration is read for the finiteness check

    Args:
        matrix: Tensor to analyse. Non-2-D tensors are flattened to
            ``(shape[0], -1)`` by ``_as_matrix``.
        iters: Number of iterations. Values that are not int-convertible fall
            back to 4; values below 1 are raised to 1.
        init: ``"ones"`` starts from the all-ones vector. Anything else starts
            from the unit vector e0.
        eps: Lower clamp applied to norms used as divisors (avoids division
            by zero).

    Returns:
        ``||W v||`` from the last iteration, with ``v`` normalised. In exact
        arithmetic this is a lower bound on the true sigma_max. Returns 0.0 for
        non-tensors, empty tensors, non-floating/non-complex dtypes (e.g.
        int8/uint8 storage, other integer or bool tensors), or when a non-finite
        value is encountered.
    """
    try:
        iters_i = int(iters)
    except Exception:
        iters_i = 4
    iters_i = max(1, iters_i)

    if not isinstance(matrix, torch.Tensor) or matrix.numel() == 0:
        return 0.0
    # Integer/bool tensors cannot take part in the float matvecs below (the
    # original int8/uint8 check missed int32/int64/bool, which crashed).
    if not (matrix.is_floating_point() or matrix.is_complex()):
        return 0.0

    W = _as_matrix(matrix.detach())
    if W.numel() == 0 or W.shape[0] == 0 or W.shape[1] == 0:
        return 0.0

    device = W.device
    dtype = W.dtype
    n = int(W.shape[1])

    with torch.no_grad():
        if init == "ones":
            v = torch.ones((n,), device=device, dtype=dtype)
        else:
            # Deterministic fallback: unit vector e0.
            v = torch.zeros((n,), device=device, dtype=dtype)
            v[0] = 1

        v_norm = torch.linalg.vector_norm(v.float()).clamp_min(eps)
        v = v / v_norm.to(dtype)

        sigma = 0.0
        for _ in range(iters_i):
            u = W @ v
            u_norm = torch.linalg.vector_norm(u.float())
            sigma_val = float(u_norm.item())
            if not math.isfinite(sigma_val):
                return 0.0
            # Clamp only the divisor, so a zero product reports sigma 0.0 (not eps).
            u = u / u_norm.clamp_min(eps).to(dtype)
            v = W.T @ u
            v_norm = torch.linalg.vector_norm(v.float()).clamp_min(eps)
            v = v / v_norm.to(dtype)
            sigma = sigma_val
    return float(sigma)
82
+
83
+
84
def frobenius_norm_sq(matrix: torch.Tensor) -> float:
    """Return ||matrix||_F^2 with float32 accumulation (device-resident).

    The Frobenius norm does not depend on layout, so the detached input is
    flattened directly with ``reshape(-1)``. This accepts tensors of any rank,
    including non-contiguous, 0-d and zero-sized ones.

    Returns 0.0 for empty tensors, or when the result is not finite (e.g. inf
    or NaN entries, or float32 overflow of the norm).
    """
    flat = matrix.detach().reshape(-1)
    if flat.numel() == 0:
        return 0.0
    with torch.no_grad():
        # Use a fused reduction to avoid materializing a W*W intermediate.
        norm = torch.linalg.vector_norm(flat, ord=2, dtype=torch.float32)
        out = float((norm * norm).item())
    return out if math.isfinite(out) else 0.0
94
+
95
+
96
def row_col_norm_extrema(
    matrix: torch.Tensor, *, eps: float = 1e-12
) -> dict[str, float]:
    """Return min/median/max of the row and column L2 norms of ``matrix``.

    The input is detached and passed through ``_as_matrix`` (non-2-D tensors
    become ``(shape[0], -1)``). Norms are accumulated in float32 and clamped
    from below at ``eps``. For an even number of norms, the median is the mean
    of the two middle values. If the matrix has no elements or a zero-length
    axis, every statistic is 0.0.
    """

    def _min_median_max(norms: torch.Tensor) -> tuple[float, float, float]:
        # Sort once: min/max are the ends, the median is read from the middle.
        ordered, _ = torch.sort(norms)
        size = int(ordered.numel())
        half = size // 2
        if size <= 0:
            median = 0.0
        elif size % 2 == 1:
            median = float(ordered[half].item())
        else:
            # Average in tensor (float32) arithmetic, matching the norms' precision.
            median = float((ordered[half - 1] + ordered[half]).mul(0.5).item())
        return float(ordered[0].item()), median, float(ordered[-1].item())

    mat = _as_matrix(matrix.detach())
    if mat.numel() == 0 or mat.shape[0] == 0 or mat.shape[1] == 0:
        return {
            "row_min": 0.0,
            "row_median": 0.0,
            "row_max": 0.0,
            "col_min": 0.0,
            "col_median": 0.0,
            "col_max": 0.0,
        }

    with torch.no_grad():
        # Per-axis fused norm reductions; no W*W temporary is materialized.
        row_norms = torch.linalg.vector_norm(
            mat, ord=2, dim=1, dtype=torch.float32
        ).clamp_min(eps)
        col_norms = torch.linalg.vector_norm(
            mat, ord=2, dim=0, dtype=torch.float32
        ).clamp_min(eps)

    row_lo, row_mid, row_hi = _min_median_max(row_norms)
    col_lo, col_mid, col_hi = _min_median_max(col_norms)
    return {
        "row_min": row_lo,
        "row_median": row_mid,
        "row_max": row_hi,
        "col_min": col_lo,
        "col_median": col_mid,
        "col_max": col_hi,
    }
139
+
140
+
141
def stable_rank_estimate(
    matrix: torch.Tensor, *, sigma_max: float, eps: float = 1e-12
) -> float:
    """Estimate the stable rank ``||W||_F^2 / ||W||_2^2`` from a supplied σ̂max.

    The numerator comes from ``frobenius_norm_sq(matrix)``. The denominator is
    the caller's ``sigma_max`` squared, floored at ``eps``.

    Returns 0.0 if ``sigma_max`` cannot be converted and squared as a float, if
    that square is non-finite or not positive, or if the ratio is non-finite.
    """
    try:
        sigma_sq = float(sigma_max) ** 2
    except Exception:
        return 0.0
    # Reject NaN/inf and zero before dividing.
    if sigma_sq <= 0.0 or not math.isfinite(sigma_sq):
        return 0.0
    denominator = max(sigma_sq, eps)
    ratio = float(frobenius_norm_sq(matrix)) / denominator
    return ratio if math.isfinite(ratio) else 0.0
@@ -15,7 +15,7 @@ from typing import Any, Literal
15
15
 
16
16
  try: # Python 3.12+
17
17
  from typing import NotRequired, TypedDict
18
- except ImportError: # Legacy fallback
18
+ except ImportError: # Python <3.12 fallback
19
19
  from typing import NotRequired
20
20
 
21
21
  from typing_extensions import TypedDict
@@ -40,6 +40,7 @@ SPECTRAL_CONSERVATIVE: SpectralPolicy = {
40
40
  "scope": "ffn", # FFN layers only (safest)
41
41
  "correction_enabled": True,
42
42
  "max_caps": 3,
43
+ "max_spectral_norm": None,
43
44
  "multiple_testing": {"method": "bonferroni", "alpha": 0.02, "m": 4},
44
45
  }
45
46
 
@@ -50,6 +51,7 @@ SPECTRAL_BALANCED: SpectralPolicy = {
50
51
  "scope": "ffn", # FFN layers only
51
52
  "correction_enabled": False,
52
53
  "max_caps": 5,
54
+ "max_spectral_norm": None,
53
55
  "multiple_testing": {"method": "bh", "alpha": 0.05, "m": 4},
54
56
  }
55
57
 
@@ -60,6 +62,7 @@ SPECTRAL_AGGRESSIVE: SpectralPolicy = {
60
62
  "scope": "all", # All layers including attention
61
63
  "correction_enabled": True,
62
64
  "max_caps": 8,
65
+ "max_spectral_norm": None,
63
66
  "multiple_testing": {"method": "bh", "alpha": 0.1, "m": 4},
64
67
  }
65
68
 
@@ -70,6 +73,7 @@ SPECTRAL_ATTN_AWARE: SpectralPolicy = {
70
73
  "scope": "attn", # Attention layers only
71
74
  "correction_enabled": False,
72
75
  "max_caps": 5,
76
+ "max_spectral_norm": None,
73
77
  "multiple_testing": {"method": "bh", "alpha": 0.05, "m": 4},
74
78
  }
75
79
 
@@ -81,7 +85,8 @@ RMT_CONSERVATIVE: RMTPolicyDict = {
81
85
  "deadband": 0.05, # 5% deadband - strict threshold
82
86
  "margin": 1.3, # Lower margin for conservative detection
83
87
  "correct": True, # Enable automatic correction
84
- "epsilon": {"attn": 0.05, "ffn": 0.06, "embed": 0.07, "other": 0.07},
88
+ "epsilon_default": 0.06,
89
+ "epsilon_by_family": {"attn": 0.05, "ffn": 0.06, "embed": 0.07, "other": 0.07},
85
90
  }
86
91
 
87
92
  # Balanced RMT policy - good for most use cases
@@ -90,7 +95,8 @@ RMT_BALANCED: RMTPolicyDict = {
90
95
  "deadband": 0.10, # 10% deadband - reasonable tolerance
91
96
  "margin": 1.5, # Standard margin for outlier detection
92
97
  "correct": False, # Monitor-only by default
93
- "epsilon": {"attn": 0.08, "ffn": 0.10, "embed": 0.12, "other": 0.12},
98
+ "epsilon_default": 0.10,
99
+ "epsilon_by_family": {"attn": 0.08, "ffn": 0.10, "embed": 0.12, "other": 0.12},
94
100
  }
95
101
 
96
102
  # Aggressive RMT policy - for research/experimental use
@@ -99,7 +105,8 @@ RMT_AGGRESSIVE: RMTPolicyDict = {
99
105
  "deadband": 0.15, # 15% deadband - more permissive
100
106
  "margin": 1.8, # Higher margin allows more deviation
101
107
  "correct": True, # Enable automatic correction
102
- "epsilon": {"attn": 0.15, "ffn": 0.15, "embed": 0.15, "other": 0.15},
108
+ "epsilon_default": 0.15,
109
+ "epsilon_by_family": {"attn": 0.15, "ffn": 0.15, "embed": 0.15, "other": 0.15},
103
110
  }
104
111
 
105
112
  # === Variance Guard Policies ===
@@ -276,6 +283,8 @@ def get_spectral_policy(
276
283
  policy["scope"] = tier_config["scope"]
277
284
  if "max_caps" in tier_config:
278
285
  policy["max_caps"] = tier_config["max_caps"]
286
+ if "max_spectral_norm" in tier_config:
287
+ policy["max_spectral_norm"] = tier_config["max_spectral_norm"]
279
288
  if "family_caps" in tier_config:
280
289
  policy["family_caps"] = tier_config["family_caps"]
281
290
  if "multiple_testing" in tier_config:
@@ -390,9 +399,10 @@ def get_rmt_policy(name: str = "balanced", *, use_yaml: bool = True) -> RMTPolic
390
399
  policy["deadband"] = tier_config["deadband"]
391
400
  if "margin" in tier_config:
392
401
  policy["margin"] = tier_config["margin"]
393
- # Use epsilon_by_family as the epsilon dict
402
+ if "epsilon_default" in tier_config:
403
+ policy["epsilon_default"] = tier_config["epsilon_default"]
394
404
  if "epsilon_by_family" in tier_config:
395
- policy["epsilon"] = tier_config["epsilon_by_family"]
405
+ policy["epsilon_by_family"] = tier_config["epsilon_by_family"]
396
406
  except Exception:
397
407
  # Fallback to hardcoded values on any error
398
408
  pass