invarlock 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +33 -0
- invarlock/__main__.py +10 -0
- invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
- invarlock/_data/runtime/profiles/release.yaml +23 -0
- invarlock/_data/runtime/tiers.yaml +76 -0
- invarlock/adapters/__init__.py +102 -0
- invarlock/adapters/_capabilities.py +45 -0
- invarlock/adapters/auto.py +99 -0
- invarlock/adapters/base.py +530 -0
- invarlock/adapters/base_types.py +85 -0
- invarlock/adapters/hf_bert.py +852 -0
- invarlock/adapters/hf_gpt2.py +403 -0
- invarlock/adapters/hf_llama.py +485 -0
- invarlock/adapters/hf_mixin.py +383 -0
- invarlock/adapters/hf_onnx.py +112 -0
- invarlock/adapters/hf_t5.py +137 -0
- invarlock/adapters/py.typed +1 -0
- invarlock/assurance/__init__.py +43 -0
- invarlock/cli/__init__.py +8 -0
- invarlock/cli/__main__.py +8 -0
- invarlock/cli/_evidence.py +25 -0
- invarlock/cli/_json.py +75 -0
- invarlock/cli/adapter_auto.py +162 -0
- invarlock/cli/app.py +287 -0
- invarlock/cli/commands/__init__.py +26 -0
- invarlock/cli/commands/certify.py +403 -0
- invarlock/cli/commands/doctor.py +1358 -0
- invarlock/cli/commands/explain_gates.py +151 -0
- invarlock/cli/commands/export_html.py +100 -0
- invarlock/cli/commands/plugins.py +1331 -0
- invarlock/cli/commands/report.py +354 -0
- invarlock/cli/commands/run.py +4146 -0
- invarlock/cli/commands/verify.py +1040 -0
- invarlock/cli/config.py +396 -0
- invarlock/cli/constants.py +68 -0
- invarlock/cli/device.py +92 -0
- invarlock/cli/doctor_helpers.py +74 -0
- invarlock/cli/errors.py +6 -0
- invarlock/cli/overhead_utils.py +60 -0
- invarlock/cli/provenance.py +66 -0
- invarlock/cli/utils.py +41 -0
- invarlock/config.py +56 -0
- invarlock/core/__init__.py +62 -0
- invarlock/core/abi.py +15 -0
- invarlock/core/api.py +274 -0
- invarlock/core/auto_tuning.py +317 -0
- invarlock/core/bootstrap.py +226 -0
- invarlock/core/checkpoint.py +221 -0
- invarlock/core/contracts.py +73 -0
- invarlock/core/error_utils.py +64 -0
- invarlock/core/events.py +298 -0
- invarlock/core/exceptions.py +95 -0
- invarlock/core/registry.py +481 -0
- invarlock/core/retry.py +146 -0
- invarlock/core/runner.py +2041 -0
- invarlock/core/types.py +154 -0
- invarlock/edits/__init__.py +12 -0
- invarlock/edits/_edit_utils.py +249 -0
- invarlock/edits/_external_utils.py +268 -0
- invarlock/edits/noop.py +47 -0
- invarlock/edits/py.typed +1 -0
- invarlock/edits/quant_rtn.py +801 -0
- invarlock/edits/registry.py +166 -0
- invarlock/eval/__init__.py +23 -0
- invarlock/eval/bench.py +1207 -0
- invarlock/eval/bootstrap.py +50 -0
- invarlock/eval/data.py +2052 -0
- invarlock/eval/metrics.py +2167 -0
- invarlock/eval/primary_metric.py +767 -0
- invarlock/eval/probes/__init__.py +24 -0
- invarlock/eval/probes/fft.py +139 -0
- invarlock/eval/probes/mi.py +213 -0
- invarlock/eval/probes/post_attention.py +323 -0
- invarlock/eval/providers/base.py +67 -0
- invarlock/eval/providers/seq2seq.py +111 -0
- invarlock/eval/providers/text_lm.py +113 -0
- invarlock/eval/providers/vision_text.py +93 -0
- invarlock/eval/py.typed +1 -0
- invarlock/guards/__init__.py +18 -0
- invarlock/guards/_contracts.py +9 -0
- invarlock/guards/invariants.py +640 -0
- invarlock/guards/policies.py +805 -0
- invarlock/guards/py.typed +1 -0
- invarlock/guards/rmt.py +2097 -0
- invarlock/guards/spectral.py +1419 -0
- invarlock/guards/tier_config.py +354 -0
- invarlock/guards/variance.py +3298 -0
- invarlock/guards_ref/__init__.py +15 -0
- invarlock/guards_ref/rmt_ref.py +40 -0
- invarlock/guards_ref/spectral_ref.py +135 -0
- invarlock/guards_ref/variance_ref.py +60 -0
- invarlock/model_profile.py +353 -0
- invarlock/model_utils.py +221 -0
- invarlock/observability/__init__.py +10 -0
- invarlock/observability/alerting.py +535 -0
- invarlock/observability/core.py +546 -0
- invarlock/observability/exporters.py +565 -0
- invarlock/observability/health.py +588 -0
- invarlock/observability/metrics.py +457 -0
- invarlock/observability/py.typed +1 -0
- invarlock/observability/utils.py +553 -0
- invarlock/plugins/__init__.py +12 -0
- invarlock/plugins/hello_guard.py +33 -0
- invarlock/plugins/hf_awq_adapter.py +82 -0
- invarlock/plugins/hf_bnb_adapter.py +79 -0
- invarlock/plugins/hf_gptq_adapter.py +78 -0
- invarlock/plugins/py.typed +1 -0
- invarlock/py.typed +1 -0
- invarlock/reporting/__init__.py +7 -0
- invarlock/reporting/certificate.py +3221 -0
- invarlock/reporting/certificate_schema.py +244 -0
- invarlock/reporting/dataset_hashing.py +215 -0
- invarlock/reporting/guards_analysis.py +948 -0
- invarlock/reporting/html.py +32 -0
- invarlock/reporting/normalizer.py +235 -0
- invarlock/reporting/policy_utils.py +517 -0
- invarlock/reporting/primary_metric_utils.py +265 -0
- invarlock/reporting/render.py +1442 -0
- invarlock/reporting/report.py +903 -0
- invarlock/reporting/report_types.py +278 -0
- invarlock/reporting/utils.py +175 -0
- invarlock/reporting/validate.py +631 -0
- invarlock/security.py +176 -0
- invarlock/sparsity_utils.py +323 -0
- invarlock/utils/__init__.py +150 -0
- invarlock/utils/digest.py +45 -0
- invarlock-0.2.0.dist-info/METADATA +586 -0
- invarlock-0.2.0.dist-info/RECORD +132 -0
- invarlock-0.2.0.dist-info/WHEEL +5 -0
- invarlock-0.2.0.dist-info/entry_points.txt +20 -0
- invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
- invarlock-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1419 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Spectral Guard Implementation
|
|
3
|
+
============================
|
|
4
|
+
|
|
5
|
+
Monitors spectral properties of model weights to detect instabilities.
|
|
6
|
+
Provides spectral control mechanisms for maintaining numerical stability.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import math
|
|
10
|
+
import time
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
from datetime import datetime
|
|
13
|
+
from typing import Any, TypedDict
|
|
14
|
+
|
|
15
|
+
# ``NotRequired`` was added to ``typing`` in Python 3.11.  The previous
# fallback re-imported the same name from ``typing``, which simply re-raised
# ImportError on 3.10.  Fall back to ``typing_extensions`` instead, with a
# last-resort shim that degrades ``NotRequired[X]`` annotations to
# ``Optional[X]`` (annotation-only; runtime behavior of the TypedDict is
# unaffected because it is declared with ``total=False``).
try:
    from typing import NotRequired
