invarlock 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (48)
  1. invarlock/__init__.py +1 -1
  2. invarlock/_data/runtime/tiers.yaml +57 -30
  3. invarlock/adapters/__init__.py +1 -1
  4. invarlock/calibration/spectral_null.py +15 -10
  5. invarlock/calibration/variance_ve.py +0 -2
  6. invarlock/cli/commands/calibrate.py +6 -2
  7. invarlock/cli/commands/certify.py +58 -39
  8. invarlock/cli/commands/doctor.py +3 -1
  9. invarlock/cli/commands/explain_gates.py +57 -8
  10. invarlock/cli/commands/report.py +1 -1
  11. invarlock/cli/commands/run.py +159 -61
  12. invarlock/cli/commands/verify.py +78 -4
  13. invarlock/cli/config.py +21 -5
  14. invarlock/core/api.py +45 -5
  15. invarlock/core/auto_tuning.py +65 -20
  16. invarlock/core/contracts.py +7 -1
  17. invarlock/core/registry.py +2 -2
  18. invarlock/core/runner.py +314 -50
  19. invarlock/eval/bench.py +0 -13
  20. invarlock/eval/data.py +73 -283
  21. invarlock/eval/metrics.py +134 -4
  22. invarlock/eval/primary_metric.py +23 -0
  23. invarlock/eval/tail_stats.py +230 -0
  24. invarlock/guards/_estimators.py +154 -0
  25. invarlock/guards/policies.py +16 -6
  26. invarlock/guards/rmt.py +625 -544
  27. invarlock/guards/spectral.py +348 -110
  28. invarlock/guards/tier_config.py +32 -30
  29. invarlock/guards/variance.py +5 -29
  30. invarlock/guards_ref/rmt_ref.py +23 -23
  31. invarlock/model_profile.py +42 -15
  32. invarlock/reporting/certificate.py +225 -46
  33. invarlock/reporting/certificate_schema.py +2 -1
  34. invarlock/reporting/dataset_hashing.py +15 -2
  35. invarlock/reporting/guards_analysis.py +197 -274
  36. invarlock/reporting/normalizer.py +6 -0
  37. invarlock/reporting/policy_utils.py +38 -36
  38. invarlock/reporting/primary_metric_utils.py +71 -17
  39. invarlock/reporting/render.py +61 -0
  40. invarlock/reporting/report.py +1 -1
  41. invarlock/reporting/report_types.py +5 -2
  42. invarlock/reporting/validate.py +1 -18
  43. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/METADATA +6 -6
  44. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/RECORD +48 -46
  45. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/WHEEL +0 -0
  46. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/entry_points.txt +0 -0
  47. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/licenses/LICENSE +0 -0
  48. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/top_level.txt +0 -0
@@ -12,18 +12,15 @@ from collections import defaultdict
12
12
  from datetime import datetime
13
13
  from typing import Any, TypedDict
14
14
 
15
- try:
16
- from typing import NotRequired
17
- except ImportError: # Python 3.10 fallback
18
- from typing import NotRequired
19
-
20
15
  import numpy as np
21
16
  import torch
22
17
 
23
18
  from invarlock.cli._evidence import maybe_dump_guard_evidence
24
19
  from invarlock.core.api import Guard
20
+ from invarlock.core.exceptions import ValidationError
25
21
 
26
22
  from ._contracts import guard_assert
23
+ from ._estimators import frobenius_norm_sq, power_iter_sigma_max, row_col_norm_extrema
27
24
 
28
25
 
29
26
  def _z_to_two_sided_pvalue(z: Any) -> float:
@@ -104,10 +101,10 @@ class SpectralPolicy(TypedDict, total=False):
104
101
  """Type definition for spectral guard policy configuration."""
105
102
 
106
103
  sigma_quantile: float
107
- contraction: NotRequired[float] # Backward compatibility alias
108
- kappa: NotRequired[float] # Legacy alias
109
104
  deadband: float
110
105
  scope: str
106
+ estimator: dict[str, Any]
107
+ degeneracy: dict[str, Any]
111
108
  correction_enabled: bool
112
109
  family_caps: dict[str, dict[str, float]]
113
110
  ignore_preview_inflation: bool
@@ -172,24 +169,16 @@ class SpectralGuard(Guard):
172
169
 
173
170
  # Default configuration
174
171
  sigma_quantile = self.config.get("sigma_quantile")
