invarlock 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1419 @@
1
+ """
2
+ Spectral Guard Implementation
3
+ ============================
4
+
5
+ Monitors spectral properties of model weights to detect instabilities.
6
+ Provides spectral control mechanisms for maintaining numerical stability.
7
+ """
8
+
9
+ import math
10
+ import time
11
+ from collections import defaultdict
12
+ from datetime import datetime
13
+ from typing import Any, TypedDict
14
+
15
try:
    from typing import NotRequired
except ImportError:  # Python < 3.11: NotRequired landed in typing in 3.11
    # BUG FIX: the original fallback re-imported NotRequired from `typing`,
    # which is the import that just failed, so the fallback always raised.
    # Use the typing_extensions backport on older interpreters instead.
    from typing_extensions import NotRequired
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from invarlock.cli._evidence import maybe_dump_guard_evidence
24
+ from invarlock.core.api import Guard
25
+
26
+ from ._contracts import guard_assert
27
+
28
+
29
class SpectralPolicy(TypedDict, total=False):
    """Type definition for spectral guard policy configuration.

    All keys are optional (``total=False``); ``SpectralGuard`` reads these
    keys from the keyword arguments / policy mapping it receives.
    """

    # Quantile used to derive the sigma target; the aliases below are folded
    # into this key when it is absent.
    sigma_quantile: float
    contraction: NotRequired[float]  # Backward compatibility alias
    kappa: NotRequired[float]  # Legacy alias
    # Dead zone forwarded to z-score computation.
    deadband: float
    # Module selection: 'all', 'ffn', or 'attn'.
    scope: str
    # Whether spectral control is applied when violations are detected.
    correction_enabled: bool
    # Per-family z-score caps, e.g. {"ffn": {"kappa": 2.5}}.
    family_caps: dict[str, dict[str, float]]
    # Skip budgeted z-cap checks during the after-edit (preview) phase.
    ignore_preview_inflation: bool
    # Budget of allowed non-fatal cap violations.
    max_caps: int
    # Multiple-testing settings, e.g. {"method": "bh", "alpha": 0.05, "m": 4}.
    multiple_testing: dict[str, Any]
42
+
43
+
44
+ def _default_family_caps() -> dict[str, dict[str, float]]:
45
+ """Default per-family spectral z-score caps."""
46
+ return {
47
+ "ffn": {"kappa": 2.5},
48
+ "attn": {"kappa": 2.8},
49
+ "embed": {"kappa": 3.0},
50
+ "other": {"kappa": 3.0},
51
+ }
52
+
53
+
54
+ def _normalize_family_caps(
55
+ caps: Any, *, default: bool = True
56
+ ) -> dict[str, dict[str, float]]:
57
+ """Normalize family cap configuration into canonical mapping."""
58
+
59
+ if not isinstance(caps, dict) or not caps:
60
+ return _default_family_caps() if default else {}
61
+
62
+ normalized: dict[str, dict[str, float]] = {}
63
+ for family, values in caps.items():
64
+ entry: dict[str, float] = {}
65
+ if isinstance(values, dict):
66
+ for key, val in values.items():
67
+ if isinstance(val, int | float) and math.isfinite(float(val)):
68
+ entry[str(key)] = float(val)
69
+ elif isinstance(values, int | float) and math.isfinite(float(values)):
70
+ entry["kappa"] = float(values)
71
+ if entry:
72
+ normalized[str(family)] = entry
73
+
74
+ if normalized:
75
+ return normalized
76
+
77
+ return _default_family_caps() if default else {}
78
+
79
+
80
+ class SpectralGuard(Guard):
81
+ """
82
+ Spectral guard for monitoring weight matrix spectral properties.
83
+
84
+ Tracks singular values and spectral norms to detect numerical instabilities.
85
+ Provides automatic spectral control when violations are detected.
86
+ """
87
+
88
+ name = "spectral"
89
+
90
    def __init__(self, **kwargs):
        """Initialize spectral guard.

        Keyword Args:
            sigma_quantile: Quantile used for the sigma target; legacy aliases
                ``contraction`` and ``kappa`` are accepted when it is absent
                (first non-None wins) and are then removed from the config.
                Defaults to 0.95.
            deadband: Dead zone forwarded to z-score computation (default 0.10).
            scope: Module selection: 'all', 'ffn', or 'attn' (default 'all').
            max_spectral_norm: Hard sigma ceiling; ``None`` disables the check
                (default 10.0).
            min_condition_number: Minimum singular value below which a matrix
                is reported as ill-conditioned (default 1e-12).
            correction_enabled: Apply spectral control after violations
                (default True).
            family_caps: Per-family z-score caps; normalized on entry.
            ignore_preview_inflation: Skip budgeted checks during the
                after-edit (preview) phase (default True).
            max_caps: Budget of allowed non-fatal cap violations (default 5).
            multiple_testing: Multiple-testing settings (default BH,
                alpha=0.05, m=4).
        """
        # Keep a mutable copy of the raw kwargs; alias keys are canonicalized
        # in place below so the stored config only ever has `sigma_quantile`.
        self.config = dict(kwargs)
        self.prepared = False
        self.baseline_metrics = {}
        self.events = []
        self.current_metrics = {}
        self.violations = []

        # Default configuration
        sigma_quantile = self.config.get("sigma_quantile")
        if sigma_quantile is None:
            # Resolve legacy aliases in priority order.
            for alias in ("contraction", "kappa"):
                if self.config.get(alias) is not None:
                    sigma_quantile = self.config[alias]
                    break
        if sigma_quantile is None:
            sigma_quantile = 0.95
        self.sigma_quantile = float(sigma_quantile)
        self.config["sigma_quantile"] = self.sigma_quantile
        # Drop the legacy alias keys so only the canonical name survives.
        self.config.pop("contraction", None)
        self.config.pop("kappa", None)
        self.deadband = kwargs.get("deadband", 0.10)
        self.scope = kwargs.get("scope", "all")  # 'all', 'ffn', 'attn'
        self.max_spectral_norm = kwargs.get("max_spectral_norm", 10.0)
        if self.max_spectral_norm is not None:
            # Coerce to float; an explicit None disables the hard-norm check.
            self.max_spectral_norm = float(self.max_spectral_norm)
        self.config["max_spectral_norm"] = self.max_spectral_norm
        self.min_condition_number = kwargs.get("min_condition_number", 1e-12)
        self.correction_enabled = kwargs.get("correction_enabled", True)
        self.family_caps = _normalize_family_caps(
            kwargs.get("family_caps"), default=True
        )
        self.ignore_preview_inflation = kwargs.get("ignore_preview_inflation", True)
        self.max_caps = kwargs.get("max_caps", 5)
        self.multiple_testing = kwargs.get(
            "multiple_testing", {"method": "bh", "alpha": 0.05, "m": 4}
        )

        # Baseline and tracking structures
        self.baseline_sigmas: dict[str, float] = {}
        self.baseline_family_stats: dict[str, dict[str, float]] = {}
        self.module_family_map: dict[str, str] = {}
        self.latest_z_scores: dict[str, float] = {}
        self.pre_edit_z_scores: dict[str, float] = {}
