invarlock 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,948 @@
1
+ # mypy: ignore-errors
2
+ from __future__ import annotations
3
+
4
+ import math
5
+ from typing import Any, no_type_check
6
+
7
+ from invarlock.core.auto_tuning import TIER_POLICIES
8
+
9
+ from .policy_utils import _promote_legacy_multiple_testing_key, _resolve_policy_tier
10
+ from .report_types import RunReport
11
+
12
+
13
+ @no_type_check
14
+ def _extract_invariants(report: RunReport) -> dict[str, Any]:
15
+ """Extract invariant check results (matches legacy shape used in tests)."""
16
+ invariants_data = (report.get("metrics", {}) or {}).get("invariants", {})
17
+ failures: list[dict[str, Any]] = []
18
+ summary: dict[str, Any] = {}
19
+
20
+ # Collect failures from metrics.invariants
21
+ if isinstance(invariants_data, dict) and invariants_data:
22
+ for check_name, check_result in invariants_data.items():
23
+ if isinstance(check_result, dict):
24
+ if bool(check_result.get("passed", True)):
25
+ continue
26
+ recorded_violation = False
27
+ violations = check_result.get("violations")
28
+ if isinstance(violations, list) and violations:
29
+ for violation in violations:
30
+ if not isinstance(violation, dict):
31
+ continue
32
+ entry: dict[str, Any] = {
33
+ "check": check_name,
34
+ "type": str(violation.get("type", "violation")),
35
+ "severity": violation.get("severity", "warning"),
36
+ }
37
+ detail = {k: v for k, v in violation.items() if k != "type"}
38
+ if detail:
39
+ entry["detail"] = detail
40
+ failures.append(entry)
41
+ recorded_violation = True
42
+ if recorded_violation:
43
+ continue
44
+ # No explicit violations list – treat as error
45
+ failure_entry = {"check": check_name}
46
+ failure_entry["type"] = str(check_result.get("type") or "failure")
47
+ failure_entry["severity"] = "error"
48
+ detail = {
49
+ k: v
50
+ for k, v in check_result.items()
51
+ if k not in {"passed", "violations", "type"}
52
+ }
53
+ if check_result.get("message"):
54
+ detail.setdefault("message", check_result["message"])
55
+ if detail:
56
+ failure_entry["detail"] = detail
57
+ failures.append(failure_entry)
58
+ else:
59
+ # Non-dict value: treat False as error severity
60
+ if not bool(check_result):
61
+ failures.append(
62
+ {"check": check_name, "type": "failure", "severity": "error"}
63
+ )
64
+
65
+ # Guard-level invariants info (counts + detailed violations)
66
+ guard_entry = None
67
+ for guard in report.get("guards", []) or []:
68
+ if str(guard.get("name", "")).lower() == "invariants":
69
+ guard_entry = guard
70
+ break
71
+
72
+ severity_status = "pass"
73
+ if guard_entry:
74
+ gm = guard_entry.get("metrics", {}) or {}
75
+ summary = {
76
+ "checks_performed": gm.get("checks_performed"),
77
+ "violations_found": gm.get("violations_found"),
78
+ "fatal_violations": gm.get("fatal_violations"),
79
+ "warning_violations": gm.get("warning_violations"),
80
+ }
81
+ violations = guard_entry.get("violations", [])
82
+ fatal_count = int(gm.get("fatal_violations", 0) or 0)
83
+ warning_count = int(gm.get("warning_violations", 0) or 0)
84
+ if violations:
85
+ for violation in violations:
86
+ if not isinstance(violation, dict):
87
+ continue
88
+ row = {
89
+ "check": str(
90
+ violation.get("check") or violation.get("name") or "invariant"
91
+ ),
92
+ "type": str(violation.get("type") or "violation"),
93
+ "severity": str(violation.get("severity") or "warning"),
94
+ }
95
+ detail = {k: v for k, v in violation.items() if k not in row}
96
+ if detail:
97
+ row["detail"] = detail
98
+ failures.append(row)
99
+ if fatal_count > 0:
100
+ severity_status = "fail"
101
+ elif warning_count > 0 or violations:
102
+ severity_status = "warn"
103
+
104
+ # If any error-severity entry exists among failures, escalate to fail
105
+ if failures:
106
+ has_error = any(str(f.get("severity", "warning")) == "error" for f in failures)
107
+ if has_error:
108
+ severity_status = "fail"
109
+ elif severity_status == "pass":
110
+ severity_status = "warn"
111
+
112
+ status = severity_status
113
+ if not summary:
114
+ summary = {
115
+ "checks_performed": 0,
116
+ "violations_found": len(failures),
117
+ "fatal_violations": 0,
118
+ "warning_violations": len(failures),
119
+ }
120
+
121
+ return {
122
+ "pre": "pass",
123
+ "post": status,
124
+ "status": status,
125
+ "summary": summary,
126
+ "details": invariants_data,
127
+ "failures": failures,
128
+ }
129
+
130
+
131
@no_type_check
def _extract_spectral_analysis(
    report: RunReport, baseline: dict[str, Any]
) -> dict[str, Any]:
    """Assemble the spectral-guard analysis section from a run report.

    Locates the guard named ``"spectral"`` (case-insensitive) in
    ``report["guards"]``, merges its metrics and policy with the tier defaults
    taken from ``TIER_POLICIES``, and relates the reported spectral norms to
    baseline values.

    Args:
        report: Run report mapping; ``guards`` and ``metrics`` are read. The
            report is not modified.
        baseline: Baseline payload. Baseline norms are looked up in
            ``baseline["spectral"]``, then ``baseline["metrics"]["spectral"]``,
            then the guard's own ``baseline_metrics``.

    Returns:
        A dict that always contains ``tier``, ``caps_applied``, ``summary``,
        ``families``, ``family_caps``, ``max_caps`` and ``caps_exceeded``.
        ``policy``, ``sigma_quantile``, ``deadband``, ``multiple_testing``,
        ``bh_family_count``, ``caps_applied_by_family``, ``top_z_scores``,
        ``top_violations`` and ``family_z_quantiles`` are added only when the
        corresponding data is available.
    """
    numeric_errors = (TypeError, ValueError, OverflowError)

    def _to_float(value: Any, default: float | None) -> float | None:
        """Convert *value* to float, or return *default* if it is not numeric."""
        try:
            return float(value)
        except numeric_errors:
            return default

    def _to_int(value: Any, default: int | None) -> int | None:
        """Convert *value* to int, or return *default* if it is not numeric."""
        try:
            return int(value)
        except numeric_errors:
            return default

    def _dict_or_empty(value: Any) -> dict[str, Any]:
        """Return *value* when it is a dict, otherwise a new empty dict."""
        return value if isinstance(value, dict) else {}

    def _compute_quantile(sorted_values: list[float], quantile: float) -> float:
        """Linearly interpolated quantile of an ascending list (0.0 if empty)."""
        if not sorted_values:
            return 0.0
        if len(sorted_values) == 1:
            return sorted_values[0]
        position = (len(sorted_values) - 1) * quantile
        lower = math.floor(position)
        upper = math.ceil(position)
        if lower == upper:
            return sorted_values[int(position)]
        fraction = position - lower
        return (
            sorted_values[lower]
            + (sorted_values[upper] - sorted_values[lower]) * fraction
        )

    def _summarize_from_z_scores(
        z_scores_map: Any, module_family_map: Any
    ) -> tuple[dict[str, dict[str, float]], dict[str, list[dict[str, Any]]]]:
        """Derive per-family |z| quantiles and the top-3 modules per family."""
        from collections import defaultdict

        if not isinstance(z_scores_map, dict) or not z_scores_map:
            return {}, {}
        if not isinstance(module_family_map, dict) or not module_family_map:
            return {}, {}

        per_family_values: dict[str, list[tuple[float, str]]] = defaultdict(list)
        for module_name, z_value in z_scores_map.items():
            family = module_family_map.get(module_name)
            if family is None:
                continue
            try:
                z_abs = abs(float(z_value))
            except (TypeError, ValueError):
                continue
            per_family_values[family].append((z_abs, module_name))

        family_quantiles_local: dict[str, dict[str, float]] = {}
        top_z_scores_local: dict[str, list[dict[str, Any]]] = {}
        for family, value_list in per_family_values.items():
            if not value_list:
                continue
            sorted_scores = sorted(z for z, _ in value_list)
            family_quantiles_local[family] = {
                "q95": _compute_quantile(sorted_scores, 0.95),
                "q99": _compute_quantile(sorted_scores, 0.99),
                "max": sorted_scores[-1],
                "count": len(sorted_scores),
            }
            top_entries = sorted(value_list, key=lambda t: abs(t[0]), reverse=True)[:3]
            top_z_scores_local[family] = [
                {"module": name, "z": float(z)} for z, name in top_entries
            ]
        return family_quantiles_local, top_z_scores_local

    def _resolve_multiple_testing(*sources: Any) -> dict[str, Any] | None:
        """Return the first non-empty multiple-testing mapping among *sources*."""
        for source in sources:
            if not isinstance(source, dict):
                continue
            candidate = source.get("multiple_testing") or source.get("multipletesting")
            if isinstance(candidate, dict) and candidate:
                return candidate
        return None

    # --- Tier defaults -----------------------------------------------------
    tier = _resolve_policy_tier(report)
    tier_defaults = TIER_POLICIES.get(tier, TIER_POLICIES.get("balanced", {}))
    spectral_defaults = (tier_defaults.get("spectral") or {}) if tier_defaults else {}
    default_sigma_quantile = spectral_defaults.get("sigma_quantile", 0.95)
    default_deadband = spectral_defaults.get("deadband", 0.1)
    default_caps = spectral_defaults.get("family_caps", {})
    default_max_caps = spectral_defaults.get("max_caps", 5)

    # --- Guard payload -----------------------------------------------------
    spectral_guard = next(
        (
            guard
            for guard in report.get("guards", []) or []
            if isinstance(guard, dict)
            and str(guard.get("name", "")).lower() == "spectral"
        ),
        None,
    )
    # Normalise to dicts so a guard carrying ``"metrics": None`` (or any other
    # non-mapping) is handled like a guard without metrics instead of raising.
    guard_policy = _dict_or_empty(spectral_guard.get("policy") if spectral_guard else None)
    guard_metrics = _dict_or_empty(
        spectral_guard.get("metrics") if spectral_guard else None
    )

    # First truthy counter wins; a bare ``correction_applied`` flag counts as 1.
    caps_applied = _to_int(
        guard_metrics.get("violations_detected")
        or guard_metrics.get("violations_found")
        or guard_metrics.get("caps_applied")
        or (1 if guard_metrics.get("correction_applied") else 0),
        0,
    )
    modules_checked = guard_metrics.get("modules_checked")
    caps_exceeded = bool(guard_metrics.get("caps_exceeded", False))
    max_caps = guard_policy.get("max_caps")
    if max_caps is None:
        max_caps = default_max_caps

    max_spectral_norm = _to_float(
        guard_metrics.get("max_spectral_norm_final")
        or guard_metrics.get("max_spectral_norm")
        or 0.0,
        0.0,
    )
    mean_spectral_norm = _to_float(
        guard_metrics.get("mean_spectral_norm_final")
        or guard_metrics.get("mean_spectral_norm")
        or 0.0,
        0.0,
    )

    # --- Baseline norms ----------------------------------------------------
    baseline_max = None
    baseline_mean = None
    baseline_spectral = (
        baseline.get("spectral", {}) if isinstance(baseline, dict) else {}
    )
    if isinstance(baseline_spectral, dict) and baseline_spectral:
        baseline_max = baseline_spectral.get(
            "max_spectral_norm", baseline_spectral.get("max_spectral_norm_final")
        )
        baseline_mean = baseline_spectral.get(
            "mean_spectral_norm", baseline_spectral.get("mean_spectral_norm_final")
        )
    if baseline_max is None:
        baseline_metrics = (
            baseline.get("metrics", {}) if isinstance(baseline, dict) else {}
        )
        if isinstance(baseline_metrics, dict) and "spectral" in baseline_metrics:
            baseline_spectral_metrics = baseline_metrics["spectral"]
            if isinstance(baseline_spectral_metrics, dict):
                baseline_max = baseline_spectral_metrics.get("max_spectral_norm_final")
                baseline_mean = baseline_spectral_metrics.get(
                    "mean_spectral_norm_final"
                )
    guard_baseline_metrics = (
        spectral_guard.get("baseline_metrics") if spectral_guard else None
    )
    if baseline_max is None and isinstance(guard_baseline_metrics, dict):
        if guard_baseline_metrics:
            baseline_max = guard_baseline_metrics.get("max_spectral_norm")
            baseline_mean = guard_baseline_metrics.get("mean_spectral_norm")
    # Missing or zero baselines are reported as None; non-numeric values are
    # treated the same way rather than aborting the whole report.
    baseline_max = (
        _to_float(baseline_max, None) if baseline_max not in (None, 0, 0.0) else None
    )
    baseline_mean = (
        _to_float(baseline_mean, None) if baseline_mean not in (None, 0, 0.0) else None
    )

    max_sigma_ratio = (
        max_spectral_norm / baseline_max if baseline_max and baseline_max > 0 else 1.0
    )
    median_sigma_ratio = (
        mean_spectral_norm / baseline_mean
        if baseline_mean and baseline_mean > 0
        else 1.0
    )

    # --- Deadband / sigma quantile (policy, then metrics, then tier default) --
    db_raw = guard_policy.get("deadband")
    if db_raw is None:
        db_raw = guard_metrics.get("deadband")
    if db_raw is None:
        db_raw = default_deadband
    deadband_used: float | None = _to_float(db_raw, None)

    # ``contraction`` and ``kappa`` are accepted as aliases of sigma_quantile.
    pol_sq = (
        guard_policy.get("sigma_quantile")
        or guard_policy.get("contraction")
        or guard_policy.get("kappa")
    )
    if pol_sq is None:
        pol_sq = default_sigma_quantile
    sigma_q_used: float | None = _to_float(pol_sq, None)

    # --- Summary -----------------------------------------------------------
    summary: dict[str, Any] = {
        "max_sigma_ratio": max_sigma_ratio,
        "median_sigma_ratio": median_sigma_ratio,
        "max_spectral_norm": max_spectral_norm,
        "mean_spectral_norm": mean_spectral_norm,
        "baseline_max_spectral_norm": baseline_max,
        "baseline_mean_spectral_norm": baseline_mean,
    }
    if sigma_q_used is not None:
        summary["sigma_quantile"] = sigma_q_used
    if deadband_used is not None:
        summary["deadband"] = deadband_used
    stability_score = _to_float(
        guard_metrics.get(
            "spectral_stability_score", guard_metrics.get("stability_score", 1.0)
        ),
        None,
    )
    if stability_score is not None:
        summary["stability_score"] = stability_score

    # Prefer explicit family_z_quantiles; otherwise accept family_z_summary.
    family_quantiles: dict[str, dict[str, float]] = _dict_or_empty(
        guard_metrics.get("family_z_quantiles")
    ) or _dict_or_empty(guard_metrics.get("family_z_summary"))

    # Resolve family caps BEFORE the families table: the kappa fallback below
    # reads them, so they must already be populated at that point.
    family_caps: dict[str, dict[str, float]] = _dict_or_empty(
        guard_metrics.get("family_caps")
    )
    if not family_caps:
        family_caps = _dict_or_empty(guard_policy.get("family_caps"))
    if not family_caps and isinstance(default_caps, dict):
        family_caps = default_caps

    # --- Families table ----------------------------------------------------
    families: dict[str, dict[str, Any]] = _dict_or_empty(guard_metrics.get("families"))
    if not families:
        # Build into a fresh dict so the report's own metrics are never mutated.
        families = {}
        # Prefer the z-summary; accept the legacy 'family_stats' key as well.
        fzs = guard_metrics.get("family_z_summary")
        if not isinstance(fzs, dict) or not fzs:
            fzs = guard_metrics.get("family_stats")
        if isinstance(fzs, dict):
            for fam, stats in fzs.items():
                if not isinstance(stats, dict):
                    continue
                entry: dict[str, Any] = {}
                for key, convert in (
                    ("max", float),
                    ("mean", float),
                    ("count", int),
                    ("violations", int),
                ):
                    if key not in stats:
                        continue
                    try:
                        entry[key] = convert(stats[key])
                    except numeric_errors:
                        pass
                # Kappa comes from the stats, falling back to the family cap.
                kappa = stats.get("kappa")
                if kappa is None:
                    cap = family_caps.get(str(fam))
                    if isinstance(cap, dict):
                        kappa = cap.get("kappa")
                kappa_value = _to_float(kappa, None) if kappa is not None else None
                if kappa_value is not None:
                    entry["kappa"] = kappa_value
                if entry:
                    families[str(fam)] = entry

    # --- Top z-scores reported by the guard (at most three per family) ------
    top_z_scores: dict[str, list[dict[str, Any]]] = {}
    for fam, entries in _dict_or_empty(guard_metrics.get("top_z_scores")).items():
        if not isinstance(entries, list):
            continue
        cleaned: list[dict[str, Any]] = []
        for item in entries:
            if not isinstance(item, dict):
                continue
            z_float = _to_float(item.get("z"), None)
            if z_float is None:
                continue
            cleaned.append({"module": item.get("module"), "z": z_float})
        if cleaned:
            cleaned.sort(key=lambda d: abs(d.get("z", 0.0)), reverse=True)
            top_z_scores[str(fam)] = cleaned[:3]

    # Derive quantiles / top z from raw z-scores when available; fill gaps only.
    if spectral_guard:
        z_map_candidate = spectral_guard.get("final_z_scores") or guard_metrics.get(
            "final_z_scores"
        )
        family_map_candidate = spectral_guard.get(
            "module_family_map"
        ) or guard_metrics.get("module_family_map")
        derived_quantiles, derived_top = _summarize_from_z_scores(
            z_map_candidate, family_map_candidate
        )
        if derived_quantiles and not family_quantiles:
            family_quantiles = derived_quantiles
        if derived_top:
            if not top_z_scores:
                top_z_scores = dict(derived_top)
            else:
                for fam, entries in derived_top.items():
                    current = top_z_scores.get(fam)
                    if not isinstance(current, list) or not current:
                        top_z_scores[fam] = entries

    # Fallback: without guard metrics, take sigma ratios from metrics.spectral.
    if not guard_metrics:
        spectral_data = (report.get("metrics", {}) or {}).get("spectral", {})
        ratios = (
            spectral_data.get("sigma_ratios")
            if isinstance(spectral_data, dict)
            else None
        )
        if isinstance(ratios, list) and ratios:
            try:
                float_ratios = [float(r) for r in ratios]
            except numeric_errors:
                float_ratios = []
            if float_ratios:
                summary["max_sigma_ratio"] = max(float_ratios)
                # Upper median for even-length lists (index len // 2).
                summary["median_sigma_ratio"] = float(
                    sorted(float_ratios)[len(float_ratios) // 2]
                )

    multiple_testing = _resolve_multiple_testing(
        guard_metrics, guard_policy, spectral_defaults
    )

    # --- Policy echo -------------------------------------------------------
    policy_out: dict[str, Any] | None = None
    if guard_policy:
        policy_out = dict(guard_policy)
        _promote_legacy_multiple_testing_key(policy_out)
        if default_sigma_quantile is not None:
            sq = (
                policy_out.get("sigma_quantile")
                or policy_out.get("contraction")
                or policy_out.get("kappa")
            )
            if sq is not None:
                sq_value = _to_float(sq, None)
                if sq_value is not None:
                    policy_out["sigma_quantile"] = sq_value
        # Aliases are folded into sigma_quantile and not echoed back.
        policy_out.pop("contraction", None)
        policy_out.pop("kappa", None)
        if tier == "balanced":
            policy_out["correction_enabled"] = False
            policy_out["max_spectral_norm"] = None
        if multiple_testing and "multiple_testing" not in policy_out:
            policy_out["multiple_testing"] = multiple_testing

    # --- Result assembly (key insertion order kept stable) -------------------
    result: dict[str, Any] = {
        "tier": tier,
        "caps_applied": caps_applied,
        "summary": summary,
        "families": families,
        "family_caps": family_caps,
    }
    # Status is attached to the summary for backward-compatibility in tests.
    summary["status"] = "stable" if caps_applied == 0 else "capped"
    if policy_out:
        result["policy"] = policy_out
    # NOTE: the top-level sigma_quantile is the tier default; the value
    # resolved from the guard policy is reported in summary["sigma_quantile"].
    if default_sigma_quantile is not None:
        result["sigma_quantile"] = default_sigma_quantile
    if deadband_used is not None:
        result["deadband"] = deadband_used
    # Always include the max_caps key for schema/tests parity.
    max_caps_val = (
        _to_int(max_caps, None) if isinstance(max_caps, int | float) else None
    )
    result["max_caps"] = max_caps_val
    summary["max_caps"] = max_caps_val

    if multiple_testing:
        mt_copy = dict(multiple_testing)
        families_present = set(families) or set(family_caps or {})
        try:
            mt_copy["m"] = int(mt_copy.get("m") or len(families_present))
        except numeric_errors:
            mt_copy["m"] = len(families_present)
        result["multiple_testing"] = mt_copy
        result["bh_family_count"] = mt_copy["m"]

    if families:
        # Unparseable violation counts are reported as 0 instead of raising.
        result["caps_applied_by_family"] = {
            fam: _to_int(details.get("violations", 0), 0)
            for fam, details in families.items()
            if isinstance(details, dict)
        }
    if top_z_scores:
        result["top_z_scores"] = top_z_scores

    # Up to five violations, copied from the guard payload in their given order.
    guard_violations = spectral_guard.get("violations") if spectral_guard else None
    if isinstance(guard_violations, list):
        top_violations: list[dict[str, Any]] = []
        for violation in guard_violations[:5]:
            if not isinstance(violation, dict):
                continue
            row: dict[str, Any] = {
                "module": violation.get("module"),
                "family": violation.get("family"),
                "kappa": violation.get("kappa"),
                "severity": violation.get("severity", "warn"),
            }
            z_score = _to_float(violation.get("z_score"), None)
            if z_score is not None:
                row["z_score"] = z_score
            top_violations.append(row)
        if top_violations:
            result["top_violations"] = top_violations

    if family_quantiles:
        result["family_z_quantiles"] = family_quantiles
    result["caps_exceeded"] = caps_exceeded
    summary["caps_exceeded"] = caps_exceeded
    if modules_checked is not None:
        modules_checked_val = _to_int(modules_checked, None)
        if modules_checked_val is not None:
            summary["modules_checked"] = modules_checked_val
    return result
614
+
615
+
616
+ @no_type_check
617
+ def _extract_rmt_analysis(
618
+ report: RunReport, baseline: dict[str, Any]
619
+ ) -> dict[str, Any]:
620
+ tier = _resolve_policy_tier(report)
621
+ tier_defaults = TIER_POLICIES.get(tier, TIER_POLICIES.get("balanced", {}))
622
+ default_epsilon_map = (
623
+ tier_defaults.get("rmt", {}).get("epsilon", {}) if tier_defaults else {}
624
+ )
625
+ default_epsilon_map = {
626
+ str(family): float(value)
627
+ for family, value in (default_epsilon_map or {}).items()
628
+ if isinstance(value, int | float)
629
+ }
630
+
631
+ outliers_guarded = 0
632
+ outliers_bare = 0
633
+ epsilon_default = 0.1
634
+ stable = True
635
+ explicit_stability = False
636
+ max_ratio = 0.0
637
+ max_deviation_ratio = 1.0
638
+ mean_deviation_ratio = 1.0
639
+ epsilon_map: dict[str, float] = {}
640
+ baseline_outliers_per_family: dict[str, int] = {}
641
+ outliers_per_family: dict[str, int] = {}
642
+ epsilon_violations: list[Any] = []
643
+
644
+ for guard in report.get("guards", []) or []:
645
+ if str(guard.get("name", "")).lower() == "rmt":
646
+ guard_metrics = guard.get("metrics", {}) or {}
647
+ guard_policy = guard.get("policy", {}) or {}
648
+ outliers_guarded = guard_metrics.get(
649
+ "rmt_outliers", guard_metrics.get("layers_flagged", outliers_guarded)
650
+ )
651
+ max_ratio = guard_metrics.get("max_ratio", 0.0)
652
+ epsilon_default = guard_policy.get(
653
+ "deadband", guard_metrics.get("deadband_used", epsilon_default)
654
+ )
655
+ epsilon_map = guard_metrics.get("epsilon_by_family", {}) or epsilon_map
656
+ baseline_outliers_per_family = (
657
+ guard_metrics.get("baseline_outliers_per_family", {})
658
+ or baseline_outliers_per_family
659
+ )
660
+ outliers_per_family = (
661
+ guard_metrics.get("outliers_per_family", {}) or outliers_per_family
662
+ )
663
+ epsilon_violations = guard_metrics.get(
664
+ "epsilon_violations", epsilon_violations
665
+ )
666
+ if outliers_per_family:
667
+ outliers_guarded = sum(
668
+ int(v)
669
+ for v in outliers_per_family.values()
670
+ if isinstance(v, int | float)
671
+ )
672
+ if baseline_outliers_per_family:
673
+ outliers_bare = sum(
674
+ int(v)
675
+ for v in baseline_outliers_per_family.values()
676
+ if isinstance(v, int | float)
677
+ )
678
+ flagged_rate = guard_metrics.get("flagged_rate", 0.0)
679
+ stable = flagged_rate <= 0.5
680
+ max_mp_ratio = guard_metrics.get("max_mp_ratio_final", 0.0)
681
+ mean_mp_ratio = guard_metrics.get("mean_mp_ratio_final", 0.0)
682
+
683
+ baseline_max = None
684
+ baseline_mean = None
685
+ baseline_rmt = baseline.get("rmt", {}) if isinstance(baseline, dict) else {}
686
+ if baseline_rmt:
687
+ baseline_max = baseline_rmt.get(
688
+ "max_mp_ratio", baseline_rmt.get("max_mp_ratio_final")
689
+ )
690
+ baseline_mean = baseline_rmt.get(
691
+ "mean_mp_ratio", baseline_rmt.get("mean_mp_ratio_final")
692
+ )
693
+ outliers_bare = baseline_rmt.get(
694
+ "outliers", baseline_rmt.get("rmt_outliers", 0)
695
+ )
696
+ if baseline_max is None:
697
+ baseline_metrics = (
698
+ baseline.get("metrics", {}) if isinstance(baseline, dict) else {}
699
+ )
700
+ if "rmt" in baseline_metrics:
701
+ baseline_rmt_metrics = baseline_metrics["rmt"]
702
+ baseline_max = baseline_rmt_metrics.get("max_mp_ratio_final")
703
+ baseline_mean = baseline_rmt_metrics.get("mean_mp_ratio_final")
704
+ if baseline_max is None and isinstance(guard.get("baseline_metrics"), dict):
705
+ gb = guard.get("baseline_metrics")
706
+ baseline_max = gb.get("max_mp_ratio")
707
+ baseline_mean = gb.get("mean_mp_ratio")
708
+ if baseline_max is not None and baseline_max > 0:
709
+ max_deviation_ratio = max_mp_ratio / baseline_max
710
+ else:
711
+ max_deviation_ratio = 1.0
712
+ if baseline_mean is not None and baseline_mean > 0:
713
+ mean_deviation_ratio = mean_mp_ratio / baseline_mean
714
+ else:
715
+ mean_deviation_ratio = 1.0
716
+ if isinstance(guard_metrics.get("stable"), bool):
717
+ stable = bool(guard_metrics.get("stable"))
718
+ explicit_stability = True
719
+ break
720
+
721
+ # Fallback: use metrics.rmt and/or top-level rmt section when guard is absent
722
+ if outliers_guarded == 0:
723
+ rmt_metrics = (report.get("metrics", {}) or {}).get("rmt", {})
724
+ if isinstance(rmt_metrics, dict):
725
+ try:
726
+ outliers_guarded = int(rmt_metrics.get("outliers", 0) or 0)
727
+ except Exception:
728
+ outliers_guarded = 0
729
+ if isinstance(rmt_metrics.get("stable"), bool):
730
+ stable = bool(rmt_metrics.get("stable"))
731
+ explicit_stability = True
732
+ rmt_top = report.get("rmt", {}) if isinstance(report.get("rmt"), dict) else {}
733
+ if isinstance(rmt_top, dict):
734
+ fams = rmt_top.get("families", {})
735
+ if isinstance(fams, dict) and fams:
736
+ for fam, rec in fams.items():
737
+ if not isinstance(rec, dict):
738
+ continue
739
+ try:
740
+ outliers_per_family[str(fam)] = int(
741
+ rec.get("outliers_guarded", 0) or 0
742
+ )
743
+ baseline_outliers_per_family[str(fam)] = int(
744
+ rec.get("outliers_bare", 0) or 0
745
+ )
746
+ if rec.get("epsilon") is not None:
747
+ try:
748
+ epsilon_map[str(fam)] = float(rec.get("epsilon"))
749
+ except Exception:
750
+ pass
751
+ except Exception:
752
+ continue
753
+ try:
754
+ if outliers_bare == 0:
755
+ outliers_bare = int(rmt_top.get("outliers", 0) or 0)
756
+ except Exception:
757
+ pass
758
+
759
+ # If stability not explicitly provided, derive from outlier behavior
760
+ if not explicit_stability:
761
+ try:
762
+ if outliers_guarded == 0 and outliers_bare == 0:
763
+ stable = True
764
+ elif outliers_guarded <= outliers_bare:
765
+ stable = True
766
+ else:
767
+ stable = (outliers_guarded - outliers_bare) / max(
768
+ outliers_bare, 1
769
+ ) <= 0.5
770
+ except Exception:
771
+ pass
772
+
773
+ delta_per_family = {
774
+ k: int(outliers_per_family.get(k, 0))
775
+ - int(baseline_outliers_per_family.get(k, 0))
776
+ for k in set(outliers_per_family) | set(baseline_outliers_per_family)
777
+ }
778
+ delta_total = int(outliers_guarded) - int(outliers_bare)
779
+ # Conservative baseline fallback when not available
780
+ if outliers_bare == 0 and outliers_guarded > 0:
781
+ # Assume baseline had fewer outliers to make acceptance harder
782
+ outliers_bare = max(0, outliers_guarded - 1)
783
+
784
+ # Recompute stability from epsilon rule when not explicitly provided
785
+ if not explicit_stability:
786
+ try:
787
+ if outliers_per_family and baseline_outliers_per_family:
788
+ families_union = set(outliers_per_family) | set(
789
+ baseline_outliers_per_family
790
+ )
791
+ checks: list[bool] = []
792
+ for fam in families_union:
793
+ guarded = int(outliers_per_family.get(fam, 0) or 0)
794
+ bare = int(baseline_outliers_per_family.get(fam, 0) or 0)
795
+ eps_val = float(epsilon_map.get(fam, epsilon_default))
796
+ allowed = math.ceil(bare * (1.0 + eps_val))
797
+ checks.append(guarded <= allowed)
798
+ if checks:
799
+ stable = all(checks)
800
+ elif outliers_bare > 0:
801
+ stable = outliers_guarded <= (
802
+ outliers_bare * (1.0 + float(epsilon_default))
803
+ )
804
+ except Exception:
805
+ pass
806
+
807
+ # Compute epsilon scalar (fallback) and detailed family breakdown
808
+ if epsilon_map:
809
+ epsilon_scalar = max(float(v) for v in epsilon_map.values())
810
+ elif default_epsilon_map:
811
+ try:
812
+ epsilon_scalar = max(float(v) for v in default_epsilon_map.values())
813
+ except Exception:
814
+ epsilon_scalar = float(epsilon_default)
815
+ else:
816
+ epsilon_scalar = float(epsilon_default)
817
+ try:
818
+ epsilon_scalar = round(float(epsilon_scalar), 3)
819
+ except Exception:
820
+ epsilon_scalar = float(epsilon_default)
821
+
822
+ def _to_int(v: Any) -> int:
823
+ try:
824
+ return int(v)
825
+ except (TypeError, ValueError):
826
+ return 0
827
+
828
+ families = (
829
+ set(outliers_per_family) | set(baseline_outliers_per_family) | set(epsilon_map)
830
+ )
831
+ family_breakdown = {
832
+ family: {
833
+ "bare": _to_int(baseline_outliers_per_family.get(family, 0)),
834
+ "guarded": _to_int(outliers_per_family.get(family, 0)),
835
+ "epsilon": float(epsilon_map.get(family, epsilon_scalar)),
836
+ }
837
+ for family in sorted(families)
838
+ }
839
+
840
+ # Stringify per-family dict keys for stability
841
+ outliers_per_family = {str(k): _to_int(v) for k, v in outliers_per_family.items()}
842
+ baseline_outliers_per_family = {
843
+ str(k): _to_int(v) for k, v in baseline_outliers_per_family.items()
844
+ }
845
+ delta_per_family = {str(k): _to_int(v) for k, v in delta_per_family.items()}
846
+
847
+ return {
848
+ "outliers_bare": outliers_bare,
849
+ "outliers_guarded": outliers_guarded,
850
+ "epsilon": epsilon_scalar,
851
+ "epsilon_default": float(epsilon_default),
852
+ "epsilon_by_family": epsilon_map,
853
+ "outliers_per_family": outliers_per_family,
854
+ "baseline_outliers_per_family": baseline_outliers_per_family,
855
+ "delta_per_family": delta_per_family,
856
+ "delta_total": delta_total,
857
+ "epsilon_violations": epsilon_violations,
858
+ "stable": stable,
859
+ "status": "stable" if stable else "unstable",
860
+ "max_ratio": max_ratio,
861
+ "max_deviation_ratio": max_deviation_ratio,
862
+ "mean_deviation_ratio": mean_deviation_ratio,
863
+ "families": family_breakdown,
864
+ }
865
+
866
+
867
+ @no_type_check
868
+ def _extract_variance_analysis(report: RunReport) -> dict[str, Any]:
869
+ ve_enabled = False
870
+ gain = None
871
+ ppl_no_ve = None
872
+ ppl_with_ve = None
873
+ ratio_ci = None
874
+ calibration = {}
875
+ guard_metrics: dict[str, Any] = {}
876
+ for guard in report.get("guards", []) or []:
877
+ if "variance" in str(guard.get("name", "")).lower():
878
+ metrics = guard.get("metrics", {}) or {}
879
+ guard_metrics = metrics
880
+ ve_enabled = metrics.get("ve_enabled", bool(metrics))
881
+ gain = metrics.get("ab_gain", metrics.get("gain", None))
882
+ ppl_no_ve = metrics.get("ppl_no_ve", None)
883
+ ppl_with_ve = metrics.get("ppl_with_ve", None)
884
+ ratio_ci = metrics.get("ratio_ci", ratio_ci)
885
+ calibration = metrics.get("calibration", calibration)
886
+ break
887
+ if gain is None:
888
+ metrics_variance = (report.get("metrics", {}) or {}).get("variance", {})
889
+ if isinstance(metrics_variance, dict):
890
+ ve_enabled = metrics_variance.get("ve_enabled", ve_enabled)
891
+ gain = metrics_variance.get("gain", gain)
892
+ ppl_no_ve = metrics_variance.get("ppl_no_ve", ppl_no_ve)
893
+ ppl_with_ve = metrics_variance.get("ppl_with_ve", ppl_with_ve)
894
+ if not guard_metrics:
895
+ guard_metrics = metrics_variance
896
+ result = {"enabled": ve_enabled, "gain": gain}
897
+ if ratio_ci:
898
+ try:
899
+ result["ratio_ci"] = (float(ratio_ci[0]), float(ratio_ci[1]))
900
+ except Exception:
901
+ pass
902
+ if calibration:
903
+ result["calibration"] = calibration
904
+ if not ve_enabled and ppl_no_ve is not None and ppl_with_ve is not None:
905
+ result["ppl_no_ve"] = ppl_no_ve
906
+ result["ppl_with_ve"] = ppl_with_ve
907
+ metadata_fields = [
908
+ "tap",
909
+ "target_modules",
910
+ "target_module_names",
911
+ "focus_modules",
912
+ "scope",
913
+ "proposed_scales",
914
+ "proposed_scales_pre_edit",
915
+ "proposed_scales_post_edit",
916
+ "monitor_only",
917
+ "max_calib_used",
918
+ "mode",
919
+ "min_rel_gain",
920
+ "alpha",
921
+ ]
922
+ for field in metadata_fields:
923
+ value = guard_metrics.get(field)
924
+ if value not in (None, {}, []):
925
+ result[field] = value
926
+ predictive_gate = guard_metrics.get("predictive_gate")
927
+ if predictive_gate:
928
+ result["predictive_gate"] = predictive_gate
929
+ ab_section: dict[str, Any] = {}
930
+ if guard_metrics.get("ab_seed_used") is not None:
931
+ ab_section["seed"] = guard_metrics["ab_seed_used"]
932
+ if guard_metrics.get("ab_windows_used") is not None:
933
+ ab_section["windows_used"] = guard_metrics["ab_windows_used"]
934
+ if guard_metrics.get("ab_provenance"):
935
+ ab_section["provenance"] = guard_metrics["ab_provenance"]
936
+ if guard_metrics.get("ab_point_estimates"):
937
+ ab_section["point_estimates"] = guard_metrics["ab_point_estimates"]
938
+ if ab_section:
939
+ result["ab_test"] = ab_section
940
+ return result
941
+
942
+
943
# Explicit export list. Every name here starts with an underscore, so a star-import would skip it unless it is listed in __all__.
+ __all__ = [
944
+ "_extract_invariants",
945
+ "_extract_spectral_analysis",
946
+ "_extract_rmt_analysis",
947
+ "_extract_variance_analysis",
948
+ ]