175
- if sigma_quantile is None:
176
- for alias in ("contraction", "kappa"):
177
- if self.config.get(alias) is not None:
178
- sigma_quantile = self.config[alias]
179
- break
180
172
  if sigma_quantile is None:
181
173
  sigma_quantile = 0.95
182
174
  self.sigma_quantile = float(sigma_quantile)
183
175
  self.config["sigma_quantile"] = self.sigma_quantile
184
- self.config.pop("contraction", None)
185
- self.config.pop("kappa", None)
186
176
  self.deadband = kwargs.get("deadband", 0.10)
187
177
  self.scope = kwargs.get("scope", "all") # 'all', 'ffn', 'attn'
188
- self.max_spectral_norm = kwargs.get("max_spectral_norm", 10.0)
178
+ self.max_spectral_norm = kwargs.get("max_spectral_norm", None)
189
179
  if self.max_spectral_norm is not None:
190
180
  self.max_spectral_norm = float(self.max_spectral_norm)
191
181
  self.config["max_spectral_norm"] = self.max_spectral_norm
192
- self.min_condition_number = kwargs.get("min_condition_number", 1e-12)
193
182
  self.correction_enabled = kwargs.get("correction_enabled", True)
194
183
  self.family_caps = _normalize_family_caps(
195
184
  kwargs.get("family_caps"), default=True
@@ -200,12 +189,61 @@ class SpectralGuard(Guard):
200
189
  "multiple_testing", {"method": "bh", "alpha": 0.05, "m": 4}
201
190
  )
202
191
 
192
+ estimator_cfg = kwargs.get("estimator")
193
+ if not isinstance(estimator_cfg, dict):
194
+ estimator_cfg = {}
195
+ try:
196
+ est_iters = int(estimator_cfg.get("iters", 4) or 4)
197
+ except Exception:
198
+ est_iters = 4
199
+ if est_iters < 1:
200
+ est_iters = 1
201
+ est_init = str(estimator_cfg.get("init", "ones") or "ones").strip().lower()
202
+ if est_init not in {"ones", "e0"}:
203
+ est_init = "ones"
204
+ self.estimator: dict[str, Any] = {
205
+ "type": "power_iter",
206
+ "iters": est_iters,
207
+ "init": est_init,
208
+ }
209
+
210
+ degeneracy_cfg = kwargs.get("degeneracy")
211
+ if not isinstance(degeneracy_cfg, dict):
212
+ degeneracy_cfg = {}
213
+ stable_rank_cfg = (
214
+ degeneracy_cfg.get("stable_rank")
215
+ if isinstance(degeneracy_cfg, dict)
216
+ else {}
217
+ )
218
+ norm_collapse_cfg = (
219
+ degeneracy_cfg.get("norm_collapse")
220
+ if isinstance(degeneracy_cfg, dict)
221
+ else {}
222
+ )
223
+ self.degeneracy: dict[str, Any] = {
224
+ "enabled": bool(degeneracy_cfg.get("enabled", True)),
225
+ "stable_rank": {
226
+ "warn_ratio": float((stable_rank_cfg or {}).get("warn_ratio", 0.5)),
227
+ "fatal_ratio": float((stable_rank_cfg or {}).get("fatal_ratio", 0.25)),
228
+ },
229
+ "norm_collapse": {
230
+ "warn_ratio": float((norm_collapse_cfg or {}).get("warn_ratio", 0.25)),
231
+ "fatal_ratio": float(
232
+ (norm_collapse_cfg or {}).get("fatal_ratio", 0.10)
233
+ ),
234
+ },
235
+ }
236
+
203
237
  # Baseline and tracking structures
204
238
  self.baseline_sigmas: dict[str, float] = {}
205
239
  self.baseline_family_stats: dict[str, dict[str, float]] = {}
206
240
  self.module_family_map: dict[str, str] = {}
207
241
  self.latest_z_scores: dict[str, float] = {}
208
242
  self.pre_edit_z_scores: dict[str, float] = {}
243
+ self.baseline_degeneracy: dict[str, dict[str, float]] = {}
244
+
245
+ # Run context (informational only; contract is policy-bound)
246
+ self._run_profile: str | None = None
209
247
 
210
248
  def _log_event(
211
249
  self, operation: str, level: str = "INFO", message: str = "", **data
@@ -221,6 +259,15 @@ class SpectralGuard(Guard):
221
259
  }
222
260
  self.events.append(event)
223
261
 