135
+
136
+ def _log_event(
137
+ self, operation: str, level: str = "INFO", message: str = "", **data
138
+ ):
139
+ """Log an event with timestamp."""
140
+ event = {
141
+ "timestamp": datetime.utcnow().isoformat(),
142
+ "component": "spectral_guard",
143
+ "operation": operation,
144
+ "level": level,
145
+ "message": message,
146
+ "data": data,
147
+ }
148
+ self.events.append(event)
149
+
150
+ def _serialize_policy(self) -> dict[str, Any]:
151
+ """Snapshot current guard policy for report serialization."""
152
+ return {
153
+ "scope": self.scope,
154
+ "sigma_quantile": float(self.sigma_quantile),
155
+ "deadband": float(self.deadband),
156
+ "max_caps": int(self.max_caps),
157
+ "max_spectral_norm": (
158
+ float(self.max_spectral_norm)
159
+ if self.max_spectral_norm is not None
160
+ else None
161
+ ),
162
+ "family_caps": self.family_caps,
163
+ "multiple_testing": self.multiple_testing,
164
+ "correction_enabled": bool(self.correction_enabled),
165
+ "ignore_preview_inflation": bool(self.ignore_preview_inflation),
166
+ }
167
+
168
    def prepare(
        self, model: Any, adapter: Any, calib: Any, policy: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Prepare spectral guard by capturing baseline spectral properties.

        Mutates *policy* in place: legacy alias keys (``contraction``,
        ``kappa``, ``multipletesting``) are replaced by their canonical names.

        Args:
            model: Model to prepare for
            adapter: ModelAdapter instance
            calib: Calibration data (unused for spectral analysis)
            policy: Policy configuration

        Returns:
            Preparation results
        """
        start_time = time.time()

        # Update configuration from policy
        if policy:
            sigma_value = policy.get("sigma_quantile")
            if sigma_value is None:
                # Accept legacy aliases; `contraction` wins over `kappa`.
                alias_value = policy.get("contraction", policy.get("kappa"))
                if alias_value is not None:
                    sigma_value = alias_value
            if sigma_value is not None:
                self.sigma_quantile = float(sigma_value)
            # Write the canonical key back and strip the aliases.
            policy["sigma_quantile"] = self.sigma_quantile
            policy.pop("contraction", None)
            policy.pop("kappa", None)
            self.config["sigma_quantile"] = self.sigma_quantile

            # Copy simple scalar settings verbatim from the policy.
            for key in [
                "sigma_quantile",
                "deadband",
                "scope",
                "max_spectral_norm",
                "correction_enabled",
                "max_caps",
            ]:
                if key in policy:
                    setattr(self, key, policy[key])
                    self.config[key] = policy[key]

            if self.max_spectral_norm is not None:
                # Re-coerce after the generic copy above may have set a raw value.
                self.max_spectral_norm = float(self.max_spectral_norm)
                self.config["max_spectral_norm"] = self.max_spectral_norm

            if "family_caps" in policy:
                self.family_caps = _normalize_family_caps(
                    policy["family_caps"], default=True
                )
                self.config["family_caps"] = self.family_caps

            if "ignore_preview_inflation" in policy:
                self.ignore_preview_inflation = bool(policy["ignore_preview_inflation"])
                self.config["ignore_preview_inflation"] = self.ignore_preview_inflation

            # Optional hydration of baseline stats from policy (e.g., baseline certificate)
            if "baseline_family_stats" in policy and isinstance(
                policy["baseline_family_stats"], dict
            ):
                self.baseline_family_stats = {
                    family: stats.copy()
                    for family, stats in policy["baseline_family_stats"].items()
                    if isinstance(stats, dict)
                }
                self.config["baseline_family_stats"] = self.baseline_family_stats
            # Accept the legacy un-underscored spelling as a fallback.
            mt_policy = policy.get("multiple_testing")
            if mt_policy is None:
                mt_policy = policy.get("multipletesting")
            if isinstance(mt_policy, dict):
                self.multiple_testing = mt_policy.copy()
                policy["multiple_testing"] = self.multiple_testing
                self.config["multiple_testing"] = self.multiple_testing
                policy.pop("multipletesting", None)

        self._log_event(
            "prepare",
            message=(
                f"Preparing spectral guard with scope={self.scope}, "
                f"sigma_quantile={self.sigma_quantile}"
            ),
        )

        try:
            # Capture baseline spectral properties
            self.baseline_sigmas = capture_baseline_sigmas(model, scope=self.scope)
            self.module_family_map = classify_model_families(
                model, scope=self.scope, existing=self.module_family_map
            )
            # Policy-hydrated stats (above) take precedence over recomputation.
            if not self.baseline_family_stats:
                self.baseline_family_stats = compute_family_stats(
                    self.baseline_sigmas, self.module_family_map
                )

            # Compute additional baseline metrics
            baseline_stats = scan_model_gains(model, scope=self.scope)
            summarized = _summarize_sigmas(self.baseline_sigmas)
            baseline_stats.update(summarized)

            # Store target sigma value
            self.target_sigma = auto_sigma_target(model, percentile=self.sigma_quantile)
            baseline_stats["target_sigma"] = self.target_sigma

            # Copy nested dicts so later mutation can't corrupt the baseline.
            baseline_stats["family_stats"] = {
                family: stats.copy()
                for family, stats in self.baseline_family_stats.items()
            }
            baseline_stats["family_caps"] = {
                family: caps.copy() for family, caps in self.family_caps.items()
            }
            baseline_stats["module_sigmas"] = self.baseline_sigmas.copy()

            self.baseline_metrics = baseline_stats

            self.prepared = True
            preparation_time = time.time() - start_time

            self._log_event(
                "prepare_success",
                message=f"Prepared spectral guard with {len(self.baseline_metrics)} baseline metrics",
                baseline_metrics_count=len(self.baseline_metrics),
                target_sigma=self.target_sigma,
                preparation_time=preparation_time,
            )

            return {
                "ready": True,
                "baseline_metrics": self.baseline_metrics,
                "target_sigma": self.target_sigma,
                "scope": self.scope,
                "preparation_time": preparation_time,
            }

        except Exception as e:
            # Any baseline-capture failure leaves the guard unprepared; later
            # hooks check `self.prepared` and skip their work.
            self.prepared = False
            self._log_event(
                "prepare_failed",
                level="ERROR",
                message=f"Failed to prepare spectral guard: {str(e)}",
                error=str(e),
            )

            return {
                "ready": False,
                "error": str(e),
                "preparation_time": time.time() - start_time,
            }
316
+
317
+ def before_edit(self, model: Any) -> None:
318
+ """Execute before edit (capture pre-edit state)."""
319
+ if not self.prepared:
320
+ self._log_event(
321
+ "before_edit_skipped",
322
+ level="WARN",
323
+ message="Spectral guard not prepared, skipping pre-edit capture",
324
+ )
325
+ return
326
+
327
+ # Capture pre-edit spectral state for comparison
328
+ self.pre_edit_metrics = capture_baseline_sigmas(model, scope=self.scope)
329
+ self.pre_edit_z_scores = compute_z_scores(
330
+ self.pre_edit_metrics,
331
+ self.baseline_family_stats,
332
+ self.module_family_map,
333
+ self.baseline_sigmas,
334
+ deadband=self.deadband,
335
+ )
336
+ self._log_event("before_edit", message="Captured pre-edit spectral state")
337
+
338
+ def after_edit(self, model: Any) -> None:
339
+ """Execute after edit (detect violations and apply control)."""
340
+ if not self.prepared:
341
+ self._log_event(
342
+ "after_edit_skipped",
343
+ level="WARN",
344
+ message="Spectral guard not prepared, skipping post-edit analysis",
345
+ )
346
+ return
347
+
348
+ try:
349
+ # Capture current spectral state
350
+ self.current_metrics = capture_baseline_sigmas(model, scope=self.scope)
351
+
352
+ # Detect violations
353
+ violations = self._detect_spectral_violations(
354
+ model, self.current_metrics, phase="after_edit"
355
+ )
356
+ self.violations = violations
357
+
358
+ # Apply spectral control if violations detected and correction enabled
359
+ if violations and self.correction_enabled:
360
+ control_result = apply_spectral_control(
361
+ model,
362
+ policy={
363
+ "sigma_quantile": self.sigma_quantile,
364
+ "scope": self.scope,
365
+ "baseline_sigmas": self.baseline_sigmas,
366
+ "target_sigma": self.target_sigma,
367
+ },
368
+ )
369
+
370
+ self._log_event(
371
+ "spectral_control_applied",
372
+ message=f"Applied spectral control, violations: {len(violations)}",
373
+ violations_count=len(violations),
374
+ control_result=control_result,
375
+ )
376
+
377
+ self._log_event(
378
+ "after_edit",
379
+ message=f"Post-edit analysis complete, {len(violations)} violations detected",
380
+ )
381
+
382
+ except Exception as e:
383
+ self._log_event(
384
+ "after_edit_failed",
385
+ level="ERROR",
386
+ message=f"Post-edit spectral analysis failed: {str(e)}",
387
+ error=str(e),
388
+ )
389
+
390
+ def _detect_spectral_violations(
391
+ self, model: Any, metrics: dict[str, float], phase: str = "finalize"
392
+ ) -> list[dict[str, Any]]:
393
+ """Detect spectral property violations using per-family z-score caps."""
394
+ violations: list[dict[str, Any]] = []
395
+ latest_z: dict[str, float] = {}
396
+
397
+ for name, module in model.named_modules():
398
+ if not self._should_check_module(name, module):
399
+ continue
400
+
401
+ try:
402
+ if hasattr(module, "weight") and module.weight.ndim == 2:
403
+ sigma_max = metrics.get(name)
404
+ if sigma_max is None:
405
+ sigma_max = compute_sigma_max(module.weight)
406
+
407
+ baseline_sigma = self.baseline_sigmas.get(name, self.target_sigma)
408
+ family = self.module_family_map.get(name)
409
+ if family is None:
410
+ family = classify_module_family(name, module)
411
+ self.module_family_map[name] = family
412
+
413
+ family_stats = self.baseline_family_stats.get(family, {})
414
+ cap_config = self.family_caps.get(family, {})
415
+ kappa_cap = float(cap_config.get("kappa", self.sigma_quantile))
416
+
417
+ z_score = compute_z_score_for_value(
418
+ sigma_max,
419
+ family_stats,
420
+ fallback_value=baseline_sigma,
421
+ deadband=self.deadband,
422
+ )
423
+ latest_z[name] = z_score
424
+
425
+ # Skip preview inflation if configured and not in final phase
426
+ if self.ignore_preview_inflation and phase == "after_edit":
427
+ continue
428
+
429
+ if z_score > kappa_cap:
430
+ violations.append(
431
+ {
432
+ "type": "family_z_cap",
433
+ "module": name,
434
+ "family": family,
435
+ "z_score": float(z_score),
436
+ "kappa": kappa_cap,
437
+ "sigma": float(sigma_max),
438
+ "baseline_sigma": float(baseline_sigma),
439
+ "message": (
440
+ f"Family '{family}' z-score {z_score:.2f}"
441
+ f" exceeds cap {kappa_cap:.2f}"
442
+ ),
443
+ }
444
+ )
445
+
446
+ if (
447
+ self.max_spectral_norm is not None
448
+ and sigma_max > self.max_spectral_norm
449
+ ):
450
+ threshold = float(self.max_spectral_norm)
451
+ violations.append(
452
+ {
453
+ "type": "max_spectral_norm",
454
+ "module": name,
455
+ "family": family,
456
+ "current_sigma": float(sigma_max),
457
+ "threshold": threshold,
458
+ "message": f"Spectral norm {sigma_max:.3f} exceeds maximum {threshold}",
459
+ }
460
+ )
461
+
462
+ # Condition number monitoring (warn only)
463
+ try:
464
+ U, S, V = torch.svd(module.weight.float())
465
+ if len(S) > 0:
466
+ condition_number = S[0].item() / max(S[-1].item(), 1e-12)
467
+ if S[-1].item() < self.min_condition_number:
468
+ violations.append(
469
+ {
470
+ "type": "ill_conditioned",
471
+ "module": name,
472
+ "family": family,
473
+ "condition_number": float(condition_number),
474
+ "min_singular_value": float(S[-1].item()),
475
+ "threshold": float(self.min_condition_number),
476
+ "message": f"Matrix is ill-conditioned, min singular value: {S[-1].item():.2e}",
477
+ }
478
+ )
479
+ except Exception:
480
+ pass # SVD failure is not a violation
481
+
482
+ except Exception as e:
483
+ self._log_event(
484
+ "violation_check_error",
485
+ level="WARN",
486
+ message=f"Failed to check module {name}: {str(e)}",
487
+ module=name,
488
+ error=str(e),
489
+ )
490
+
491
+ self.latest_z_scores = latest_z
492
+ return violations
493
+
494
+ def _should_check_module(self, name: str, module: Any) -> bool:
495
+ """Determine if a module should be checked based on scope."""
496
+ if not hasattr(module, "weight") or module.weight.ndim != 2:
497
+ return False
498
+
499
+ if self.scope == "all":
500
+ return True
501
+ elif self.scope == "attn":
502
+ return any(
503
+ keyword in name.lower()
504
+ for keyword in ["attn", "attention", "self_attn"]
505
+ )
506
+ elif self.scope == "ffn":
507
+ return any(
508
+ keyword in name.lower()
509
+ for keyword in ["mlp", "ffn", "feed_forward", "fc"]
510
+ )
511
+
512
+ return True
513
+
514
+ def _compute_family_observability(
515
+ self,
516
+ ) -> tuple[dict[str, dict[str, float]], dict[str, list[dict[str, Any]]]]:
517
+ """Generate per-family quantiles and top-|z| listings from latest z-scores."""
518
+ family_scores: dict[str, list[float]] = defaultdict(list)
519
+ family_modules: dict[str, list[tuple[float, str]]] = defaultdict(list)
520
+
521
+ for module_name, z_value in (self.latest_z_scores or {}).items():
522
+ family = self.module_family_map.get(module_name)
523
+ if family is None:
524
+ continue
525
+ try:
526
+ z_abs = abs(float(z_value))
527
+ except (TypeError, ValueError):
528
+ continue
529
+ family_scores.setdefault(family, []).append(z_abs)
530
+ family_modules.setdefault(family, []).append((z_abs, module_name))
531
+
532
+ def _quantile(sorted_values: list[float], quantile: float) -> float:
533
+ if not sorted_values:
534
+ return 0.0
535
+ if len(sorted_values) == 1:
536
+ return sorted_values[0]
537
+ position = (len(sorted_values) - 1) * quantile
538
+ lower = math.floor(position)
539
+ upper = math.ceil(position)
540
+ if lower == upper:
541
+ return sorted_values[int(position)]
542
+ fraction = position - lower
543
+ return (
544
+ sorted_values[lower]
545
+ + (sorted_values[upper] - sorted_values[lower]) * fraction
546
+ )
547
+
548
+ family_quantiles: dict[str, dict[str, float]] = {}
549
+ for family, scores in family_scores.items():
550
+ sorted_scores = sorted(scores)
551
+ family_quantiles[family] = {
552
+ "q95": _quantile(sorted_scores, 0.95),
553
+ "q99": _quantile(sorted_scores, 0.99),
554
+ "max": sorted_scores[-1] if sorted_scores else 0.0,
555
+ "count": len(sorted_scores),
556
+ }
557
+
558
+ top_z_scores: dict[str, list[dict[str, Any]]] = {}
559
+ for family, module_entries in family_modules.items():
560
+ module_entries.sort(key=lambda item: item[0], reverse=True)
561
+ top_entries: list[dict[str, Any]] = []
562
+ for z_abs, module_name in module_entries[:3]:
563
+ top_entries.append(
564
+ {"module": module_name, "z": float(z_abs), "family": family}
565
+ )
566
+ top_z_scores[family] = top_entries
567
+
568
+ return family_quantiles, top_z_scores
569
+
570
    def validate(
        self, model: Any, adapter: Any, context: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Validate model spectral properties.

        Auto-prepares with an empty policy when `prepare` was never called.
        Violations of type ``max_spectral_norm`` / ``ill_conditioned`` are
        fatal; all others consume the ``max_caps`` budget.

        Args:
            model: Model to validate
            adapter: ModelAdapter instance
            context: Validation context

        Returns:
            Dictionary with validation results
        """
        try:
            if not self.prepared:
                # Auto-prepare if needed
                self.prepare(model, adapter, None, {})

            # Capture current spectral state
            current_metrics = capture_baseline_sigmas(model, scope=self.scope)

            # Detect violations (final validation phase)
            violations = self._detect_spectral_violations(
                model, current_metrics, phase="validate"
            )

            # Determine if passed under budget/fatal rules
            fatal_violation_types = {"max_spectral_norm", "ill_conditioned"}
            budgeted_violations = [
                violation
                for violation in violations
                if violation.get("type") not in fatal_violation_types
            ]
            fatal_violations = [
                violation
                for violation in violations
                if violation.get("type") in fatal_violation_types
            ]

            caps_applied = len(budgeted_violations)
            caps_exceeded = caps_applied > int(self.max_caps)
            passed = not fatal_violations and not caps_exceeded
            # Action: abort on fatal/over-budget, warn when budget was used.
            if fatal_violations or caps_exceeded:
                action = "abort"
            elif caps_applied > 0:
                action = "warn"
            else:
                action = "continue"

            # Compute overall metrics
            family_summary = summarize_family_z_scores(
                self.latest_z_scores, self.module_family_map, self.family_caps
            )
            metrics = {
                "modules_checked": len(current_metrics),
                "violations_found": len(violations),
                "budgeted_violations": caps_applied,
                "fatal_violations": len(fatal_violations),
                "max_spectral_norm": max(current_metrics.values())
                if current_metrics
                else 0.0,
                "mean_spectral_norm": np.mean(list(current_metrics.values()))
                if current_metrics
                else 0.0,
                # Fraction of checked modules with no violation, clamped to [0, 1].
                "stability_score": 1.0
                - min(len(violations) / max(len(current_metrics), 1), 1.0),
                "family_z_summary": family_summary,
                "family_caps": self.family_caps,
                "sigma_quantile": float(self.sigma_quantile),
                "deadband": float(self.deadband),
                "max_caps": int(self.max_caps),
                "caps_applied": caps_applied,
                "caps_exceeded": caps_exceeded,
                "multiple_testing": self.multiple_testing,
            }

            # Observability extras are only attached when non-empty.
            family_quantiles, top_z_scores = self._compute_family_observability()
            if family_quantiles:
                metrics["family_z_quantiles"] = family_quantiles
            if top_z_scores:
                metrics["top_z_scores"] = top_z_scores

            if passed:
                message = (
                    "Spectral validation passed with "
                    f"{len(violations)} violations "
                    f"(caps_applied={caps_applied}, max_caps={self.max_caps})"
                )
            else:
                reason = (
                    "fatal spectral violation detected"
                    if fatal_violations
                    else "cap budget exceeded"
                )
                message = (
                    f"Spectral validation failed: {reason} "
                    f"(caps_applied={caps_applied}, max_caps={self.max_caps})"
                )

            # Runtime contracts (lightweight)
            mt = self.multiple_testing or {}
            try:
                alpha = float(mt.get("alpha", 0.05)) if isinstance(mt, dict) else 0.05
            except Exception:
                alpha = 0.05
            guard_assert(self.deadband >= 0.0, "spectral.deadband must be >= 0")
            guard_assert(
                0.0 < alpha <= 1.0, "spectral.multiple_testing.alpha out of range"
            )
            guard_assert(self.max_caps >= 0, "spectral.max_caps must be >= 0")

            return {
                "passed": passed,
                "action": action,
                "metrics": metrics,
                "violations": violations,
                "message": message,
                "policy": self._serialize_policy(),
                "final_z_scores": self.latest_z_scores.copy(),
                "module_family_map": dict(self.module_family_map),
            }

        except Exception as e:
            # Evaluation failure degrades to a non-fatal "warn" result.
            return {
                "passed": False,
                "action": "warn",
                "error": str(e),
                "metrics": {},
                "message": f"Spectral validation failed: {e}",
            }
701
+
702
    def finalize(self, model: Any) -> dict[str, Any]:
        """
        Finalize spectral guard and return comprehensive results.

        Applies the same budget/fatal pass rules as `validate`: violations of
        type ``max_spectral_norm`` / ``ill_conditioned`` are fatal (reported
        as errors); everything else is budgeted against ``max_caps`` and
        reported as warnings.

        Args:
            model: The final model state

        Returns:
            Dictionary with spectral guard results
        """
        if not self.prepared:
            # Without a baseline there is nothing meaningful to report.
            return {
                "passed": False,
                "metrics": {},
                "warnings": ["Spectral guard not properly prepared"],
                "errors": ["Preparation failed or not called"],
                "events": self.events,
            }

        # Final spectral analysis
        final_metrics = capture_baseline_sigmas(model, scope=self.scope)
        final_violations = self._detect_spectral_violations(
            model, final_metrics, phase="finalize"
        )
        final_z_summary = summarize_family_z_scores(
            self.latest_z_scores, self.module_family_map, self.family_caps
        )
        final_family_stats = compute_family_stats(final_metrics, self.module_family_map)

        family_quantiles, top_z_scores = self._compute_family_observability()

        # Determine overall status based on budgeted vs fatal violations
        fatal_violation_types = {"max_spectral_norm", "ill_conditioned"}
        budgeted_violations = [
            violation
            for violation in final_violations
            if violation.get("type") not in fatal_violation_types
        ]
        fatal_violations = [
            violation
            for violation in final_violations
            if violation.get("type") in fatal_violation_types
        ]

        caps_applied = len(budgeted_violations)
        caps_exceeded = caps_applied > int(self.max_caps)
        passed = not fatal_violations and not caps_exceeded

        # Compute comprehensive metrics
        metrics = {
            "modules_analyzed": len(final_metrics),
            "violations_detected": len(final_violations),
            "budgeted_violations": caps_applied,
            "fatal_violations": len(fatal_violations),
            "baseline_modules": len(self.baseline_metrics),
            "scope": self.scope,
            "max_spectral_norm_final": max(final_metrics.values())
            if final_metrics
            else 0.0,
            "mean_spectral_norm_final": np.mean(list(final_metrics.values()))
            if final_metrics
            else 0.0,
            # Fraction of analyzed modules with no violation, clamped to [0, 1].
            "spectral_stability_score": 1.0
            - min(len(final_violations) / max(len(final_metrics), 1), 1.0),
            "target_sigma": self.target_sigma,
            "correction_applied": len(final_violations) > 0 and self.correction_enabled,
            "family_caps": self.family_caps,
            "family_z_summary": final_z_summary,
            "family_stats": final_family_stats,
            "sigma_quantile": float(self.sigma_quantile),
            "deadband": float(self.deadband),
            "max_caps": int(self.max_caps),
            "caps_applied": caps_applied,
            "caps_exceeded": caps_exceeded,
            "multiple_testing": self.multiple_testing,
            "family_z_quantiles": family_quantiles,
            "top_z_scores": top_z_scores,
        }

        # Categorize violations
        warnings = []
        errors = []

        for violation in final_violations:
            if violation["type"] in ["max_spectral_norm", "ill_conditioned"]:
                errors.append(violation["message"])
            else:
                warnings.append(violation["message"])

        result = {
            "passed": passed,
            "metrics": metrics,
            "warnings": warnings,
            "errors": errors,
            "violations": final_violations,
            "events": self.events,
            "baseline_metrics": self.baseline_metrics,
            "final_metrics": final_metrics,
            "final_z_scores": self.latest_z_scores,
            "module_family_map": dict(self.module_family_map),
            "policy": self._serialize_policy(),
        }

        # Env-gated tiny evidence dump for auditors
        try:
            payload = {
                "spectral": {
                    "sigma_quantile": float(self.sigma_quantile),
                    "deadband": float(self.deadband),
                    "max_caps": int(self.max_caps),
                    "multiple_testing": self.multiple_testing.get("method")
                    if isinstance(self.multiple_testing, dict)
                    else None,
                    "evaluated": True,
                }
            }
            maybe_dump_guard_evidence(".", payload)
        except Exception:
            # Evidence dumping is best-effort only; never fail finalize.
            pass

        return result
823
+
824
+
825
def compute_sigma_max(weight_matrix: Any) -> float:
    """
    Compute the maximum singular value (spectral norm) of a weight matrix.

    Args:
        weight_matrix: Weight matrix to analyze

    Returns:
        Maximum singular value; 0.0 for empty matrices, 1.0 as a neutral
        fallback for quantized weights, non-tensor inputs, or any failure.
    """
    try:
        if not isinstance(weight_matrix, torch.Tensor):
            return 1.0  # Fallback for non-tensor inputs

        # Quantized weights are skipped rather than decomposed
        if weight_matrix.dtype in [torch.int8]:
            return 1.0

        matrix = weight_matrix.float()  # SVD requires a floating dtype

        # Degenerate shapes carry no singular values
        if matrix.numel() == 0 or matrix.shape[0] == 0 or matrix.shape[1] == 0:
            return 0.0

        try:
            # Deterministic singular-value routine when the backend has it
            singular_values = torch.linalg.svdvals(matrix)
        except RuntimeError:
            # Older backends without svdvals
            singular_values = torch.linalg.svd(matrix, full_matrices=False).S

        if singular_values.numel() == 0:
            return 0.0
        return singular_values[0].item()

    except Exception:
        return 1.0  # Fallback on any error
862
+
863
+
864
def auto_sigma_target(model: Any, percentile: float = 0.95, **kwargs: Any) -> float:
    """
    Automatically determine a sigma target for a model.

    Args:
        model: Model to analyze
        percentile: Scale factor (target percentile of spectral norms)
        **kwargs: Legacy support — ``kappa`` overrides ``percentile`` only
            when the caller left ``percentile`` at its default.

    Returns:
        Target sigma value (falls back to ``percentile`` when no usable
        spectral norms are found or the scan fails).
    """
    # Legacy alias: honor kappa only when percentile keeps its default
    if percentile == 0.95 and "kappa" in kwargs:
        try:
            percentile = float(kwargs["kappa"])
        except (TypeError, ValueError):
            pass

    try:
        # Collect positive spectral norms across all 2-D weight modules
        norms = []
        for _name, module in model.named_modules():
            if hasattr(module, "weight") and module.weight.ndim == 2:
                sigma = compute_sigma_max(module.weight)
                if sigma > 0:
                    norms.append(sigma)

        if not norms:
            return percentile  # Fallback to requested sigma quantile

        # Use the kappa-percentile of observed norms as the target
        return float(np.percentile(norms, percentile * 100))

    except Exception:
        return percentile  # Default fallback
899
+
900
+
901
def apply_weight_rescale(
    model: Any, scale_factor: float = 1.0, scope: str = "all"
) -> dict[str, Any]:
    """
    Apply uniform weight rescaling to model parameters.

    Args:
        model: Model to rescale
        scale_factor: Scaling factor to apply
        scope: Which modules to rescale ('all', 'attn', 'ffn')

    Returns:
        Rescaling results: modules touched, per-module failures, summary.
    """
    try:
        rescaled_modules: list[str] = []
        failed_modules: list[tuple[str, str]] = []

        for name, module in model.named_modules():
            if not _should_process_module(name, module, scope):
                continue

            try:
                weight = getattr(module, "weight", None)
                if weight is None or weight.ndim != 2:
                    continue
                # Quantized weights cannot be rescaled in-place
                if hasattr(weight, "dtype") and weight.dtype in [torch.int8]:
                    continue

                # Scale weight (and bias, when present) without autograd
                with torch.no_grad():
                    weight.mul_(scale_factor)
                    bias = getattr(module, "bias", None)
                    if bias is not None:
                        bias.mul_(scale_factor)

                rescaled_modules.append(name)

            except Exception as exc:
                failed_modules.append((name, str(exc)))

        return {
            "applied": len(rescaled_modules) > 0,
            "scale_factor": scale_factor,
            "rescaled_modules": rescaled_modules,
            "failed_modules": failed_modules,
            "message": f"Rescaled {len(rescaled_modules)} modules with factor {scale_factor}",
        }

    except Exception as exc:
        return {
            "applied": False,
            "error": str(exc),
            "message": f"Weight rescaling failed: {exc}",
        }
956
+
957
+
958
def apply_relative_spectral_cap(
    model: Any,
    cap_ratio: float = 2.0,
    scope: str = "all",
    baseline_sigmas: dict[str, float] | None = None,
) -> dict[str, Any]:
    """
    Apply relative spectral capping to model weights.

    Modules whose current spectral norm exceeds ``baseline * cap_ratio`` are
    uniformly rescaled back down to that bound.

    Args:
        model: Model to cap
        cap_ratio: Maximum allowed ratio relative to baseline
        scope: Which modules to cap ('all', 'attn', 'ffn')
        baseline_sigmas: Mapping of module name to pre-edit sigma values
            (captured from the current model state when omitted)

    Returns:
        Capping results
    """
    try:
        if baseline_sigmas is None:
            # No reference supplied: treat current state as the baseline
            baseline_sigmas = capture_baseline_sigmas(model, scope=scope)

        capped_modules: list[dict[str, Any]] = []
        failed_modules: list[tuple[str, str]] = []

        for name, module in model.named_modules():
            if not _should_process_module(name, module, scope):
                continue

            try:
                weight = getattr(module, "weight", None)
                if weight is None or weight.ndim != 2:
                    continue
                # Quantized weights cannot be rescaled in-place
                if hasattr(weight, "dtype") and weight.dtype in [torch.int8]:
                    continue

                current_sigma = compute_sigma_max(weight)
                baseline_sigma = baseline_sigmas.get(name, current_sigma)
                max_allowed = baseline_sigma * cap_ratio

                if current_sigma <= max_allowed:
                    continue

                # Uniform rescale brings sigma down to exactly max_allowed
                scale_factor = max_allowed / current_sigma
                with torch.no_grad():
                    weight.mul_(scale_factor)

                capped_modules.append(
                    {
                        "module": name,
                        "original_sigma": current_sigma,
                        "capped_sigma": max_allowed,
                        "scale_factor": scale_factor,
                    }
                )

            except Exception as exc:
                failed_modules.append((name, str(exc)))

        return {
            "applied": len(capped_modules) > 0,
            "cap_ratio": cap_ratio,
            "capped_modules": capped_modules,
            "failed_modules": failed_modules,
            "message": f"Applied spectral capping to {len(capped_modules)} modules",
        }

    except Exception as exc:
        return {
            "applied": False,
            "error": str(exc),
            "message": f"Spectral capping failed: {exc}",
        }
1032
+
1033
+
1034
def apply_spectral_control(model: Any, policy: dict[str, Any]) -> dict[str, Any]:
    """
    Apply spectral control based on policy.

    Runs relative spectral capping first, then optional uniform rescaling
    when the policy carries a ``rescale_factor``.

    Args:
        model: Model to control
        policy: Spectral control policy (scope, cap_ratio, baseline_sigmas,
            optional rescale_factor)

    Returns:
        Control results
    """
    try:
        results: dict[str, Any] = {
            "rescaling_applied": False,
            "capping_applied": False,
            "modules_processed": 0,
            "corrections": [],
        }

        scope = policy.get("scope", "all")

        # Step 1: relative spectral capping against baseline sigmas
        cap_result = apply_relative_spectral_cap(
            model,
            cap_ratio=policy.get("cap_ratio", 2.0),
            scope=scope,
            baseline_sigmas=policy.get("baseline_sigmas"),
        )
        if cap_result["applied"]:
            results["capping_applied"] = True
            results["corrections"].extend(cap_result["capped_modules"])

        # Step 2: optional uniform rescaling
        if "rescale_factor" in policy:
            rescale_result = apply_weight_rescale(
                model, scale_factor=policy["rescale_factor"], scope=scope
            )
            if rescale_result["applied"]:
                results["rescaling_applied"] = True
                results["modules_processed"] += len(rescale_result["rescaled_modules"])

        results["applied"] = results["rescaling_applied"] or results["capping_applied"]
        results["policy"] = policy
        results["message"] = (
            f"Spectral control applied: capping={results['capping_applied']}, rescaling={results['rescaling_applied']}"
        )

        return results

    except Exception as e:
        return {
            "applied": False,
            "error": str(e),
            "policy": policy,
            "message": f"Spectral control failed: {e}",
        }
1094
+
1095
+
1096
+ def _summarize_sigmas(sigmas: dict[str, float]) -> dict[str, float]:
1097
+ """Compute summary statistics for a sigma dictionary."""
1098
+ if not sigmas:
1099
+ return {
1100
+ "max_spectral_norm": 0.0,
1101
+ "mean_spectral_norm": 0.0,
1102
+ "min_spectral_norm": 0.0,
1103
+ }
1104
+
1105
+ values = np.array(list(sigmas.values()), dtype=float)
1106
+ return {
1107
+ "max_spectral_norm": float(values.max()),
1108
+ "mean_spectral_norm": float(values.mean()),
1109
+ "min_spectral_norm": float(values.min()),
1110
+ }
1111
+
1112
+
1113
def compute_z_score_for_value(
    sigma: float,
    family_stats: dict[str, float],
    fallback_value: float,
    deadband: float,
) -> float:
    """Compute a per-family z-score for a spectral norm with sensible fallbacks.

    With a positive family std this is a standard z-score. Otherwise the
    relative change against ``fallback_value`` is used: zero inside the
    deadband, scaled by the deadband width outside it.
    """
    mean = float(family_stats.get("mean", 0.0) or 0.0)
    std = float(family_stats.get("std", 0.0) or 0.0)

    if std > 0:
        return float((sigma - mean) / std)

    # No spread information: deadband-scaled relative change fallback
    base = fallback_value if fallback_value > 0 else 1.0
    relative = sigma / base - 1.0

    if abs(relative) <= deadband:
        return 0.0

    return float(relative / (deadband if deadband > 0 else 1.0))
1135
+
1136
+
1137
def compute_z_scores(
    metrics: dict[str, float],
    baseline_family_stats: dict[str, dict[str, float]],
    module_family_map: dict[str, str],
    baseline_sigmas: dict[str, float],
    deadband: float,
) -> dict[str, float]:
    """Compute z-scores for all modules given baseline family stats."""
    z_scores: dict[str, float] = {}
    for module_name, sigma in metrics.items():
        stats = baseline_family_stats.get(
            module_family_map.get(module_name, "other"), {}
        )
        # Prefer the module's own baseline sigma, then the family mean,
        # then the current value itself.
        fallback = baseline_sigmas.get(module_name, stats.get("mean", sigma))
        z_scores[module_name] = compute_z_score_for_value(
            float(sigma),
            stats,
            float(fallback),
            deadband=deadband,
        )
    return z_scores
1157
+
1158
+
1159
def summarize_family_z_scores(
    z_scores: dict[str, float],
    module_family_map: dict[str, str],
    family_caps: dict[str, dict[str, float]],
) -> dict[str, dict[str, float]]:
    """Summarize z-scores per family, including violation counts."""
    grouped: dict[str, list[float]] = defaultdict(list)
    for module_name, z_value in z_scores.items():
        grouped[module_family_map.get(module_name, "other")].append(float(z_value))

    summary: dict[str, dict[str, float]] = {}
    for family, z_values in grouped.items():
        if not z_values:
            continue
        values = np.array(z_values, dtype=float)
        entry: dict[str, float] = {
            "max": float(values.max()),
            "mean": float(values.mean()),
            "count": len(z_values),
            "violations": 0,
        }
        kappa = family_caps.get(family, {}).get("kappa")
        if kappa is not None:
            # Violations count entries strictly above the family cap
            entry["violations"] = int(np.sum(values > float(kappa)))
            entry["kappa"] = float(kappa)
        summary[family] = entry
    return summary
1188
+
1189
+
1190
def compute_family_stats(
    sigmas: dict[str, float], family_map: dict[str, str]
) -> dict[str, dict[str, float]]:
    """Compute per-family statistics (mean/std/min/max/count)."""
    grouped: dict[str, list[float]] = defaultdict(list)
    for module_name, sigma in sigmas.items():
        grouped[family_map.get(module_name, "other")].append(float(sigma))

    stats: dict[str, dict[str, float]] = {}
    for family, sigma_values in grouped.items():
        if not sigma_values:
            continue
        values = np.array(sigma_values, dtype=float)
        stats[family] = {
            "count": len(sigma_values),
            "mean": float(values.mean()),
            # Population std (ddof=0)
            "std": float(values.std(ddof=0)),
            "min": float(values.min()),
            "max": float(values.max()),
        }
    return stats
1212
+
1213
+
1214
def classify_model_families(
    model: Any, scope: str = "all", existing: dict[str, str] | None = None
) -> dict[str, str]:
    """Build or update a module→family map for the provided model."""
    # Copy so the caller's mapping is never mutated in place
    family_map: dict[str, str] = dict(existing) if existing else {}
    for name, module in model.named_modules():
        if not _should_process_module(name, module, scope):
            continue
        family_map[name] = classify_module_family(name, module)
    return family_map
1223
+
1224
+
1225
def capture_baseline_sigmas(model: Any, scope: str = "all") -> dict[str, float]:
    """
    Capture baseline singular values for model layers.

    Args:
        model: Model to analyze
        scope: Which modules to analyze ('all', 'attn', 'ffn')

    Returns:
        Dictionary of layer name to max singular value (empty on failure).
    """
    try:
        baselines: dict[str, float] = {}
        for name, module in model.named_modules():
            if not _should_process_module(name, module, scope):
                continue
            weight = getattr(module, "weight", None)
            if weight is not None and weight.ndim == 2:
                baselines[name] = compute_sigma_max(weight)
        return baselines
    except Exception:
        # Best-effort scan: any traversal failure yields an empty baseline
        return {}
1249
+
1250
+
1251
def scan_model_gains(model: Any, scope: str = "all") -> dict[str, Any]:
    """
    Scan model for gain values and spectral statistics.

    Args:
        model: Model to scan
        scope: Which modules to scan ('all', 'attn', 'ffn')

    Returns:
        Gain analysis results: per-module spectral norms, condition numbers,
        elementwise weight statistics, and aggregate summaries.
    """
    try:
        results: dict[str, Any] = {
            "total_layers": 0,
            "scanned_modules": 0,
            "spectral_norms": [],
            "condition_numbers": [],
            "weight_statistics": {},
        }

        for name, module in model.named_modules():
            results["total_layers"] += 1

            if not _should_process_module(name, module, scope):
                continue
            if not (hasattr(module, "weight") and module.weight.ndim == 2):
                continue

            results["scanned_modules"] += 1

            # Spectral norm (largest singular value)
            results["spectral_norms"].append(compute_sigma_max(module.weight))

            # Condition number: ratio of largest to smallest singular value.
            # Use torch.linalg.svdvals instead of the deprecated torch.svd;
            # this also skips computing the unused U/V factors and matches
            # the backend used by compute_sigma_max.
            try:
                S = torch.linalg.svdvals(module.weight.float())
                if len(S) > 1:
                    results["condition_numbers"].append((S[0] / S[-1]).item())
            except Exception:
                pass  # Best-effort: skip modules where SVD fails

            # Basic elementwise weight statistics
            try:
                results["weight_statistics"][name] = {
                    "mean": module.weight.mean().item(),
                    "std": module.weight.std().item(),
                    "min": module.weight.min().item(),
                    "max": module.weight.max().item(),
                }
            except Exception:
                pass  # Best-effort: statistics are optional

        # Aggregate summaries
        if results["spectral_norms"]:
            results["mean_spectral_norm"] = np.mean(results["spectral_norms"])
            results["max_spectral_norm"] = np.max(results["spectral_norms"])
            results["min_spectral_norm"] = np.min(results["spectral_norms"])

        if results["condition_numbers"]:
            results["mean_condition_number"] = np.mean(results["condition_numbers"])
            results["max_condition_number"] = np.max(results["condition_numbers"])

        results["message"] = (
            f"Scanned {results['scanned_modules']} modules out of {results['total_layers']} total layers"
        )

        return results

    except Exception as e:
        return {
            "total_layers": sum(1 for _ in model.named_modules()),
            "scanned_modules": 0,
            "error": str(e),
            "message": f"Model scanning failed: {e}",
        }
1326
+
1327
+
1328
+ def _should_process_module(name: str, module: Any, scope: str) -> bool:
1329
+ """Helper function to determine if a module should be processed based on scope."""
1330
+ if not hasattr(module, "weight") or module.weight.ndim != 2:
1331
+ return False
1332
+
1333
+ if scope == "all":
1334
+ return True
1335
+ elif scope == "attn":
1336
+ return any(
1337
+ keyword in name.lower()
1338
+ for keyword in ["attn", "attention", "self_attn", "c_attn", "c_proj"]
1339
+ )
1340
+ elif scope == "ffn":
1341
+ return any(
1342
+ keyword in name.lower()
1343
+ for keyword in ["mlp", "ffn", "feed_forward", "fc", "c_fc"]
1344
+ )
1345
+ elif scope == "ffn+proj":
1346
+ lname = name.lower()
1347
+ return any(
1348
+ keyword in lname
1349
+ for keyword in [
1350
+ "mlp",
1351
+ "ffn",
1352
+ "feed_forward",
1353
+ "fc",
1354
+ "c_fc",
1355
+ "c_proj",
1356
+ "projection",
1357
+ ]
1358
+ )
1359
+
1360
+ return True
1361
+
1362
+
1363
def classify_module_family(name: str, module: Any) -> str:
    """Classify module into a spectral family for policy purposes.

    Name substrings are checked in priority order (router → expert_ffn →
    ffn → attn → embed); the module's class name is a last-resort fallback.
    """
    lname = name.lower()

    # MoE router/gating
    router_tokens = ("router", "routing", "gate", "gating", "dispatch", "switch")
    if any(token in lname for token in router_tokens):
        return "router"

    # MoE expert FFN
    expert_tokens = ("experts", "expert", "moe", "mixture_of_experts")
    if any(token in lname for token in expert_tokens):
        return "expert_ffn"

    if any(token in lname for token in ("mlp", "ffn", "feed_forward")):
        return "ffn"

    attn_projections = ("q_proj", "k_proj", "v_proj", "o_proj", "c_attn")
    if (
        "attn" in lname
        or "attention" in lname
        or any(token in lname for token in attn_projections)
    ):
        return "attn"

    if any(token in lname for token in ("embed", "wte", "embedding")):
        return "embed"

    # Fall back to the module's class name when the path is uninformative
    module_type = module.__class__.__name__.lower()
    if "embedding" in module_type:
        return "embed"
    if "conv1d" in module_type or "linear" in module_type:
        if "attn" in lname:
            return "attn"
        if "mlp" in lname or "ffn" in lname:
            return "ffn"

    return "other"
1403
+
1404
+
1405
# Export the main components.
# NOTE(review): some public helpers defined in this module (e.g.
# compute_z_scores, compute_z_score_for_value, classify_model_families)
# are absent from __all__ and therefore excluded from `import *` —
# confirm whether that omission is intentional.
__all__ = [
    "SpectralGuard",
    "SpectralPolicy",
    "compute_sigma_max",
    "auto_sigma_target",
    "apply_weight_rescale",
    "apply_relative_spectral_cap",
    "apply_spectral_control",
    "capture_baseline_sigmas",
    "scan_model_gains",
    "compute_family_stats",
    "summarize_family_z_scores",
    "classify_module_family",
]