except ImportError:  # Python 3.10 fallback
    try:
        from typing_extensions import NotRequired  # type: ignore[assignment]
    except ImportError:  # pragma: no cover - typing_extensions unavailable
        from typing import Optional as NotRequired  # type: ignore[assignment]
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import torch
|
|
22
|
+
|
|
23
|
+
from invarlock.cli._evidence import maybe_dump_guard_evidence
|
|
24
|
+
from invarlock.core.api import Guard
|
|
25
|
+
|
|
26
|
+
from ._contracts import guard_assert
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class SpectralPolicy(TypedDict, total=False):
    """Type definition for spectral guard policy configuration.

    Every key is optional (``total=False``); ``SpectralGuard.__init__`` and
    ``SpectralGuard.prepare`` supply defaults for anything omitted.
    """

    # Quantile used to derive the sigma target (guard default: 0.95).
    sigma_quantile: float
    contraction: NotRequired[float]  # Backward compatibility alias
    kappa: NotRequired[float]  # Legacy alias
    # Tolerance band forwarded to z-score computation (guard default: 0.10).
    deadband: float
    # Module selection: 'all', 'ffn', or 'attn' (see _should_check_module).
    scope: str
    # When True, after_edit applies automatic spectral control on violations.
    correction_enabled: bool
    # Per-family z-score caps, e.g. {"ffn": {"kappa": 2.5}}.
    family_caps: dict[str, dict[str, float]]
    # When True, budgeted violation checks are skipped during "after_edit".
    ignore_preview_inflation: bool
    # Budget of non-fatal ("budgeted") violations tolerated by validate().
    max_caps: int
    # Multiple-testing settings, e.g. {"method": "bh", "alpha": 0.05, "m": 4}.
    multiple_testing: dict[str, Any]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _default_family_caps() -> dict[str, dict[str, float]]:
|
|
45
|
+
"""Default per-family spectral z-score caps."""
|
|
46
|
+
return {
|
|
47
|
+
"ffn": {"kappa": 2.5},
|
|
48
|
+
"attn": {"kappa": 2.8},
|
|
49
|
+
"embed": {"kappa": 3.0},
|
|
50
|
+
"other": {"kappa": 3.0},
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _normalize_family_caps(
|
|
55
|
+
caps: Any, *, default: bool = True
|
|
56
|
+
) -> dict[str, dict[str, float]]:
|
|
57
|
+
"""Normalize family cap configuration into canonical mapping."""
|
|
58
|
+
|
|
59
|
+
if not isinstance(caps, dict) or not caps:
|
|
60
|
+
return _default_family_caps() if default else {}
|
|
61
|
+
|
|
62
|
+
normalized: dict[str, dict[str, float]] = {}
|
|
63
|
+
for family, values in caps.items():
|
|
64
|
+
entry: dict[str, float] = {}
|
|
65
|
+
if isinstance(values, dict):
|
|
66
|
+
for key, val in values.items():
|
|
67
|
+
if isinstance(val, int | float) and math.isfinite(float(val)):
|
|
68
|
+
entry[str(key)] = float(val)
|
|
69
|
+
elif isinstance(values, int | float) and math.isfinite(float(values)):
|
|
70
|
+
entry["kappa"] = float(values)
|
|
71
|
+
if entry:
|
|
72
|
+
normalized[str(family)] = entry
|
|
73
|
+
|
|
74
|
+
if normalized:
|
|
75
|
+
return normalized
|
|
76
|
+
|
|
77
|
+
return _default_family_caps() if default else {}
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class SpectralGuard(Guard):
|
|
81
|
+
"""
|
|
82
|
+
Spectral guard for monitoring weight matrix spectral properties.
|
|
83
|
+
|
|
84
|
+
Tracks singular values and spectral norms to detect numerical instabilities.
|
|
85
|
+
Provides automatic spectral control when violations are detected.
|
|
86
|
+
"""
|
|
87
|
+
|
|
88
|
+
name = "spectral"
|
|
89
|
+
|
|
90
|
+
def __init__(self, **kwargs):
    """Initialize spectral guard from keyword configuration.

    Recognizes ``sigma_quantile`` (with legacy aliases ``contraction`` and
    ``kappa``), ``deadband``, ``scope``, ``max_spectral_norm``,
    ``min_condition_number``, ``correction_enabled``, ``family_caps``,
    ``ignore_preview_inflation``, ``max_caps`` and ``multiple_testing``.
    """
    self.config = dict(kwargs)
    self.prepared = False
    self.baseline_metrics = {}
    self.events = []
    self.current_metrics = {}
    self.violations = []

    # Resolve sigma_quantile, honoring legacy aliases in priority order,
    # then scrub the aliases from the stored config.
    sigma = self.config.get("sigma_quantile")
    if sigma is None:
        sigma = next(
            (
                self.config[alias]
                for alias in ("contraction", "kappa")
                if self.config.get(alias) is not None
            ),
            None,
        )
    self.sigma_quantile = float(0.95 if sigma is None else sigma)
    self.config["sigma_quantile"] = self.sigma_quantile
    self.config.pop("contraction", None)
    self.config.pop("kappa", None)

    self.deadband = kwargs.get("deadband", 0.10)
    self.scope = kwargs.get("scope", "all")  # 'all', 'ffn', 'attn'
    self.max_spectral_norm = kwargs.get("max_spectral_norm", 10.0)
    if self.max_spectral_norm is not None:
        # Coerce to float and keep the config snapshot in sync.
        self.max_spectral_norm = float(self.max_spectral_norm)
        self.config["max_spectral_norm"] = self.max_spectral_norm
    self.min_condition_number = kwargs.get("min_condition_number", 1e-12)
    self.correction_enabled = kwargs.get("correction_enabled", True)
    self.family_caps = _normalize_family_caps(
        kwargs.get("family_caps"), default=True
    )
    self.ignore_preview_inflation = kwargs.get("ignore_preview_inflation", True)
    self.max_caps = kwargs.get("max_caps", 5)
    self.multiple_testing = kwargs.get(
        "multiple_testing", {"method": "bh", "alpha": 0.05, "m": 4}
    )

    # Baseline and tracking structures, populated by prepare()/analysis.
    self.baseline_sigmas: dict[str, float] = {}
    self.baseline_family_stats: dict[str, dict[str, float]] = {}
    self.module_family_map: dict[str, str] = {}
    self.latest_z_scores: dict[str, float] = {}
    self.pre_edit_z_scores: dict[str, float] = {}
|
|
135
|
+
|
|
136
|
+
def _log_event(
|
|
137
|
+
self, operation: str, level: str = "INFO", message: str = "", **data
|
|
138
|
+
):
|
|
139
|
+
"""Log an event with timestamp."""
|
|
140
|
+
event = {
|
|
141
|
+
"timestamp": datetime.utcnow().isoformat(),
|
|
142
|
+
"component": "spectral_guard",
|
|
143
|
+
"operation": operation,
|
|
144
|
+
"level": level,
|
|
145
|
+
"message": message,
|
|
146
|
+
"data": data,
|
|
147
|
+
}
|
|
148
|
+
self.events.append(event)
|
|
149
|
+
|
|
150
|
+
def _serialize_policy(self) -> dict[str, Any]:
|
|
151
|
+
"""Snapshot current guard policy for report serialization."""
|
|
152
|
+
return {
|
|
153
|
+
"scope": self.scope,
|
|
154
|
+
"sigma_quantile": float(self.sigma_quantile),
|
|
155
|
+
"deadband": float(self.deadband),
|
|
156
|
+
"max_caps": int(self.max_caps),
|
|
157
|
+
"max_spectral_norm": (
|
|
158
|
+
float(self.max_spectral_norm)
|
|
159
|
+
if self.max_spectral_norm is not None
|
|
160
|
+
else None
|
|
161
|
+
),
|
|
162
|
+
"family_caps": self.family_caps,
|
|
163
|
+
"multiple_testing": self.multiple_testing,
|
|
164
|
+
"correction_enabled": bool(self.correction_enabled),
|
|
165
|
+
"ignore_preview_inflation": bool(self.ignore_preview_inflation),
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
def prepare(
    self, model: Any, adapter: Any, calib: Any, policy: dict[str, Any]
) -> dict[str, Any]:
    """
    Prepare spectral guard by capturing baseline spectral properties.

    Note: mutates the caller's ``policy`` dict in place — legacy keys
    ("contraction", "kappa", "multipletesting") are removed and canonical
    keys are written back.

    Args:
        model: Model to prepare for
        adapter: ModelAdapter instance
        calib: Calibration data (unused for spectral analysis)
        policy: Policy configuration

    Returns:
        Preparation results
    """
    start_time = time.time()

    # Update configuration from policy
    if policy:
        # Resolve sigma_quantile, falling back to legacy aliases
        # ("contraction" takes precedence over "kappa"), then
        # canonicalize the policy dict in place.
        sigma_value = policy.get("sigma_quantile")
        if sigma_value is None:
            alias_value = policy.get("contraction", policy.get("kappa"))
            if alias_value is not None:
                sigma_value = alias_value
        if sigma_value is not None:
            self.sigma_quantile = float(sigma_value)
        policy["sigma_quantile"] = self.sigma_quantile
        policy.pop("contraction", None)
        policy.pop("kappa", None)
        self.config["sigma_quantile"] = self.sigma_quantile

        # Mirror simple scalar overrides onto both the instance and the
        # stored config snapshot.  NOTE: values are applied verbatim here;
        # only max_spectral_norm is coerced to float below.
        for key in [
            "sigma_quantile",
            "deadband",
            "scope",
            "max_spectral_norm",
            "correction_enabled",
            "max_caps",
        ]:
            if key in policy:
                setattr(self, key, policy[key])
                self.config[key] = policy[key]

        if self.max_spectral_norm is not None:
            self.max_spectral_norm = float(self.max_spectral_norm)
            self.config["max_spectral_norm"] = self.max_spectral_norm

        if "family_caps" in policy:
            self.family_caps = _normalize_family_caps(
                policy["family_caps"], default=True
            )
            self.config["family_caps"] = self.family_caps

        if "ignore_preview_inflation" in policy:
            self.ignore_preview_inflation = bool(policy["ignore_preview_inflation"])
            self.config["ignore_preview_inflation"] = self.ignore_preview_inflation

        # Optional hydration of baseline stats from policy (e.g., baseline certificate)
        if "baseline_family_stats" in policy and isinstance(
            policy["baseline_family_stats"], dict
        ):
            # Shallow-copy each family's stats so later mutation of the
            # guard state does not leak back into the policy dict.
            self.baseline_family_stats = {
                family: stats.copy()
                for family, stats in policy["baseline_family_stats"].items()
                if isinstance(stats, dict)
            }
            self.config["baseline_family_stats"] = self.baseline_family_stats
        # Accept both "multiple_testing" and the legacy "multipletesting"
        # spelling; normalize to the canonical key.
        mt_policy = policy.get("multiple_testing")
        if mt_policy is None:
            mt_policy = policy.get("multipletesting")
        if isinstance(mt_policy, dict):
            self.multiple_testing = mt_policy.copy()
            policy["multiple_testing"] = self.multiple_testing
            self.config["multiple_testing"] = self.multiple_testing
            policy.pop("multipletesting", None)

    self._log_event(
        "prepare",
        message=(
            f"Preparing spectral guard with scope={self.scope}, "
            f"sigma_quantile={self.sigma_quantile}"
        ),
    )

    try:
        # Capture baseline spectral properties
        self.baseline_sigmas = capture_baseline_sigmas(model, scope=self.scope)
        self.module_family_map = classify_model_families(
            model, scope=self.scope, existing=self.module_family_map
        )
        # Only compute family stats if they were not hydrated from policy.
        if not self.baseline_family_stats:
            self.baseline_family_stats = compute_family_stats(
                self.baseline_sigmas, self.module_family_map
            )

        # Compute additional baseline metrics
        baseline_stats = scan_model_gains(model, scope=self.scope)
        summarized = _summarize_sigmas(self.baseline_sigmas)
        baseline_stats.update(summarized)

        # Store target sigma value
        self.target_sigma = auto_sigma_target(model, percentile=self.sigma_quantile)
        baseline_stats["target_sigma"] = self.target_sigma

        # Snapshot (shallow-copied) per-family stats and caps for reporting.
        baseline_stats["family_stats"] = {
            family: stats.copy()
            for family, stats in self.baseline_family_stats.items()
        }
        baseline_stats["family_caps"] = {
            family: caps.copy() for family, caps in self.family_caps.items()
        }
        baseline_stats["module_sigmas"] = self.baseline_sigmas.copy()

        self.baseline_metrics = baseline_stats

        self.prepared = True
        preparation_time = time.time() - start_time

        self._log_event(
            "prepare_success",
            message=f"Prepared spectral guard with {len(self.baseline_metrics)} baseline metrics",
            baseline_metrics_count=len(self.baseline_metrics),
            target_sigma=self.target_sigma,
            preparation_time=preparation_time,
        )

        return {
            "ready": True,
            "baseline_metrics": self.baseline_metrics,
            "target_sigma": self.target_sigma,
            "scope": self.scope,
            "preparation_time": preparation_time,
        }

    except Exception as e:
        # Any failure leaves the guard unprepared; callers check "ready".
        self.prepared = False
        self._log_event(
            "prepare_failed",
            level="ERROR",
            message=f"Failed to prepare spectral guard: {str(e)}",
            error=str(e),
        )

        return {
            "ready": False,
            "error": str(e),
            "preparation_time": time.time() - start_time,
        }
|
|
316
|
+
|
|
317
|
+
def before_edit(self, model: Any) -> None:
    """Record the pre-edit spectral state so later phases can compare."""
    if not self.prepared:
        self._log_event(
            "before_edit_skipped",
            level="WARN",
            message="Spectral guard not prepared, skipping pre-edit capture",
        )
        return

    # Snapshot per-module sigmas, then derive z-scores against baseline.
    pre_sigmas = capture_baseline_sigmas(model, scope=self.scope)
    self.pre_edit_metrics = pre_sigmas
    self.pre_edit_z_scores = compute_z_scores(
        pre_sigmas,
        self.baseline_family_stats,
        self.module_family_map,
        self.baseline_sigmas,
        deadband=self.deadband,
    )
    self._log_event("before_edit", message="Captured pre-edit spectral state")
|
|
337
|
+
|
|
338
|
+
def after_edit(self, model: Any) -> None:
    """Analyze the post-edit spectra and optionally apply corrective control."""
    if not self.prepared:
        self._log_event(
            "after_edit_skipped",
            level="WARN",
            message="Spectral guard not prepared, skipping post-edit analysis",
        )
        return

    try:
        # Snapshot the post-edit spectral state, then look for violations.
        self.current_metrics = capture_baseline_sigmas(model, scope=self.scope)
        detected = self._detect_spectral_violations(
            model, self.current_metrics, phase="after_edit"
        )
        self.violations = detected

        # With correction enabled, attempt an automatic control pass.
        if detected and self.correction_enabled:
            control_result = apply_spectral_control(
                model,
                policy={
                    "sigma_quantile": self.sigma_quantile,
                    "scope": self.scope,
                    "baseline_sigmas": self.baseline_sigmas,
                    "target_sigma": self.target_sigma,
                },
            )
            self._log_event(
                "spectral_control_applied",
                message=f"Applied spectral control, violations: {len(detected)}",
                violations_count=len(detected),
                control_result=control_result,
            )

        self._log_event(
            "after_edit",
            message=f"Post-edit analysis complete, {len(detected)} violations detected",
        )

    except Exception as e:
        # Analysis failure is logged but never propagated to the caller.
        self._log_event(
            "after_edit_failed",
            level="ERROR",
            message=f"Post-edit spectral analysis failed: {str(e)}",
            error=str(e),
        )
|
|
389
|
+
|
|
390
|
+
def _detect_spectral_violations(
    self, model: Any, metrics: dict[str, float], phase: str = "finalize"
) -> list[dict[str, Any]]:
    """Detect spectral property violations using per-family z-score caps.

    Side effects: refreshes ``self.latest_z_scores`` for every checked
    module and lazily extends ``self.module_family_map`` for modules not
    yet classified.

    Args:
        model: Model whose named modules are scanned.
        metrics: Precomputed per-module sigma_max values; missing entries
            are computed on the fly via ``compute_sigma_max``.
        phase: Pipeline phase label ("after_edit", "validate", "finalize").

    Returns:
        List of violation dicts with a "type" of "family_z_cap",
        "max_spectral_norm", or "ill_conditioned".
    """
    violations: list[dict[str, Any]] = []
    latest_z: dict[str, float] = {}

    for name, module in model.named_modules():
        if not self._should_check_module(name, module):
            continue

        try:
            if hasattr(module, "weight") and module.weight.ndim == 2:
                sigma_max = metrics.get(name)
                if sigma_max is None:
                    sigma_max = compute_sigma_max(module.weight)

                baseline_sigma = self.baseline_sigmas.get(name, self.target_sigma)
                family = self.module_family_map.get(name)
                if family is None:
                    # Lazily classify and cache modules first seen here.
                    family = classify_module_family(name, module)
                    self.module_family_map[name] = family

                family_stats = self.baseline_family_stats.get(family, {})
                cap_config = self.family_caps.get(family, {})
                # Fall back to sigma_quantile when the family has no cap.
                kappa_cap = float(cap_config.get("kappa", self.sigma_quantile))

                z_score = compute_z_score_for_value(
                    sigma_max,
                    family_stats,
                    fallback_value=baseline_sigma,
                    deadband=self.deadband,
                )
                latest_z[name] = z_score

                # Skip preview inflation if configured and not in final phase
                # NOTE(review): this `continue` also skips the
                # max_spectral_norm and conditioning checks below during
                # "after_edit" — confirm that is intentional.
                if self.ignore_preview_inflation and phase == "after_edit":
                    continue

                if z_score > kappa_cap:
                    violations.append(
                        {
                            "type": "family_z_cap",
                            "module": name,
                            "family": family,
                            "z_score": float(z_score),
                            "kappa": kappa_cap,
                            "sigma": float(sigma_max),
                            "baseline_sigma": float(baseline_sigma),
                            "message": (
                                f"Family '{family}' z-score {z_score:.2f}"
                                f" exceeds cap {kappa_cap:.2f}"
                            ),
                        }
                    )

                # Absolute norm ceiling (treated as fatal by validate()).
                if (
                    self.max_spectral_norm is not None
                    and sigma_max > self.max_spectral_norm
                ):
                    threshold = float(self.max_spectral_norm)
                    violations.append(
                        {
                            "type": "max_spectral_norm",
                            "module": name,
                            "family": family,
                            "current_sigma": float(sigma_max),
                            "threshold": threshold,
                            "message": f"Spectral norm {sigma_max:.3f} exceeds maximum {threshold}",
                        }
                    )

                # Condition number monitoring (warn only)
                try:
                    # NOTE(review): torch.svd is deprecated in favor of
                    # torch.linalg.svd; behavior here relies on S being
                    # sorted descending so S[-1] is the smallest value.
                    U, S, V = torch.svd(module.weight.float())
                    if len(S) > 0:
                        condition_number = S[0].item() / max(S[-1].item(), 1e-12)
                        if S[-1].item() < self.min_condition_number:
                            violations.append(
                                {
                                    "type": "ill_conditioned",
                                    "module": name,
                                    "family": family,
                                    "condition_number": float(condition_number),
                                    "min_singular_value": float(S[-1].item()),
                                    "threshold": float(self.min_condition_number),
                                    "message": f"Matrix is ill-conditioned, min singular value: {S[-1].item():.2e}",
                                }
                            )
                except Exception:
                    pass  # SVD failure is not a violation

        except Exception as e:
            # Per-module failures are logged and scanning continues.
            self._log_event(
                "violation_check_error",
                level="WARN",
                message=f"Failed to check module {name}: {str(e)}",
                module=name,
                error=str(e),
            )

    self.latest_z_scores = latest_z
    return violations
|
|
493
|
+
|
|
494
|
+
def _should_check_module(self, name: str, module: Any) -> bool:
|
|
495
|
+
"""Determine if a module should be checked based on scope."""
|
|
496
|
+
if not hasattr(module, "weight") or module.weight.ndim != 2:
|
|
497
|
+
return False
|
|
498
|
+
|
|
499
|
+
if self.scope == "all":
|
|
500
|
+
return True
|
|
501
|
+
elif self.scope == "attn":
|
|
502
|
+
return any(
|
|
503
|
+
keyword in name.lower()
|
|
504
|
+
for keyword in ["attn", "attention", "self_attn"]
|
|
505
|
+
)
|
|
506
|
+
elif self.scope == "ffn":
|
|
507
|
+
return any(
|
|
508
|
+
keyword in name.lower()
|
|
509
|
+
for keyword in ["mlp", "ffn", "feed_forward", "fc"]
|
|
510
|
+
)
|
|
511
|
+
|
|
512
|
+
return True
|
|
513
|
+
|
|
514
|
+
def _compute_family_observability(
|
|
515
|
+
self,
|
|
516
|
+
) -> tuple[dict[str, dict[str, float]], dict[str, list[dict[str, Any]]]]:
|
|
517
|
+
"""Generate per-family quantiles and top-|z| listings from latest z-scores."""
|
|
518
|
+
family_scores: dict[str, list[float]] = defaultdict(list)
|
|
519
|
+
family_modules: dict[str, list[tuple[float, str]]] = defaultdict(list)
|
|
520
|
+
|
|
521
|
+
for module_name, z_value in (self.latest_z_scores or {}).items():
|
|
522
|
+
family = self.module_family_map.get(module_name)
|
|
523
|
+
if family is None:
|
|
524
|
+
continue
|
|
525
|
+
try:
|
|
526
|
+
z_abs = abs(float(z_value))
|
|
527
|
+
except (TypeError, ValueError):
|
|
528
|
+
continue
|
|
529
|
+
family_scores.setdefault(family, []).append(z_abs)
|
|
530
|
+
family_modules.setdefault(family, []).append((z_abs, module_name))
|
|
531
|
+
|
|
532
|
+
def _quantile(sorted_values: list[float], quantile: float) -> float:
|
|
533
|
+
if not sorted_values:
|
|
534
|
+
return 0.0
|
|
535
|
+
if len(sorted_values) == 1:
|
|
536
|
+
return sorted_values[0]
|
|
537
|
+
position = (len(sorted_values) - 1) * quantile
|
|
538
|
+
lower = math.floor(position)
|
|
539
|
+
upper = math.ceil(position)
|
|
540
|
+
if lower == upper:
|
|
541
|
+
return sorted_values[int(position)]
|
|
542
|
+
fraction = position - lower
|
|
543
|
+
return (
|
|
544
|
+
sorted_values[lower]
|
|
545
|
+
+ (sorted_values[upper] - sorted_values[lower]) * fraction
|
|
546
|
+
)
|
|
547
|
+
|
|
548
|
+
family_quantiles: dict[str, dict[str, float]] = {}
|
|
549
|
+
for family, scores in family_scores.items():
|
|
550
|
+
sorted_scores = sorted(scores)
|
|
551
|
+
family_quantiles[family] = {
|
|
552
|
+
"q95": _quantile(sorted_scores, 0.95),
|
|
553
|
+
"q99": _quantile(sorted_scores, 0.99),
|
|
554
|
+
"max": sorted_scores[-1] if sorted_scores else 0.0,
|
|
555
|
+
"count": len(sorted_scores),
|
|
556
|
+
}
|
|
557
|
+
|
|
558
|
+
top_z_scores: dict[str, list[dict[str, Any]]] = {}
|
|
559
|
+
for family, module_entries in family_modules.items():
|
|
560
|
+
module_entries.sort(key=lambda item: item[0], reverse=True)
|
|
561
|
+
top_entries: list[dict[str, Any]] = []
|
|
562
|
+
for z_abs, module_name in module_entries[:3]:
|
|
563
|
+
top_entries.append(
|
|
564
|
+
{"module": module_name, "z": float(z_abs), "family": family}
|
|
565
|
+
)
|
|
566
|
+
top_z_scores[family] = top_entries
|
|
567
|
+
|
|
568
|
+
return family_quantiles, top_z_scores
|
|
569
|
+
|
|
570
|
+
def validate(
    self, model: Any, adapter: Any, context: dict[str, Any]
) -> dict[str, Any]:
    """
    Validate model spectral properties.

    Fatal violation types ("max_spectral_norm", "ill_conditioned") always
    fail validation; other ("budgeted") violations fail only when their
    count exceeds ``self.max_caps``.

    Args:
        model: Model to validate
        adapter: ModelAdapter instance
        context: Validation context

    Returns:
        Dictionary with validation results
    """
    try:
        if not self.prepared:
            # Auto-prepare if needed
            self.prepare(model, adapter, None, {})

        # Capture current spectral state
        current_metrics = capture_baseline_sigmas(model, scope=self.scope)

        # Detect violations (final validation phase)
        violations = self._detect_spectral_violations(
            model, current_metrics, phase="validate"
        )

        # Determine if passed under budget/fatal rules
        fatal_violation_types = {"max_spectral_norm", "ill_conditioned"}
        budgeted_violations = [
            violation
            for violation in violations
            if violation.get("type") not in fatal_violation_types
        ]
        fatal_violations = [
            violation
            for violation in violations
            if violation.get("type") in fatal_violation_types
        ]

        caps_applied = len(budgeted_violations)
        caps_exceeded = caps_applied > int(self.max_caps)
        passed = not fatal_violations and not caps_exceeded
        # Action ladder: abort > warn > continue.
        if fatal_violations or caps_exceeded:
            action = "abort"
        elif caps_applied > 0:
            action = "warn"
        else:
            action = "continue"

        # Compute overall metrics
        family_summary = summarize_family_z_scores(
            self.latest_z_scores, self.module_family_map, self.family_caps
        )
        metrics = {
            "modules_checked": len(current_metrics),
            "violations_found": len(violations),
            "budgeted_violations": caps_applied,
            "fatal_violations": len(fatal_violations),
            "max_spectral_norm": max(current_metrics.values())
            if current_metrics
            else 0.0,
            "mean_spectral_norm": np.mean(list(current_metrics.values()))
            if current_metrics
            else 0.0,
            # Fraction of checked modules without a violation, floored at 0.
            "stability_score": 1.0
            - min(len(violations) / max(len(current_metrics), 1), 1.0),
            "family_z_summary": family_summary,
            "family_caps": self.family_caps,
            "sigma_quantile": float(self.sigma_quantile),
            "deadband": float(self.deadband),
            "max_caps": int(self.max_caps),
            "caps_applied": caps_applied,
            "caps_exceeded": caps_exceeded,
            "multiple_testing": self.multiple_testing,
        }

        # Optional observability extras (omitted when empty).
        family_quantiles, top_z_scores = self._compute_family_observability()
        if family_quantiles:
            metrics["family_z_quantiles"] = family_quantiles
        if top_z_scores:
            metrics["top_z_scores"] = top_z_scores

        if passed:
            message = (
                "Spectral validation passed with "
                f"{len(violations)} violations "
                f"(caps_applied={caps_applied}, max_caps={self.max_caps})"
            )
        else:
            reason = (
                "fatal spectral violation detected"
                if fatal_violations
                else "cap budget exceeded"
            )
            message = (
                f"Spectral validation failed: {reason} "
                f"(caps_applied={caps_applied}, max_caps={self.max_caps})"
            )

        # Runtime contracts (lightweight)
        mt = self.multiple_testing or {}
        try:
            alpha = float(mt.get("alpha", 0.05)) if isinstance(mt, dict) else 0.05
        except Exception:
            alpha = 0.05
        guard_assert(self.deadband >= 0.0, "spectral.deadband must be >= 0")
        guard_assert(
            0.0 < alpha <= 1.0, "spectral.multiple_testing.alpha out of range"
        )
        guard_assert(self.max_caps >= 0, "spectral.max_caps must be >= 0")

        return {
            "passed": passed,
            "action": action,
            "metrics": metrics,
            "violations": violations,
            "message": message,
            "policy": self._serialize_policy(),
            "final_z_scores": self.latest_z_scores.copy(),
            "module_family_map": dict(self.module_family_map),
        }

    except Exception as e:
        # Any internal failure degrades to a non-fatal "warn" result.
        return {
            "passed": False,
            "action": "warn",
            "error": str(e),
            "metrics": {},
            "message": f"Spectral validation failed: {e}",
        }
|
|
701
|
+
|
|
702
|
+
def finalize(self, model: Any) -> dict[str, Any]:
    """
    Finalize spectral guard and return comprehensive results.

    Re-measures spectral norms on the final model, re-runs violation
    detection, and folds everything into a single report dict.

    Args:
        model: The final model state

    Returns:
        Dictionary with spectral guard results: pass/fail flag, metrics,
        categorized warnings/errors, raw violations, recorded events,
        baseline vs final sigma maps, z-scores, the module→family map and
        the serialized policy.
    """
    # Guard was never prepared (prepare failed or was skipped): report a
    # hard failure rather than analyzing an unbaselined model.
    if not self.prepared:
        return {
            "passed": False,
            "metrics": {},
            "warnings": ["Spectral guard not properly prepared"],
            "errors": ["Preparation failed or not called"],
            "events": self.events,
        }

    # Final spectral analysis: re-measure per-module sigmas and re-detect
    # violations on the post-edit model.
    final_metrics = capture_baseline_sigmas(model, scope=self.scope)
    final_violations = self._detect_spectral_violations(
        model, final_metrics, phase="finalize"
    )
    final_z_summary = summarize_family_z_scores(
        self.latest_z_scores, self.module_family_map, self.family_caps
    )
    final_family_stats = compute_family_stats(final_metrics, self.module_family_map)

    family_quantiles, top_z_scores = self._compute_family_observability()

    # Determine overall status based on budgeted vs fatal violations:
    # budgeted ones only count against the cap budget, fatal ones fail
    # the guard outright.
    fatal_violation_types = {"max_spectral_norm", "ill_conditioned"}
    budgeted_violations = [
        violation
        for violation in final_violations
        if violation.get("type") not in fatal_violation_types
    ]
    fatal_violations = [
        violation
        for violation in final_violations
        if violation.get("type") in fatal_violation_types
    ]

    caps_applied = len(budgeted_violations)
    caps_exceeded = caps_applied > int(self.max_caps)
    passed = not fatal_violations and not caps_exceeded

    # Compute comprehensive metrics
    metrics = {
        "modules_analyzed": len(final_metrics),
        "violations_detected": len(final_violations),
        "budgeted_violations": caps_applied,
        "fatal_violations": len(fatal_violations),
        "baseline_modules": len(self.baseline_metrics),
        "scope": self.scope,
        "max_spectral_norm_final": max(final_metrics.values())
        if final_metrics
        else 0.0,
        "mean_spectral_norm_final": np.mean(list(final_metrics.values()))
        if final_metrics
        else 0.0,
        # Fraction of modules NOT in violation, clamped to [0, 1].
        "spectral_stability_score": 1.0
        - min(len(final_violations) / max(len(final_metrics), 1), 1.0),
        "target_sigma": self.target_sigma,
        "correction_applied": len(final_violations) > 0 and self.correction_enabled,
        "family_caps": self.family_caps,
        "family_z_summary": final_z_summary,
        "family_stats": final_family_stats,
        "sigma_quantile": float(self.sigma_quantile),
        "deadband": float(self.deadband),
        "max_caps": int(self.max_caps),
        "caps_applied": caps_applied,
        "caps_exceeded": caps_exceeded,
        "multiple_testing": self.multiple_testing,
        "family_z_quantiles": family_quantiles,
        "top_z_scores": top_z_scores,
    }

    # Categorize violations: fatal types become errors, the rest warnings.
    warnings = []
    errors = []

    for violation in final_violations:
        if violation["type"] in ["max_spectral_norm", "ill_conditioned"]:
            errors.append(violation["message"])
        else:
            warnings.append(violation["message"])

    result = {
        "passed": passed,
        "metrics": metrics,
        "warnings": warnings,
        "errors": errors,
        "violations": final_violations,
        "events": self.events,
        "baseline_metrics": self.baseline_metrics,
        "final_metrics": final_metrics,
        "final_z_scores": self.latest_z_scores,
        "module_family_map": dict(self.module_family_map),
        "policy": self._serialize_policy(),
    }

    # Env-gated tiny evidence dump for auditors; best-effort only and
    # never allowed to fail finalization.
    try:
        payload = {
            "spectral": {
                "sigma_quantile": float(self.sigma_quantile),
                "deadband": float(self.deadband),
                "max_caps": int(self.max_caps),
                "multiple_testing": self.multiple_testing.get("method")
                if isinstance(self.multiple_testing, dict)
                else None,
                "evaluated": True,
            }
        }
        maybe_dump_guard_evidence(".", payload)
    except Exception:
        pass

    return result
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
def compute_sigma_max(weight_matrix: Any) -> float:
    """
    Compute maximum singular value of a weight matrix.

    Args:
        weight_matrix: Weight matrix to analyze

    Returns:
        Maximum singular value; 1.0 for non-tensor / quantized / failing
        inputs, 0.0 for empty matrices.
    """
    try:
        if not isinstance(weight_matrix, torch.Tensor):
            # Non-tensor inputs cannot be decomposed; report neutral gain.
            return 1.0

        if weight_matrix.dtype in (torch.int8,):
            # Quantized weights are skipped.
            return 1.0

        # SVD requires a floating-point matrix.
        mat = weight_matrix.float()

        # Degenerate shapes have no singular values.
        if mat.numel() == 0 or mat.shape[0] == 0 or mat.shape[1] == 0:
            return 0.0

        # Prefer the values-only routine; fall back for older backends
        # that do not ship torch.linalg.svdvals.
        try:
            svals = torch.linalg.svdvals(mat)
        except RuntimeError:
            svals = torch.linalg.svd(mat, full_matrices=False).S

        if svals.numel() == 0:
            return 0.0
        return svals[0].item()

    except Exception:
        # Best-effort: any failure degrades to neutral gain.
        return 1.0
|
|
862
|
+
|
|
863
|
+
|
|
864
|
+
def auto_sigma_target(model: Any, percentile: float = 0.95, **kwargs: Any) -> float:
    """
    Automatically determine sigma target for a model.

    Args:
        model: Model to analyze
        percentile: Scale factor (target percentile of spectral norms)

    Returns:
        Target sigma value
    """
    # Legacy alias: honor `kappa` only when `percentile` was left at its
    # default, so an explicit percentile always wins.
    if "kappa" in kwargs and percentile == 0.95:
        try:
            percentile = float(kwargs["kappa"])
        except (TypeError, ValueError):
            pass

    try:
        # Collect positive spectral norms across all 2-D weight matrices.
        norms = [
            compute_sigma_max(module.weight)
            for _name, module in model.named_modules()
            if hasattr(module, "weight") and module.weight.ndim == 2
        ]
        norms = [sigma for sigma in norms if sigma > 0]

        if not norms:
            # No usable weights: fall back to the requested quantile value.
            return percentile

        # Use the kappa-percentile of observed norms as the target.
        return float(np.percentile(norms, percentile * 100))

    except Exception:
        # Default fallback on any analysis failure.
        return percentile
|
|
899
|
+
|
|
900
|
+
|
|
901
|
+
def apply_weight_rescale(
    model: Any, scale_factor: float = 1.0, scope: str = "all"
) -> dict[str, Any]:
    """
    Apply weight rescaling to model parameters.

    Args:
        model: Model to rescale
        scale_factor: Scaling factor to apply
        scope: Which modules to rescale ('all', 'attn', 'ffn')

    Returns:
        Rescaling results
    """
    try:
        done: list[str] = []
        failures: list[tuple[str, str]] = []

        for name, module in model.named_modules():
            if not _should_process_module(name, module, scope):
                continue

            try:
                # Only 2-D weight matrices are rescaled.
                if not (hasattr(module, "weight") and module.weight.ndim == 2):
                    continue
                # Quantized weights are left untouched.
                if hasattr(module.weight, "dtype") and module.weight.dtype in [
                    torch.int8,
                ]:
                    continue

                # In-place rescale of weight (and bias, when present) so the
                # module's effective gain scales uniformly.
                with torch.no_grad():
                    module.weight.mul_(scale_factor)
                    if hasattr(module, "bias") and module.bias is not None:
                        module.bias.mul_(scale_factor)

                done.append(name)

            except Exception as exc:
                failures.append((name, str(exc)))

        return {
            "applied": len(done) > 0,
            "scale_factor": scale_factor,
            "rescaled_modules": done,
            "failed_modules": failures,
            "message": f"Rescaled {len(done)} modules with factor {scale_factor}",
        }

    except Exception as e:
        return {
            "applied": False,
            "error": str(e),
            "message": f"Weight rescaling failed: {e}",
        }
|
|
956
|
+
|
|
957
|
+
|
|
958
|
+
def apply_relative_spectral_cap(
    model: Any,
    cap_ratio: float = 2.0,
    scope: str = "all",
    baseline_sigmas: dict[str, float] | None = None,
) -> dict[str, Any]:
    """
    Apply relative spectral capping to model weights.

    Args:
        model: Model to cap
        cap_ratio: Maximum allowed ratio relative to baseline
        scope: Which modules to cap ('all', 'attn', 'ffn')
        baseline_sigmas: Mapping of module name to pre-edit sigma values

    Returns:
        Capping results
    """
    try:
        # Re-measure baselines when the caller did not supply them.
        reference = (
            baseline_sigmas
            if baseline_sigmas is not None
            else capture_baseline_sigmas(model, scope=scope)
        )

        capped: list[dict[str, Any]] = []
        failures: list[tuple[str, str]] = []

        for name, module in model.named_modules():
            if not _should_process_module(name, module, scope):
                continue

            try:
                # Only 2-D weight matrices are eligible.
                if not (hasattr(module, "weight") and module.weight.ndim == 2):
                    continue
                # Quantized weights are left untouched.
                if hasattr(module.weight, "dtype") and module.weight.dtype in [
                    torch.int8,
                ]:
                    continue

                sigma_now = compute_sigma_max(module.weight)
                # Unknown modules cap against themselves (i.e. no-op).
                sigma_ref = reference.get(name, sigma_now)
                ceiling = sigma_ref * cap_ratio

                if sigma_now <= ceiling:
                    continue

                # Scale the whole matrix down so its top singular value
                # lands exactly on the allowed ceiling.
                factor = ceiling / sigma_now
                with torch.no_grad():
                    module.weight.mul_(factor)

                capped.append(
                    {
                        "module": name,
                        "original_sigma": sigma_now,
                        "capped_sigma": ceiling,
                        "scale_factor": factor,
                    }
                )

            except Exception as exc:
                failures.append((name, str(exc)))

        return {
            "applied": len(capped) > 0,
            "cap_ratio": cap_ratio,
            "capped_modules": capped,
            "failed_modules": failures,
            "message": f"Applied spectral capping to {len(capped)} modules",
        }

    except Exception as e:
        return {
            "applied": False,
            "error": str(e),
            "message": f"Spectral capping failed: {e}",
        }
|
|
1032
|
+
|
|
1033
|
+
|
|
1034
|
+
def apply_spectral_control(model: Any, policy: dict[str, Any]) -> dict[str, Any]:
    """
    Apply spectral control based on policy.

    Args:
        model: Model to control
        policy: Spectral control policy

    Returns:
        Control results
    """
    try:
        outcome: dict[str, Any] = {
            "rescaling_applied": False,
            "capping_applied": False,
            "modules_processed": 0,
            "corrections": [],
        }

        scope = policy.get("scope", "all")

        # Relative spectral capping runs first, against recorded baselines.
        cap_result = apply_relative_spectral_cap(
            model,
            cap_ratio=policy.get("cap_ratio", 2.0),
            scope=scope,
            baseline_sigmas=policy.get("baseline_sigmas"),
        )
        if cap_result["applied"]:
            outcome["capping_applied"] = True
            outcome["corrections"].extend(cap_result["capped_modules"])

        # Optional uniform rescaling when the policy requests a factor.
        if "rescale_factor" in policy:
            rescale_result = apply_weight_rescale(
                model, scale_factor=policy["rescale_factor"], scope=scope
            )
            if rescale_result["applied"]:
                outcome["rescaling_applied"] = True
                outcome["modules_processed"] += len(
                    rescale_result["rescaled_modules"]
                )

        outcome["applied"] = (
            outcome["rescaling_applied"] or outcome["capping_applied"]
        )
        outcome["policy"] = policy
        outcome["message"] = (
            f"Spectral control applied: capping={outcome['capping_applied']}, rescaling={outcome['rescaling_applied']}"
        )

        return outcome

    except Exception as e:
        return {
            "applied": False,
            "error": str(e),
            "policy": policy,
            "message": f"Spectral control failed: {e}",
        }
|
|
1094
|
+
|
|
1095
|
+
|
|
1096
|
+
def _summarize_sigmas(sigmas: dict[str, float]) -> dict[str, float]:
    """Compute summary statistics (max/mean/min) for a sigma dictionary."""
    # Empty input yields an all-zero summary rather than raising.
    if not sigmas:
        return {
            "max_spectral_norm": 0.0,
            "mean_spectral_norm": 0.0,
            "min_spectral_norm": 0.0,
        }

    arr = np.asarray(list(sigmas.values()), dtype=float)
    return {
        "max_spectral_norm": float(arr.max()),
        "mean_spectral_norm": float(arr.mean()),
        "min_spectral_norm": float(arr.min()),
    }
|
|
1111
|
+
|
|
1112
|
+
|
|
1113
|
+
def compute_z_score_for_value(
    sigma: float,
    family_stats: dict[str, float],
    fallback_value: float,
    deadband: float,
) -> float:
    """Compute per-family z-score for a spectral norm with sensible fallbacks."""
    mu = float(family_stats.get("mean", 0.0) or 0.0)
    sd = float(family_stats.get("std", 0.0) or 0.0)

    # Standard z-score when the family baseline has real spread.
    if sd > 0:
        return float((sigma - mu) / sd)

    # Degenerate family (no spread): compare against the module's own
    # baseline instead, expressed in units of the deadband width.
    reference = fallback_value if fallback_value > 0 else 1.0
    relative = (sigma / reference) - 1.0

    # Changes within the deadband are treated as noise.
    if abs(relative) <= deadband:
        return 0.0

    return float(relative / (deadband if deadband > 0 else 1.0))
|
|
1135
|
+
|
|
1136
|
+
|
|
1137
|
+
def compute_z_scores(
    metrics: dict[str, float],
    baseline_family_stats: dict[str, dict[str, float]],
    module_family_map: dict[str, str],
    baseline_sigmas: dict[str, float],
    deadband: float,
) -> dict[str, float]:
    """Compute z-scores for all modules given baseline family stats."""
    result: dict[str, float] = {}
    for module_name, sigma in metrics.items():
        # Unmapped modules land in the catch-all "other" family.
        stats = baseline_family_stats.get(
            module_family_map.get(module_name, "other"), {}
        )
        # Prefer the module's own baseline sigma; fall back to the family
        # mean, then to the current sigma itself.
        fallback = baseline_sigmas.get(module_name, stats.get("mean", sigma))
        result[module_name] = compute_z_score_for_value(
            float(sigma),
            stats,
            float(fallback),
            deadband=deadband,
        )
    return result
|
|
1157
|
+
|
|
1158
|
+
|
|
1159
|
+
def summarize_family_z_scores(
    z_scores: dict[str, float],
    module_family_map: dict[str, str],
    family_caps: dict[str, dict[str, float]],
) -> dict[str, dict[str, float]]:
    """Summarize z-scores per family, including violation counts."""
    # Bucket z-scores by family; unmapped modules fall into "other".
    grouped: dict[str, list[float]] = {}
    for module_name, z in z_scores.items():
        family = module_family_map.get(module_name, "other")
        grouped.setdefault(family, []).append(float(z))

    summary: dict[str, dict[str, float]] = {}
    for family, zs in grouped.items():
        if not zs:
            continue
        entry: dict[str, float] = {
            "max": float(max(zs)),
            "mean": float(sum(zs) / len(zs)),
            "count": len(zs),
        }
        cap = family_caps.get(family, {}).get("kappa")
        if cap is None:
            # No cap configured: nothing can be in violation.
            entry["violations"] = 0
        else:
            threshold = float(cap)
            # Strictly-greater-than the configured kappa counts as violation.
            entry["violations"] = sum(1 for z in zs if z > threshold)
            entry["kappa"] = threshold
        summary[family] = entry
    return summary
|
|
1188
|
+
|
|
1189
|
+
|
|
1190
|
+
def compute_family_stats(
    sigmas: dict[str, float], family_map: dict[str, str]
) -> dict[str, dict[str, float]]:
    """Compute per-family statistics (mean/std/min/max/count)."""
    # Group sigma values by family; unmapped modules go to "other".
    grouped: dict[str, list[float]] = {}
    for module_name, value in sigmas.items():
        family = family_map.get(module_name, "other")
        grouped.setdefault(family, []).append(float(value))

    result: dict[str, dict[str, float]] = {}
    for family, values in grouped.items():
        if not values:
            continue
        arr = np.asarray(values, dtype=float)
        result[family] = {
            "count": len(values),
            "mean": float(arr.mean()),
            # Population std (ddof=0), matching baseline comparisons.
            "std": float(arr.std(ddof=0)),
            "min": float(arr.min()),
            "max": float(arr.max()),
        }
    return result
|
|
1212
|
+
|
|
1213
|
+
|
|
1214
|
+
def classify_model_families(
    model: Any, scope: str = "all", existing: dict[str, str] | None = None
) -> dict[str, str]:
    """Build or update a module→family map for the provided model."""
    # Start from a copy so the caller's mapping is never mutated.
    mapping: dict[str, str] = dict(existing) if existing else {}
    mapping.update(
        (name, classify_module_family(name, module))
        for name, module in model.named_modules()
        if _should_process_module(name, module, scope)
    )
    return mapping
|
|
1223
|
+
|
|
1224
|
+
|
|
1225
|
+
def capture_baseline_sigmas(model: Any, scope: str = "all") -> dict[str, float]:
    """
    Capture baseline singular values for model layers.

    Args:
        model: Model to analyze
        scope: Which modules to analyze ('all', 'attn', 'ffn')

    Returns:
        Dictionary of layer name to max singular value; empty on failure.
    """
    sigmas: dict[str, float] = {}
    try:
        for name, module in model.named_modules():
            if not _should_process_module(name, module, scope):
                continue
            # Only 2-D weight matrices have a meaningful spectral norm here.
            if hasattr(module, "weight") and module.weight.ndim == 2:
                sigmas[name] = compute_sigma_max(module.weight)
    except Exception:
        # Best-effort: any traversal failure yields an empty baseline.
        return {}
    return sigmas
|
|
1249
|
+
|
|
1250
|
+
|
|
1251
|
+
def scan_model_gains(model: Any, scope: str = "all") -> dict[str, Any]:
    """
    Scan model for gain values and spectral statistics.

    Args:
        model: Model to scan
        scope: Which modules to scan ('all', 'attn', 'ffn')

    Returns:
        Gain analysis results: per-module weight statistics plus summary
        spectral-norm and condition-number aggregates. On failure, a dict
        with an "error" key and zero scanned modules.
    """
    try:
        results: dict[str, Any] = {
            "total_layers": 0,
            "scanned_modules": 0,
            "spectral_norms": [],
            "condition_numbers": [],
            "weight_statistics": {},
        }

        for name, module in model.named_modules():
            results["total_layers"] += 1

            if _should_process_module(name, module, scope):
                if hasattr(module, "weight") and module.weight.ndim == 2:
                    results["scanned_modules"] += 1

                    # Spectral norm (largest singular value).
                    sigma_max = compute_sigma_max(module.weight)
                    results["spectral_norms"].append(sigma_max)

                    # Condition number (sigma_max / sigma_min) if possible.
                    # Uses torch.linalg.svdvals instead of the deprecated
                    # torch.svd, consistent with compute_sigma_max; values
                    # are returned in descending order.
                    try:
                        S = torch.linalg.svdvals(module.weight.float())
                        if len(S) > 1:
                            condition_num = (S[0] / S[-1]).item()
                            results["condition_numbers"].append(condition_num)
                    except Exception:
                        pass

                    # Basic weight statistics; best-effort per module.
                    try:
                        weight_stats = {
                            "mean": module.weight.mean().item(),
                            "std": module.weight.std().item(),
                            "min": module.weight.min().item(),
                            "max": module.weight.max().item(),
                        }
                        results["weight_statistics"][name] = weight_stats
                    except Exception:
                        pass

        # Summary statistics (only when anything was collected).
        if results["spectral_norms"]:
            results["mean_spectral_norm"] = np.mean(results["spectral_norms"])
            results["max_spectral_norm"] = np.max(results["spectral_norms"])
            results["min_spectral_norm"] = np.min(results["spectral_norms"])

        if results["condition_numbers"]:
            results["mean_condition_number"] = np.mean(results["condition_numbers"])
            results["max_condition_number"] = np.max(results["condition_numbers"])

        results["message"] = (
            f"Scanned {results['scanned_modules']} modules out of {results['total_layers']} total layers"
        )

        return results

    except Exception as e:
        return {
            "total_layers": sum(1 for _ in model.named_modules()),
            "scanned_modules": 0,
            "error": str(e),
            "message": f"Model scanning failed: {e}",
        }
|
|
1326
|
+
|
|
1327
|
+
|
|
1328
|
+
def _should_process_module(name: str, module: Any, scope: str) -> bool:
    """Return True when *module* exposes a 2-D weight and matches *scope*.

    Scope narrows by name keyword: 'all' accepts everything, 'attn' and
    'ffn' match attention / feed-forward naming conventions, 'ffn+proj'
    additionally matches projection layers. Unknown scope strings keep the
    permissive default (True).
    """
    # Robustness fix: some modules expose ``weight`` set to ``None`` (e.g.
    # ``LayerNorm(..., elementwise_affine=False)``); the previous direct
    # ``module.weight.ndim`` access raised AttributeError for those and
    # escaped callers' per-module error handling.
    weight = getattr(module, "weight", None)
    if weight is None or getattr(weight, "ndim", None) != 2:
        return False

    if scope == "all":
        return True

    lname = name.lower()
    if scope == "attn":
        return any(
            keyword in lname
            for keyword in ["attn", "attention", "self_attn", "c_attn", "c_proj"]
        )
    if scope == "ffn":
        return any(
            keyword in lname
            for keyword in ["mlp", "ffn", "feed_forward", "fc", "c_fc"]
        )
    if scope == "ffn+proj":
        return any(
            keyword in lname
            for keyword in [
                "mlp",
                "ffn",
                "feed_forward",
                "fc",
                "c_fc",
                "c_proj",
                "projection",
            ]
        )

    # Unknown scope: process everything (matches original behavior).
    return True
|
|
1361
|
+
|
|
1362
|
+
|
|
1363
|
+
def classify_module_family(name: str, module: Any) -> str:
    """Classify module into a spectral family for policy purposes."""
    lowered = name.lower()

    # MoE router/gating layers form their own family.
    if any(
        token in lowered
        for token in ("router", "routing", "gate", "gating", "dispatch", "switch")
    ):
        return "router"
    # MoE expert feed-forward layers.
    if any(
        token in lowered
        for token in ("experts", "expert", "moe", "mixture_of_experts")
    ):
        return "expert_ffn"

    if any(token in lowered for token in ("mlp", "ffn", "feed_forward")):
        return "ffn"

    attn_markers = ("q_proj", "k_proj", "v_proj", "o_proj", "c_attn")
    if (
        "attn" in lowered
        or "attention" in lowered
        or any(marker in lowered for marker in attn_markers)
    ):
        return "attn"

    if any(token in lowered for token in ("embed", "wte", "embedding")):
        return "embed"

    # Fall back to the module's class name when the path gives no hint.
    kind = module.__class__.__name__.lower()
    if "embedding" in kind:
        return "embed"
    if "conv1d" in kind or "linear" in kind:
        if "attn" in lowered:
            return "attn"
        if "mlp" in lowered or "ffn" in lowered:
            return "ffn"

    return "other"
|
|
1403
|
+
|
|
1404
|
+
|
|
1405
|
+
# Export the main components.
# Consistency fix: include all public helpers defined in this module
# (compute_z_score_for_value, compute_z_scores, classify_model_families
# were defined here but missing from the export list).
__all__ = [
    "SpectralGuard",
    "SpectralPolicy",
    "compute_sigma_max",
    "auto_sigma_target",
    "apply_weight_rescale",
    "apply_relative_spectral_cap",
    "apply_spectral_control",
    "capture_baseline_sigmas",
    "scan_model_gains",
    "compute_family_stats",
    "compute_z_score_for_value",
    "compute_z_scores",
    "summarize_family_z_scores",
    "classify_module_family",
    "classify_model_families",
]
|