262
+ def set_run_context(self, report: Any) -> None:
263
+ """Capture run profile context for reporting (informational only)."""
264
+ ctx = getattr(report, "context", {}) or {}
265
+ profile = ""
266
+ if isinstance(ctx, dict):
267
+ profile = str(ctx.get("profile", "") or "").strip().lower()
268
+
269
+ self._run_profile = profile or None
270
+
224
271
  def _serialize_policy(self) -> dict[str, Any]:
225
272
  """Snapshot current guard policy for report serialization."""
226
273
  return {
@@ -235,6 +282,8 @@ class SpectralGuard(Guard):
235
282
  ),
236
283
  "family_caps": self.family_caps,
237
284
  "multiple_testing": self.multiple_testing,
285
+ "estimator": self.estimator,
286
+ "degeneracy": self.degeneracy,
238
287
  "correction_enabled": bool(self.correction_enabled),
239
288
  "ignore_preview_inflation": bool(self.ignore_preview_inflation),
240
289
  }
@@ -259,15 +308,14 @@ class SpectralGuard(Guard):
259
308
  # Update configuration from policy
260
309
  if policy:
261
310
  sigma_value = policy.get("sigma_quantile")
262
- if sigma_value is None:
263
- alias_value = policy.get("contraction", policy.get("kappa"))
264
- if alias_value is not None:
265
- sigma_value = alias_value
311
+ if "contraction" in policy or "kappa" in policy:
312
+ raise ValueError(
313
+ "Spectral policy keys 'contraction'/'kappa' are not supported; "
314
+ "use 'sigma_quantile'."
315
+ )
266
316
  if sigma_value is not None:
267
317
  self.sigma_quantile = float(sigma_value)
268
318
  policy["sigma_quantile"] = self.sigma_quantile
269
- policy.pop("contraction", None)
270
- policy.pop("kappa", None)
271
319
  self.config["sigma_quantile"] = self.sigma_quantile
272
320
 
273
321
  for key in [
@@ -306,14 +354,65 @@ class SpectralGuard(Guard):
306
354
  if isinstance(stats, dict)
307
355
  }
308
356
  self.config["baseline_family_stats"] = self.baseline_family_stats
357
+ if "multipletesting" in policy:
358
+ raise ValidationError(
359
+ code="E501",
360
+ message="POLICY-PARAM-INVALID",
361
+ details={
362
+ "param": "multipletesting",
363
+ "hint": "Use spectral.multiple_testing instead.",
364
+ },
365
+ )
309
366
  mt_policy = policy.get("multiple_testing")
310
- if mt_policy is None:
311
- mt_policy = policy.get("multipletesting")
312
367
  if isinstance(mt_policy, dict):
313
368
  self.multiple_testing = mt_policy.copy()
314
369
  policy["multiple_testing"] = self.multiple_testing
315
370
  self.config["multiple_testing"] = self.multiple_testing
316
- policy.pop("multipletesting", None)
371
+
372
+ estimator_policy = policy.get("estimator")
373
+ if isinstance(estimator_policy, dict):
374
+ try:
375
+ est_iters = int(estimator_policy.get("iters", 4) or 4)
376
+ except Exception:
377
+ est_iters = 4
378
+ if est_iters < 1:
379
+ est_iters = 1
380
+ est_init = (
381
+ str(estimator_policy.get("init", "ones") or "ones").strip().lower()
382
+ )
383
+ if est_init not in {"ones", "e0"}:
384
+ est_init = "ones"
385
+ self.estimator = {
386
+ "type": "power_iter",
387
+ "iters": est_iters,
388
+ "init": est_init,
389
+ }
390
+ self.config["estimator"] = self.estimator
391
+
392
+ degeneracy_policy = policy.get("degeneracy")
393
+ if isinstance(degeneracy_policy, dict):
394
+ stable_rank_cfg = degeneracy_policy.get("stable_rank")
395
+ norm_collapse_cfg = degeneracy_policy.get("norm_collapse")
396
+ self.degeneracy = {
397
+ "enabled": bool(degeneracy_policy.get("enabled", True)),
398
+ "stable_rank": {
399
+ "warn_ratio": float(
400
+ (stable_rank_cfg or {}).get("warn_ratio", 0.5)
401
+ ),
402
+ "fatal_ratio": float(
403
+ (stable_rank_cfg or {}).get("fatal_ratio", 0.25)
404
+ ),
405
+ },
406
+ "norm_collapse": {
407
+ "warn_ratio": float(
408
+ (norm_collapse_cfg or {}).get("warn_ratio", 0.25)
409
+ ),
410
+ "fatal_ratio": float(
411
+ (norm_collapse_cfg or {}).get("fatal_ratio", 0.10)
412
+ ),
413
+ },
414
+ }
415
+ self.config["degeneracy"] = self.degeneracy
317
416
 
