invarlock-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/reporting/validate.py (new file)
@@ -0,0 +1,631 @@
+ """
+ InvarLock Validation Framework
+ ==============================
+
+ Validation utilities for checking pruning results against baseline metrics.
+ Supports both automated CI testing and flexible user validation.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import warnings
+ from pathlib import Path
+ from typing import Any, cast
+
+ __all__ = [
+     "validate_against_baseline",
+     "validate_drift_gate",
+     "validate_guard_overhead",
+     "ValidationResult",
+     "load_baseline",
+     "save_baseline",
+     "create_baseline_from_report",
+ ]
+
+
+ class ValidationResult:
+     """Container for validation results."""
+
+     def __init__(
+         self,
+         passed: bool,
+         checks: dict[str, bool],
+         metrics: dict[str, float],
+         messages: list[str],
+         warnings: list[str] | None = None,
+         errors: list[str] | None = None,
+     ):
+         self.passed = passed
+         self.checks = checks
+         self.metrics = metrics
+         self.messages = messages
+         self.warnings = warnings or []
+         self.errors = errors or []
+
+     def to_dict(self) -> dict[str, Any]:
+         """Convert to dictionary for serialization."""
+         return {
+             "passed": self.passed,
+             "checks": self.checks,
+             "metrics": self.metrics,
+             "messages": self.messages,
+             "warnings": self.warnings,
+             "errors": self.errors,
+         }
+
+     def summary(self) -> str:
+         """Get human-readable summary."""
+         status = "✓ PASSED" if self.passed else "✗ FAILED"
+         passed_count = sum(1 for check in self.checks.values() if check)
+         total_count = len(self.checks)
+
+         lines = [
+             f"Validation {status} ({passed_count}/{total_count} checks passed)",
+             "",
+         ]
+
+         # Show individual check results
+         for check_name, passed in self.checks.items():
+             symbol = "✓" if passed else "✗"
+             lines.append(f" {symbol} {check_name}")
+
+         # Show messages
+         if self.messages:
+             lines.append("")
+             lines.extend(f" {msg}" for msg in self.messages)
+
+         # Show warnings and errors
+         if self.warnings:
+             lines.append("")
+             lines.append("Warnings:")
+             lines.extend(f" ⚠️ {warning}" for warning in self.warnings)
+
+         if self.errors:
+             lines.append("")
+             lines.append("Errors:")
+             lines.extend(f" ❌ {error}" for error in self.errors)
+
+         return "\n".join(lines)
+
+
+ def validate_against_baseline(
+     run_report: dict[str, Any],
+     baseline: dict[str, Any],
+     *,
+     tol_ratio: float = 0.02,
+     tol_param_ratio: float = 0.02,
+     ratio_bounds: tuple[float, float] = (1.25, 1.32),
+     delta_bounds_pp: tuple[float, float] | None = None,
+     structural_exact: bool = True,
+     **kwargs,
+ ) -> ValidationResult:
+     """
+     Validate pruning results against baseline metrics (PM-only API).
+
+     Args:
+         run_report: Report from pruning run (dict with metrics)
+         baseline: Baseline metrics to compare against
+         tol_ratio: Tolerance for primary metric ratio deviation
+             (±2% = 0.02) for lower-is-better families
+         tol_param_ratio: Tolerance for parameter reduction ratio deviation
+         ratio_bounds: Acceptable ratio bounds (min, max) for lower-is-better families
+         delta_bounds_pp: Acceptable delta bounds in percentage points (min, max)
+             for higher-is-better families
+         structural_exact: Whether structural counts must match exactly
+
+     Returns:
+         ValidationResult with detailed check results
+     """
+     # Backward-compatible kwargs (deprecated): enable via INVARLOCK_VALIDATE_LEGACY=1
+     legacy = str(os.environ.get("INVARLOCK_VALIDATE_LEGACY", "")).strip().lower() in {
+         "1",
+         "true",
+         "yes",
+         "on",
+     }
+     if legacy:
+         if "tol_ppl_ratio" in kwargs and isinstance(
+             kwargs["tol_ppl_ratio"], int | float
+         ):
+             tol_ratio = float(kwargs["tol_ppl_ratio"])
+         if "ppl_bounds" in kwargs and isinstance(kwargs["ppl_bounds"], tuple):
+             # Coerce after runtime guard
+             ratio_bounds = cast(tuple[float, float], kwargs["ppl_bounds"])
+
+     checks: dict[str, bool] = {}
+     metrics: dict[str, float] = {}
+     messages: list[str] = []
+     warnings_list: list[str] = []
+     errors: list[str] = []
+
+     try:
+         # Extract primary metric ratio (canonical)
+         current_ratio = None
+         pm_kind = None
+         pm = (
+             (run_report.get("metrics") or {}).get("primary_metric")
+             if isinstance(run_report.get("metrics"), dict)
+             else None
+         )
+         if isinstance(pm, dict) and pm:
+             val = pm.get("ratio_vs_baseline")
+             if isinstance(val, int | float):
+                 current_ratio = float(val)
+             try:
+                 pm_kind = str(pm.get("kind") or "").lower()
+             except Exception:
+                 pm_kind = None
+         if current_ratio is None:
+             errors.append("Cannot extract ratio_vs_baseline from run report")
+
+         if "param_reduction_ratio" in run_report:
+             current_param_ratio = run_report["param_reduction_ratio"]
+         elif "parameters_removed" in run_report and "original_params" in run_report:
+             current_param_ratio = (
+                 run_report["parameters_removed"] / run_report["original_params"]
+             )
+         else:
+             current_param_ratio = None
+             errors.append("Cannot extract parameter reduction ratio from run report")
+
+         # Extract baseline metrics
+         baseline_ratio = baseline.get("ratio_vs_baseline")
+         baseline_param_ratio = baseline.get("param_reduction_ratio")
+
+         if baseline_ratio is None:
+             errors.append("Baseline missing ratio_vs_baseline")
+         if baseline_param_ratio is None:
+             errors.append("Baseline missing param_reduction_ratio")
+
+         # Primary metric tolerance (lower-is-better families)
+         if pm_kind in {"ppl_causal", "ppl_mlm", "ppl_seq2seq", None}:
+             if current_ratio is not None and baseline_ratio is not None:
+                 rel_diff = abs(current_ratio - float(baseline_ratio)) / float(
+                     baseline_ratio
+                 )
+                 checks["ratio_tolerance"] = rel_diff <= tol_ratio
+                 metrics["ratio_diff"] = rel_diff
+                 metrics["current_ratio"] = current_ratio
+                 metrics["baseline_ratio"] = float(baseline_ratio)
+
+                 if not checks["ratio_tolerance"]:
+                     msg = f"Primary metric ratio deviation {rel_diff:.3f} exceeds tolerance {tol_ratio:.3f}"
+                     messages.append(msg)
+                 else:
+                     messages.append(
+                         f"Primary metric ratio within tolerance: {current_ratio:.3f} vs baseline {float(baseline_ratio):.3f}"
+                     )
+             else:
+                 checks["ratio_tolerance"] = False
+
+         # Parameter ratio validation
+         if current_param_ratio is not None and baseline_param_ratio is not None:
+             param_relative_diff = (
+                 abs(current_param_ratio - baseline_param_ratio) / baseline_param_ratio
+             )
+             checks["param_ratio_tolerance"] = param_relative_diff <= tol_param_ratio
+             metrics["param_ratio_diff"] = param_relative_diff
+             metrics["current_param_ratio"] = current_param_ratio
+             metrics["baseline_param_ratio"] = baseline_param_ratio
+
+             if not checks["param_ratio_tolerance"]:
+                 messages.append(
+                     f"Parameter ratio deviation {param_relative_diff:.3f} exceeds tolerance {tol_param_ratio:.3f}"
+                 )
+             else:
+                 messages.append(
+                     f"Parameter ratio within tolerance: {current_param_ratio:.3f} vs baseline {baseline_param_ratio:.3f}"
+                 )
+         else:
+             checks["param_ratio_tolerance"] = False
+
+         # Bounds check
+         if current_ratio is not None:
+             if pm_kind in {"accuracy", "vqa_accuracy"}:
+                 # Interpret current_ratio as delta proportion; compare in pp when bounds provided
+                 if isinstance(delta_bounds_pp, tuple) and len(delta_bounds_pp) == 2:
+                     delta_pp = 100.0 * float(current_ratio)
+                     lo_pp, hi_pp = float(delta_bounds_pp[0]), float(delta_bounds_pp[1])
+                     checks["delta_bounds_pp"] = lo_pp <= delta_pp <= hi_pp
+                     if not checks["delta_bounds_pp"]:
+                         messages.append(
+                             f"Δpp {delta_pp:+.2f} outside acceptable bounds {delta_bounds_pp}"
+                         )
+                     else:
+                         messages.append(
+                             f"Δpp {delta_pp:+.2f} within acceptable bounds {delta_bounds_pp}"
+                         )
+             else:
+                 checks["ratio_bounds"] = (
+                     ratio_bounds[0] <= current_ratio <= ratio_bounds[1]
+                 )
+                 if not checks["ratio_bounds"]:
+                     messages.append(
+                         f"Ratio {current_ratio:.3f} outside acceptable bounds {ratio_bounds}"
+                     )
+                 else:
+                     messages.append(
+                         f"Ratio {current_ratio:.3f} within acceptable bounds {ratio_bounds}"
+                     )
+         else:
+             if pm_kind in {"accuracy", "vqa_accuracy"}:
+                 checks["delta_bounds_pp"] = False
+             else:
+                 checks["ratio_bounds"] = False
+
+         # Structural count validation
+         if structural_exact:
+             structural_checks = _validate_structural_counts(run_report, baseline)
+             checks.update(structural_checks["checks"])
+             messages.extend(structural_checks["messages"])
+             warnings_list.extend(structural_checks["warnings"])
+         else:
+             checks["structural_counts"] = True  # Skip structural validation
+
+         # Invariants validation (if present in report)
+         invariants_passed = _validate_invariants(run_report)
+         if invariants_passed is not None:
+             checks["invariants"] = invariants_passed
+             if not invariants_passed:
+                 errors.append("Model invariants validation failed")
+
+         # Overall pass/fail
+         passed = all(checks.values()) and len(errors) == 0
+
+         return ValidationResult(
+             passed=passed,
+             checks=checks,
+             metrics=metrics,
+             messages=messages,
+             warnings=warnings_list,
+             errors=errors,
+         )
+
+     except Exception as e:
+         return ValidationResult(
+             passed=False,
+             checks={"validation_error": False},
+             metrics={},
+             messages=[],
+             warnings=[],
+             errors=[f"Validation failed with exception: {str(e)}"],
+         )
+
+
+ def validate_drift_gate(
+     run_report: dict[str, Any], drift_bounds: tuple[float, float] = (0.95, 1.05)
+ ) -> ValidationResult:
+     """
+     Validate hard drift gate: 0.95 ≤ final/preview ≤ 1.05.
+
+     Args:
+         run_report: Report from run with metrics.primary_metric preview/final
+         drift_bounds: Acceptable drift bounds (min, max) - default (0.95, 1.05)
+
+     Returns:
+         ValidationResult with drift gate check
+     """
+     checks = {}
+     metrics = {}
+     messages = []
+     warnings: list[str] = []
+     errors = []
+
+     try:
+         # Extract preview and final from primary_metric
+         pm = (
+             (run_report.get("metrics") or {}).get("primary_metric")
+             if isinstance(run_report.get("metrics"), dict)
+             else None
+         )
+         pm_preview = pm.get("preview") if isinstance(pm, dict) else None
+         pm_final = pm.get("final") if isinstance(pm, dict) else None
+
+         # Calculate drift ratio (final/preview) for lower-is-better families
+         if (
+             isinstance(pm_preview, int | float)
+             and isinstance(pm_final, int | float)
+             and pm_preview > 0
+         ):
+             drift_ratio = float(pm_final) / float(pm_preview)
+             metrics["drift_ratio"] = drift_ratio
+             metrics["preview"] = float(pm_preview)
+             metrics["final"] = float(pm_final)
+
+             # Apply hard gate
+             checks["drift_gate"] = drift_bounds[0] <= drift_ratio <= drift_bounds[1]
+
+             if checks["drift_gate"]:
+                 messages.append(
+                     f"Drift gate PASSED: {drift_ratio:.3f} within bounds {drift_bounds}"
+                 )
+             else:
+                 errors.append(
+                     f"Drift gate FAILED: {drift_ratio:.3f} outside bounds {drift_bounds} "
+                     f"(±5% drift limit exceeded)"
+                 )
+         else:
+             errors.append(
+                 "Cannot calculate drift: missing primary_metric preview/final"
+             )
+             checks["drift_gate"] = False
+
+         # Overall pass/fail
+         passed = all(checks.values()) and len(errors) == 0
+
+         return ValidationResult(
+             passed=passed,
+             checks=checks,
+             metrics=metrics,
+             messages=messages,
+             warnings=warnings,
+             errors=errors,
+         )
+
+     except Exception as e:
+         return ValidationResult(
+             passed=False,
+             checks={"drift_gate_error": False},
+             metrics={},
+             messages=[],
+             warnings=[],
+             errors=[f"Drift gate validation failed: {str(e)}"],
+         )
+
+
+ def validate_guard_overhead(
+     bare_report: dict[str, Any],
+     guarded_report: dict[str, Any],
+     overhead_threshold: float = 0.01,
+ ) -> ValidationResult:
+     """
+     Validate guard overhead using primary_metric: require
+     final(guarded) / final(bare) ≤ 1 + overhead_threshold (default 1%).
+
+     Args:
+         bare_report: Report from bare (no guards) run (expects metrics.primary_metric)
+         guarded_report: Report from guarded run (expects metrics.primary_metric)
+         overhead_threshold: Maximum allowed overhead (default 0.01 = 1%)
+
+     Returns:
+         ValidationResult with guard overhead check
+     """
+     checks = {}
+     metrics = {}
+     messages = []
+     warnings: list[str] = []
+     errors = []
+
+     try:
+         # Extract primary metric final from both reports
+         bare_pm = (
+             (bare_report.get("metrics") or {}).get("primary_metric")
+             if isinstance(bare_report.get("metrics"), dict)
+             else None
+         )
+         guarded_pm = (
+             (guarded_report.get("metrics") or {}).get("primary_metric")
+             if isinstance(guarded_report.get("metrics"), dict)
+             else None
+         )
+
+         bare_final = None
+         guarded_final = None
+         if isinstance(bare_pm, dict):
+             bare_final = bare_pm.get("final")
+         if isinstance(guarded_pm, dict):
+             guarded_final = guarded_pm.get("final")
+
+         if (
+             isinstance(bare_final, int | float)
+             and bare_final > 0
+             and isinstance(guarded_final, int | float)
+         ):
+             overhead_ratio = float(guarded_final) / float(bare_final)
+             overhead_percent = (overhead_ratio - 1.0) * 100
+
+             metrics["overhead_ratio"] = overhead_ratio
+             metrics["overhead_percent"] = overhead_percent
+             metrics["bare_final"] = float(bare_final)
+             metrics["guarded_final"] = float(guarded_final)
+
+             # Apply overhead gate
+             checks["guard_overhead"] = overhead_ratio <= (1.0 + overhead_threshold)
+
+             if checks["guard_overhead"]:
+                 messages.append(
+                     f"Guard overhead PASSED: {overhead_percent:+.2f}% ≤ {overhead_threshold * 100:.1f}%"
+                 )
+             else:
+                 errors.append(
+                     f"Guard overhead FAILED: {overhead_percent:+.2f}% > {overhead_threshold * 100:.1f}% "
+                     f"(guards add too much primary-metric overhead)"
+                 )
+         else:
+             errors.append(
+                 "Cannot calculate guard overhead: missing primary_metric data"
+             )
+             checks["guard_overhead"] = False
+
+         # Overall pass/fail
+         passed = all(checks.values()) and len(errors) == 0
+
+         return ValidationResult(
+             passed=passed,
+             checks=checks,
+             metrics=metrics,
+             messages=messages,
+             warnings=warnings,
+             errors=errors,
+         )
+
+     except Exception as e:
+         return ValidationResult(
+             passed=False,
+             checks={"guard_overhead_error": False},
+             metrics={},
+             messages=[],
+             warnings=[],
+             errors=[f"Guard overhead validation failed: {str(e)}"],
+         )
+
+
+ def _validate_structural_counts(
+     run_report: dict[str, Any], baseline: dict[str, Any]
+ ) -> dict[str, Any]:
+     """Validate that structural counts match exactly."""
+     checks = {}
+     messages = []
+     warnings = []
+
+     # Heads/neurons counts removed from simplified schema; only validate layers
+
+     # Check layers modified
+     current_layers = run_report.get(
+         "layers_modified", run_report.get("metrics", {}).get("layers_modified")
+     )
+     baseline_layers = baseline.get("layers_modified")
+
+     if current_layers is not None and baseline_layers is not None:
+         checks["layers_count_exact"] = current_layers == baseline_layers
+         if checks["layers_count_exact"]:
+             messages.append(f"Modified layers count matches: {current_layers}")
+         else:
+             messages.append(
+                 f"Modified layers mismatch: {current_layers} vs baseline {baseline_layers}"
+             )
+     else:
+         warnings.append("Cannot validate layers count - missing data")
+         checks["layers_count_exact"] = True  # Don't fail on missing data
+
+     return {"checks": checks, "messages": messages, "warnings": warnings}
+
+
+ def _validate_invariants(run_report: dict[str, Any]) -> bool | None:
+     """Check if model invariants passed."""
+     # Look for invariants check in guard reports
+     guard_reports = run_report.get("guard_reports", {})
+
+     for guard_name, guard_report in guard_reports.items():
+         if "invariants" in guard_name.lower():
+             passed = guard_report.get("passed", True)
+             return bool(passed) if passed is not None else True
+
+     # Look for validation results in metrics
+     metrics = run_report.get("metrics", {})
+     if "invariants_passed" in metrics:
+         passed = metrics["invariants_passed"]
+         return bool(passed) if passed is not None else None
+
+     # No invariants check found
+     return None
+
+
+ def load_baseline(baseline_path: Path) -> dict[str, Any]:
+     """Load baseline metrics from JSON file."""
+     try:
+         with open(baseline_path) as f:
+             data = json.load(f)
+         if not isinstance(data, dict):
+             raise ValueError(
+                 f"Baseline file must contain a JSON object, got {type(data)}"
+             )
+         return data
+     except FileNotFoundError as e:
+         raise FileNotFoundError(f"Baseline file not found: {baseline_path}") from e
+     except json.JSONDecodeError as e:
+         raise ValueError(f"Invalid JSON in baseline file: {e}") from e
+
+
+ def save_baseline(baseline: dict[str, Any], baseline_path: Path) -> None:
+     """Save baseline metrics to JSON file."""
+     baseline_path.parent.mkdir(parents=True, exist_ok=True)
+     with open(baseline_path, "w") as f:
+         json.dump(baseline, f, indent=2)
+
+
+ def create_baseline_from_report(run_report: dict[str, Any]) -> dict[str, Any]:
+     """Create a baseline structure from a run report."""
+     baseline: dict[str, Any] = {}
+
+     # Extract core metrics (PM-only)
+     try:
+         pm = (
+             run_report.get("metrics", {}).get("primary_metric")
+             if isinstance(run_report.get("metrics"), dict)
+             else None
+         )
+         if isinstance(pm, dict) and pm.get("ratio_vs_baseline") is not None:
+             baseline["ratio_vs_baseline"] = float(pm["ratio_vs_baseline"])
+     except Exception:
+         pass
+
+     if "param_reduction_ratio" in run_report:
+         baseline["param_reduction_ratio"] = run_report["param_reduction_ratio"]
+     elif "parameters_removed" in run_report and "original_params" in run_report:
+         baseline["param_reduction_ratio"] = (
+             run_report["parameters_removed"] / run_report["original_params"]
+         )
+
+     # Extract structural counts
+     metrics = run_report.get("metrics", {})
+     for key in ["heads_pruned", "neurons_pruned", "layers_modified"]:
+         if key in run_report:
+             baseline[key] = run_report[key]
+         elif key in metrics:
+             baseline[key] = metrics[key]
+
+     # Extract sparsity metrics
+     sparsity = run_report.get("actual_sparsity", {})
+     for key in ["head_sparsity", "neuron_sparsity", "weight_sparsity"]:
+         if key in sparsity:
+             baseline[key] = sparsity[key]
+
+     # Add metadata
+     baseline["baseline_created"] = True
+     baseline["source"] = "run_report"
+
+     return baseline
+
+
+ def validate_gpt2_small_wt2_baseline(
+     run_report: dict[str, Any], baseline_path: Path | None = None
+ ) -> ValidationResult:
+     """
+     Validate against the canonical GPT-2 small + WikiText-2 baseline.
+
+     This is the CI validation function that uses the pinned baseline.
+     """
+     if baseline_path is None:
+         # Use default baseline path
+         baseline_path = (
+             Path(__file__).parent.parent.parent
+             / "benchmarks"
+             / "baselines"
+             / "gpt2_small_wt2.json"
+         )
+
+     try:
+         baseline = load_baseline(baseline_path)
+     except FileNotFoundError:
+         # Create a default baseline if file doesn't exist
+         warnings.warn(
+             f"Baseline file not found: {baseline_path}. Using default values.",
+             stacklevel=2,
+         )
+         baseline = {
+             "ratio_vs_baseline": 1.285,  # Target: ~1.25-1.32
+             "param_reduction_ratio": 0.022,  # Target: ~2.2%
+             "heads_pruned": 16,  # Example values
+             "neurons_pruned": 1024,
+             "layers_modified": 8,
+             "head_sparsity": 0.1,
+             "neuron_sparsity": 0.1,
+         }
+
+     return validate_against_baseline(
+         run_report,
+         baseline,
+         tol_ratio=0.02,
+         tol_param_ratio=0.02,
+         ratio_bounds=(1.25, 1.32),
+         structural_exact=True,
+     )
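
Usage sketch (not part of the diff): a minimal example of how the validators above fit together, assuming the wheel is installed and the module is importable as invarlock.reporting.validate (the path shown in the file list). The report dicts are illustrative shapes inferred from the docstrings in this file, and baselines/example.json is a hypothetical path.

# Sketch only; report shapes inferred from the docstrings above.
from pathlib import Path

from invarlock.reporting.validate import (
    create_baseline_from_report,
    load_baseline,
    save_baseline,
    validate_against_baseline,
    validate_drift_gate,
    validate_guard_overhead,
)

# A run report in the PM-only layout the validators read.
run_report = {
    "metrics": {
        "primary_metric": {
            "kind": "ppl_causal",  # lower-is-better family
            "ratio_vs_baseline": 1.29,
            "preview": 24.8,
            "final": 25.1,
        },
        "layers_modified": 8,
    },
    "param_reduction_ratio": 0.022,
}

# Round-trip a baseline, then validate the run against it.
baseline_path = Path("baselines/example.json")  # hypothetical location
save_baseline(create_baseline_from_report(run_report), baseline_path)
result = validate_against_baseline(run_report, load_baseline(baseline_path))
print(result.summary())

# Hard drift gate: final/preview must stay within (0.95, 1.05).
print(validate_drift_gate(run_report).passed)

# Guard-overhead gate: guarded final may exceed bare final by at most 1%.
bare = {"metrics": {"primary_metric": {"final": 25.0}}}
guarded = {"metrics": {"primary_metric": {"final": 25.2}}}
print(validate_guard_overhead(bare, guarded).passed)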