invarlock 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +33 -0
- invarlock/__main__.py +10 -0
- invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
- invarlock/_data/runtime/profiles/release.yaml +23 -0
- invarlock/_data/runtime/tiers.yaml +76 -0
- invarlock/adapters/__init__.py +102 -0
- invarlock/adapters/_capabilities.py +45 -0
- invarlock/adapters/auto.py +99 -0
- invarlock/adapters/base.py +530 -0
- invarlock/adapters/base_types.py +85 -0
- invarlock/adapters/hf_bert.py +852 -0
- invarlock/adapters/hf_gpt2.py +403 -0
- invarlock/adapters/hf_llama.py +485 -0
- invarlock/adapters/hf_mixin.py +383 -0
- invarlock/adapters/hf_onnx.py +112 -0
- invarlock/adapters/hf_t5.py +137 -0
- invarlock/adapters/py.typed +1 -0
- invarlock/assurance/__init__.py +43 -0
- invarlock/cli/__init__.py +8 -0
- invarlock/cli/__main__.py +8 -0
- invarlock/cli/_evidence.py +25 -0
- invarlock/cli/_json.py +75 -0
- invarlock/cli/adapter_auto.py +162 -0
- invarlock/cli/app.py +287 -0
- invarlock/cli/commands/__init__.py +26 -0
- invarlock/cli/commands/certify.py +403 -0
- invarlock/cli/commands/doctor.py +1358 -0
- invarlock/cli/commands/explain_gates.py +151 -0
- invarlock/cli/commands/export_html.py +100 -0
- invarlock/cli/commands/plugins.py +1331 -0
- invarlock/cli/commands/report.py +354 -0
- invarlock/cli/commands/run.py +4146 -0
- invarlock/cli/commands/verify.py +1040 -0
- invarlock/cli/config.py +396 -0
- invarlock/cli/constants.py +68 -0
- invarlock/cli/device.py +92 -0
- invarlock/cli/doctor_helpers.py +74 -0
- invarlock/cli/errors.py +6 -0
- invarlock/cli/overhead_utils.py +60 -0
- invarlock/cli/provenance.py +66 -0
- invarlock/cli/utils.py +41 -0
- invarlock/config.py +56 -0
- invarlock/core/__init__.py +62 -0
- invarlock/core/abi.py +15 -0
- invarlock/core/api.py +274 -0
- invarlock/core/auto_tuning.py +317 -0
- invarlock/core/bootstrap.py +226 -0
- invarlock/core/checkpoint.py +221 -0
- invarlock/core/contracts.py +73 -0
- invarlock/core/error_utils.py +64 -0
- invarlock/core/events.py +298 -0
- invarlock/core/exceptions.py +95 -0
- invarlock/core/registry.py +481 -0
- invarlock/core/retry.py +146 -0
- invarlock/core/runner.py +2041 -0
- invarlock/core/types.py +154 -0
- invarlock/edits/__init__.py +12 -0
- invarlock/edits/_edit_utils.py +249 -0
- invarlock/edits/_external_utils.py +268 -0
- invarlock/edits/noop.py +47 -0
- invarlock/edits/py.typed +1 -0
- invarlock/edits/quant_rtn.py +801 -0
- invarlock/edits/registry.py +166 -0
- invarlock/eval/__init__.py +23 -0
- invarlock/eval/bench.py +1207 -0
- invarlock/eval/bootstrap.py +50 -0
- invarlock/eval/data.py +2052 -0
- invarlock/eval/metrics.py +2167 -0
- invarlock/eval/primary_metric.py +767 -0
- invarlock/eval/probes/__init__.py +24 -0
- invarlock/eval/probes/fft.py +139 -0
- invarlock/eval/probes/mi.py +213 -0
- invarlock/eval/probes/post_attention.py +323 -0
- invarlock/eval/providers/base.py +67 -0
- invarlock/eval/providers/seq2seq.py +111 -0
- invarlock/eval/providers/text_lm.py +113 -0
- invarlock/eval/providers/vision_text.py +93 -0
- invarlock/eval/py.typed +1 -0
- invarlock/guards/__init__.py +18 -0
- invarlock/guards/_contracts.py +9 -0
- invarlock/guards/invariants.py +640 -0
- invarlock/guards/policies.py +805 -0
- invarlock/guards/py.typed +1 -0
- invarlock/guards/rmt.py +2097 -0
- invarlock/guards/spectral.py +1419 -0
- invarlock/guards/tier_config.py +354 -0
- invarlock/guards/variance.py +3298 -0
- invarlock/guards_ref/__init__.py +15 -0
- invarlock/guards_ref/rmt_ref.py +40 -0
- invarlock/guards_ref/spectral_ref.py +135 -0
- invarlock/guards_ref/variance_ref.py +60 -0
- invarlock/model_profile.py +353 -0
- invarlock/model_utils.py +221 -0
- invarlock/observability/__init__.py +10 -0
- invarlock/observability/alerting.py +535 -0
- invarlock/observability/core.py +546 -0
- invarlock/observability/exporters.py +565 -0
- invarlock/observability/health.py +588 -0
- invarlock/observability/metrics.py +457 -0
- invarlock/observability/py.typed +1 -0
- invarlock/observability/utils.py +553 -0
- invarlock/plugins/__init__.py +12 -0
- invarlock/plugins/hello_guard.py +33 -0
- invarlock/plugins/hf_awq_adapter.py +82 -0
- invarlock/plugins/hf_bnb_adapter.py +79 -0
- invarlock/plugins/hf_gptq_adapter.py +78 -0
- invarlock/plugins/py.typed +1 -0
- invarlock/py.typed +1 -0
- invarlock/reporting/__init__.py +7 -0
- invarlock/reporting/certificate.py +3221 -0
- invarlock/reporting/certificate_schema.py +244 -0
- invarlock/reporting/dataset_hashing.py +215 -0
- invarlock/reporting/guards_analysis.py +948 -0
- invarlock/reporting/html.py +32 -0
- invarlock/reporting/normalizer.py +235 -0
- invarlock/reporting/policy_utils.py +517 -0
- invarlock/reporting/primary_metric_utils.py +265 -0
- invarlock/reporting/render.py +1442 -0
- invarlock/reporting/report.py +903 -0
- invarlock/reporting/report_types.py +278 -0
- invarlock/reporting/utils.py +175 -0
- invarlock/reporting/validate.py +631 -0
- invarlock/security.py +176 -0
- invarlock/sparsity_utils.py +323 -0
- invarlock/utils/__init__.py +150 -0
- invarlock/utils/digest.py +45 -0
- invarlock-0.2.0.dist-info/METADATA +586 -0
- invarlock-0.2.0.dist-info/RECORD +132 -0
- invarlock-0.2.0.dist-info/WHEEL +5 -0
- invarlock-0.2.0.dist-info/entry_points.txt +20 -0
- invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
- invarlock-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,631 @@
|
|
|
1
|
+
"""
|
|
2
|
+
InvarLock Validation Framework
|
|
3
|
+
==============================
|
|
4
|
+
|
|
5
|
+
Validation utilities for checking pruning results against baseline metrics.
|
|
6
|
+
Supports both automated CI testing and flexible user validation.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import os
|
|
13
|
+
import warnings
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Any, cast
|
|
16
|
+
|
|
17
|
+
# Names exported by ``from <this module> import *``.
__all__ = [
    # Validation entry points
    "validate_against_baseline",
    "validate_drift_gate",
    "validate_guard_overhead",
    # Result container
    "ValidationResult",
    # Baseline persistence helpers
    "load_baseline",
    "save_baseline",
    "create_baseline_from_report",
]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ValidationResult:
|
|
29
|
+
"""Container for validation results."""
|
|
30
|
+
|
|
31
|
+
def __init__(
|
|
32
|
+
self,
|
|
33
|
+
passed: bool,
|
|
34
|
+
checks: dict[str, bool],
|
|
35
|
+
metrics: dict[str, float],
|
|
36
|
+
messages: list[str],
|
|
37
|
+
warnings: list[str] | None = None,
|
|
38
|
+
errors: list[str] | None = None,
|
|
39
|
+
):
|
|
40
|
+
self.passed = passed
|
|
41
|
+
self.checks = checks
|
|
42
|
+
self.metrics = metrics
|
|
43
|
+
self.messages = messages
|
|
44
|
+
self.warnings = warnings or []
|
|
45
|
+
self.errors = errors or []
|
|
46
|
+
|
|
47
|
+
def to_dict(self) -> dict[str, Any]:
|
|
48
|
+
"""Convert to dictionary for serialization."""
|
|
49
|
+
return {
|
|
50
|
+
"passed": self.passed,
|
|
51
|
+
"checks": self.checks,
|
|
52
|
+
"metrics": self.metrics,
|
|
53
|
+
"messages": self.messages,
|
|
54
|
+
"warnings": self.warnings,
|
|
55
|
+
"errors": self.errors,
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
def summary(self) -> str:
|
|
59
|
+
"""Get human-readable summary."""
|
|
60
|
+
status = "✓ PASSED" if self.passed else "✗ FAILED"
|
|
61
|
+
passed_count = sum(1 for check in self.checks.values() if check)
|
|
62
|
+
total_count = len(self.checks)
|
|
63
|
+
|
|
64
|
+
lines = [
|
|
65
|
+
f"Validation {status} ({passed_count}/{total_count} checks passed)",
|
|
66
|
+
"",
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
# Show individual check results
|
|
70
|
+
for check_name, passed in self.checks.items():
|
|
71
|
+
symbol = "✓" if passed else "✗"
|
|
72
|
+
lines.append(f" {symbol} {check_name}")
|
|
73
|
+
|
|
74
|
+
# Show messages
|
|
75
|
+
if self.messages:
|
|
76
|
+
lines.append("")
|
|
77
|
+
lines.extend(f" {msg}" for msg in self.messages)
|
|
78
|
+
|
|
79
|
+
# Show warnings and errors
|
|
80
|
+
if self.warnings:
|
|
81
|
+
lines.append("")
|
|
82
|
+
lines.append("Warnings:")
|
|
83
|
+
lines.extend(f" ⚠️ {warning}" for warning in self.warnings)
|
|
84
|
+
|
|
85
|
+
if self.errors:
|
|
86
|
+
lines.append("")
|
|
87
|
+
lines.append("Errors:")
|
|
88
|
+
lines.extend(f" ❌ {error}" for error in self.errors)
|
|
89
|
+
|
|
90
|
+
return "\n".join(lines)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def validate_against_baseline(
|
|
94
|
+
run_report: dict[str, Any],
|
|
95
|
+
baseline: dict[str, Any],
|
|
96
|
+
*,
|
|
97
|
+
tol_ratio: float = 0.02,
|
|
98
|
+
tol_param_ratio: float = 0.02,
|
|
99
|
+
ratio_bounds: tuple[float, float] = (1.25, 1.32),
|
|
100
|
+
delta_bounds_pp: tuple[float, float] | None = None,
|
|
101
|
+
structural_exact: bool = True,
|
|
102
|
+
**kwargs,
|
|
103
|
+
) -> ValidationResult:
|
|
104
|
+
# Backward-compatible kwargs (deprecated): enable via INVARLOCK_VALIDATE_LEGACY=1
|
|
105
|
+
legacy = str(os.environ.get("INVARLOCK_VALIDATE_LEGACY", "")).strip().lower() in {
|
|
106
|
+
"1",
|
|
107
|
+
"true",
|
|
108
|
+
"yes",
|
|
109
|
+
"on",
|
|
110
|
+
}
|
|
111
|
+
if legacy:
|
|
112
|
+
if "tol_ppl_ratio" in kwargs and isinstance(
|
|
113
|
+
kwargs["tol_ppl_ratio"], int | float
|
|
114
|
+
):
|
|
115
|
+
tol_ratio = float(kwargs["tol_ppl_ratio"])
|
|
116
|
+
if "ppl_bounds" in kwargs and isinstance(kwargs["ppl_bounds"], tuple):
|
|
117
|
+
# Coerce after runtime guard
|
|
118
|
+
ratio_bounds = cast(tuple[float, float], kwargs["ppl_bounds"])
|
|
119
|
+
"""
|
|
120
|
+
Validate pruning results against baseline metrics (PM-only API).
|
|
121
|
+
|
|
122
|
+
Args:
|
|
123
|
+
run_report: Report from pruning run (dict with metrics)
|
|
124
|
+
baseline: Baseline metrics to compare against
|
|
125
|
+
tol_ratio: Tolerance for primary metric ratio deviation (±2% = 0.02) for lower-is-better families
|
|
126
|
+
tol_param_ratio: Tolerance for parameter reduction ratio deviation
|
|
127
|
+
ratio_bounds: Acceptable ratio bounds for lower-is-better families (min, max)
|
|
128
|
+
delta_bounds_pp: Acceptable delta bounds in percentage points for higher-is-better families (min, max)
|
|
129
|
+
structural_exact: Whether structural counts must match exactly
|
|
130
|
+
|
|
131
|
+
Returns:
|
|
132
|
+
ValidationResult with detailed check results
|
|
133
|
+
"""
|
|
134
|
+
checks: dict[str, bool] = {}
|
|
135
|
+
metrics: dict[str, float] = {}
|
|
136
|
+
messages: list[str] = []
|
|
137
|
+
warnings_list: list[str] = []
|
|
138
|
+
errors: list[str] = []
|
|
139
|
+
|
|
140
|
+
try:
|
|
141
|
+
# Extract primary metric ratio (canonical)
|
|
142
|
+
current_ratio = None
|
|
143
|
+
pm_kind = None
|
|
144
|
+
pm = (
|
|
145
|
+
(run_report.get("metrics") or {}).get("primary_metric")
|
|
146
|
+
if isinstance(run_report.get("metrics"), dict)
|
|
147
|
+
else None
|
|
148
|
+
)
|
|
149
|
+
if isinstance(pm, dict) and pm:
|
|
150
|
+
val = pm.get("ratio_vs_baseline")
|
|
151
|
+
if isinstance(val, int | float):
|
|
152
|
+
current_ratio = float(val)
|
|
153
|
+
try:
|
|
154
|
+
pm_kind = str(pm.get("kind") or "").lower()
|
|
155
|
+
except Exception:
|
|
156
|
+
pm_kind = None
|
|
157
|
+
if current_ratio is None:
|
|
158
|
+
errors.append("Cannot extract ratio_vs_baseline from run report")
|
|
159
|
+
|
|
160
|
+
if "param_reduction_ratio" in run_report:
|
|
161
|
+
current_param_ratio = run_report["param_reduction_ratio"]
|
|
162
|
+
elif "parameters_removed" in run_report and "original_params" in run_report:
|
|
163
|
+
current_param_ratio = (
|
|
164
|
+
run_report["parameters_removed"] / run_report["original_params"]
|
|
165
|
+
)
|
|
166
|
+
else:
|
|
167
|
+
current_param_ratio = None
|
|
168
|
+
errors.append("Cannot extract parameter reduction ratio from run report")
|
|
169
|
+
|
|
170
|
+
# Extract baseline metrics
|
|
171
|
+
baseline_ratio = baseline.get("ratio_vs_baseline")
|
|
172
|
+
baseline_param_ratio = baseline.get("param_reduction_ratio")
|
|
173
|
+
|
|
174
|
+
if baseline_ratio is None:
|
|
175
|
+
errors.append("Baseline missing ratio_vs_baseline")
|
|
176
|
+
if baseline_param_ratio is None:
|
|
177
|
+
errors.append("Baseline missing param_reduction_ratio")
|
|
178
|
+
|
|
179
|
+
# Primary metric tolerance (lower-is-better families)
|
|
180
|
+
if pm_kind in {"ppl_causal", "ppl_mlm", "ppl_seq2seq", None}:
|
|
181
|
+
if current_ratio is not None and baseline_ratio is not None:
|
|
182
|
+
rel_diff = abs(current_ratio - float(baseline_ratio)) / float(
|
|
183
|
+
baseline_ratio
|
|
184
|
+
)
|
|
185
|
+
checks["ratio_tolerance"] = rel_diff <= tol_ratio
|
|
186
|
+
metrics["ratio_diff"] = rel_diff
|
|
187
|
+
metrics["current_ratio"] = current_ratio
|
|
188
|
+
metrics["baseline_ratio"] = float(baseline_ratio)
|
|
189
|
+
|
|
190
|
+
if not checks["ratio_tolerance"]:
|
|
191
|
+
msg = f"Primary metric ratio deviation {rel_diff:.3f} exceeds tolerance {tol_ratio:.3f}"
|
|
192
|
+
messages.append(msg)
|
|
193
|
+
else:
|
|
194
|
+
messages.append(
|
|
195
|
+
f"Primary metric ratio within tolerance: {current_ratio:.3f} vs baseline {float(baseline_ratio):.3f}"
|
|
196
|
+
)
|
|
197
|
+
else:
|
|
198
|
+
checks["ratio_tolerance"] = False
|
|
199
|
+
|
|
200
|
+
# Parameter ratio validation
|
|
201
|
+
if current_param_ratio is not None and baseline_param_ratio is not None:
|
|
202
|
+
param_relative_diff = (
|
|
203
|
+
abs(current_param_ratio - baseline_param_ratio) / baseline_param_ratio
|
|
204
|
+
)
|
|
205
|
+
checks["param_ratio_tolerance"] = param_relative_diff <= tol_param_ratio
|
|
206
|
+
metrics["param_ratio_diff"] = param_relative_diff
|
|
207
|
+
metrics["current_param_ratio"] = current_param_ratio
|
|
208
|
+
metrics["baseline_param_ratio"] = baseline_param_ratio
|
|
209
|
+
|
|
210
|
+
if not checks["param_ratio_tolerance"]:
|
|
211
|
+
messages.append(
|
|
212
|
+
f"Parameter ratio deviation {param_relative_diff:.3f} exceeds tolerance {tol_param_ratio:.3f}"
|
|
213
|
+
)
|
|
214
|
+
else:
|
|
215
|
+
messages.append(
|
|
216
|
+
f"Parameter ratio within tolerance: {current_param_ratio:.3f} vs baseline {baseline_param_ratio:.3f}"
|
|
217
|
+
)
|
|
218
|
+
else:
|
|
219
|
+
checks["param_ratio_tolerance"] = False
|
|
220
|
+
|
|
221
|
+
# Bounds check
|
|
222
|
+
if current_ratio is not None:
|
|
223
|
+
if pm_kind in {"accuracy", "vqa_accuracy"}:
|
|
224
|
+
# Interpret current_ratio as delta proportion; compare in pp when bounds provided
|
|
225
|
+
if isinstance(delta_bounds_pp, tuple) and len(delta_bounds_pp) == 2:
|
|
226
|
+
delta_pp = 100.0 * float(current_ratio)
|
|
227
|
+
lo_pp, hi_pp = float(delta_bounds_pp[0]), float(delta_bounds_pp[1])
|
|
228
|
+
checks["delta_bounds_pp"] = lo_pp <= delta_pp <= hi_pp
|
|
229
|
+
if not checks["delta_bounds_pp"]:
|
|
230
|
+
messages.append(
|
|
231
|
+
f"Δpp {delta_pp:+.2f} outside acceptable bounds {delta_bounds_pp}"
|
|
232
|
+
)
|
|
233
|
+
else:
|
|
234
|
+
messages.append(
|
|
235
|
+
f"Δpp {delta_pp:+.2f} within acceptable bounds {delta_bounds_pp}"
|
|
236
|
+
)
|
|
237
|
+
else:
|
|
238
|
+
checks["ratio_bounds"] = (
|
|
239
|
+
ratio_bounds[0] <= current_ratio <= ratio_bounds[1]
|
|
240
|
+
)
|
|
241
|
+
if not checks["ratio_bounds"]:
|
|
242
|
+
messages.append(
|
|
243
|
+
f"Ratio {current_ratio:.3f} outside acceptable bounds {ratio_bounds}"
|
|
244
|
+
)
|
|
245
|
+
else:
|
|
246
|
+
messages.append(
|
|
247
|
+
f"Ratio {current_ratio:.3f} within acceptable bounds {ratio_bounds}"
|
|
248
|
+
)
|
|
249
|
+
else:
|
|
250
|
+
if pm_kind in {"accuracy", "vqa_accuracy"}:
|
|
251
|
+
checks["delta_bounds_pp"] = False
|
|
252
|
+
else:
|
|
253
|
+
checks["ratio_bounds"] = False
|
|
254
|
+
|
|
255
|
+
# Structural count validation
|
|
256
|
+
if structural_exact:
|
|
257
|
+
structural_checks = _validate_structural_counts(run_report, baseline)
|
|
258
|
+
checks.update(structural_checks["checks"])
|
|
259
|
+
messages.extend(structural_checks["messages"])
|
|
260
|
+
warnings_list.extend(structural_checks["warnings"])
|
|
261
|
+
else:
|
|
262
|
+
checks["structural_counts"] = True # Skip structural validation
|
|
263
|
+
|
|
264
|
+
# Invariants validation (if present in report)
|
|
265
|
+
invariants_passed = _validate_invariants(run_report)
|
|
266
|
+
if invariants_passed is not None:
|
|
267
|
+
checks["invariants"] = invariants_passed
|
|
268
|
+
if not invariants_passed:
|
|
269
|
+
errors.append("Model invariants validation failed")
|
|
270
|
+
|
|
271
|
+
# Overall pass/fail
|
|
272
|
+
passed = all(checks.values()) and len(errors) == 0
|
|
273
|
+
|
|
274
|
+
return ValidationResult(
|
|
275
|
+
passed=passed,
|
|
276
|
+
checks=checks,
|
|
277
|
+
metrics=metrics,
|
|
278
|
+
messages=messages,
|
|
279
|
+
warnings=warnings_list,
|
|
280
|
+
errors=errors,
|
|
281
|
+
)
|
|
282
|
+
|
|
283
|
+
except Exception as e:
|
|
284
|
+
return ValidationResult(
|
|
285
|
+
passed=False,
|
|
286
|
+
checks={"validation_error": False},
|
|
287
|
+
metrics={},
|
|
288
|
+
messages=[],
|
|
289
|
+
warnings=[],
|
|
290
|
+
errors=[f"Validation failed with exception: {str(e)}"],
|
|
291
|
+
)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def validate_drift_gate(
    run_report: dict[str, Any], drift_bounds: tuple[float, float] = (0.95, 1.05)
) -> ValidationResult:
    """
    Validate hard drift gate: drift_bounds[0] ≤ final/preview ≤ drift_bounds[1]
    (0.95 ≤ final/preview ≤ 1.05 by default).

    Args:
        run_report: Report from run with metrics.primary_metric preview/final
        drift_bounds: Acceptable drift bounds (min, max) - default (0.95, 1.05)

    Returns:
        ValidationResult with a ``drift_gate`` check. Missing or non-numeric
        preview/final values, or a non-positive preview, fail the gate. An
        unexpected exception yields a failed result with a ``drift_gate_error``
        check.
    """
    checks: dict[str, bool] = {}
    metrics: dict[str, float] = {}
    messages: list[str] = []
    warnings: list[str] = []
    errors: list[str] = []

    try:
        # Extract preview and final from primary_metric
        report_metrics = run_report.get("metrics")
        pm = (
            report_metrics.get("primary_metric")
            if isinstance(report_metrics, dict)
            else None
        )
        pm_preview = pm.get("preview") if isinstance(pm, dict) else None
        pm_final = pm.get("final") if isinstance(pm, dict) else None

        # Calculate drift ratio (final/preview); a positive preview is required
        if (
            isinstance(pm_preview, int | float)
            and isinstance(pm_final, int | float)
            and pm_preview > 0
        ):
            drift_ratio = float(pm_final) / float(pm_preview)
            metrics["drift_ratio"] = drift_ratio
            metrics["preview"] = float(pm_preview)
            metrics["final"] = float(pm_final)

            # Apply hard gate
            lo, hi = drift_bounds[0], drift_bounds[1]
            checks["drift_gate"] = lo <= drift_ratio <= hi

            if checks["drift_gate"]:
                messages.append(
                    f"Drift gate PASSED: {drift_ratio:.3f} within bounds {drift_bounds}"
                )
            else:
                # Derive the limit text from the bounds actually used so custom
                # bounds are reported correctly (symmetric bounds read "±N%").
                below_pct = (1.0 - float(lo)) * 100.0
                above_pct = (float(hi) - 1.0) * 100.0
                if abs(below_pct - above_pct) < 1e-9:
                    limit_text = f"±{above_pct:g}% drift limit exceeded"
                else:
                    limit_text = (
                        f"{-below_pct:+g}%/{above_pct:+g}% drift limit exceeded"
                    )
                errors.append(
                    f"Drift gate FAILED: {drift_ratio:.3f} outside bounds {drift_bounds} "
                    f"({limit_text})"
                )
        else:
            errors.append(
                "Cannot calculate drift: missing primary_metric preview/final"
            )
            checks["drift_gate"] = False

        # Overall pass/fail
        passed = all(checks.values()) and len(errors) == 0

        return ValidationResult(
            passed=passed,
            checks=checks,
            metrics=metrics,
            messages=messages,
            warnings=warnings,
            errors=errors,
        )

    except Exception as e:
        return ValidationResult(
            passed=False,
            checks={"drift_gate_error": False},
            metrics={},
            messages=[],
            warnings=[],
            errors=[f"Drift gate validation failed: {str(e)}"],
        )
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def validate_guard_overhead(
    bare_report: dict[str, Any],
    guarded_report: dict[str, Any],
    overhead_threshold: float = 0.01,
) -> ValidationResult:
    """
    Check that guards do not inflate the primary metric beyond a threshold.

    The gate passes when final(guarded) / final(bare) <= 1 + overhead_threshold,
    where ``final`` is read from ``metrics.primary_metric.final`` of each report.
    NOTE(review): a larger guarded ``final`` counts as overhead, i.e. this
    assumes a lower-is-better primary metric — confirm for other metric kinds.

    Args:
        bare_report: Report from bare (no guards) run (expects metrics.primary_metric)
        guarded_report: Report from guarded run (expects metrics.primary_metric)
        overhead_threshold: Maximum allowed overhead (default 0.01 = 1%)

    Returns:
        ValidationResult with a ``guard_overhead`` check and, when computable,
        ``overhead_ratio``/``overhead_percent``/``bare_final``/``guarded_final``
        metrics. An unexpected exception yields a failed result with a
        ``guard_overhead_error`` check.
    """

    def _pm_final(report: dict[str, Any]) -> Any:
        # metrics.primary_metric.final, or None when a level is missing / not a dict.
        section = report.get("metrics")
        if not isinstance(section, dict):
            return None
        pm = section.get("primary_metric")
        if not isinstance(pm, dict):
            return None
        return pm.get("final")

    checks: dict[str, bool] = {}
    metrics: dict[str, float] = {}
    messages: list[str] = []
    warnings: list[str] = []
    errors: list[str] = []

    try:
        bare_final = _pm_final(bare_report)
        guarded_final = _pm_final(guarded_report)

        # The bare value is the denominator, so it must be a positive number.
        computable = (
            isinstance(bare_final, (int, float))
            and bare_final > 0
            and isinstance(guarded_final, (int, float))
        )
        if not computable:
            errors.append(
                "Cannot calculate guard overhead: missing primary_metric data"
            )
            checks["guard_overhead"] = False
        else:
            ratio = float(guarded_final) / float(bare_final)
            pct = (ratio - 1.0) * 100

            metrics["overhead_ratio"] = ratio
            metrics["overhead_percent"] = pct
            metrics["bare_final"] = float(bare_final)
            metrics["guarded_final"] = float(guarded_final)

            within_budget = ratio <= (1.0 + overhead_threshold)
            checks["guard_overhead"] = within_budget

            if within_budget:
                messages.append(
                    f"Guard overhead PASSED: {pct:+.2f}% ≤ {overhead_threshold * 100:.1f}%"
                )
            else:
                errors.append(
                    f"Guard overhead FAILED: {pct:+.2f}% > {overhead_threshold * 100:.1f}% "
                    f"(guards add too much primary-metric overhead)"
                )

        return ValidationResult(
            passed=not errors and all(checks.values()),
            checks=checks,
            metrics=metrics,
            messages=messages,
            warnings=warnings,
            errors=errors,
        )

    except Exception as e:
        return ValidationResult(
            passed=False,
            checks={"guard_overhead_error": False},
            metrics={},
            messages=[],
            warnings=[],
            errors=[f"Guard overhead validation failed: {str(e)}"],
        )
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
def _validate_structural_counts(
|
|
472
|
+
run_report: dict[str, Any], baseline: dict[str, Any]
|
|
473
|
+
) -> dict[str, Any]:
|
|
474
|
+
"""Validate that structural counts match exactly."""
|
|
475
|
+
checks = {}
|
|
476
|
+
messages = []
|
|
477
|
+
warnings = []
|
|
478
|
+
|
|
479
|
+
# Heads/neurons counts removed from simplified schema; only validate layers
|
|
480
|
+
|
|
481
|
+
# Check layers modified
|
|
482
|
+
current_layers = run_report.get(
|
|
483
|
+
"layers_modified", run_report.get("metrics", {}).get("layers_modified")
|
|
484
|
+
)
|
|
485
|
+
baseline_layers = baseline.get("layers_modified")
|
|
486
|
+
|
|
487
|
+
if current_layers is not None and baseline_layers is not None:
|
|
488
|
+
checks["layers_count_exact"] = current_layers == baseline_layers
|
|
489
|
+
if checks["layers_count_exact"]:
|
|
490
|
+
messages.append(f"Modified layers count matches: {current_layers}")
|
|
491
|
+
else:
|
|
492
|
+
messages.append(
|
|
493
|
+
f"Modified layers mismatch: {current_layers} vs baseline {baseline_layers}"
|
|
494
|
+
)
|
|
495
|
+
else:
|
|
496
|
+
warnings.append("Cannot validate layers count - missing data")
|
|
497
|
+
checks["layers_count_exact"] = True # Don't fail on missing data
|
|
498
|
+
|
|
499
|
+
return {"checks": checks, "messages": messages, "warnings": warnings}
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def _validate_invariants(run_report: dict[str, Any]) -> bool | None:
|
|
503
|
+
"""Check if model invariants passed."""
|
|
504
|
+
# Look for invariants check in guard reports
|
|
505
|
+
guard_reports = run_report.get("guard_reports", {})
|
|
506
|
+
|
|
507
|
+
for guard_name, guard_report in guard_reports.items():
|
|
508
|
+
if "invariants" in guard_name.lower():
|
|
509
|
+
passed = guard_report.get("passed", True)
|
|
510
|
+
return bool(passed) if passed is not None else True
|
|
511
|
+
|
|
512
|
+
# Look for validation results in metrics
|
|
513
|
+
metrics = run_report.get("metrics", {})
|
|
514
|
+
if "invariants_passed" in metrics:
|
|
515
|
+
passed = metrics["invariants_passed"]
|
|
516
|
+
return bool(passed) if passed is not None else None
|
|
517
|
+
|
|
518
|
+
# No invariants check found
|
|
519
|
+
return None
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def load_baseline(baseline_path: Path) -> dict[str, Any]:
|
|
523
|
+
"""Load baseline metrics from JSON file."""
|
|
524
|
+
try:
|
|
525
|
+
with open(baseline_path) as f:
|
|
526
|
+
data = json.load(f)
|
|
527
|
+
if not isinstance(data, dict):
|
|
528
|
+
raise ValueError(
|
|
529
|
+
f"Baseline file must contain a JSON object, got {type(data)}"
|
|
530
|
+
)
|
|
531
|
+
return data
|
|
532
|
+
except FileNotFoundError as e:
|
|
533
|
+
raise FileNotFoundError(f"Baseline file not found: {baseline_path}") from e
|
|
534
|
+
except json.JSONDecodeError as e:
|
|
535
|
+
raise ValueError(f"Invalid JSON in baseline file: {e}") from e
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
def save_baseline(baseline: dict[str, Any], baseline_path: Path) -> None:
|
|
539
|
+
"""Save baseline metrics to JSON file."""
|
|
540
|
+
baseline_path.parent.mkdir(parents=True, exist_ok=True)
|
|
541
|
+
with open(baseline_path, "w") as f:
|
|
542
|
+
json.dump(baseline, f, indent=2)
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
def create_baseline_from_report(run_report: dict[str, Any]) -> dict[str, Any]:
    """Create a baseline structure from a run report.

    Copies, when present:
      * ``ratio_vs_baseline`` from ``metrics.primary_metric`` (coerced to float;
        skipped if not convertible),
      * ``param_reduction_ratio`` (taken directly, or computed as
        ``parameters_removed / original_params`` when ``original_params`` != 0),
      * ``heads_pruned`` / ``neurons_pruned`` / ``layers_modified`` (top level
        first, then ``metrics``),
      * ``head_sparsity`` / ``neuron_sparsity`` / ``weight_sparsity`` from
        ``actual_sparsity``,
    then adds ``baseline_created=True`` and ``source="run_report"``.
    Null or non-dict ``metrics``/``actual_sparsity`` sections are skipped.
    """
    baseline: dict[str, Any] = {}

    metrics = run_report.get("metrics")
    if not isinstance(metrics, dict):
        metrics = {}

    # Extract core metrics (PM-only)
    pm = metrics.get("primary_metric")
    if isinstance(pm, dict) and pm.get("ratio_vs_baseline") is not None:
        try:
            baseline["ratio_vs_baseline"] = float(pm["ratio_vs_baseline"])
        except (TypeError, ValueError, OverflowError):
            # Non-numeric ratio: omit it rather than fail baseline creation.
            pass

    if "param_reduction_ratio" in run_report:
        baseline["param_reduction_ratio"] = run_report["param_reduction_ratio"]
    elif "parameters_removed" in run_report and "original_params" in run_report:
        original_params = run_report["original_params"]
        if original_params != 0:  # a zero denominator would raise ZeroDivisionError
            baseline["param_reduction_ratio"] = (
                run_report["parameters_removed"] / original_params
            )

    # Extract structural counts (top-level values take precedence)
    for key in ("heads_pruned", "neurons_pruned", "layers_modified"):
        if key in run_report:
            baseline[key] = run_report[key]
        elif key in metrics:
            baseline[key] = metrics[key]

    # Extract sparsity metrics
    sparsity = run_report.get("actual_sparsity")
    if isinstance(sparsity, dict):
        for key in ("head_sparsity", "neuron_sparsity", "weight_sparsity"):
            if key in sparsity:
                baseline[key] = sparsity[key]

    # Add metadata
    baseline["baseline_created"] = True
    baseline["source"] = "run_report"

    return baseline
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def validate_gpt2_small_wt2_baseline(
    run_report: dict[str, Any], baseline_path: Path | None = None
) -> ValidationResult:
    """
    Validate a run against the canonical GPT-2 small + WikiText-2 baseline.

    Intended for CI. When ``baseline_path`` is omitted, the pinned file
    ``benchmarks/baselines/gpt2_small_wt2.json`` is resolved relative to the
    grandparent of this module's directory. If the file is missing, a
    ``UserWarning`` is emitted and built-in default values are used instead.
    The comparison uses fixed tolerances (2% ratio, 2% param ratio), ratio
    bounds (1.25, 1.32) and exact structural matching.
    """
    if baseline_path is None:
        root = Path(__file__).parent.parent.parent
        baseline_path = root.joinpath("benchmarks", "baselines", "gpt2_small_wt2.json")

    try:
        baseline = load_baseline(baseline_path)
    except FileNotFoundError:
        # Fall back to built-in defaults so CI can still run without the pinned file.
        warnings.warn(
            f"Baseline file not found: {baseline_path}. Using default values.",
            stacklevel=2,
        )
        baseline = dict(
            ratio_vs_baseline=1.285,  # Target: ~1.25-1.32
            param_reduction_ratio=0.022,  # Target: ~2.2%
            heads_pruned=16,  # Example values
            neurons_pruned=1024,
            layers_modified=8,
            head_sparsity=0.1,
            neuron_sparsity=0.1,
        )

    pinned_settings = dict(
        tol_ratio=0.02,
        tol_param_ratio=0.02,
        ratio_bounds=(1.25, 1.32),
        structural_exact=True,
    )
    return validate_against_baseline(run_report, baseline, **pinned_settings)
|