318
417
  self._log_event(
319
418
  "prepare",
@@ -325,7 +424,7 @@ class SpectralGuard(Guard):
325
424
 
326
425
  try:
327
426
  # Capture baseline spectral properties
328
- self.baseline_sigmas = capture_baseline_sigmas(model, scope=self.scope)
427
+ self.baseline_sigmas = self._capture_sigmas(model, phase="prepare")
329
428
  self.module_family_map = classify_model_families(
330
429
  model, scope=self.scope, existing=self.module_family_map
331
430
  )
@@ -334,15 +433,55 @@ class SpectralGuard(Guard):
334
433
  self.baseline_sigmas, self.module_family_map
335
434
  )
336
435
 
337
- # Compute additional baseline metrics
338
- baseline_stats = scan_model_gains(model, scope=self.scope)
339
- summarized = _summarize_sigmas(self.baseline_sigmas)
340
- baseline_stats.update(summarized)
436
+ baseline_stats: dict[str, Any] = _summarize_sigmas(self.baseline_sigmas)
341
437
 
342
- # Store target sigma value
343
- self.target_sigma = auto_sigma_target(model, percentile=self.sigma_quantile)
438
+ try:
439
+ values = np.array(list(self.baseline_sigmas.values()), dtype=float)
440
+ if values.size:
441
+ self.target_sigma = float(
442
+ np.percentile(values, float(self.sigma_quantile) * 100.0)
443
+ )
444
+ else:
445
+ self.target_sigma = float(self.sigma_quantile)
446
+ except Exception:
447
+ self.target_sigma = float(self.sigma_quantile)
344
448
  baseline_stats["target_sigma"] = self.target_sigma
345
449
 
450
+ self.baseline_degeneracy = {}
451
+ if bool((self.degeneracy or {}).get("enabled")):
452
+ eps = 1e-12
453
+ for name, module in model.named_modules():
454
+ if not self._should_check_module(name, module):
455
+ continue
456
+ weight = getattr(module, "weight", None)
457
+ if not isinstance(weight, torch.Tensor) or weight.ndim != 2:
458
+ continue
459
+ sigma = self.baseline_sigmas.get(name)
460
+ if not (
461
+ isinstance(sigma, int | float) and math.isfinite(float(sigma))
462
+ ):
463
+ continue
464
+ try:
465
+ stable_rank = frobenius_norm_sq(weight) / max(
466
+ float(sigma) ** 2, eps
467
+ )
468
+ norms = row_col_norm_extrema(weight, eps=eps)
469
+ row_med = max(float(norms.get("row_median", 0.0)), eps)
470
+ col_med = max(float(norms.get("col_median", 0.0)), eps)
471
+ collapse = min(
472
+ float(norms.get("row_min", 0.0)) / row_med,
473
+ float(norms.get("col_min", 0.0)) / col_med,
474
+ )
475
+ self.baseline_degeneracy[name] = {
476
+ "stable_rank": float(stable_rank),
477
+ "norm_collapse": float(collapse),
478
+ }
479
+ except Exception:
480
+ continue
481
+ baseline_stats["baseline_degeneracy"] = {
482
+ name: vals.copy() for name, vals in self.baseline_degeneracy.items()
483
+ }
484
+
346
485
  baseline_stats["family_stats"] = {
347
486
  family: stats.copy()
348
487
  for family, stats in self.baseline_family_stats.items()
@@ -351,6 +490,10 @@ class SpectralGuard(Guard):
351
490
  family: caps.copy() for family, caps in self.family_caps.items()
352
491
  }
353
492
  baseline_stats["module_sigmas"] = self.baseline_sigmas.copy()
493
+ baseline_stats["measurement_contract"] = {
494
+ "estimator": self.estimator,
495
+ "degeneracy": self.degeneracy,
496
+ }
354
497
 
355
498
  self.baseline_metrics = baseline_stats
356
499
 
@@ -399,7 +542,7 @@ class SpectralGuard(Guard):
399
542
  return
400
543
 
401
544
  # Capture pre-edit spectral state for comparison
