invarlock 0.3.5__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +1 -1
- invarlock/_data/runtime/tiers.yaml +57 -30
- invarlock/adapters/__init__.py +1 -1
- invarlock/calibration/spectral_null.py +15 -10
- invarlock/calibration/variance_ve.py +0 -2
- invarlock/cli/commands/calibrate.py +6 -2
- invarlock/cli/commands/certify.py +58 -39
- invarlock/cli/commands/doctor.py +3 -1
- invarlock/cli/commands/explain_gates.py +57 -8
- invarlock/cli/commands/report.py +1 -1
- invarlock/cli/commands/run.py +159 -61
- invarlock/cli/commands/verify.py +78 -4
- invarlock/cli/config.py +21 -5
- invarlock/core/api.py +45 -5
- invarlock/core/auto_tuning.py +65 -20
- invarlock/core/contracts.py +7 -1
- invarlock/core/registry.py +2 -2
- invarlock/core/runner.py +314 -50
- invarlock/eval/bench.py +0 -13
- invarlock/eval/data.py +14 -28
- invarlock/eval/metrics.py +4 -1
- invarlock/eval/primary_metric.py +23 -0
- invarlock/eval/tail_stats.py +230 -0
- invarlock/guards/_estimators.py +154 -0
- invarlock/guards/policies.py +16 -6
- invarlock/guards/rmt.py +625 -544
- invarlock/guards/spectral.py +348 -110
- invarlock/guards/tier_config.py +32 -30
- invarlock/guards/variance.py +5 -29
- invarlock/guards_ref/rmt_ref.py +23 -23
- invarlock/model_profile.py +42 -15
- invarlock/reporting/certificate.py +225 -46
- invarlock/reporting/certificate_schema.py +2 -1
- invarlock/reporting/dataset_hashing.py +15 -2
- invarlock/reporting/guards_analysis.py +197 -274
- invarlock/reporting/normalizer.py +6 -0
- invarlock/reporting/policy_utils.py +38 -36
- invarlock/reporting/primary_metric_utils.py +71 -17
- invarlock/reporting/render.py +61 -0
- invarlock/reporting/report.py +1 -1
- invarlock/reporting/report_types.py +5 -2
- invarlock/reporting/validate.py +1 -18
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/METADATA +6 -6
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/RECORD +48 -46
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/WHEEL +0 -0
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/entry_points.txt +0 -0
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/licenses/LICENSE +0 -0
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/top_level.txt +0 -0
invarlock/guards/spectral.py
CHANGED
|
@@ -12,18 +12,15 @@ from collections import defaultdict
|
|
|
12
12
|
from datetime import datetime
|
|
13
13
|
from typing import Any, TypedDict
|
|
14
14
|
|
|
15
|
-
try:
|
|
16
|
-
from typing import NotRequired
|
|
17
|
-
except ImportError: # Python 3.10 fallback
|
|
18
|
-
from typing import NotRequired
|
|
19
|
-
|
|
20
15
|
import numpy as np
|
|
21
16
|
import torch
|
|
22
17
|
|
|
23
18
|
from invarlock.cli._evidence import maybe_dump_guard_evidence
|
|
24
19
|
from invarlock.core.api import Guard
|
|
20
|
+
from invarlock.core.exceptions import ValidationError
|
|
25
21
|
|
|
26
22
|
from ._contracts import guard_assert
|
|
23
|
+
from ._estimators import frobenius_norm_sq, power_iter_sigma_max, row_col_norm_extrema
|
|
27
24
|
|
|
28
25
|
|
|
29
26
|
def _z_to_two_sided_pvalue(z: Any) -> float:
|
|
@@ -104,10 +101,10 @@ class SpectralPolicy(TypedDict, total=False):
|
|
|
104
101
|
"""Type definition for spectral guard policy configuration."""
|
|
105
102
|
|
|
106
103
|
sigma_quantile: float
|
|
107
|
-
contraction: NotRequired[float] # Backward compatibility alias
|
|
108
|
-
kappa: NotRequired[float] # Legacy alias
|
|
109
104
|
deadband: float
|
|
110
105
|
scope: str
|
|
106
|
+
estimator: dict[str, Any]
|
|
107
|
+
degeneracy: dict[str, Any]
|
|
111
108
|
correction_enabled: bool
|
|
112
109
|
family_caps: dict[str, dict[str, float]]
|
|
113
110
|
ignore_preview_inflation: bool
|
|
@@ -172,24 +169,16 @@ class SpectralGuard(Guard):
|
|
|
172
169
|
|
|
173
170
|
# Default configuration
|
|
174
171
|
sigma_quantile = self.config.get("sigma_quantile")
|
|
175
|
-
if sigma_quantile is None:
|
|
176
|
-
for alias in ("contraction", "kappa"):
|
|
177
|
-
if self.config.get(alias) is not None:
|
|
178
|
-
sigma_quantile = self.config[alias]
|
|
179
|
-
break
|
|
180
172
|
if sigma_quantile is None:
|
|
181
173
|
sigma_quantile = 0.95
|
|
182
174
|
self.sigma_quantile = float(sigma_quantile)
|
|
183
175
|
self.config["sigma_quantile"] = self.sigma_quantile
|
|
184
|
-
self.config.pop("contraction", None)
|
|
185
|
-
self.config.pop("kappa", None)
|
|
186
176
|
self.deadband = kwargs.get("deadband", 0.10)
|
|
187
177
|
self.scope = kwargs.get("scope", "all") # 'all', 'ffn', 'attn'
|
|
188
|
-
self.max_spectral_norm = kwargs.get("max_spectral_norm",
|
|
178
|
+
self.max_spectral_norm = kwargs.get("max_spectral_norm", None)
|
|
189
179
|
if self.max_spectral_norm is not None:
|
|
190
180
|
self.max_spectral_norm = float(self.max_spectral_norm)
|
|
191
181
|
self.config["max_spectral_norm"] = self.max_spectral_norm
|
|
192
|
-
self.min_condition_number = kwargs.get("min_condition_number", 1e-12)
|
|
193
182
|
self.correction_enabled = kwargs.get("correction_enabled", True)
|
|
194
183
|
self.family_caps = _normalize_family_caps(
|
|
195
184
|
kwargs.get("family_caps"), default=True
|
|
@@ -200,12 +189,61 @@ class SpectralGuard(Guard):
|
|
|
200
189
|
"multiple_testing", {"method": "bh", "alpha": 0.05, "m": 4}
|
|
201
190
|
)
|
|
202
191
|
|
|
192
|
+
estimator_cfg = kwargs.get("estimator")
|
|
193
|
+
if not isinstance(estimator_cfg, dict):
|
|
194
|
+
estimator_cfg = {}
|
|
195
|
+
try:
|
|
196
|
+
est_iters = int(estimator_cfg.get("iters", 4) or 4)
|
|
197
|
+
except Exception:
|
|
198
|
+
est_iters = 4
|
|
199
|
+
if est_iters < 1:
|
|
200
|
+
est_iters = 1
|
|
201
|
+
est_init = str(estimator_cfg.get("init", "ones") or "ones").strip().lower()
|
|
202
|
+
if est_init not in {"ones", "e0"}:
|
|
203
|
+
est_init = "ones"
|
|
204
|
+
self.estimator: dict[str, Any] = {
|
|
205
|
+
"type": "power_iter",
|
|
206
|
+
"iters": est_iters,
|
|
207
|
+
"init": est_init,
|
|
208
|
+
}
|
|
209
|
+
|
|
210
|
+
degeneracy_cfg = kwargs.get("degeneracy")
|
|
211
|
+
if not isinstance(degeneracy_cfg, dict):
|
|
212
|
+
degeneracy_cfg = {}
|
|
213
|
+
stable_rank_cfg = (
|
|
214
|
+
degeneracy_cfg.get("stable_rank")
|
|
215
|
+
if isinstance(degeneracy_cfg, dict)
|
|
216
|
+
else {}
|
|
217
|
+
)
|
|
218
|
+
norm_collapse_cfg = (
|
|
219
|
+
degeneracy_cfg.get("norm_collapse")
|
|
220
|
+
if isinstance(degeneracy_cfg, dict)
|
|
221
|
+
else {}
|
|
222
|
+
)
|
|
223
|
+
self.degeneracy: dict[str, Any] = {
|
|
224
|
+
"enabled": bool(degeneracy_cfg.get("enabled", True)),
|
|
225
|
+
"stable_rank": {
|
|
226
|
+
"warn_ratio": float((stable_rank_cfg or {}).get("warn_ratio", 0.5)),
|
|
227
|
+
"fatal_ratio": float((stable_rank_cfg or {}).get("fatal_ratio", 0.25)),
|
|
228
|
+
},
|
|
229
|
+
"norm_collapse": {
|
|
230
|
+
"warn_ratio": float((norm_collapse_cfg or {}).get("warn_ratio", 0.25)),
|
|
231
|
+
"fatal_ratio": float(
|
|
232
|
+
(norm_collapse_cfg or {}).get("fatal_ratio", 0.10)
|
|
233
|
+
),
|
|
234
|
+
},
|
|
235
|
+
}
|
|
236
|
+
|
|
203
237
|
# Baseline and tracking structures
|
|
204
238
|
self.baseline_sigmas: dict[str, float] = {}
|
|
205
239
|
self.baseline_family_stats: dict[str, dict[str, float]] = {}
|
|
206
240
|
self.module_family_map: dict[str, str] = {}
|
|
207
241
|
self.latest_z_scores: dict[str, float] = {}
|
|
208
242
|
self.pre_edit_z_scores: dict[str, float] = {}
|
|
243
|
+
self.baseline_degeneracy: dict[str, dict[str, float]] = {}
|
|
244
|
+
|
|
245
|
+
# Run context (informational only; contract is policy-bound)
|
|
246
|
+
self._run_profile: str | None = None
|
|
209
247
|
|
|
210
248
|
def _log_event(
|
|
211
249
|
self, operation: str, level: str = "INFO", message: str = "", **data
|
|
@@ -221,6 +259,15 @@ class SpectralGuard(Guard):
|
|
|
221
259
|
}
|
|
222
260
|
self.events.append(event)
|
|
223
261
|
|
|
262
|
+
def set_run_context(self, report: Any) -> None:
|
|
263
|
+
"""Capture run profile context for reporting (informational only)."""
|
|
264
|
+
ctx = getattr(report, "context", {}) or {}
|
|
265
|
+
profile = ""
|
|
266
|
+
if isinstance(ctx, dict):
|
|
267
|
+
profile = str(ctx.get("profile", "") or "").strip().lower()
|
|
268
|
+
|
|
269
|
+
self._run_profile = profile or None
|
|
270
|
+
|
|
224
271
|
def _serialize_policy(self) -> dict[str, Any]:
|
|
225
272
|
"""Snapshot current guard policy for report serialization."""
|
|
226
273
|
return {
|
|
@@ -235,6 +282,8 @@ class SpectralGuard(Guard):
|
|
|
235
282
|
),
|
|
236
283
|
"family_caps": self.family_caps,
|
|
237
284
|
"multiple_testing": self.multiple_testing,
|
|
285
|
+
"estimator": self.estimator,
|
|
286
|
+
"degeneracy": self.degeneracy,
|
|
238
287
|
"correction_enabled": bool(self.correction_enabled),
|
|
239
288
|
"ignore_preview_inflation": bool(self.ignore_preview_inflation),
|
|
240
289
|
}
|
|
@@ -259,15 +308,14 @@ class SpectralGuard(Guard):
|
|
|
259
308
|
# Update configuration from policy
|
|
260
309
|
if policy:
|
|
261
310
|
sigma_value = policy.get("sigma_quantile")
|
|
262
|
-
if
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
311
|
+
if "contraction" in policy or "kappa" in policy:
|
|
312
|
+
raise ValueError(
|
|
313
|
+
"Spectral policy keys 'contraction'/'kappa' are not supported; "
|
|
314
|
+
"use 'sigma_quantile'."
|
|
315
|
+
)
|
|
266
316
|
if sigma_value is not None:
|
|
267
317
|
self.sigma_quantile = float(sigma_value)
|
|
268
318
|
policy["sigma_quantile"] = self.sigma_quantile
|
|
269
|
-
policy.pop("contraction", None)
|
|
270
|
-
policy.pop("kappa", None)
|
|
271
319
|
self.config["sigma_quantile"] = self.sigma_quantile
|
|
272
320
|
|
|
273
321
|
for key in [
|
|
@@ -306,14 +354,65 @@ class SpectralGuard(Guard):
|
|
|
306
354
|
if isinstance(stats, dict)
|
|
307
355
|
}
|
|
308
356
|
self.config["baseline_family_stats"] = self.baseline_family_stats
|
|
357
|
+
if "multipletesting" in policy:
|
|
358
|
+
raise ValidationError(
|
|
359
|
+
code="E501",
|
|
360
|
+
message="POLICY-PARAM-INVALID",
|
|
361
|
+
details={
|
|
362
|
+
"param": "multipletesting",
|
|
363
|
+
"hint": "Use spectral.multiple_testing instead.",
|
|
364
|
+
},
|
|
365
|
+
)
|
|
309
366
|
mt_policy = policy.get("multiple_testing")
|
|
310
|
-
if mt_policy is None:
|
|
311
|
-
mt_policy = policy.get("multipletesting")
|
|
312
367
|
if isinstance(mt_policy, dict):
|
|
313
368
|
self.multiple_testing = mt_policy.copy()
|
|
314
369
|
policy["multiple_testing"] = self.multiple_testing
|
|
315
370
|
self.config["multiple_testing"] = self.multiple_testing
|
|
316
|
-
|
|
371
|
+
|
|
372
|
+
estimator_policy = policy.get("estimator")
|
|
373
|
+
if isinstance(estimator_policy, dict):
|
|
374
|
+
try:
|
|
375
|
+
est_iters = int(estimator_policy.get("iters", 4) or 4)
|
|
376
|
+
except Exception:
|
|
377
|
+
est_iters = 4
|
|
378
|
+
if est_iters < 1:
|
|
379
|
+
est_iters = 1
|
|
380
|
+
est_init = (
|
|
381
|
+
str(estimator_policy.get("init", "ones") or "ones").strip().lower()
|
|
382
|
+
)
|
|
383
|
+
if est_init not in {"ones", "e0"}:
|
|
384
|
+
est_init = "ones"
|
|
385
|
+
self.estimator = {
|
|
386
|
+
"type": "power_iter",
|
|
387
|
+
"iters": est_iters,
|
|
388
|
+
"init": est_init,
|
|
389
|
+
}
|
|
390
|
+
self.config["estimator"] = self.estimator
|
|
391
|
+
|
|
392
|
+
degeneracy_policy = policy.get("degeneracy")
|
|
393
|
+
if isinstance(degeneracy_policy, dict):
|
|
394
|
+
stable_rank_cfg = degeneracy_policy.get("stable_rank")
|
|
395
|
+
norm_collapse_cfg = degeneracy_policy.get("norm_collapse")
|
|
396
|
+
self.degeneracy = {
|
|
397
|
+
"enabled": bool(degeneracy_policy.get("enabled", True)),
|
|
398
|
+
"stable_rank": {
|
|
399
|
+
"warn_ratio": float(
|
|
400
|
+
(stable_rank_cfg or {}).get("warn_ratio", 0.5)
|
|
401
|
+
),
|
|
402
|
+
"fatal_ratio": float(
|
|
403
|
+
(stable_rank_cfg or {}).get("fatal_ratio", 0.25)
|
|
404
|
+
),
|
|
405
|
+
},
|
|
406
|
+
"norm_collapse": {
|
|
407
|
+
"warn_ratio": float(
|
|
408
|
+
(norm_collapse_cfg or {}).get("warn_ratio", 0.25)
|
|
409
|
+
),
|
|
410
|
+
"fatal_ratio": float(
|
|
411
|
+
(norm_collapse_cfg or {}).get("fatal_ratio", 0.10)
|
|
412
|
+
),
|
|
413
|
+
},
|
|
414
|
+
}
|
|
415
|
+
self.config["degeneracy"] = self.degeneracy
|
|
317
416
|
|
|
318
417
|
self._log_event(
|
|
319
418
|
"prepare",
|
|
@@ -325,7 +424,7 @@ class SpectralGuard(Guard):
|
|
|
325
424
|
|
|
326
425
|
try:
|
|
327
426
|
# Capture baseline spectral properties
|
|
328
|
-
self.baseline_sigmas =
|
|
427
|
+
self.baseline_sigmas = self._capture_sigmas(model, phase="prepare")
|
|
329
428
|
self.module_family_map = classify_model_families(
|
|
330
429
|
model, scope=self.scope, existing=self.module_family_map
|
|
331
430
|
)
|
|
@@ -334,15 +433,55 @@ class SpectralGuard(Guard):
|
|
|
334
433
|
self.baseline_sigmas, self.module_family_map
|
|
335
434
|
)
|
|
336
435
|
|
|
337
|
-
|
|
338
|
-
baseline_stats = scan_model_gains(model, scope=self.scope)
|
|
339
|
-
summarized = _summarize_sigmas(self.baseline_sigmas)
|
|
340
|
-
baseline_stats.update(summarized)
|
|
436
|
+
baseline_stats: dict[str, Any] = _summarize_sigmas(self.baseline_sigmas)
|
|
341
437
|
|
|
342
|
-
|
|
343
|
-
|
|
438
|
+
try:
|
|
439
|
+
values = np.array(list(self.baseline_sigmas.values()), dtype=float)
|
|
440
|
+
if values.size:
|
|
441
|
+
self.target_sigma = float(
|
|
442
|
+
np.percentile(values, float(self.sigma_quantile) * 100.0)
|
|
443
|
+
)
|
|
444
|
+
else:
|
|
445
|
+
self.target_sigma = float(self.sigma_quantile)
|
|
446
|
+
except Exception:
|
|
447
|
+
self.target_sigma = float(self.sigma_quantile)
|
|
344
448
|
baseline_stats["target_sigma"] = self.target_sigma
|
|
345
449
|
|
|
450
|
+
self.baseline_degeneracy = {}
|
|
451
|
+
if bool((self.degeneracy or {}).get("enabled")):
|
|
452
|
+
eps = 1e-12
|
|
453
|
+
for name, module in model.named_modules():
|
|
454
|
+
if not self._should_check_module(name, module):
|
|
455
|
+
continue
|
|
456
|
+
weight = getattr(module, "weight", None)
|
|
457
|
+
if not isinstance(weight, torch.Tensor) or weight.ndim != 2:
|
|
458
|
+
continue
|
|
459
|
+
sigma = self.baseline_sigmas.get(name)
|
|
460
|
+
if not (
|
|
461
|
+
isinstance(sigma, int | float) and math.isfinite(float(sigma))
|
|
462
|
+
):
|
|
463
|
+
continue
|
|
464
|
+
try:
|
|
465
|
+
stable_rank = frobenius_norm_sq(weight) / max(
|
|
466
|
+
float(sigma) ** 2, eps
|
|
467
|
+
)
|
|
468
|
+
norms = row_col_norm_extrema(weight, eps=eps)
|
|
469
|
+
row_med = max(float(norms.get("row_median", 0.0)), eps)
|
|
470
|
+
col_med = max(float(norms.get("col_median", 0.0)), eps)
|
|
471
|
+
collapse = min(
|
|
472
|
+
float(norms.get("row_min", 0.0)) / row_med,
|
|
473
|
+
float(norms.get("col_min", 0.0)) / col_med,
|
|
474
|
+
)
|
|
475
|
+
self.baseline_degeneracy[name] = {
|
|
476
|
+
"stable_rank": float(stable_rank),
|
|
477
|
+
"norm_collapse": float(collapse),
|
|
478
|
+
}
|
|
479
|
+
except Exception:
|
|
480
|
+
continue
|
|
481
|
+
baseline_stats["baseline_degeneracy"] = {
|
|
482
|
+
name: vals.copy() for name, vals in self.baseline_degeneracy.items()
|
|
483
|
+
}
|
|
484
|
+
|
|
346
485
|
baseline_stats["family_stats"] = {
|
|
347
486
|
family: stats.copy()
|
|
348
487
|
for family, stats in self.baseline_family_stats.items()
|
|
@@ -351,6 +490,10 @@ class SpectralGuard(Guard):
|
|
|
351
490
|
family: caps.copy() for family, caps in self.family_caps.items()
|
|
352
491
|
}
|
|
353
492
|
baseline_stats["module_sigmas"] = self.baseline_sigmas.copy()
|
|
493
|
+
baseline_stats["measurement_contract"] = {
|
|
494
|
+
"estimator": self.estimator,
|
|
495
|
+
"degeneracy": self.degeneracy,
|
|
496
|
+
}
|
|
354
497
|
|
|
355
498
|
self.baseline_metrics = baseline_stats
|
|
356
499
|
|
|
@@ -399,7 +542,7 @@ class SpectralGuard(Guard):
|
|
|
399
542
|
return
|
|
400
543
|
|
|
401
544
|
# Capture pre-edit spectral state for comparison
|
|
402
|
-
self.pre_edit_metrics =
|
|
545
|
+
self.pre_edit_metrics = self._capture_sigmas(model, phase="before_edit")
|
|
403
546
|
self.pre_edit_z_scores = compute_z_scores(
|
|
404
547
|
self.pre_edit_metrics,
|
|
405
548
|
self.baseline_family_stats,
|
|
@@ -421,7 +564,7 @@ class SpectralGuard(Guard):
|
|
|
421
564
|
|
|
422
565
|
try:
|
|
423
566
|
# Capture current spectral state
|
|
424
|
-
self.current_metrics =
|
|
567
|
+
self.current_metrics = self._capture_sigmas(model, phase="after_edit")
|
|
425
568
|
|
|
426
569
|
# Detect violations
|
|
427
570
|
violations = self._detect_spectral_violations(
|
|
@@ -461,6 +604,39 @@ class SpectralGuard(Guard):
|
|
|
461
604
|
error=str(e),
|
|
462
605
|
)
|
|
463
606
|
|
|
607
|
+
def _capture_sigmas(self, model: Any, *, phase: str) -> dict[str, float]:
|
|
608
|
+
"""Capture σ̂max for each in-scope module under the measurement contract."""
|
|
609
|
+
_ = phase # reserved for future observability hooks
|
|
610
|
+
sigmas: dict[str, float] = {}
|
|
611
|
+
try:
|
|
612
|
+
iters = int((self.estimator or {}).get("iters", 4) or 4)
|
|
613
|
+
except Exception:
|
|
614
|
+
iters = 4
|
|
615
|
+
if iters < 1:
|
|
616
|
+
iters = 1
|
|
617
|
+
init = str((self.estimator or {}).get("init", "ones") or "ones").strip().lower()
|
|
618
|
+
if init not in {"ones", "e0"}:
|
|
619
|
+
init = "ones"
|
|
620
|
+
|
|
621
|
+
for name, module in model.named_modules():
|
|
622
|
+
if not _should_process_module(name, module, self.scope):
|
|
623
|
+
continue
|
|
624
|
+
weight = getattr(module, "weight", None)
|
|
625
|
+
if not isinstance(weight, torch.Tensor) or weight.ndim != 2:
|
|
626
|
+
continue
|
|
627
|
+
if weight.dtype in {torch.int8, torch.uint8}:
|
|
628
|
+
# Skip quantized weights; treat as neutral for baseline-relative stats.
|
|
629
|
+
sigmas[name] = 1.0
|
|
630
|
+
continue
|
|
631
|
+
try:
|
|
632
|
+
sigmas[name] = float(
|
|
633
|
+
power_iter_sigma_max(weight, iters=iters, init=init)
|
|
634
|
+
)
|
|
635
|
+
except Exception:
|
|
636
|
+
sigmas[name] = 1.0
|
|
637
|
+
|
|
638
|
+
return sigmas
|
|
639
|
+
|
|
464
640
|
def _detect_spectral_violations(
|
|
465
641
|
self, model: Any, metrics: dict[str, float], phase: str = "finalize"
|
|
466
642
|
) -> list[dict[str, Any]]:
|
|
@@ -504,6 +680,7 @@ class SpectralGuard(Guard):
|
|
|
504
680
|
violations.append(
|
|
505
681
|
{
|
|
506
682
|
"type": "family_z_cap",
|
|
683
|
+
"severity": "budgeted",
|
|
507
684
|
"module": name,
|
|
508
685
|
"family": family,
|
|
509
686
|
"z_score": float(z_score),
|
|
@@ -525,6 +702,7 @@ class SpectralGuard(Guard):
|
|
|
525
702
|
violations.append(
|
|
526
703
|
{
|
|
527
704
|
"type": "max_spectral_norm",
|
|
705
|
+
"severity": "fatal",
|
|
528
706
|
"module": name,
|
|
529
707
|
"family": family,
|
|
530
708
|
"current_sigma": float(sigma_max),
|
|
@@ -533,25 +711,98 @@ class SpectralGuard(Guard):
|
|
|
533
711
|
}
|
|
534
712
|
)
|
|
535
713
|
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
714
|
+
if bool((self.degeneracy or {}).get("enabled")):
|
|
715
|
+
base = self.baseline_degeneracy.get(name) or {}
|
|
716
|
+
base_sr = base.get("stable_rank")
|
|
717
|
+
base_nc = base.get("norm_collapse")
|
|
718
|
+
eps = 1e-12
|
|
719
|
+
try:
|
|
720
|
+
sr_cfg = (
|
|
721
|
+
(self.degeneracy.get("stable_rank") or {})
|
|
722
|
+
if isinstance(self.degeneracy, dict)
|
|
723
|
+
else {}
|
|
724
|
+
)
|
|
725
|
+
nc_cfg = (
|
|
726
|
+
(self.degeneracy.get("norm_collapse") or {})
|
|
727
|
+
if isinstance(self.degeneracy, dict)
|
|
728
|
+
else {}
|
|
729
|
+
)
|
|
730
|
+
sr_warn = float(sr_cfg.get("warn_ratio", 0.5))
|
|
731
|
+
sr_fatal = float(sr_cfg.get("fatal_ratio", 0.25))
|
|
732
|
+
nc_warn = float(nc_cfg.get("warn_ratio", 0.25))
|
|
733
|
+
nc_fatal = float(nc_cfg.get("fatal_ratio", 0.10))
|
|
734
|
+
except Exception:
|
|
735
|
+
sr_warn, sr_fatal, nc_warn, nc_fatal = 0.5, 0.25, 0.25, 0.10
|
|
736
|
+
|
|
737
|
+
if (
|
|
738
|
+
isinstance(base_sr, int | float)
|
|
739
|
+
and math.isfinite(float(base_sr))
|
|
740
|
+
and float(base_sr) > 0
|
|
741
|
+
):
|
|
742
|
+
try:
|
|
743
|
+
sr_cur = frobenius_norm_sq(module.weight) / max(
|
|
744
|
+
float(sigma_max) ** 2, eps
|
|
552
745
|
)
|
|
553
|
-
|
|
554
|
-
|
|
746
|
+
sr_ratio = float(sr_cur) / max(float(base_sr), eps)
|
|
747
|
+
if math.isfinite(sr_ratio) and sr_ratio < sr_warn:
|
|
748
|
+
violations.append(
|
|
749
|
+
{
|
|
750
|
+
"type": "degeneracy_stable_rank_drop",
|
|
751
|
+
"severity": "fatal"
|
|
752
|
+
if sr_ratio < sr_fatal
|
|
753
|
+
else "budgeted",
|
|
754
|
+
"module": name,
|
|
755
|
+
"family": family,
|
|
756
|
+
"stable_rank_base": float(base_sr),
|
|
757
|
+
"stable_rank_cur": float(sr_cur),
|
|
758
|
+
"ratio": float(sr_ratio),
|
|
759
|
+
"warn_ratio": float(sr_warn),
|
|
760
|
+
"fatal_ratio": float(sr_fatal),
|
|
761
|
+
"message": (
|
|
762
|
+
f"Stable-rank ratio {sr_ratio:.3f} below "
|
|
763
|
+
f"{sr_warn:.3f} (base={float(base_sr):.3f}, cur={float(sr_cur):.3f})"
|
|
764
|
+
),
|
|
765
|
+
}
|
|
766
|
+
)
|
|
767
|
+
except Exception:
|
|
768
|
+
pass
|
|
769
|
+
|
|
770
|
+
if (
|
|
771
|
+
isinstance(base_nc, int | float)
|
|
772
|
+
and math.isfinite(float(base_nc))
|
|
773
|
+
and float(base_nc) > 0
|
|
774
|
+
):
|
|
775
|
+
try:
|
|
776
|
+
norms = row_col_norm_extrema(module.weight, eps=eps)
|
|
777
|
+
row_med = max(float(norms.get("row_median", 0.0)), eps)
|
|
778
|
+
col_med = max(float(norms.get("col_median", 0.0)), eps)
|
|
779
|
+
nc_cur = min(
|
|
780
|
+
float(norms.get("row_min", 0.0)) / row_med,
|
|
781
|
+
float(norms.get("col_min", 0.0)) / col_med,
|
|
782
|
+
)
|
|
783
|
+
nc_ratio = float(nc_cur) / max(float(base_nc), eps)
|
|
784
|
+
if math.isfinite(nc_ratio) and nc_ratio < nc_warn:
|
|
785
|
+
violations.append(
|
|
786
|
+
{
|
|
787
|
+
"type": "degeneracy_norm_collapse",
|
|
788
|
+
"severity": "fatal"
|
|
789
|
+
if nc_ratio < nc_fatal
|
|
790
|
+
else "budgeted",
|
|
791
|
+
"module": name,
|
|
792
|
+
"family": family,
|
|
793
|
+
"norm_collapse_base": float(base_nc),
|
|
794
|
+
"norm_collapse_cur": float(nc_cur),
|
|
795
|
+
"ratio": float(nc_ratio),
|
|
796
|
+
"warn_ratio": float(nc_warn),
|
|
797
|
+
"fatal_ratio": float(nc_fatal),
|
|
798
|
+
"message": (
|
|
799
|
+
f"Norm-collapse ratio {nc_ratio:.3f} below "
|
|
800
|
+
f"{nc_warn:.3f} (base={float(base_nc):.3e}, cur={float(nc_cur):.3e})"
|
|
801
|
+
),
|
|
802
|
+
}
|
|
803
|
+
)
|
|
804
|
+
except Exception:
|
|
805
|
+
pass
|
|
555
806
|
|
|
556
807
|
except Exception as e:
|
|
557
808
|
self._log_event(
|
|
@@ -776,7 +1027,7 @@ class SpectralGuard(Guard):
|
|
|
776
1027
|
self.prepare(model, adapter, None, {})
|
|
777
1028
|
|
|
778
1029
|
# Capture current spectral state
|
|
779
|
-
current_metrics =
|
|
1030
|
+
current_metrics = self._capture_sigmas(model, phase="validate")
|
|
780
1031
|
|
|
781
1032
|
# Detect violations (final validation phase)
|
|
782
1033
|
violations = self._detect_spectral_violations(
|
|
@@ -784,16 +1035,16 @@ class SpectralGuard(Guard):
|
|
|
784
1035
|
)
|
|
785
1036
|
|
|
786
1037
|
# Determine if passed under budget/fatal rules
|
|
787
|
-
|
|
788
|
-
budgeted_violations = [
|
|
1038
|
+
fatal_violations = [
|
|
789
1039
|
violation
|
|
790
1040
|
for violation in violations
|
|
791
|
-
if violation.get("
|
|
1041
|
+
if (violation.get("severity") == "fatal")
|
|
1042
|
+
or (violation.get("type") == "max_spectral_norm")
|
|
792
1043
|
]
|
|
793
|
-
|
|
1044
|
+
budgeted_violations = [
|
|
794
1045
|
violation
|
|
795
1046
|
for violation in violations
|
|
796
|
-
if violation
|
|
1047
|
+
if violation not in fatal_violations
|
|
797
1048
|
]
|
|
798
1049
|
|
|
799
1050
|
selected_budgeted, mt_selection = self._select_budgeted_violations(
|
|
@@ -839,6 +1090,10 @@ class SpectralGuard(Guard):
|
|
|
839
1090
|
"caps_exceeded": caps_exceeded,
|
|
840
1091
|
"multiple_testing": self.multiple_testing,
|
|
841
1092
|
"multiple_testing_selection": mt_selection,
|
|
1093
|
+
"measurement_contract": {
|
|
1094
|
+
"estimator": self.estimator,
|
|
1095
|
+
"degeneracy": self.degeneracy,
|
|
1096
|
+
},
|
|
842
1097
|
}
|
|
843
1098
|
|
|
844
1099
|
family_quantiles, top_z_scores = self._compute_family_observability()
|
|
@@ -916,7 +1171,7 @@ class SpectralGuard(Guard):
|
|
|
916
1171
|
}
|
|
917
1172
|
|
|
918
1173
|
# Final spectral analysis
|
|
919
|
-
final_metrics =
|
|
1174
|
+
final_metrics = self._capture_sigmas(model, phase="finalize")
|
|
920
1175
|
final_violations = self._detect_spectral_violations(
|
|
921
1176
|
model, final_metrics, phase="finalize"
|
|
922
1177
|
)
|
|
@@ -928,16 +1183,16 @@ class SpectralGuard(Guard):
|
|
|
928
1183
|
family_quantiles, top_z_scores = self._compute_family_observability()
|
|
929
1184
|
|
|
930
1185
|
# Determine overall status based on budgeted vs fatal violations
|
|
931
|
-
|
|
932
|
-
budgeted_violations = [
|
|
1186
|
+
fatal_violations = [
|
|
933
1187
|
violation
|
|
934
1188
|
for violation in final_violations
|
|
935
|
-
if violation.get("
|
|
1189
|
+
if (violation.get("severity") == "fatal")
|
|
1190
|
+
or (violation.get("type") == "max_spectral_norm")
|
|
936
1191
|
]
|
|
937
|
-
|
|
1192
|
+
budgeted_violations = [
|
|
938
1193
|
violation
|
|
939
1194
|
for violation in final_violations
|
|
940
|
-
if violation
|
|
1195
|
+
if violation not in fatal_violations
|
|
941
1196
|
]
|
|
942
1197
|
|
|
943
1198
|
selected_budgeted, mt_selection = self._select_budgeted_violations(
|
|
@@ -983,6 +1238,10 @@ class SpectralGuard(Guard):
|
|
|
983
1238
|
"multiple_testing_selection": mt_selection,
|
|
984
1239
|
"family_z_quantiles": family_quantiles,
|
|
985
1240
|
"top_z_scores": top_z_scores,
|
|
1241
|
+
"measurement_contract": {
|
|
1242
|
+
"estimator": self.estimator,
|
|
1243
|
+
"degeneracy": self.degeneracy,
|
|
1244
|
+
},
|
|
986
1245
|
}
|
|
987
1246
|
|
|
988
1247
|
# Categorize violations
|
|
@@ -990,7 +1249,9 @@ class SpectralGuard(Guard):
|
|
|
990
1249
|
errors = []
|
|
991
1250
|
|
|
992
1251
|
for violation in selected_final_violations:
|
|
993
|
-
if violation
|
|
1252
|
+
if (violation.get("severity") == "fatal") or (
|
|
1253
|
+
violation.get("type") == "max_spectral_norm"
|
|
1254
|
+
):
|
|
994
1255
|
errors.append(violation["message"])
|
|
995
1256
|
else:
|
|
996
1257
|
warnings.append(violation["message"])
|
|
@@ -1029,7 +1290,9 @@ class SpectralGuard(Guard):
|
|
|
1029
1290
|
return result
|
|
1030
1291
|
|
|
1031
1292
|
|
|
1032
|
-
def compute_sigma_max(
|
|
1293
|
+
def compute_sigma_max(
|
|
1294
|
+
weight_matrix: Any, *, iters: int = 4, init: str = "ones"
|
|
1295
|
+
) -> float:
|
|
1033
1296
|
"""
|
|
1034
1297
|
Compute maximum singular value of a weight matrix.
|
|
1035
1298
|
|
|
@@ -1040,35 +1303,29 @@ def compute_sigma_max(weight_matrix: Any) -> float:
|
|
|
1040
1303
|
Maximum singular value
|
|
1041
1304
|
"""
|
|
1042
1305
|
try:
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
singular_values = torch.linalg.svdvals(W)
|
|
1059
|
-
except RuntimeError:
|
|
1060
|
-
# Fallback for older backends without svdvals
|
|
1061
|
-
singular_values = torch.linalg.svd(W, full_matrices=False).S
|
|
1062
|
-
|
|
1063
|
-
return singular_values[0].item() if singular_values.numel() > 0 else 0.0
|
|
1064
|
-
else:
|
|
1065
|
-
return 1.0 # Fallback for non-tensor inputs
|
|
1306
|
+
iters_i = int(iters)
|
|
1307
|
+
except Exception:
|
|
1308
|
+
iters_i = 4
|
|
1309
|
+
if iters_i < 1:
|
|
1310
|
+
iters_i = 1
|
|
1311
|
+
init_s = str(init or "ones").strip().lower()
|
|
1312
|
+
if init_s not in {"ones", "e0"}:
|
|
1313
|
+
init_s = "ones"
|
|
1314
|
+
|
|
1315
|
+
if not isinstance(weight_matrix, torch.Tensor):
|
|
1316
|
+
return 1.0
|
|
1317
|
+
if weight_matrix.dtype in {torch.int8, torch.uint8}:
|
|
1318
|
+
return 1.0
|
|
1319
|
+
if weight_matrix.numel() == 0 or weight_matrix.ndim != 2:
|
|
1320
|
+
return 0.0
|
|
1066
1321
|
|
|
1322
|
+
try:
|
|
1323
|
+
return float(power_iter_sigma_max(weight_matrix, iters=iters_i, init=init_s))
|
|
1067
1324
|
except Exception:
|
|
1068
|
-
return 1.0
|
|
1325
|
+
return 1.0
|
|
1069
1326
|
|
|
1070
1327
|
|
|
1071
|
-
def auto_sigma_target(model: Any, percentile: float = 0.95
|
|
1328
|
+
def auto_sigma_target(model: Any, percentile: float = 0.95) -> float:
|
|
1072
1329
|
"""
|
|
1073
1330
|
Automatically determine sigma target for a model.
|
|
1074
1331
|
|
|
@@ -1079,11 +1336,6 @@ def auto_sigma_target(model: Any, percentile: float = 0.95, **kwargs: Any) -> fl
|
|
|
1079
1336
|
Returns:
|
|
1080
1337
|
Target sigma value
|
|
1081
1338
|
"""
|
|
1082
|
-
if "kappa" in kwargs and percentile == 0.95:
|
|
1083
|
-
try:
|
|
1084
|
-
percentile = float(kwargs["kappa"])
|
|
1085
|
-
except (TypeError, ValueError):
|
|
1086
|
-
pass
|
|
1087
1339
|
try:
|
|
1088
1340
|
# Collect all spectral norms
|
|
1089
1341
|
spectral_norms = []
|
|
@@ -1471,7 +1723,6 @@ def scan_model_gains(model: Any, scope: str = "all") -> dict[str, Any]:
|
|
|
1471
1723
|
"total_layers": 0,
|
|
1472
1724
|
"scanned_modules": 0,
|
|
1473
1725
|
"spectral_norms": [],
|
|
1474
|
-
"condition_numbers": [],
|
|
1475
1726
|
"weight_statistics": {},
|
|
1476
1727
|
}
|
|
1477
1728
|
|
|
@@ -1486,15 +1737,6 @@ def scan_model_gains(model: Any, scope: str = "all") -> dict[str, Any]:
|
|
|
1486
1737
|
sigma_max = compute_sigma_max(module.weight)
|
|
1487
1738
|
results["spectral_norms"].append(sigma_max)
|
|
1488
1739
|
|
|
1489
|
-
# Compute condition number if possible
|
|
1490
|
-
try:
|
|
1491
|
-
U, S, V = torch.svd(module.weight.float())
|
|
1492
|
-
if len(S) > 1:
|
|
1493
|
-
condition_num = (S[0] / S[-1]).item()
|
|
1494
|
-
results["condition_numbers"].append(condition_num)
|
|
1495
|
-
except Exception:
|
|
1496
|
-
pass
|
|
1497
|
-
|
|
1498
1740
|
# Basic weight statistics
|
|
1499
1741
|
try:
|
|
1500
1742
|
weight_stats = {
|
|
@@ -1513,10 +1755,6 @@ def scan_model_gains(model: Any, scope: str = "all") -> dict[str, Any]:
|
|
|
1513
1755
|
results["max_spectral_norm"] = np.max(results["spectral_norms"])
|
|
1514
1756
|
results["min_spectral_norm"] = np.min(results["spectral_norms"])
|
|
1515
1757
|
|
|
1516
|
-
if results["condition_numbers"]:
|
|
1517
|
-
results["mean_condition_number"] = np.mean(results["condition_numbers"])
|
|
1518
|
-
results["max_condition_number"] = np.max(results["condition_numbers"])
|
|
1519
|
-
|
|
1520
1758
|
results["message"] = (
|
|
1521
1759
|
f"Scanned {results['scanned_modules']} modules out of {results['total_layers']} total layers"
|
|
1522
1760
|
)
|