402
- self.pre_edit_metrics = capture_baseline_sigmas(model, scope=self.scope)
545
+ self.pre_edit_metrics = self._capture_sigmas(model, phase="before_edit")
403
546
  self.pre_edit_z_scores = compute_z_scores(
404
547
  self.pre_edit_metrics,
405
548
  self.baseline_family_stats,
@@ -421,7 +564,7 @@ class SpectralGuard(Guard):
421
564
 
422
565
  try:
423
566
  # Capture current spectral state
424
- self.current_metrics = capture_baseline_sigmas(model, scope=self.scope)
567
+ self.current_metrics = self._capture_sigmas(model, phase="after_edit")
425
568
 
426
569
  # Detect violations
427
570
  violations = self._detect_spectral_violations(
@@ -461,6 +604,39 @@ class SpectralGuard(Guard):
461
604
  error=str(e),
462
605
  )
463
606
 
607
+ def _capture_sigmas(self, model: Any, *, phase: str) -> dict[str, float]:
608
+ """Capture σ̂max for each in-scope module under the measurement contract."""
609
+ _ = phase # reserved for future observability hooks
610
+ sigmas: dict[str, float] = {}
611
+ try:
612
+ iters = int((self.estimator or {}).get("iters", 4) or 4)
613
+ except Exception:
614
+ iters = 4
615
+ if iters < 1:
616
+ iters = 1
617
+ init = str((self.estimator or {}).get("init", "ones") or "ones").strip().lower()
618
+ if init not in {"ones", "e0"}:
619
+ init = "ones"
620
+
621
+ for name, module in model.named_modules():
622
+ if not _should_process_module(name, module, self.scope):
623
+ continue
624
+ weight = getattr(module, "weight", None)
625
+ if not isinstance(weight, torch.Tensor) or weight.ndim != 2:
626
+ continue
627
+ if weight.dtype in {torch.int8, torch.uint8}:
628
+ # Skip quantized weights; treat as neutral for baseline-relative stats.
629
+ sigmas[name] = 1.0
630
+ continue
631
+ try:
632
+ sigmas[name] = float(
633
+ power_iter_sigma_max(weight, iters=iters, init=init)
634
+ )
635
+ except Exception:
636
+ sigmas[name] = 1.0
637
+
638
+ return sigmas
639
+
464
640
  def _detect_spectral_violations(
465
641
  self, model: Any, metrics: dict[str, float], phase: str = "finalize"
466
642
  ) -> list[dict[str, Any]]:
@@ -504,6 +680,7 @@ class SpectralGuard(Guard):
504
680
  violations.append(
505
681
  {
506
682
  "type": "family_z_cap",
683
+ "severity": "budgeted",
507
684
  "module": name,
508
685
  "family": family,
509
686
  "z_score": float(z_score),
@@ -525,6 +702,7 @@ class SpectralGuard(Guard):
525
702
  violations.append(
526
703
  {
527
704
  "type": "max_spectral_norm",
705
+ "severity": "fatal",
528
706
  "module": name,
529
707
  "family": family,
530
708
  "current_sigma": float(sigma_max),
@@ -533,25 +711,98 @@ class SpectralGuard(Guard):
533
711
  }
534
712
  )
535
713
 
536
- # Condition number monitoring (warn only)
537
- try:
538
- U, S, V = torch.svd(module.weight.float())
539
- if len(S) > 0:
540
- condition_number = S[0].item() / max(S[-1].item(), 1e-12)
541
- if S[-1].item() < self.min_condition_number:
542
- violations.append(
543
- {
544
- "type": "ill_conditioned",
545
- "module": name,
546
- "family": family,
547
- "condition_number": float(condition_number),
548
- "min_singular_value": float(S[-1].item()),
549
- "threshold": float(self.min_condition_number),
550
- "message": f"Matrix is ill-conditioned, min singular value: {S[-1].item():.2e}",
551
- }
714
+ if bool((self.degeneracy or {}).get("enabled")):
715
+ base = self.baseline_degeneracy.get(name) or {}
716
+ base_sr = base.get("stable_rank")
717
+ base_nc = base.get("norm_collapse")
718
+ eps = 1e-12
719
+ try:
720
+ sr_cfg = (
721
+ (self.degeneracy.get("stable_rank") or {})
722
+ if isinstance(self.degeneracy, dict)
723
+ else {}
724
+ )
725
+ nc_cfg = (
726
+ (self.degeneracy.get("norm_collapse") or {})
727
+ if isinstance(self.degeneracy, dict)
728
+ else {}
729
+ )
730
+ sr_warn = float(sr_cfg.get("warn_ratio", 0.5))
731
+ sr_fatal = float(sr_cfg.get("fatal_ratio", 0.25))
732
+ nc_warn = float(nc_cfg.get("warn_ratio", 0.25))
733
+ nc_fatal = float(nc_cfg.get("fatal_ratio", 0.10))
734
+ except Exception:
735
+ sr_warn, sr_fatal, nc_warn, nc_fatal = 0.5, 0.25, 0.25, 0.10
736
+
737
+ if (
738
+ isinstance(base_sr, int | float)
739
+ and math.isfinite(float(base_sr))
740
+ and float(base_sr) > 0
741
+ ):
742
+ try:
743
+ sr_cur = frobenius_norm_sq(module.weight) / max(
744
+ float(sigma_max) ** 2, eps
552
745
  )
553
- except Exception:
554
- pass # SVD failure is not a violation
746
+ sr_ratio = float(sr_cur) / max(float(base_sr), eps)
747
+ if math.isfinite(sr_ratio) and sr_ratio < sr_warn:
748
+ violations.append(
749
+ {
750
+ "type": "degeneracy_stable_rank_drop",
751
+ "severity": "fatal"
752
+ if sr_ratio < sr_fatal
753
+ else "budgeted",
754
+ "module": name,
755
+ "family": family,
756
+ "stable_rank_base": float(base_sr),
757
+ "stable_rank_cur": float(sr_cur),
758
+ "ratio": float(sr_ratio),
759
+ "warn_ratio": float(sr_warn),
760
+ "fatal_ratio": float(sr_fatal),
761
+ "message": (
762
+ f"Stable-rank ratio {sr_ratio:.3f} below "
763
+ f"{sr_warn:.3f} (base={float(base_sr):.3f}, cur={float(sr_cur):.3f})"
764
+ ),
765
+ }
766
+ )
767
+ except Exception:
768
+ pass
769
+
770
+ if (
771
+ isinstance(base_nc, int | float)
772
+ and math.isfinite(float(base_nc))
773
+ and float(base_nc) > 0
774
+ ):
775
+ try:
776
+ norms = row_col_norm_extrema(module.weight, eps=eps)
777
+ row_med = max(float(norms.get("row_median", 0.0)), eps)
778
+ col_med = max(float(norms.get("col_median", 0.0)), eps)
779
+ nc_cur = min(
780
+ float(norms.get("row_min", 0.0)) / row_med,
781
+ float(norms.get("col_min", 0.0)) / col_med,
782
+ )
783
+ nc_ratio = float(nc_cur) / max(float(base_nc), eps)
784
+ if math.isfinite(nc_ratio) and nc_ratio < nc_warn:
785
+ violations.append(
786
+ {
787
+ "type": "degeneracy_norm_collapse",
788
+ "severity": "fatal"
789
+ if nc_ratio < nc_fatal
790
+ else "budgeted",
791
+ "module": name,
792
+ "family": family,
793
+ "norm_collapse_base": float(base_nc),
794
+ "norm_collapse_cur": float(nc_cur),
795
+ "ratio": float(nc_ratio),
796
+ "warn_ratio": float(nc_warn),
797
+ "fatal_ratio": float(nc_fatal),
798
+ "message": (
799
+ f"Norm-collapse ratio {nc_ratio:.3f} below "
800
+ f"{nc_warn:.3f} (base={float(base_nc):.3e}, cur={float(nc_cur):.3e})"
801
+ ),
802
+ }
803
+ )
804
+ except Exception:
805
+ pass
555
806
 
556
807
  except Exception as e:
557
808
  self._log_event(
@@ -776,7 +1027,7 @@ class SpectralGuard(Guard):
776
1027
  self.prepare(model, adapter, None, {})
777
1028
 
778
1029
  # Capture current spectral state
779
- current_metrics = capture_baseline_sigmas(model, scope=self.scope)
1030
+ current_metrics = self._capture_sigmas(model, phase="validate")
780
1031
 
781
1032
  # Detect violations (final validation phase)
782
1033
  violations = self._detect_spectral_violations(
@@ -784,16 +1035,16 @@ class SpectralGuard(Guard):
784
1035
  )
785
1036
 
786
1037
  # Determine if passed under budget/fatal rules
787
- fatal_violation_types = {"max_spectral_norm", "ill_conditioned"}
788
- budgeted_violations = [
1038
+ fatal_violations = [
789
1039
  violation
790
1040
  for violation in violations
791
- if violation.get("type") not in fatal_violation_types
1041
+ if (violation.get("severity") == "fatal")
1042
+ or (violation.get("type") == "max_spectral_norm")
792
1043
  ]
793
- fatal_violations = [
1044
+ budgeted_violations = [
794
1045
  violation
795
1046
  for violation in violations
796
- if violation.get("type") in fatal_violation_types
1047
+ if violation not in fatal_violations
797
1048
  ]
798
1049
 
799
1050
  selected_budgeted, mt_selection = self._select_budgeted_violations(
@@ -839,6 +1090,10 @@ class SpectralGuard(Guard):
839
1090
  "caps_exceeded": caps_exceeded,
840
1091
  "multiple_testing": self.multiple_testing,
841
1092
  "multiple_testing_selection": mt_selection,
1093
+ "measurement_contract": {
1094
+ "estimator": self.estimator,
1095
+ "degeneracy": self.degeneracy,
1096
+ },
842
1097
  }
843
1098
 
844
1099
  family_quantiles, top_z_scores = self._compute_family_observability()
@@ -916,7 +1171,7 @@ class SpectralGuard(Guard):
916
1171
  }
917
1172
 
918
1173
  # Final spectral analysis
919
- final_metrics = capture_baseline_sigmas(model, scope=self.scope)
1174
+ final_metrics = self._capture_sigmas(model, phase="finalize")
920
1175
  final_violations = self._detect_spectral_violations(
921
1176
  model, final_metrics, phase="finalize"
922
1177
  )
@@ -928,16 +1183,16 @@ class SpectralGuard(Guard):
928
1183
  family_quantiles, top_z_scores = self._compute_family_observability()
929
1184
 
930
1185
  # Determine overall status based on budgeted vs fatal violations
931
- fatal_violation_types = {"max_spectral_norm", "ill_conditioned"}
932
- budgeted_violations = [
1186
+ fatal_violations = [
933
1187
  violation
934
1188
  for violation in final_violations
935
- if violation.get("type") not in fatal_violation_types
1189
+ if (violation.get("severity") == "fatal")
1190
+ or (violation.get("type") == "max_spectral_norm")
936
1191
  ]
937
- fatal_violations = [
1192
+ budgeted_violations = [
938
1193
  violation
939
1194
  for violation in final_violations
940
- if violation.get("type") in fatal_violation_types
1195
+ if violation not in fatal_violations
941
1196
  ]
942
1197
 
943
1198
  selected_budgeted, mt_selection = self._select_budgeted_violations(
@@ -983,6 +1238,10 @@ class SpectralGuard(Guard):
983
1238
  "multiple_testing_selection": mt_selection,
984
1239
  "family_z_quantiles": family_quantiles,
985
1240
  "top_z_scores": top_z_scores,
1241
+ "measurement_contract": {
1242
+ "estimator": self.estimator,
1243
+ "degeneracy": self.degeneracy,
1244
+ },
986
1245
  }
987
1246
 
988
1247
  # Categorize violations
@@ -990,7 +1249,9 @@ class SpectralGuard(Guard):
990
1249
  errors = []
991
1250
 
992
1251
  for violation in selected_final_violations:
993
- if violation["type"] in ["max_spectral_norm", "ill_conditioned"]:
1252
+ if (violation.get("severity") == "fatal") or (
1253
+ violation.get("type") == "max_spectral_norm"
1254
+ ):
994
1255
  errors.append(violation["message"])
995
1256
  else:
996
1257
  warnings.append(violation["message"])
@@ -1029,7 +1290,9 @@ class SpectralGuard(Guard):
1029
1290
  return result
1030
1291
 
1031
1292
 
1032
- def compute_sigma_max(weight_matrix: Any) -> float:
1293
+ def compute_sigma_max(
1294
+ weight_matrix: Any, *, iters: int = 4, init: str = "ones"
1295
+ ) -> float:
1033
1296
  """
1034
1297
  Compute maximum singular value of a weight matrix.
1035
1298
 
@@ -1040,35 +1303,29 @@ def compute_sigma_max(weight_matrix: Any) -> float:
1040
1303
  Maximum singular value
1041
1304
  """
1042
1305
  try:
1043
- if isinstance(weight_matrix, torch.Tensor):
1044
- # Handle different tensor types
1045
- if weight_matrix.dtype in [torch.int8]:
1046
- # Skip quantized weights
1047
- return 1.0
1048
-
1049
- # Ensure float type for SVD
1050
- W = weight_matrix.float()
1051
-
1052
- # Handle edge cases
1053
- if W.numel() == 0 or W.shape[0] == 0 or W.shape[1] == 0:
1054
- return 0.0
1055
-
1056
- # Compute singular values using deterministic backend when available
1057
- try:
1058
- singular_values = torch.linalg.svdvals(W)
1059
- except RuntimeError:
1060
- # Fallback for older backends without svdvals
1061
- singular_values = torch.linalg.svd(W, full_matrices=False).S
1062
-
1063
- return singular_values[0].item() if singular_values.numel() > 0 else 0.0
1064
- else:
1065
- return 1.0 # Fallback for non-tensor inputs
1306
+ iters_i = int(iters)
1307
+ except Exception:
1308
+ iters_i = 4
1309
+ if iters_i < 1:
1310
+ iters_i = 1
1311
+ init_s = str(init or "ones").strip().lower()
1312
+ if init_s not in {"ones", "e0"}:
1313
+ init_s = "ones"
1314
+
1315
+ if not isinstance(weight_matrix, torch.Tensor):
1316
+ return 1.0
1317
+ if weight_matrix.dtype in {torch.int8, torch.uint8}:
1318
+ return 1.0
1319
+ if weight_matrix.numel() == 0 or weight_matrix.ndim != 2:
1320
+ return 0.0
1066
1321
 
1322
+ try:
1323
+ return float(power_iter_sigma_max(weight_matrix, iters=iters_i, init=init_s))
1067
1324
  except Exception:
1068
- return 1.0 # Fallback on any error
1325
+ return 1.0
1069
1326
 
1070
1327
 
1071
- def auto_sigma_target(model: Any, percentile: float = 0.95, **kwargs: Any) -> float:
1328
+ def auto_sigma_target(model: Any, percentile: float = 0.95) -> float:
1072
1329
  """
1073
1330
  Automatically determine sigma target for a model.
1074
1331
 
@@ -1079,11 +1336,6 @@ def auto_sigma_target(model: Any, percentile: float = 0.95, **kwargs: Any) -> fl
1079
1336
  Returns:
1080
1337
  Target sigma value
1081
1338
  """
1082
- if "kappa" in kwargs and percentile == 0.95:
1083
- try:
1084
- percentile = float(kwargs["kappa"])
1085
- except (TypeError, ValueError):
1086
- pass
1087
1339
  try:
1088
1340
  # Collect all spectral norms
1089
1341
  spectral_norms = []
@@ -1471,7 +1723,6 @@ def scan_model_gains(model: Any, scope: str = "all") -> dict[str, Any]:
1471
1723
  "total_layers": 0,
1472
1724
  "scanned_modules": 0,
1473
1725
  "spectral_norms": [],
1474
- "condition_numbers": [],
1475
1726
  "weight_statistics": {},
1476
1727
  }
1477
1728
 
@@ -1486,15 +1737,6 @@ def scan_model_gains(model: Any, scope: str = "all") -> dict[str, Any]:
1486
1737
  sigma_max = compute_sigma_max(module.weight)
1487
1738
  results["spectral_norms"].append(sigma_max)
1488
1739
 
1489
- # Compute condition number if possible
1490
- try:
1491
- U, S, V = torch.svd(module.weight.float())
1492
- if len(S) > 1:
1493
- condition_num = (S[0] / S[-1]).item()
1494
- results["condition_numbers"].append(condition_num)
1495
- except Exception:
1496
- pass
1497
-
1498
1740
  # Basic weight statistics
1499
1741
  try:
1500
1742
  weight_stats = {
@@ -1513,10 +1755,6 @@ def scan_model_gains(model: Any, scope: str = "all") -> dict[str, Any]:
1513
1755
  results["max_spectral_norm"] = np.max(results["spectral_norms"])
1514
1756
  results["min_spectral_norm"] = np.min(results["spectral_norms"])
1515
1757
 
1516
- if results["condition_numbers"]:
1517
- results["mean_condition_number"] = np.mean(results["condition_numbers"])
1518
- results["max_condition_number"] = np.max(results["condition_numbers"])
1519
-
1520
1758
  results["message"] = (
1521
1759
  f"Scanned {results['scanned_modules']} modules out of {results['total_layers']} total layers"
1522
1760
  )