invarlock 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +33 -0
- invarlock/__main__.py +10 -0
- invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
- invarlock/_data/runtime/profiles/release.yaml +23 -0
- invarlock/_data/runtime/tiers.yaml +76 -0
- invarlock/adapters/__init__.py +102 -0
- invarlock/adapters/_capabilities.py +45 -0
- invarlock/adapters/auto.py +99 -0
- invarlock/adapters/base.py +530 -0
- invarlock/adapters/base_types.py +85 -0
- invarlock/adapters/hf_bert.py +852 -0
- invarlock/adapters/hf_gpt2.py +403 -0
- invarlock/adapters/hf_llama.py +485 -0
- invarlock/adapters/hf_mixin.py +383 -0
- invarlock/adapters/hf_onnx.py +112 -0
- invarlock/adapters/hf_t5.py +137 -0
- invarlock/adapters/py.typed +1 -0
- invarlock/assurance/__init__.py +43 -0
- invarlock/cli/__init__.py +8 -0
- invarlock/cli/__main__.py +8 -0
- invarlock/cli/_evidence.py +25 -0
- invarlock/cli/_json.py +75 -0
- invarlock/cli/adapter_auto.py +162 -0
- invarlock/cli/app.py +287 -0
- invarlock/cli/commands/__init__.py +26 -0
- invarlock/cli/commands/certify.py +403 -0
- invarlock/cli/commands/doctor.py +1358 -0
- invarlock/cli/commands/explain_gates.py +151 -0
- invarlock/cli/commands/export_html.py +100 -0
- invarlock/cli/commands/plugins.py +1331 -0
- invarlock/cli/commands/report.py +354 -0
- invarlock/cli/commands/run.py +4146 -0
- invarlock/cli/commands/verify.py +1040 -0
- invarlock/cli/config.py +396 -0
- invarlock/cli/constants.py +68 -0
- invarlock/cli/device.py +92 -0
- invarlock/cli/doctor_helpers.py +74 -0
- invarlock/cli/errors.py +6 -0
- invarlock/cli/overhead_utils.py +60 -0
- invarlock/cli/provenance.py +66 -0
- invarlock/cli/utils.py +41 -0
- invarlock/config.py +56 -0
- invarlock/core/__init__.py +62 -0
- invarlock/core/abi.py +15 -0
- invarlock/core/api.py +274 -0
- invarlock/core/auto_tuning.py +317 -0
- invarlock/core/bootstrap.py +226 -0
- invarlock/core/checkpoint.py +221 -0
- invarlock/core/contracts.py +73 -0
- invarlock/core/error_utils.py +64 -0
- invarlock/core/events.py +298 -0
- invarlock/core/exceptions.py +95 -0
- invarlock/core/registry.py +481 -0
- invarlock/core/retry.py +146 -0
- invarlock/core/runner.py +2041 -0
- invarlock/core/types.py +154 -0
- invarlock/edits/__init__.py +12 -0
- invarlock/edits/_edit_utils.py +249 -0
- invarlock/edits/_external_utils.py +268 -0
- invarlock/edits/noop.py +47 -0
- invarlock/edits/py.typed +1 -0
- invarlock/edits/quant_rtn.py +801 -0
- invarlock/edits/registry.py +166 -0
- invarlock/eval/__init__.py +23 -0
- invarlock/eval/bench.py +1207 -0
- invarlock/eval/bootstrap.py +50 -0
- invarlock/eval/data.py +2052 -0
- invarlock/eval/metrics.py +2167 -0
- invarlock/eval/primary_metric.py +767 -0
- invarlock/eval/probes/__init__.py +24 -0
- invarlock/eval/probes/fft.py +139 -0
- invarlock/eval/probes/mi.py +213 -0
- invarlock/eval/probes/post_attention.py +323 -0
- invarlock/eval/providers/base.py +67 -0
- invarlock/eval/providers/seq2seq.py +111 -0
- invarlock/eval/providers/text_lm.py +113 -0
- invarlock/eval/providers/vision_text.py +93 -0
- invarlock/eval/py.typed +1 -0
- invarlock/guards/__init__.py +18 -0
- invarlock/guards/_contracts.py +9 -0
- invarlock/guards/invariants.py +640 -0
- invarlock/guards/policies.py +805 -0
- invarlock/guards/py.typed +1 -0
- invarlock/guards/rmt.py +2097 -0
- invarlock/guards/spectral.py +1419 -0
- invarlock/guards/tier_config.py +354 -0
- invarlock/guards/variance.py +3298 -0
- invarlock/guards_ref/__init__.py +15 -0
- invarlock/guards_ref/rmt_ref.py +40 -0
- invarlock/guards_ref/spectral_ref.py +135 -0
- invarlock/guards_ref/variance_ref.py +60 -0
- invarlock/model_profile.py +353 -0
- invarlock/model_utils.py +221 -0
- invarlock/observability/__init__.py +10 -0
- invarlock/observability/alerting.py +535 -0
- invarlock/observability/core.py +546 -0
- invarlock/observability/exporters.py +565 -0
- invarlock/observability/health.py +588 -0
- invarlock/observability/metrics.py +457 -0
- invarlock/observability/py.typed +1 -0
- invarlock/observability/utils.py +553 -0
- invarlock/plugins/__init__.py +12 -0
- invarlock/plugins/hello_guard.py +33 -0
- invarlock/plugins/hf_awq_adapter.py +82 -0
- invarlock/plugins/hf_bnb_adapter.py +79 -0
- invarlock/plugins/hf_gptq_adapter.py +78 -0
- invarlock/plugins/py.typed +1 -0
- invarlock/py.typed +1 -0
- invarlock/reporting/__init__.py +7 -0
- invarlock/reporting/certificate.py +3221 -0
- invarlock/reporting/certificate_schema.py +244 -0
- invarlock/reporting/dataset_hashing.py +215 -0
- invarlock/reporting/guards_analysis.py +948 -0
- invarlock/reporting/html.py +32 -0
- invarlock/reporting/normalizer.py +235 -0
- invarlock/reporting/policy_utils.py +517 -0
- invarlock/reporting/primary_metric_utils.py +265 -0
- invarlock/reporting/render.py +1442 -0
- invarlock/reporting/report.py +903 -0
- invarlock/reporting/report_types.py +278 -0
- invarlock/reporting/utils.py +175 -0
- invarlock/reporting/validate.py +631 -0
- invarlock/security.py +176 -0
- invarlock/sparsity_utils.py +323 -0
- invarlock/utils/__init__.py +150 -0
- invarlock/utils/digest.py +45 -0
- invarlock-0.2.0.dist-info/METADATA +586 -0
- invarlock-0.2.0.dist-info/RECORD +132 -0
- invarlock-0.2.0.dist-info/WHEEL +5 -0
- invarlock-0.2.0.dist-info/entry_points.txt +20 -0
- invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
- invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/guards/rmt.py
ADDED
|
@@ -0,0 +1,2097 @@
|
|
|
1
|
+
"""
|
|
2
|
+
InvarLock – Safety: Random Matrix Theory (RMT) Health Check
|
|
3
|
+
=======================================================
|
|
4
|
+
|
|
5
|
+
Detect-only mode for v0: identifies singular value outliers that
|
|
6
|
+
deviate from the Marchenko-Pastur bulk distribution.
|
|
7
|
+
|
|
8
|
+
Based on insights from Słowik et al., 2025 linking MP outliers
|
|
9
|
+
to training instability.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import math
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from datetime import datetime
|
|
17
|
+
from typing import Any, Literal, TypedDict
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
import torch
|
|
21
|
+
import torch.linalg as tla
|
|
22
|
+
import torch.nn as nn
|
|
23
|
+
|
|
24
|
+
from invarlock.cli._evidence import maybe_dump_guard_evidence
|
|
25
|
+
from invarlock.core.api import Guard
|
|
26
|
+
|
|
27
|
+
from ._contracts import guard_assert
|
|
28
|
+
|
|
29
|
+
__all__ = [
|
|
30
|
+
# Utility functions
|
|
31
|
+
"mp_bulk_edges",
|
|
32
|
+
"mp_bulk_edge",
|
|
33
|
+
"layer_svd_stats",
|
|
34
|
+
"rmt_detect",
|
|
35
|
+
"rmt_detect_report",
|
|
36
|
+
"rmt_detect_with_names",
|
|
37
|
+
"clip_full_svd",
|
|
38
|
+
"analyze_weight_distribution",
|
|
39
|
+
"rmt_growth_ratio",
|
|
40
|
+
"within_deadband",
|
|
41
|
+
"capture_baseline_mp_stats",
|
|
42
|
+
# Guard classes and types
|
|
43
|
+
"RMTGuard",
|
|
44
|
+
"RMTPolicy",
|
|
45
|
+
"RMTPolicyDict",
|
|
46
|
+
# Policy utilities
|
|
47
|
+
"get_rmt_policy",
|
|
48
|
+
"create_custom_rmt_policy",
|
|
49
|
+
]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def mp_bulk_edges(m: int, n: int, whitened: bool = True) -> tuple[float, float]:
    """
    Return the Marchenko-Pastur bulk edges (σ_min, σ_max) for an m×n matrix.

    For W ∈ ℝ^{m×n} with i.i.d. entries of variance 1/m, the singular values
    of the whitened matrix concentrate in [|1-√q|, 1+√q] with q = n/m.

    Args:
        m: Number of rows (input features for Conv1D)
        n: Number of columns (output features for Conv1D)
        whitened: If True, assumes W is already whitened by √m

    Returns:
        (σ_min, σ_max) theoretical bulk edges for singular values
    """
    # Degenerate matrix: no singular values at all.
    if m == 0 or n == 0:
        return 0.0, 0.0

    # Aspect ratio q = n/m drives the MP edge locations.
    aspect = n / m
    root_q = np.sqrt(aspect)

    # Unwhitened matrices carry an extra √m scale on every singular value.
    scale = 1.0 if whitened else np.sqrt(m)

    upper = scale * (1.0 + root_q)
    # For q > 1 the lower bulk edge is reported as 0 (rank-deficient regime).
    lower = scale * abs(1.0 - root_q) if aspect <= 1 else 0.0

    return lower, upper
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def mp_bulk_edge(m: int, n: int, whitened: bool = False) -> float:
    """
    Return the upper Marchenko-Pastur bulk edge σ_max for an m×n matrix.

    σ_max is the theoretical largest singular value of a random matrix with
    i.i.d. entries; values beyond it are candidate outliers.

    Args:
        m: Number of rows (input features for Conv1D)
        n: Number of columns (output features for Conv1D)
        whitened: If True, assumes W is already whitened by √m

    Returns:
        σ_max theoretical bulk edge for singular values
    """
    # Degenerate matrix: no spectrum.
    if m == 0 or n == 0:
        return 0.0

    # q = n/m (aspect ratio); whitened edge is 1 + √q.
    edge = 1.0 + np.sqrt(n / m)
    if not whitened:
        # Unwhitened: every singular value is scaled by √m.
        edge = np.sqrt(m) * edge

    return float(edge)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _iter_weight_matrices(layer: nn.Module):
    """Yield (name, detached tensor) for every 2D weight parameter of *layer*.

    Biases and other 1D parameters are skipped; only parameters whose name
    contains "weight" and whose tensor is rank-2 are produced.
    """
    for param_name, param in layer.named_parameters():
        is_weight_matrix = param.ndim == 2 and "weight" in param_name
        if is_weight_matrix:
            # Detach so downstream SVD work never touches the autograd graph.
            yield param_name, param.detach()
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def rmt_growth_ratio(
    sigma_cur: float, mp_cur: float, sigma_base: float, mp_base: float
) -> float:
    """
    Baseline-aware growth ratio for RMT outlier detection.

    Measures how much the σ/mp_edge ratio has grown relative to the baseline
    model: (σ_cur / mp_cur) / (σ_base / mp_base).

    Args:
        sigma_cur: Current maximum singular value
        mp_cur: Current MP bulk edge
        sigma_base: Baseline maximum singular value
        mp_base: Baseline MP bulk edge

    Returns:
        The growth ratio of the normalized spectral norm.
    """
    # Clamp every denominator to avoid division by zero on degenerate inputs.
    eps = 1e-12
    normalized_base = sigma_base / max(mp_base, eps)
    normalized_cur = sigma_cur / max(mp_cur, eps)
    return normalized_cur / max(normalized_base, eps)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def within_deadband(sigma_cur: float, sigma_base: float, deadband: float) -> bool:
    """
    Report whether the current spectral norm stays within the deadband.

    Args:
        sigma_cur: Current spectral norm
        sigma_base: Baseline spectral norm
        deadband: Relative tolerance (e.g. 0.1 allows 10% growth)

    Returns:
        True when sigma_cur does not exceed (1 + deadband) * sigma_base.
    """
    allowed_ceiling = (1.0 + deadband) * sigma_base
    return sigma_cur <= allowed_ceiling
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def layer_svd_stats(
    layer: nn.Module,
    baseline_sigmas: dict[str, float] | None = None,
    baseline_mp_stats: dict[str, dict[str, float]] | None = None,
    module_name: str | None = None,
) -> dict[str, Any]:
    """
    Compute SVD statistics for a single layer with baseline-aware normalization.

    For HuggingFace Conv1D layers:
    - Weight shape is (in_features, out_features)
    - m = in_features, n = out_features

    Args:
        layer: Transformer layer to analyze
        baseline_sigmas: Optional baseline singular values for baseline-aware comparison
        baseline_mp_stats: Optional baseline MP statistics (mp_bulk_edge, r_mp_base) for each weight matrix
        module_name: Optional module name for baseline lookups

    Returns:
        Dict with sigma_min, sigma_max, worst_ratio, and (when at least one
        matrix was analyzed) a "worst_details" dict describing the worst
        offender.

    Notes:
        - If the layer has no analyzable 2D weight matrices (empty, non-finite,
          or SVD failure on all of them), sigma_min is returned as +inf and
          sigma_max/worst_ratio as 0.0 — callers should be prepared for that.
    """
    # Aggregates across every 2D weight matrix in the layer.
    sigma_min_global = float("inf")
    sigma_max_global = 0.0
    worst_ratio = 0.0
    worst_details = None

    for name, W in _iter_weight_matrices(layer):
        # Skip matrices that cannot produce a meaningful spectrum.
        if W.numel() == 0:
            continue
        if not torch.isfinite(W).all():
            continue

        # For Conv1D: W.shape = (in_features, out_features)
        m, n = W.shape  # m = in_features, n = out_features

        # Compute singular values of the actual matrix; SVD runs on CPU in
        # float32 for numerical stability and device independence.
        try:
            s_actual = tla.svdvals(W.float().cpu())
            s_min = s_actual[-1].item()  # svdvals is sorted descending
            s_max = s_actual[0].item()
        except (RuntimeError, torch.linalg.LinAlgError):
            # SVD failed to converge — skip this matrix rather than abort.
            continue

        # Track global min/max over all matrices in the layer.
        sigma_min_global = min(sigma_min_global, s_min)
        sigma_max_global = max(sigma_max_global, s_max)

        # Baseline-aware ratio computation for better outlier detection.
        if baseline_sigmas and module_name and module_name in baseline_sigmas:
            # Use baseline-aware growth ratio (preferred method).
            baseline_sigma = baseline_sigmas[module_name]
            if baseline_sigma > 0:
                # Compute current MP edge.
                mp_edge_current = mp_bulk_edge(m, n, whitened=False)

                # Get baseline MP edge from stored stats, or fall back to the
                # current edge (valid when the shape did not change).
                if baseline_mp_stats and module_name in baseline_mp_stats:
                    mp_edge_baseline = baseline_mp_stats[module_name].get(
                        "mp_bulk_edge_base", mp_edge_current
                    )
                else:
                    # Fallback: assume same shape so use same MP edge.
                    mp_edge_baseline = mp_edge_current

                # Shared helper keeps the growth-ratio definition consistent
                # with the rest of the module.
                ratio = rmt_growth_ratio(
                    s_max, mp_edge_current, baseline_sigma, mp_edge_baseline
                )
            else:
                # Degenerate baseline sigma — treat as neutral.
                ratio = 1.0
        else:
            # Fallback: quantile-based normalization when no baseline available.
            if len(s_actual) > 1:
                # Use 98th percentile as a robust reference (less sensitive to
                # the very outliers we are trying to detect).
                s_sorted = s_actual.sort()[0]
                idx_98 = int(0.98 * len(s_sorted))
                s_98 = s_sorted[idx_98].item()

                if s_98 > 0:
                    # Ratio of the top singular value to the 98th percentile.
                    ratio = s_max / s_98
                else:
                    ratio = 1.0
            else:
                # Single singular value — nothing to compare against.
                ratio = 1.0

        # Track the worst deviation and capture diagnostics for reporting.
        if ratio > worst_ratio:
            worst_ratio = ratio
            worst_details = {
                "name": name,
                "shape": (m, n),
                "s_max": s_max,
                "s_min": s_min,
                "s_median": s_actual.median().item() if len(s_actual) > 1 else s_max,
                "s_98": s_actual.sort()[0][int(0.98 * len(s_actual))].item()
                if len(s_actual) > 1
                else s_max,
                "ratio": ratio,
                "mp_edge": mp_bulk_edge(m, n, whitened=False),
                # Record which normalization produced this ratio so the
                # report can label baseline-aware vs. percentile results.
                "normalization": "baseline_aware"
                if baseline_sigmas and module_name and module_name in baseline_sigmas
                else "98th_percentile",
            }

    result = {
        "sigma_min": sigma_min_global,
        "sigma_max": sigma_max_global,
        "worst_ratio": worst_ratio,
    }

    # Only attach details when at least one matrix was successfully analyzed.
    if worst_details:
        result["worst_details"] = worst_details

    return result
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def capture_baseline_mp_stats(model: nn.Module) -> dict[str, dict[str, float]]:
    """
    Capture baseline MP statistics for linear layers only.

    CRITICAL: Only includes layers where MP analysis makes sense:
    - attn.c_attn, attn.c_proj, mlp.c_fc, mlp.c_proj
    - EXCLUDES: wte, wpe, lm_head, layer norms, biases

    Stores mp_bulk_edge and r_mp_base (sigma/mp_edge ratio) for each weight matrix.
    This enables true baseline-aware RMT detection.

    Args:
        model: Model to analyze

    Returns:
        Dict mapping module names to their MP statistics:
        {
            'module_name': {
                'mp_bulk_edge_base': float,
                'r_mp_base': float,
                'sigma_base': float
            }
        }
    """
    mp_stats: dict[str, dict[str, float]] = {}

    # Resolve the HuggingFace Conv1D class ONCE up front. The original code
    # re-imported it inside the innermost per-parameter loop; hoisting the
    # probe removes that redundant work without changing behavior (if the
    # import succeeds here, it would have succeeded there too).
    try:
        from transformers.pytorch_utils import Conv1D as _HFConv1D
    except ImportError:
        _HFConv1D = None

    if _HFConv1D is not None:
        module_types: tuple[type, ...] = (nn.Linear, nn.Conv1d, _HFConv1D)
    else:
        module_types = (nn.Linear, nn.Conv1d)

    # Define allowlist for RMT analysis - only linear layers where MP makes sense
    allowed_suffixes = [".attn.c_attn", ".attn.c_proj", ".mlp.c_fc", ".mlp.c_proj"]

    for name, module in model.named_modules():
        if not (isinstance(module, module_types) and hasattr(module, "weight")):
            continue
        # CRITICAL: Restrict to only linear layers where MP analysis is meaningful.
        # Skip embeddings, LM head, layer norms - MP heuristics don't apply there.
        if not any(name.endswith(suffix) for suffix in allowed_suffixes):
            continue

        # Get 2D weight matrix
        for param_name, param in module.named_parameters(recurse=False):
            if param.ndim == 2 and "weight" in param_name:
                W = param.detach()

                # Handle Conv1D transposition: HF Conv1D stores weights as
                # (in_features, out_features); transpose to Linear convention.
                if _HFConv1D is not None and isinstance(module, _HFConv1D):
                    W = W.T

                if W.ndim == 2:
                    m, n = W.shape

                    # Skip matrices containing NaN/inf - SVD would be meaningless.
                    # NOTE: `continue` moves to the NEXT parameter (matching the
                    # original control flow), whereas a successful analysis or a
                    # non-2D W breaks out after the first weight parameter.
                    if not torch.isfinite(W).all():
                        continue
                    try:
                        s_actual = torch.linalg.svdvals(W.float().cpu())
                        sigma_base = s_actual[0].item()
                        mp_edge_base = mp_bulk_edge(m, n, whitened=False)

                        # Compute baseline r_mp ratio
                        r_mp_base = sigma_base / max(mp_edge_base, 1e-12)

                        # Store statistics with consistent naming
                        mp_stats[name] = {
                            "mp_bulk_edge_base": mp_edge_base,
                            "r_mp_base": r_mp_base,
                            "sigma_base": sigma_base,
                        }
                    except (RuntimeError, torch.linalg.LinAlgError):
                        # Skip if SVD fails
                        continue
                break  # Only process first weight parameter

    return mp_stats
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def _iter_transformer_layers(model: nn.Module):
    """Yield the transformer blocks of *model*, probing known architectures.

    Tries GPT-2 (``transformer.h``), LLaMA (``model.layers``) and BERT
    (``encoder.layer``) layouts in that order; only the first matching layout
    is used. If none match, falls back to yielding every submodule that
    exposes both ``attn`` and ``mlp`` attributes.
    """
    # (container attribute, block-list attribute) pairs, in priority order.
    known_layouts = (
        ("transformer", "h"),  # GPT-2 style
        ("model", "layers"),  # LLaMA style
        ("encoder", "layer"),  # BERT style
    )

    for container_attr, list_attr in known_layouts:
        container = getattr(model, container_attr, None)
        if container is None or not hasattr(container, list_attr):
            continue
        blocks = getattr(container, list_attr)
        # Only iterate genuine sequences; oddly-shaped attributes yield nothing.
        if hasattr(blocks, "__iter__") and hasattr(blocks, "__len__"):
            try:
                yield from blocks
            except (TypeError, AttributeError):
                pass
        # A matching layout ends the search even if it produced no blocks.
        return

    # Fallback: anything that structurally looks like a transformer block.
    for candidate in model.modules():
        if hasattr(candidate, "attn") and hasattr(candidate, "mlp"):
            yield candidate
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def rmt_detect(
    model: nn.Module,
    threshold: float = 1.5,
    detect_only: bool = True,
    correction_factor: float | None = None,
    layer_indices: list[int] | None = None,
    target_layers: list[str] | None = None,  # Alternative layer specification
    verbose: bool = False,
    max_iterations: int = 2,  # Add iteration guard
    baseline_sigmas: dict[str, float]
    | None = None,  # Add baseline sigmas for baseline-aware checking
    baseline_mp_stats: dict[str, dict[str, float]]
    | None = None,  # Store baseline MP statistics
    deadband: float = 0.0,  # Add deadband parameter to align with spectral control
    use_quantile_mp: bool = False,  # Use quantile-based MP edge for heavy-tailed spectra
) -> dict[str, Any]:
    """
    Detect RMT outliers in model with baseline-aware checking and iteration guard.

    Args:
        model: Model to analyze
        threshold: Ratio threshold for flagging outliers (default 1.5)
        detect_only: If True, only detect outliers without correction
        correction_factor: Factor to apply for correction (if not detect_only)
        layer_indices: Specific layers to analyze by index (None = all)
        target_layers: Specific layers to analyze by name (None = all)
        verbose: Whether to print warnings and details
        max_iterations: Maximum iterations for correction (default 2)
        baseline_sigmas: Baseline sigmas for baseline-aware checking
        baseline_mp_stats: Baseline MP statistics (mp_bulk_edge, r_mp_base) for each weight matrix
        deadband: Deadband threshold to align with spectral control
        use_quantile_mp: Use quantile-based MP edge for heavy-tailed spectra.
            NOTE(review): this parameter is currently unused in the body —
            confirm whether it is reserved for future use or dead.

    Returns:
        Dict with detection results including per-layer details: keys
        has_outliers, n_layers_flagged, outlier_count, max_ratio, threshold,
        correction_iterations, per_layer, flagged_layers, layers.
    """
    per_layer: list[dict[str, Any]] = []
    flagged_layers: list[int] = []

    # Analyze only linear layers where MP analysis is meaningful.
    modules_to_analyze = []

    # Define allowlist for RMT analysis - same as in capture_baseline_mp_stats
    allowed_suffixes = [".attn.c_attn", ".attn.c_proj", ".mlp.c_fc", ".mlp.c_proj"]

    if layer_indices is not None or target_layers is not None:
        # If specific layers requested, only analyze transformer layers.
        for idx, layer in enumerate(_iter_transformer_layers(model)):
            # Skip if not in specified layers (by index)
            if layer_indices is not None and idx not in layer_indices:
                continue

            # Skip if not in specified layers (by name). The layer object is
            # mapped back to its dotted module name via identity search.
            if target_layers is not None:
                layer_name = None
                for name, module in model.named_modules():
                    if module is layer:
                        layer_name = name
                        break
                if layer_name is None or not any(
                    target in layer_name for target in target_layers
                ):
                    continue

            modules_to_analyze.append((f"transformer_layer_{idx}", layer))
    else:
        # CRITICAL: Only analyze modules where MP analysis makes sense.
        # Exclude embeddings, LM head, layer norms - they have different spectral properties.
        for name, module in model.named_modules():
            # Check if this is an allowed module type with 2D weights
            if any(name.endswith(suffix) for suffix in allowed_suffixes):
                has_2d_weights = any(
                    param.ndim == 2 and "weight" in param_name
                    for param_name, param in module.named_parameters(recurse=False)
                )
                if has_2d_weights:
                    modules_to_analyze.append((name, module))

    # Iteration guard for correction: stop when outlier count stops improving.
    prev_outlier_count = float("inf")
    correction_iterations = 0

    while correction_iterations < max_iterations:
        current_outliers = 0
        per_layer = []  # Reset per iteration
        flagged_layers = []

        for idx, (module_name, module) in enumerate(modules_to_analyze):
            # Use baseline-aware stats if available
            stats = layer_svd_stats(
                module, baseline_sigmas, baseline_mp_stats, module_name
            )

            # Apply baseline-aware RMT detection with deadband support.
            # Three detection paths, in decreasing preference:
            #   1) full baseline (sigma + MP stats), 2) sigma-only deadband,
            #   3) absolute worst_ratio from layer_svd_stats.
            has_outlier = False
            skip_reason = None

            if (
                baseline_sigmas
                and baseline_mp_stats
                and module_name in baseline_sigmas
                and module_name in baseline_mp_stats
            ):
                # Step 5 spec: ratio = σ_max_post / bulk_edge_base, flag if ratio > (1+deadband)*margin
                sigma_post = stats["sigma_max"]
                mp_stats = baseline_mp_stats[module_name]
                bulk_edge_base = mp_stats.get("mp_bulk_edge_base", 1.0)

                # Exact Step 5 detection rule
                ratio = sigma_post / max(bulk_edge_base, 1e-12)
                detection_threshold = (1.0 + deadband) * threshold

                if ratio > detection_threshold:
                    has_outlier = True
                    skip_reason = None
                else:
                    # Determine skip reason for clear logging
                    skip_reason = (
                        f"≤ threshold (ratio={ratio:.2f} ≤ {detection_threshold:.2f})"
                    )
            elif deadband > 0.0 and baseline_sigmas and module_name in baseline_sigmas:
                # Partial baseline-aware: deadband check only (fallback when no MP stats)
                baseline_sigma = baseline_sigmas[module_name]
                sigma_post = stats["sigma_max"]
                ratio = sigma_post / max(baseline_sigma, 1e-12)
                detection_threshold = (1.0 + deadband) * threshold

                if ratio > detection_threshold:
                    has_outlier = True
                    skip_reason = None
                else:
                    skip_reason = (
                        f"≤ threshold (ratio={ratio:.2f} ≤ {detection_threshold:.2f})"
                    )
            else:
                # Standard check without baseline awareness (fallback)
                ratio = stats["worst_ratio"]
                if ratio > threshold:
                    has_outlier = True
                    skip_reason = None
                else:
                    skip_reason = f"≤ threshold (ratio={ratio:.2f} ≤ {threshold:.2f})"

            layer_info = {
                "layer": idx,
                "module_name": module_name,
                "sigma_min": stats["sigma_min"],
                "sigma_max": stats["sigma_max"],
                "worst_ratio": stats["worst_ratio"],
                "has_outlier": has_outlier,
            }

            # Add detailed info if available
            if "worst_details" in stats:
                layer_info["details"] = stats["worst_details"]

            per_layer.append(layer_info)

            # Store skip reason in layer info for better logging
            layer_info["skip_reason"] = skip_reason

            if has_outlier:
                flagged_layers.append(idx)
                current_outliers += 1
                if verbose:
                    normalization = stats.get("worst_details", {}).get(
                        "normalization", "unknown"
                    )
                    print(
                        f" Module {module_name}: ratio={stats['worst_ratio']:.2f} "
                        f"(σ_max={stats['sigma_max']:.2f}, norm={normalization})"
                    )
            elif verbose and skip_reason:
                print(f" Module {module_name}: SKIP: {skip_reason}")

        # Apply correction if requested and not detect-only.
        # Correction is applied on the first iteration only; later iterations
        # just re-measure and check for improvement.
        if not detect_only and current_outliers > 0 and correction_factor is not None:
            if correction_iterations == 0:
                if verbose:
                    print(
                        f" Applying RMT correction (iteration {correction_iterations + 1})..."
                    )
                # Apply correction to flagged modules (delegates to the
                # module-level _apply_rmt_correction helper).
                for idx in flagged_layers:
                    module_name, module = modules_to_analyze[idx]
                    _apply_rmt_correction(
                        module,
                        correction_factor,
                        baseline_sigmas,
                        baseline_mp_stats,
                        module_name,
                        deadband,
                        verbose,
                        adapter=None,
                    )
            else:
                # Check if improvement occurred; a stalled correction is
                # downgraded to a warning rather than looping forever.
                if current_outliers >= prev_outlier_count:
                    if verbose:
                        print(
                            f" RMT correction stalled ({current_outliers} outliers unchanged), "
                            f"downgrading to warning"
                        )
                    break
                elif verbose:
                    print(
                        f" RMT correction improving ({prev_outlier_count} → {current_outliers} outliers)"
                    )
        else:
            # No correction requested, exit after first iteration
            break

        prev_outlier_count = current_outliers
        correction_iterations += 1

        # Exit if no outliers remain
        if current_outliers == 0:
            break

    # Aggregate results
    n_outliers = len(flagged_layers)
    max_ratio = max((item["worst_ratio"] for item in per_layer), default=0.0)
    has_outliers = n_outliers > 0

    if verbose and has_outliers:
        baseline_note = (
            " (baseline-aware)"
            if baseline_sigmas and baseline_mp_stats
            else " (absolute)"
        )
        deadband_note = f" with {deadband:.0%} deadband" if deadband > 0.0 else ""

        # Count detected vs will-be-capped
        n_detected = n_outliers
        n_will_be_capped = n_outliers if not detect_only else 0

        print(f" ⚠️ RMT outliers detected{baseline_note}{deadband_note}:")
        print(f" Detected: {n_detected}, will correct: {n_will_be_capped}")
        print(f" Max ratio: {max_ratio:.2f}")
        print(" Top offenders (σ_post / σ_ref):")

        # Show top 3 offenders with detailed information
        top_offenders = sorted(
            [
                (item["worst_ratio"], item["module_name"], item.get("details", {}))
                for item in per_layer
                if item["has_outlier"]
            ],
            reverse=True,
        )[:3]

        for ratio, module_name, details in top_offenders:
            sigma_max = details.get("s_max", 0.0)
            ref_type = "mp_bulk_edge" if not baseline_sigmas else "baseline-aware"
            print(
                f" - {module_name}: {ratio:.2f} (σ_post={sigma_max:.2f}, ref={ref_type})"
            )

        if len(top_offenders) < n_outliers:
            print(
                f" ... and {n_outliers - len(top_offenders)} more layers flagged"
            )

    return {
        "has_outliers": has_outliers,
        "n_layers_flagged": n_outliers,
        "outlier_count": n_outliers,  # Alias for compatibility
        "max_ratio": max_ratio,
        "threshold": threshold,
        "correction_iterations": correction_iterations,
        "per_layer": per_layer,
        "flagged_layers": flagged_layers,
        "layers": {
            f"layer_{item['layer']}": item for item in per_layer
        },  # Alternative format
    }
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
def rmt_detect_report(
    model: nn.Module, threshold: float = 1.5
) -> tuple[dict, list[dict]]:
    """
    Generate an RMT health report.

    Args:
        model: Model to analyze
        threshold: Ratio threshold for outliers

    Returns:
        (summary_dict, per_layer_list)
    """
    detection = rmt_detect(model, threshold, verbose=False)

    flagged = detection["has_outliers"]
    worst_ratio = detection["max_ratio"]

    summary = {
        "has_outliers": flagged,
        "n_layers_flagged": detection["n_layers_flagged"],
        "max_ratio": worst_ratio,
        # Aliases kept for compatibility with older consumers.
        "rmt_max_ratio": worst_ratio,
        "rmt_has_outliers": flagged,
    }

    return summary, detection["per_layer"]
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
def rmt_detect_with_names(
    model: nn.Module, threshold: float = 1.5, verbose: bool = False
) -> dict[str, Any]:
    """
    Detect RMT outliers in model and return detailed information with module names.

    Args:
        model: Model to analyze
        threshold: Ratio threshold for flagging outliers (default 1.5)
        verbose: Whether to print warnings and details

    Returns:
        Dict with detection results including per-layer details and module names
    """
    # --- Locate transformer blocks for the common HF architectures ---
    named_layers: list[tuple[str, nn.Module]] = []
    transformer = getattr(model, "transformer", None)
    inner_model = getattr(model, "model", None)
    encoder = getattr(model, "encoder", None)

    if transformer is not None and hasattr(transformer, "h"):
        # GPT-2 style
        blocks = transformer.h
        if hasattr(blocks, "__iter__"):
            named_layers = [
                (f"transformer.h.{i}", blk) for i, blk in enumerate(blocks)
            ]
    elif inner_model is not None and hasattr(inner_model, "layers"):
        # LLaMA style
        blocks = inner_model.layers
        if hasattr(blocks, "__iter__"):
            named_layers = [
                (f"model.layers.{i}", blk) for i, blk in enumerate(blocks)
            ]
    elif encoder is not None and hasattr(encoder, "layer"):
        # BERT style
        blocks = encoder.layer
        if hasattr(blocks, "__iter__"):
            named_layers = [
                (f"encoder.layer.{i}", blk) for i, blk in enumerate(blocks)
            ]
    else:
        # Fallback - any module exposing both attention and MLP submodules
        named_layers = [
            (name, module)
            for name, module in model.named_modules()
            if hasattr(module, "attn") and hasattr(module, "mlp")
        ]

    per_layer: list[dict[str, Any]] = []
    outliers: list[dict[str, Any]] = []
    flagged_layers: list[str] = []

    for layer_name, block in named_layers:
        stats = layer_svd_stats(block, module_name=layer_name)
        flagged = stats["worst_ratio"] > threshold

        record: dict[str, Any] = {
            "layer_name": layer_name,
            "sigma_min": stats["sigma_min"],
            "sigma_max": stats["sigma_max"],
            "worst_ratio": stats["worst_ratio"],
            "has_outlier": flagged,
        }

        # Detailed per-module info (and outlier bookkeeping) is only
        # available when the SVD stats carry "worst_details".
        if "worst_details" in stats:
            record["details"] = stats["worst_details"]
            if flagged:
                outliers.append(
                    {
                        "layer_name": layer_name,
                        "module_name": f"{layer_name}.{stats['worst_details']['name']}",
                        "sigma_max": stats["sigma_max"],
                        "ratio": stats["worst_ratio"],
                        "details": stats["worst_details"],
                    }
                )
                flagged_layers.append(layer_name)

        per_layer.append(record)

    # --- Aggregate ---
    n_outliers = len(flagged_layers)
    max_ratio = 0.0
    if per_layer:
        try:
            max_ratio = max(float(rec.get("worst_ratio", 0.0)) for rec in per_layer)
        except (TypeError, ValueError):
            max_ratio = 0.0
    has_outliers = n_outliers > 0

    if verbose and has_outliers:
        print(" ⚠️ RMT outliers detected:")
        print(f" Layers flagged: {n_outliers}")
        print(f" Max ratio: {max_ratio:.2f}")
        print(f" Threshold: {threshold:.2f}")
        print(" Top offenders (σ_post / σ_ref):")
        # Show top offenders with full module names and consistent formatting
        for outlier in outliers[:3]:
            print(
                f" - {outlier['module_name']}: {outlier['ratio']:.2f} (σ_post={outlier['sigma_max']:.2f}, ref=mp_bulk_edge)"
            )
        if len(outliers) > 3:
            print(f" ... and {len(outliers) - 3} more layers flagged")

    return {
        "has_outliers": has_outliers,
        "n_layers_flagged": n_outliers,
        "outlier_count": n_outliers,
        "max_ratio": max_ratio,
        "threshold": threshold,
        "per_layer": per_layer,
        "flagged_layers": flagged_layers,
        "outliers": outliers,  # Add the outliers list with full module names
        "layers": {item["layer_name"]: item for item in per_layer},
    }
|
|
829
|
+
|
|
830
|
+
|
|
831
|
+
def _apply_rmt_correction(
    layer: nn.Module,
    factor: float,
    baseline_sigmas: dict[str, float] | None = None,
    baseline_mp_stats: dict[str, dict[str, float]] | None = None,
    layer_name: str = "",
    deadband: float = 0.0,
    verbose: bool = False,
    adapter=None,
):
    """
    Apply RMT-based correction to layer weights with proper cap application.

    Enhanced for Step 5 with:
    - Step 5 detection rule: target = bulk_edge_base * margin * (1 - deadband)
    - Adapter tying map support for preserving weight tying relationships
    - IN-PLACE scaling (param.mul_) to preserve weight tying
    - Never rewraps Parameters to avoid breaking lm_head ↔ wte aliasing

    Args:
        layer: Module whose 2-D "weight" parameters are inspected/scaled.
        factor: Fallback multiplicative shrink applied when SVD fails.
        baseline_sigmas: Optional per-module baseline sigma map (presence
            gates the baseline-aware target path).
        baseline_mp_stats: Optional per-module MP stats; the entry for
            ``layer_name`` supplies ``sigma_base``.
        layer_name: Fully-qualified module name used for MP-stat lookup,
            tying-map lookup and log messages.
        deadband: Shrinks the correction target by (1 - deadband).
        verbose: Print per-parameter correction/skip lines.
        adapter: Optional adapter exposing get_tying_map()/
            get_parameter_by_name() for tied-weight propagation.
    """
    # Only square-ish weight matrices are in scope; biases and 1-D params skipped.
    for name, param in layer.named_parameters():
        if param.ndim == 2 and "weight" in name:
            with torch.no_grad():
                # Get current spectral norm
                try:
                    W = param.detach()
                    # Handle Conv1D transposition (HF GPT-2 stores transposed weights)
                    Conv1D = None
                    try:
                        from transformers.pytorch_utils import Conv1D as _Conv1D

                        Conv1D = _Conv1D

                        if isinstance(layer, Conv1D):
                            W = W.T
                    except ImportError:
                        pass

                    # Non-finite weights cannot be meaningfully corrected; skip.
                    if not torch.isfinite(W).all():
                        continue
                    # SVD on CPU/float32 for numerical robustness.
                    s_vals = torch.linalg.svdvals(W.float().cpu())
                    sigma_pre = s_vals[0].item()

                    # Step 5 correction logic: target based on MP bulk edge
                    target_sigma = None

                    if (
                        baseline_sigmas
                        and baseline_mp_stats
                        and layer_name in baseline_mp_stats
                    ):
                        # CORRECTED Step 5: Use baseline sigma for target calculation
                        mp_stats = baseline_mp_stats[layer_name]
                        sigma_base = mp_stats.get("sigma_base", 1.0)

                        # Step 5 target: baseline * margin * (1 - deadband) for conservative correction
                        margin = (
                            1.5  # Default from policy, could be passed as parameter
                        )
                        target_sigma = sigma_base * margin * (1.0 - deadband)
                    else:
                        # Fallback: Use current MP edge
                        m, n = W.shape
                        mp_edge = mp_bulk_edge(m, n, whitened=False)
                        target_sigma = mp_edge * 1.0  # Conservative cap at edge

                    # Apply correction only if needed
                    if sigma_pre > target_sigma:
                        # Compute proper scale: target/σ_pre
                        scale = target_sigma / sigma_pre
                        scale = max(
                            scale, 0.1
                        )  # Floor at 10% to avoid extreme shrinkage

                        # Check for tied parameters using adapter's tying map
                        tied_params = []
                        if adapter and hasattr(adapter, "get_tying_map"):
                            try:
                                tying_map = adapter.get_tying_map()
                                full_param_name = f"{layer_name}.{name}"
                                tied_params = tying_map.get(full_param_name, [])
                            except Exception:
                                # Fallback if adapter doesn't support tying map
                                tied_params = []

                        # CRITICAL: Apply IN-PLACE scaling to preserve weight tying
                        param.mul_(scale)  # PRESERVES TYING - same data pointer

                        # Apply same scaling to tied parameters if any
                        if tied_params and adapter:
                            for tied_name in tied_params:
                                try:
                                    # Get tied parameter and apply same scale
                                    tied_param = adapter.get_parameter_by_name(
                                        tied_name
                                    )
                                    if tied_param is not None:
                                        tied_param.mul_(scale)
                                except Exception:
                                    # Continue if tied parameter access fails
                                    pass

                        # Recompute sigma after scaling for accurate logging
                        W_after = param.detach()
                        if Conv1D is not None and isinstance(layer, Conv1D):
                            W_after = W_after.T
                        s_vals_after = torch.linalg.svdvals(W_after.float().cpu())
                        sigma_post = s_vals_after[0].item()

                        # Log the correction with proper values
                        if verbose:
                            tied_info = (
                                f", tied to {len(tied_params)} params"
                                if tied_params
                                else ""
                            )
                            print(
                                f" {layer_name}.{name}: σ={sigma_pre:.2f}→{sigma_post:.2f} "
                                f"(scale={scale:.3f}, target={target_sigma:.2f}{tied_info})"
                            )
                    else:
                        # No correction needed - log skip reason
                        if verbose:
                            print(
                                f" {layer_name}.{name}: SKIP: ≤ target (σ={sigma_pre:.2f} ≤ {target_sigma:.2f})"
                            )

                except (RuntimeError, torch.linalg.LinAlgError):
                    # CRITICAL: Even fallback must use in-place scaling
                    # NOTE(review): the fallback multiplies by `factor` on ANY SVD
                    # failure, even for parameters that may not have been outliers
                    # — confirm this unconditional shrink is intended.
                    param.mul_(factor)
                    if verbose:
                        print(
                            f" {layer_name}.{name}: fallback scaling (SVD failed)"
                        )
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
def clip_full_svd(
|
|
967
|
+
W: torch.Tensor, clip_val: float, return_components: bool = False
|
|
968
|
+
) -> torch.Tensor:
|
|
969
|
+
"""
|
|
970
|
+
Clip singular values of a matrix using full SVD.
|
|
971
|
+
|
|
972
|
+
Args:
|
|
973
|
+
W: Weight matrix
|
|
974
|
+
clip_val: Maximum singular value
|
|
975
|
+
return_components: If True, return (U, S_clipped, Vt)
|
|
976
|
+
|
|
977
|
+
Returns:
|
|
978
|
+
Clipped weight matrix or components
|
|
979
|
+
"""
|
|
980
|
+
if not torch.isfinite(W).all():
|
|
981
|
+
if return_components:
|
|
982
|
+
return None, None, None
|
|
983
|
+
return W
|
|
984
|
+
|
|
985
|
+
try:
|
|
986
|
+
U, S, Vt = torch.linalg.svd(W.float(), full_matrices=False)
|
|
987
|
+
S_clipped = torch.clamp(S, max=clip_val)
|
|
988
|
+
|
|
989
|
+
if return_components:
|
|
990
|
+
return U, S_clipped, Vt
|
|
991
|
+
else:
|
|
992
|
+
return (U @ torch.diag(S_clipped) @ Vt).to(W.dtype)
|
|
993
|
+
except (RuntimeError, torch.linalg.LinAlgError):
|
|
994
|
+
# Return original on error
|
|
995
|
+
if return_components:
|
|
996
|
+
return None, None, None
|
|
997
|
+
return W
|
|
998
|
+
|
|
999
|
+
|
|
1000
|
+
def analyze_weight_distribution(model: nn.Module, n_bins: int = 50) -> dict[str, Any]:
    """
    Analyze weight distribution statistics for RMT analysis.

    Args:
        model: Model to analyze
        n_bins: Number of histogram bins

    Returns:
        Dict with distribution statistics; empty dict when the model has no
        finite 2-D weight matrices. Singular-value/MP-edge entries are only
        present when at least one SVD succeeded.
    """
    all_weights = []
    all_singular_values = []

    for name, param in model.named_parameters():
        if param.ndim == 2 and "weight" in name:
            param_cpu = param.detach().cpu()
            # Skip matrices containing NaN/Inf entirely.
            if not torch.isfinite(param_cpu).all():
                continue

            # Collect weights
            all_weights.append(param_cpu.flatten())

            # Collect singular values (tolerate per-matrix SVD failures)
            try:
                all_singular_values.append(torch.linalg.svdvals(param_cpu.float()))
            except (RuntimeError, torch.linalg.LinAlgError):
                continue

    if not all_weights:
        return {}

    # Concatenate all weights into one flat sample.
    weights = torch.cat(all_weights)

    # Elementwise distribution statistics.
    stats: dict[str, Any] = {
        "mean": weights.mean().item(),
        "std": weights.std().item(),
        "min": weights.min().item(),
        "max": weights.max().item(),
        "sparsity": (weights.abs() < 1e-6).float().mean().item(),
    }

    # Histogram of the pooled weights.
    hist, edges = torch.histogram(weights, bins=n_bins)
    stats["histogram"] = hist.tolist()
    stats["bin_edges"] = edges.tolist()

    # Singular value statistics + MP edge estimates (merged: both blocks were
    # guarded by the same condition in the original).
    if all_singular_values:
        s_all = torch.cat(all_singular_values)
        singular_values_dict: dict[str, float] = {
            "mean": s_all.mean().item(),
            "std": s_all.std().item(),
            "min": s_all.min().item(),
            "max": s_all.max().item(),
            "condition_number": (s_all.max() / (s_all.min() + 1e-8)).item(),
        }
        stats["singular_values"] = singular_values_dict

        # Estimate MP edges from the pooled singular-value counts.
        n_samples: float = sum(s.shape[0] for s in all_singular_values)
        n_features: float = np.mean([s.shape[0] for s in all_singular_values])
        mp_min, mp_max = mp_bulk_edges(int(n_samples), int(n_features))
        stats["mp_edges"] = {"min": mp_min, "max": mp_max}

        # BUG FIX: this alias was previously set unconditionally and raised
        # KeyError when every per-matrix SVD failed.
        stats["eigenvalue_stats"] = stats["singular_values"]

    return stats
|
|
1075
|
+
|
|
1076
|
+
|
|
1077
|
+
# === Guard Implementation ===
|
|
1078
|
+
|
|
1079
|
+
# Import GuardOutcome types if available
|
|
1080
|
+
try:
|
|
1081
|
+
from invarlock.core.types import GuardOutcome
|
|
1082
|
+
|
|
1083
|
+
HAS_GUARD_OUTCOME = True
|
|
1084
|
+
except ImportError:
|
|
1085
|
+
# Fallback for standalone usage or when types not available
|
|
1086
|
+
HAS_GUARD_OUTCOME = False
|
|
1087
|
+
GuardOutcome = dict
|
|
1088
|
+
|
|
1089
|
+
|
|
1090
|
+
@dataclass
|
|
1091
|
+
class RMTPolicy:
|
|
1092
|
+
"""
|
|
1093
|
+
RMT Guard Policy Configuration.
|
|
1094
|
+
|
|
1095
|
+
Defines parameters for baseline-aware RMT outlier detection and correction.
|
|
1096
|
+
"""
|
|
1097
|
+
|
|
1098
|
+
q: float | Literal["auto"] = (
|
|
1099
|
+
"auto" # MP aspect ratio m/n (auto-derived from weights)
|
|
1100
|
+
)
|
|
1101
|
+
deadband: float = 0.10 # Tolerance margin (10%)
|
|
1102
|
+
margin: float = 1.5 # RMT threshold ratio
|
|
1103
|
+
correct: bool = True # Enable automatic correction
|
|
1104
|
+
|
|
1105
|
+
|
|
1106
|
+
class RMTPolicyDict(TypedDict):
    """TypedDict version of RMTPolicy for compatibility."""

    # MP aspect ratio m/n, or "auto" to derive it from weight shapes.
    q: float | Literal["auto"]
    # Tolerance margin before flagging (e.g. 0.10 = 10%).
    deadband: float
    # RMT threshold ratio.
    margin: float
    # Enable automatic correction.
    correct: bool
    # Epsilon budget: a single value, a per-family mapping, or None.
    # (Extra field relative to the RMTPolicy dataclass.)
    epsilon: float | dict[str, float] | None
|
|
1114
|
+
|
|
1115
|
+
|
|
1116
|
+
class RMTGuard(Guard):
|
|
1117
|
+
"""
|
|
1118
|
+
Standalone RMT Guard for baseline-aware outlier detection and correction.
|
|
1119
|
+
|
|
1120
|
+
Implements Marchenko-Pastur theory-based spectral health checking with:
|
|
1121
|
+
- Baseline capture of MP bulk edges for linear layers
|
|
1122
|
+
- Conservative outlier detection with deadband support
|
|
1123
|
+
- Optional in-place correction preserving weight tying
|
|
1124
|
+
- Comprehensive event logging and metrics
|
|
1125
|
+
|
|
1126
|
+
Policy Structure:
|
|
1127
|
+
- q: MP aspect ratio (auto-derived or manual)
|
|
1128
|
+
- deadband: Tolerance margin before flagging (default 0.10 = 10%)
|
|
1129
|
+
- margin: RMT threshold ratio (default 1.5)
|
|
1130
|
+
- correct: Enable automatic correction (default True)
|
|
1131
|
+
|
|
1132
|
+
Linear Layer Scope (enforced):
|
|
1133
|
+
- attn.c_attn, attn.c_proj, mlp.c_fc, mlp.c_proj
|
|
1134
|
+
- Excludes: embeddings, LM head, layer norms, biases
|
|
1135
|
+
"""
|
|
1136
|
+
|
|
1137
|
+
name = "rmt"
|
|
1138
|
+
|
|
1139
|
+
def __init__(
|
|
1140
|
+
self,
|
|
1141
|
+
q: float | Literal["auto"] = "auto",
|
|
1142
|
+
deadband: float = 0.10,
|
|
1143
|
+
margin: float = 1.5,
|
|
1144
|
+
correct: bool = True,
|
|
1145
|
+
epsilon: float | dict[str, float] | None = None,
|
|
1146
|
+
):
|
|
1147
|
+
"""
|
|
1148
|
+
Initialize RMT Guard.
|
|
1149
|
+
|
|
1150
|
+
Args:
|
|
1151
|
+
q: MP aspect ratio (auto-derived from weight shapes if "auto")
|
|
1152
|
+
deadband: Tolerance margin before flagging outliers (0.10 = 10%)
|
|
1153
|
+
margin: RMT threshold ratio for outlier detection (1.5)
|
|
1154
|
+
correct: Enable automatic correction when outliers detected
|
|
1155
|
+
"""
|
|
1156
|
+
self.q = q
|
|
1157
|
+
self.deadband = deadband
|
|
1158
|
+
self.margin = margin
|
|
1159
|
+
self.correct = correct
|
|
1160
|
+
self.epsilon_default = 0.10
|
|
1161
|
+
self.epsilon_by_family: dict[str, float] = {}
|
|
1162
|
+
self._set_epsilon(epsilon)
|
|
1163
|
+
for family_key in ("attn", "ffn", "embed", "other"):
|
|
1164
|
+
self.epsilon_by_family.setdefault(family_key, self.epsilon_default)
|
|
1165
|
+
|
|
1166
|
+
# Internal state
|
|
1167
|
+
self.baseline_mp_stats: dict[str, dict[str, float]] | None = None
|
|
1168
|
+
self.baseline_sigmas: dict[str, float] | None = None
|
|
1169
|
+
self.prepared = False
|
|
1170
|
+
self.events: list[dict[str, Any]] = []
|
|
1171
|
+
self._last_result: dict[str, Any] | None = None
|
|
1172
|
+
self.adapter = None # Store adapter for tying map access
|
|
1173
|
+
|
|
1174
|
+
# Linear layer scope enforcement - same as existing RMT
|
|
1175
|
+
self.allowed_suffixes = [
|
|
1176
|
+
".attn.c_attn",
|
|
1177
|
+
".attn.c_proj",
|
|
1178
|
+
".mlp.c_fc",
|
|
1179
|
+
".mlp.c_proj",
|
|
1180
|
+
]
|
|
1181
|
+
self.baseline_outliers_per_family: dict[str, int] = {}
|
|
1182
|
+
self.baseline_total_outliers: int = 0
|
|
1183
|
+
self.outliers_per_family: dict[str, int] = {}
|
|
1184
|
+
self.outliers_total: int = 0
|
|
1185
|
+
self.epsilon_violations: list[dict[str, Any]] = []
|
|
1186
|
+
|
|
1187
|
+
def _log_event(
|
|
1188
|
+
self, operation: str, level: str = "INFO", message: str = "", **data
|
|
1189
|
+
):
|
|
1190
|
+
"""Log an event with timestamp."""
|
|
1191
|
+
event = {
|
|
1192
|
+
"timestamp": datetime.utcnow().isoformat(),
|
|
1193
|
+
"component": "rmt_guard",
|
|
1194
|
+
"operation": operation,
|
|
1195
|
+
"level": level,
|
|
1196
|
+
"message": message,
|
|
1197
|
+
"data": data,
|
|
1198
|
+
}
|
|
1199
|
+
self.events.append(event)
|
|
1200
|
+
|
|
1201
|
+
def _set_epsilon(self, epsilon: float | dict[str, float] | None) -> None:
|
|
1202
|
+
"""Configure epsilon defaults and per-family overrides."""
|
|
1203
|
+
if isinstance(epsilon, dict):
|
|
1204
|
+
mapped: dict[str, float] = {}
|
|
1205
|
+
for family, value in epsilon.items():
|
|
1206
|
+
try:
|
|
1207
|
+
mapped[str(family)] = float(value)
|
|
1208
|
+
except (TypeError, ValueError):
|
|
1209
|
+
continue
|
|
1210
|
+
if mapped:
|
|
1211
|
+
self.epsilon_by_family.update(mapped)
|
|
1212
|
+
self.epsilon_default = max(mapped.values())
|
|
1213
|
+
elif isinstance(epsilon, int | float):
|
|
1214
|
+
self.epsilon_default = float(epsilon)
|
|
1215
|
+
if self.epsilon_by_family:
|
|
1216
|
+
for family in list(self.epsilon_by_family):
|
|
1217
|
+
self.epsilon_by_family[family] = self.epsilon_default
|
|
1218
|
+
|
|
1219
|
+
@staticmethod
|
|
1220
|
+
def _classify_family(module_name: str) -> str:
|
|
1221
|
+
"""Classify module name into a guard family."""
|
|
1222
|
+
lower = module_name.lower()
|
|
1223
|
+
# MoE
|
|
1224
|
+
if any(
|
|
1225
|
+
tok in lower
|
|
1226
|
+
for tok in ("router", "routing", "gate", "gating", "dispatch", "switch")
|
|
1227
|
+
):
|
|
1228
|
+
return "router"
|
|
1229
|
+
if any(
|
|
1230
|
+
tok in lower for tok in ("experts", "expert", "moe", "mixture_of_experts")
|
|
1231
|
+
):
|
|
1232
|
+
return "expert_ffn"
|
|
1233
|
+
if ".attn." in lower or "attention" in lower:
|
|
1234
|
+
return "attn"
|
|
1235
|
+
if ".mlp." in lower or "ffn" in lower or ".c_fc" in lower:
|
|
1236
|
+
return "ffn"
|
|
1237
|
+
if "embed" in lower or "wte" in lower or "wpe" in lower:
|
|
1238
|
+
return "embed"
|
|
1239
|
+
return "other"
|
|
1240
|
+
|
|
1241
|
+
def _count_outliers_per_family(
|
|
1242
|
+
self, per_layer: list[dict[str, Any]]
|
|
1243
|
+
) -> dict[str, int]:
|
|
1244
|
+
"""Count outliers grouped by family."""
|
|
1245
|
+
counts: dict[str, int] = {}
|
|
1246
|
+
for layer_info in per_layer:
|
|
1247
|
+
if not layer_info.get("has_outlier"):
|
|
1248
|
+
continue
|
|
1249
|
+
module_name = layer_info.get("module_name", "")
|
|
1250
|
+
family = self._classify_family(module_name)
|
|
1251
|
+
counts[family] = counts.get(family, 0) + 1
|
|
1252
|
+
return counts
|
|
1253
|
+
|
|
1254
|
+
def _compute_epsilon_violations(self) -> list[dict[str, Any]]:
|
|
1255
|
+
"""Compute epsilon-rule violations per family."""
|
|
1256
|
+
violations: list[dict[str, Any]] = []
|
|
1257
|
+
families = set(self.outliers_per_family) | set(
|
|
1258
|
+
self.baseline_outliers_per_family
|
|
1259
|
+
)
|
|
1260
|
+
for family in families:
|
|
1261
|
+
bare = int(self.baseline_outliers_per_family.get(family, 0) or 0)
|
|
1262
|
+
guarded = int(self.outliers_per_family.get(family, 0) or 0)
|
|
1263
|
+
epsilon_val = float(
|
|
1264
|
+
self.epsilon_by_family.get(family, self.epsilon_default)
|
|
1265
|
+
)
|
|
1266
|
+
allowed = math.ceil(bare * (1 + epsilon_val))
|
|
1267
|
+
if guarded > allowed:
|
|
1268
|
+
violations.append(
|
|
1269
|
+
{
|
|
1270
|
+
"family": family,
|
|
1271
|
+
"bare": bare,
|
|
1272
|
+
"guarded": guarded,
|
|
1273
|
+
"allowed": allowed,
|
|
1274
|
+
"epsilon": epsilon_val,
|
|
1275
|
+
}
|
|
1276
|
+
)
|
|
1277
|
+
return violations
|
|
1278
|
+
|
|
1279
|
+
def _get_linear_modules(self, model: nn.Module) -> list[tuple[str, nn.Module]]:
|
|
1280
|
+
"""
|
|
1281
|
+
Get linear modules that are in scope for RMT analysis.
|
|
1282
|
+
|
|
1283
|
+
Args:
|
|
1284
|
+
model: Model to analyze
|
|
1285
|
+
|
|
1286
|
+
Returns:
|
|
1287
|
+
List of (name, module) tuples for linear layers in scope
|
|
1288
|
+
"""
|
|
1289
|
+
modules = []
|
|
1290
|
+
|
|
1291
|
+
# Get module types
|
|
1292
|
+
try:
|
|
1293
|
+
from transformers.pytorch_utils import Conv1D
|
|
1294
|
+
|
|
1295
|
+
module_types_with_conv1d_2: tuple[
|
|
1296
|
+
type[nn.Linear], type[nn.Conv1d], type[Conv1D]
|
|
1297
|
+
] = (nn.Linear, nn.Conv1d, Conv1D)
|
|
1298
|
+
module_types = module_types_with_conv1d_2
|
|
1299
|
+
except ImportError:
|
|
1300
|
+
module_types_without_conv1d_2: tuple[type[nn.Linear], type[nn.Conv1d]] = (
|
|
1301
|
+
nn.Linear,
|
|
1302
|
+
nn.Conv1d,
|
|
1303
|
+
)
|
|
1304
|
+
module_types = module_types_without_conv1d_2
|
|
1305
|
+
|
|
1306
|
+
modules: list[tuple[str, nn.Module]] = []
|
|
1307
|
+
for name, module in model.named_modules():
|
|
1308
|
+
if isinstance(module, module_types) and hasattr(module, "weight"):
|
|
1309
|
+
# Strict scope enforcement - only allowed linear layers
|
|
1310
|
+
if any(name.endswith(suffix) for suffix in self.allowed_suffixes):
|
|
1311
|
+
modules.append((name, module))
|
|
1312
|
+
|
|
1313
|
+
return modules
|
|
1314
|
+
|
|
1315
|
+
    def _apply_rmt_detection_and_correction(self, model: nn.Module) -> dict[str, Any]:
        """
        Apply Step 5 RMT detection and correction with adapter support.

        Uses exact Step 5 detection rule: ratio = σ_max_post / bulk_edge_base
        Flag if ratio > (1+deadband)*margin

        Args:
            model: Model whose in-scope linear modules are analyzed and,
                when ``self.correct`` is set, corrected in place.

        Returns:
            Aggregate detection dict: outlier flags/counts, max ratio,
            per-layer records, flagged indices, and a "layers" mapping.
        """
        per_layer = []
        flagged_layers = []
        corrected_layers = 0

        # Get linear modules in scope (suffix-whitelisted linear layers only)
        modules_to_analyze = self._get_linear_modules(model)

        self._log_event(
            "rmt_correction",
            message=f"Applying Step 5 detection and correction to {len(modules_to_analyze)} modules",
        )

        for idx, (module_name, module) in enumerate(modules_to_analyze):
            # Get current stats
            stats = layer_svd_stats(
                module, self.baseline_sigmas, self.baseline_mp_stats, module_name
            )

            # Step 5 detection rule
            has_outlier = False
            skip_reason = None

            if self.baseline_mp_stats and module_name in self.baseline_mp_stats:
                sigma_post = stats["sigma_max"]
                mp_stats = self.baseline_mp_stats[module_name]
                sigma_base = mp_stats.get("sigma_base", 1.0)

                # CORRECTED Step 5 detection rule: baseline-aware growth ratio
                # Compare current σ_max to baseline σ_max, normalized for stability
                ratio = sigma_post / max(sigma_base, 1e-12)
                detection_threshold = (1.0 + self.deadband) * self.margin

                if ratio > detection_threshold:
                    has_outlier = True

                    # Apply correction using enhanced logic with adapter support
                    if self.correct:
                        try:
                            _apply_rmt_correction(
                                module,
                                0.95,  # Conservative factor (not used in Step 5 logic)
                                self.baseline_sigmas,
                                self.baseline_mp_stats,
                                module_name,
                                self.deadband,
                                verbose=False,
                                adapter=self.adapter,
                            )
                            corrected_layers += 1

                            self._log_event(
                                "rmt_correct",
                                message=f"Applied correction to {module_name}",
                                module_name=module_name,
                                pre_ratio=ratio,
                                threshold=detection_threshold,
                            )

                            # Re-compute stats after correction
                            stats_post = layer_svd_stats(
                                module,
                                self.baseline_sigmas,
                                self.baseline_mp_stats,
                                module_name,
                            )
                            mp_stats = self.baseline_mp_stats[module_name]
                            # NOTE(review): the pre-correction ratio divides by
                            # sigma_base while this post-correction ratio divides
                            # by mp_bulk_edge_base — confirm the asymmetric
                            # denominators are intentional.
                            bulk_edge_base = mp_stats.get("mp_bulk_edge_base", 1.0)
                            ratio_post = stats_post["sigma_max"] / max(
                                bulk_edge_base, 1e-12
                            )

                            # Update has_outlier based on post-correction ratio
                            has_outlier = ratio_post > detection_threshold

                        except Exception as e:
                            # Correction failure is logged but does not abort the
                            # sweep; the layer stays flagged with its pre-ratio.
                            self._log_event(
                                "rmt_correct_failed",
                                level="ERROR",
                                message=f"Correction failed for {module_name}: {str(e)}",
                                module_name=module_name,
                                error=str(e),
                            )
                else:
                    skip_reason = (
                        f"≤ threshold (ratio={ratio:.2f} ≤ {detection_threshold:.2f})"
                    )
            else:
                # Fallback when no baseline MP stats
                ratio = stats["worst_ratio"]
                if ratio > self.margin:
                    has_outlier = True
                else:
                    skip_reason = f"≤ margin (ratio={ratio:.2f} ≤ {self.margin:.2f})"

            layer_info = {
                "layer": idx,
                "module_name": module_name,
                "sigma_min": stats["sigma_min"],
                "sigma_max": stats["sigma_max"],
                "worst_ratio": stats["worst_ratio"],
                "has_outlier": has_outlier,
                "skip_reason": skip_reason,
            }

            if "worst_details" in stats:
                layer_info["details"] = stats["worst_details"]

            per_layer.append(layer_info)

            if has_outlier:
                flagged_layers.append(idx)

        # Aggregate results
        n_outliers = len(flagged_layers)
        max_ratio = max((float(item["worst_ratio"]) for item in per_layer), default=0.0)
        has_outliers = n_outliers > 0

        return {
            "has_outliers": has_outliers,
            "n_layers_flagged": n_outliers,
            "outlier_count": n_outliers,
            "max_ratio": max_ratio,
            "threshold": self.margin,
            "correction_iterations": 1 if corrected_layers > 0 else 0,
            "corrected_layers": corrected_layers,
            "per_layer": per_layer,
            "flagged_layers": flagged_layers,
            "layers": {f"layer_{item['layer']}": item for item in per_layer},
        }
|
|
1451
|
+
|
|
1452
|
+
def prepare(
|
|
1453
|
+
self,
|
|
1454
|
+
model: nn.Module,
|
|
1455
|
+
adapter=None,
|
|
1456
|
+
calib=None,
|
|
1457
|
+
policy: dict[str, Any] | None = None,
|
|
1458
|
+
) -> dict[str, Any]:
|
|
1459
|
+
"""
|
|
1460
|
+
Prepare RMT guard by capturing baseline MP statistics.
|
|
1461
|
+
|
|
1462
|
+
Args:
|
|
1463
|
+
model: The model that will be edited
|
|
1464
|
+
adapter: ModelAdapter (optional, for tying map access)
|
|
1465
|
+
calib: Calibration data (unused for RMT)
|
|
1466
|
+
policy: Guard policy parameters (optional)
|
|
1467
|
+
|
|
1468
|
+
Returns:
|
|
1469
|
+
Dictionary with preparation results and baseline metrics
|
|
1470
|
+
"""
|
|
1471
|
+
import time
|
|
1472
|
+
|
|
1473
|
+
start_time = time.time()
|
|
1474
|
+
|
|
1475
|
+
# Store adapter for tying map access during correction
|
|
1476
|
+
self.adapter = adapter
|
|
1477
|
+
|
|
1478
|
+
# Update parameters from policy if provided
|
|
1479
|
+
if policy:
|
|
1480
|
+
self.q = policy.get("q", self.q)
|
|
1481
|
+
self.deadband = policy.get("deadband", self.deadband)
|
|
1482
|
+
self.margin = policy.get("margin", self.margin)
|
|
1483
|
+
self.correct = policy.get("correct", self.correct)
|
|
1484
|
+
if "epsilon" in policy:
|
|
1485
|
+
self._set_epsilon(policy["epsilon"])
|
|
1486
|
+
if "epsilon_by_family" in policy:
|
|
1487
|
+
self._set_epsilon(policy["epsilon_by_family"])
|
|
1488
|
+
|
|
1489
|
+
self._log_event(
|
|
1490
|
+
"prepare",
|
|
1491
|
+
message=f"Preparing RMT guard with q={self.q}, deadband={self.deadband}, margin={self.margin}, correct={self.correct}",
|
|
1492
|
+
)
|
|
1493
|
+
|
|
1494
|
+
try:
|
|
1495
|
+
# Capture baseline MP statistics for linear layers
|
|
1496
|
+
self.baseline_mp_stats = capture_baseline_mp_stats(model)
|
|
1497
|
+
|
|
1498
|
+
# Extract baseline sigmas for compatibility with existing detection
|
|
1499
|
+
self.baseline_sigmas = {}
|
|
1500
|
+
for name, stats in self.baseline_mp_stats.items():
|
|
1501
|
+
self.baseline_sigmas[name] = stats.get("sigma_base", 0.0)
|
|
1502
|
+
|
|
1503
|
+
# Get linear modules in scope
|
|
1504
|
+
linear_modules = self._get_linear_modules(model)
|
|
1505
|
+
|
|
1506
|
+
baseline_detection = rmt_detect(
|
|
1507
|
+
model=model,
|
|
1508
|
+
threshold=self.margin,
|
|
1509
|
+
detect_only=True,
|
|
1510
|
+
baseline_sigmas=self.baseline_sigmas,
|
|
1511
|
+
baseline_mp_stats=self.baseline_mp_stats,
|
|
1512
|
+
deadband=self.deadband,
|
|
1513
|
+
)
|
|
1514
|
+
self.baseline_total_outliers = baseline_detection.get("n_layers_flagged", 0)
|
|
1515
|
+
self.baseline_outliers_per_family = self._count_outliers_per_family(
|
|
1516
|
+
baseline_detection.get("per_layer", [])
|
|
1517
|
+
)
|
|
1518
|
+
for family_key in ("attn", "ffn", "embed", "other"):
|
|
1519
|
+
self.baseline_outliers_per_family.setdefault(family_key, 0)
|
|
1520
|
+
self.outliers_per_family = {}
|
|
1521
|
+
self.outliers_total = 0
|
|
1522
|
+
self.epsilon_violations = []
|
|
1523
|
+
|
|
1524
|
+
self.prepared = True
|
|
1525
|
+
preparation_time = time.time() - start_time
|
|
1526
|
+
|
|
1527
|
+
self._log_event(
|
|
1528
|
+
"prepare_success",
|
|
1529
|
+
message=f"Captured {len(self.baseline_mp_stats)} baseline MP statistics",
|
|
1530
|
+
baseline_count=len(self.baseline_mp_stats),
|
|
1531
|
+
linear_modules_count=len(linear_modules),
|
|
1532
|
+
preparation_time=preparation_time,
|
|
1533
|
+
)
|
|
1534
|
+
|
|
1535
|
+
return {
|
|
1536
|
+
"baseline_metrics": {
|
|
1537
|
+
"mp_stats_sample": dict(list(self.baseline_mp_stats.items())[:3]),
|
|
1538
|
+
"total_layers": len(self.baseline_mp_stats),
|
|
1539
|
+
"linear_modules_in_scope": len(linear_modules),
|
|
1540
|
+
"scope_suffixes": self.allowed_suffixes,
|
|
1541
|
+
"average_baseline_sigma": np.mean(
|
|
1542
|
+
list(self.baseline_sigmas.values())
|
|
1543
|
+
),
|
|
1544
|
+
"max_baseline_sigma": max(self.baseline_sigmas.values())
|
|
1545
|
+
if self.baseline_sigmas
|
|
1546
|
+
else 0.0,
|
|
1547
|
+
"min_baseline_sigma": min(self.baseline_sigmas.values())
|
|
1548
|
+
if self.baseline_sigmas
|
|
1549
|
+
else 0.0,
|
|
1550
|
+
},
|
|
1551
|
+
"policy_applied": {
|
|
1552
|
+
"q": self.q,
|
|
1553
|
+
"deadband": self.deadband,
|
|
1554
|
+
"margin": self.margin,
|
|
1555
|
+
"correct": self.correct,
|
|
1556
|
+
},
|
|
1557
|
+
"preparation_time": preparation_time,
|
|
1558
|
+
"ready": True,
|
|
1559
|
+
}
|
|
1560
|
+
|
|
1561
|
+
except Exception as e:
|
|
1562
|
+
self.prepared = False
|
|
1563
|
+
self._log_event(
|
|
1564
|
+
"prepare_failed",
|
|
1565
|
+
level="ERROR",
|
|
1566
|
+
message=f"Failed to prepare RMT guard: {str(e)}",
|
|
1567
|
+
error=str(e),
|
|
1568
|
+
)
|
|
1569
|
+
|
|
1570
|
+
return {
|
|
1571
|
+
"baseline_metrics": {},
|
|
1572
|
+
"policy_applied": policy or {},
|
|
1573
|
+
"preparation_time": time.time() - start_time,
|
|
1574
|
+
"ready": False,
|
|
1575
|
+
"error": str(e),
|
|
1576
|
+
}
|
|
1577
|
+
|
|
1578
|
+
def before_edit(self, model: nn.Module) -> None:
|
|
1579
|
+
"""
|
|
1580
|
+
Execute before edit (no action needed for RMT).
|
|
1581
|
+
|
|
1582
|
+
Args:
|
|
1583
|
+
model: The model about to be edited
|
|
1584
|
+
"""
|
|
1585
|
+
if self.prepared:
|
|
1586
|
+
self._log_event(
|
|
1587
|
+
"before_edit",
|
|
1588
|
+
message="RMT guard ready for post-edit detection and correction",
|
|
1589
|
+
)
|
|
1590
|
+
|
|
1591
|
+
    def after_edit(self, model: nn.Module) -> None:
        """
        Execute after edit - perform RMT detection and optional correction.

        Runs either the combined detect-and-correct path (when ``self.correct``
        is enabled) or a detect-only pass, then records the outcome on the
        instance (``_last_result``, ``outliers_per_family``, ``outliers_total``,
        ``epsilon_violations``) for ``finalize`` to consume. Any exception is
        caught and converted into an empty "no outliers" result.

        Args:
            model: The model that was just edited
        """
        # Guard: without a successful prepare() there is no baseline to
        # compare against, so detection would be meaningless.
        if not self.prepared or not self.baseline_mp_stats:
            self._log_event(
                "after_edit_skipped",
                level="WARN",
                message="RMT guard not prepared, skipping post-edit detection",
            )
            return

        self._log_event("after_edit", message="Applying RMT detection and correction")

        try:
            # Perform RMT detection with baseline awareness
            # Create custom detection with proper adapter support
            if self.correct:
                # Apply correction using enhanced logic with adapter support
                detection_result = self._apply_rmt_detection_and_correction(model)
            else:
                # Detection only
                detection_result = rmt_detect(
                    model=model,
                    threshold=self.margin,  # Use margin as threshold
                    detect_only=True,
                    verbose=False,
                    baseline_sigmas=self.baseline_sigmas,
                    baseline_mp_stats=self.baseline_mp_stats,
                    deadband=self.deadband,
                )

            # Store results
            self._last_result = detection_result
            self.outliers_per_family = self._count_outliers_per_family(
                detection_result.get("per_layer", [])
            )
            # Ensure every known family key is present even when no outliers
            # were attributed to it.
            for family_key in ("attn", "ffn", "embed", "other"):
                self.outliers_per_family.setdefault(family_key, 0)
            # NOTE(review): the fallback default here is the number of family
            # keys (always 4 after the setdefault loop above), not an outlier
            # count — confirm this is intended when "n_layers_flagged" is absent.
            self.outliers_total = detection_result.get(
                "n_layers_flagged", len(self.outliers_per_family)
            )
            self.epsilon_violations = self._compute_epsilon_violations()

            flagged_layers = detection_result.get("n_layers_flagged", 0)
            corrected_layers = detection_result.get("correction_iterations", 0)

            self._log_event(
                "rmt_detection_complete",
                message=f"Detected {flagged_layers} outlier layers, correction enabled: {self.correct}",
                layers_flagged=flagged_layers,
                correction_iterations=corrected_layers,
                has_outliers=detection_result.get("has_outliers", False),
                max_ratio=detection_result.get("max_ratio", 0.0),
            )

            # Log individual layer results
            for layer_info in detection_result.get("per_layer", []):
                if layer_info.get("has_outlier", False):
                    self._log_event(
                        "outlier_detected",
                        message=f"Outlier detected in {layer_info.get('module_name', 'unknown')}",
                        layer_name=layer_info.get("module_name"),
                        ratio=layer_info.get("worst_ratio", 0.0),
                        sigma_max=layer_info.get("sigma_max", 0.0),
                        corrected=self.correct,
                    )
                elif layer_info.get("skip_reason"):
                    self._log_event(
                        "layer_skipped",
                        message=f"Layer {layer_info.get('module_name', 'unknown')} skipped: {layer_info.get('skip_reason')}",
                        layer_name=layer_info.get("module_name"),
                        skip_reason=layer_info.get("skip_reason"),
                    )

        except Exception as e:
            # Detection failure is logged but not re-raised; finalize() will
            # see a clean "no outliers" result below.
            self._log_event(
                "after_edit_failed",
                level="ERROR",
                message=f"RMT detection failed: {str(e)}",
                error=str(e),
            )
            # Store empty result for finalize
            self._last_result = {
                "has_outliers": False,
                "n_layers_flagged": 0,
                "per_layer": [],
                "max_ratio": 0.0,
            }
            self.outliers_per_family = {}
            self.outliers_total = 0
            self.epsilon_violations = []
def validate(
|
|
1688
|
+
self, model: Any, adapter: Any, context: dict[str, Any]
|
|
1689
|
+
) -> dict[str, Any]:
|
|
1690
|
+
"""
|
|
1691
|
+
Validate model state (Guard ABC interface).
|
|
1692
|
+
|
|
1693
|
+
Args:
|
|
1694
|
+
model: Model to validate
|
|
1695
|
+
adapter: ModelAdapter instance
|
|
1696
|
+
context: Validation context
|
|
1697
|
+
|
|
1698
|
+
Returns:
|
|
1699
|
+
Dictionary with validation results
|
|
1700
|
+
"""
|
|
1701
|
+
# Use finalize to get comprehensive results
|
|
1702
|
+
result = self.finalize(model, adapter)
|
|
1703
|
+
|
|
1704
|
+
# Convert to simple dict format if GuardOutcome
|
|
1705
|
+
if (
|
|
1706
|
+
hasattr(result, "passed")
|
|
1707
|
+
and hasattr(result, "action")
|
|
1708
|
+
and hasattr(result, "metrics")
|
|
1709
|
+
):
|
|
1710
|
+
violations_list: list[str] = []
|
|
1711
|
+
if hasattr(result, "violations") and result.violations:
|
|
1712
|
+
violations_list = [str(v) for v in result.violations]
|
|
1713
|
+
return {
|
|
1714
|
+
"passed": bool(result.passed),
|
|
1715
|
+
"action": str(result.action),
|
|
1716
|
+
"metrics": dict(result.metrics),
|
|
1717
|
+
"violations": violations_list,
|
|
1718
|
+
"message": "RMT guard validation completed",
|
|
1719
|
+
}
|
|
1720
|
+
else:
|
|
1721
|
+
return {
|
|
1722
|
+
"passed": result.get("passed", False),
|
|
1723
|
+
"action": "continue" if result.get("passed", False) else "warn",
|
|
1724
|
+
"metrics": result.get("metrics", {}),
|
|
1725
|
+
"violations": result.get("errors", []),
|
|
1726
|
+
"message": "RMT guard validation completed",
|
|
1727
|
+
}
|
|
1728
|
+
|
|
1729
|
+
    def finalize(self, model: nn.Module, adapter=None) -> GuardOutcome | dict[str, Any]:
        """
        Finalize RMT guard and return comprehensive results.

        Aggregates the detection results recorded by ``after_edit``, evaluates
        the pass/fail gate (flagged-rate thresholds plus the ε-rule), and
        returns either a ``GuardOutcome`` (when available) or a legacy dict.

        Args:
            model: The final edited model
            adapter: Optional adapter for tying map access

        Returns:
            GuardOutcome or dict with RMT detection and correction results
        """
        import time

        start_time = time.time()

        # Unprepared guard: fail immediately with an "abort" outcome.
        if not self.prepared:
            self._log_event(
                "finalize_failed",
                level="ERROR",
                message="RMT guard not properly prepared",
            )

            if HAS_GUARD_OUTCOME:
                return GuardOutcome(
                    name=self.name,
                    passed=False,
                    action="abort",
                    violations=[
                        {
                            "type": "preparation",
                            "severity": "error",
                            "message": "RMT guard not properly prepared",
                            "module_name": None,
                        }
                    ],
                    metrics={
                        "prepared": False,
                        "finalize_time": time.time() - start_time,
                    },
                )
            else:
                return {
                    "passed": False,
                    "metrics": {
                        "prepared": False,
                        "finalize_time": time.time() - start_time,
                    },
                    "warnings": ["RMT guard not properly prepared"],
                    "errors": ["Preparation failed or baseline MP stats not captured"],
                    "events": self.events,
                }

        # Get results from after_edit
        result = self._last_result or {
            "has_outliers": False,
            "n_layers_flagged": 0,
            "per_layer": [],
            "max_ratio": 0.0,
        }

        # Recompute per-family counts only if after_edit didn't populate them.
        if result and not self.outliers_per_family:
            self.outliers_per_family = self._count_outliers_per_family(
                result.get("per_layer", [])
            )
        for family_key in ("attn", "ffn", "embed", "other"):
            self.outliers_per_family.setdefault(family_key, 0)
            self.baseline_outliers_per_family.setdefault(family_key, 0)
        self.outliers_total = result.get("n_layers_flagged", self.outliers_total or 0)
        self.epsilon_violations = self._compute_epsilon_violations()
        # Contracts: epsilon non-negative, counts non-negative
        for fam, eps in self.epsilon_by_family.items():
            guard_assert(eps >= 0.0, f"rmt.epsilon[{fam}] must be >= 0")
        for fam in set(self.outliers_per_family) | set(
            self.baseline_outliers_per_family
        ):
            guard_assert(
                self.outliers_per_family.get(fam, 0) >= 0,
                "rmt.outliers_per_family negative",
            )
            guard_assert(
                self.baseline_outliers_per_family.get(fam, 0) >= 0,
                "rmt.baseline_outliers negative",
            )

        # Calculate metrics
        flagged_layers = result.get("n_layers_flagged", 0)
        total_layers = len(self.baseline_mp_stats) if self.baseline_mp_stats else 0
        flagged_rate = flagged_layers / total_layers if total_layers > 0 else 0.0

        # Step 5 validation gate: no increase in outliers vs bare edit, ≤1% primary-metric cost
        # For now, use flagged rate as proxy (will be enhanced with PM checking)
        passed = flagged_rate <= 0.5  # Allow up to 50% flagged for conservative gate

        # Generate violations for GuardOutcome
        violations = []
        warnings = []
        errors = []

        # Create violations for each flagged layer
        for layer_info in result.get("per_layer", []):
            if layer_info.get("has_outlier", False):
                # Flagged layers count as warnings when correction is on
                # (they were corrected), errors otherwise.
                violations.append(
                    {
                        "type": "rmt_outlier",
                        "severity": "warning" if self.correct else "error",
                        "message": f"RMT outlier detected: ratio={layer_info.get('worst_ratio', 0.0):.2f}",
                        "module_name": layer_info.get("module_name"),
                        "ratio": layer_info.get("worst_ratio", 0.0),
                        "threshold": (1.0 + self.deadband) * self.margin,
                        "corrected": self.correct,
                    }
                )

        if flagged_rate > 0.3:  # Warning threshold at 30%
            warnings.append(
                f"High RMT outlier rate: {flagged_layers}/{total_layers} layers flagged ({flagged_rate:.1%})"
            )

        if flagged_rate > 0.7:  # Error threshold at 70%
            errors.append(
                f"Excessive RMT outliers: {flagged_layers}/{total_layers} layers flagged"
            )
            passed = False

        # Any ε-rule violation fails the gate outright.
        if self.epsilon_violations:
            passed = False
            for failure in self.epsilon_violations:
                errors.append(
                    "RMT ε-rule violation: "
                    f"{failure['family']} bare={failure['bare']} "
                    f"guarded={failure['guarded']} allowed={failure['allowed']} "
                    f"(ε={failure['epsilon']:.3f})"
                )

        finalize_time = time.time() - start_time

        # Final metrics
        final_metrics = {
            "layers_flagged": flagged_layers,
            "total_layers": total_layers,
            "flagged_rate": flagged_rate,
            "rmt_outliers": flagged_layers,
            "max_ratio": result.get("max_ratio", 0.0),
            "correction_enabled": self.correct,
            "correction_iterations": result.get("correction_iterations", 0),
            "q_used": self.q,
            "deadband_used": self.deadband,
            "margin_used": self.margin,
            "detection_threshold": (1.0 + self.deadband) * self.margin,
            "baseline_layers_captured": len(self.baseline_mp_stats)
            if self.baseline_mp_stats
            else 0,
            "finalize_time": finalize_time,
            "baseline_outliers_per_family": {
                k: int(v) for k, v in self.baseline_outliers_per_family.items()
            },
            "outliers_per_family": {
                k: int(v) for k, v in self.outliers_per_family.items()
            },
            "baseline_outliers_total": int(self.baseline_total_outliers),
            "outliers_total": int(self.outliers_total),
            "epsilon_by_family": {
                k: float(v) for k, v in self.epsilon_by_family.items()
            },
            "epsilon_default": float(self.epsilon_default),
            "epsilon_violations": self.epsilon_violations,
        }

        self._log_event(
            "finalize_complete",
            message=f"RMT guard finalized - {'PASSED' if passed else 'FAILED'}",
            passed=passed,
            flagged_rate=flagged_rate,
            finalize_time=finalize_time,
        )

        # Return GuardOutcome if available, otherwise legacy dict
        # Env-gated tiny evidence dump for auditors
        try:
            payload = {
                "rmt": {
                    "epsilon_by_family": {
                        k: float(v) for k, v in self.epsilon_by_family.items()
                    },
                    "deadband": float(self.deadband),
                    "margin": float(self.margin),
                    "evaluated": True,
                }
            }
            maybe_dump_guard_evidence(".", payload)
        except Exception:
            # Evidence dump is strictly best-effort; never fail finalize on it.
            pass

        if HAS_GUARD_OUTCOME:
            # Add details to metrics since GuardOutcome doesn't have a details field
            final_metrics.update(
                {
                    "guard_type": "rmt",
                    "baseline_captured": self.baseline_mp_stats is not None,
                    "baseline_count": len(self.baseline_mp_stats)
                    if self.baseline_mp_stats
                    else 0,
                    "flagged_layer_names": [v["module_name"] for v in violations],
                    "per_layer_results": result.get("per_layer", []),
                    "policy": {
                        "q": self.q,
                        "deadband": self.deadband,
                        "margin": self.margin,
                        "correct": self.correct,
                        "epsilon": self.epsilon_by_family.copy(),
                    },
                    "scope_suffixes": self.allowed_suffixes,
                }
            )

            return GuardOutcome(
                name=self.name,
                passed=passed,
                action="none" if passed else "rollback",
                violations=violations,
                metrics=final_metrics,
            )
        else:
            return {
                "passed": passed,
                "metrics": final_metrics,
                "warnings": warnings,
                "errors": errors,
                "violations": violations,
                "events": self.events,
                "details": {
                    "guard_type": "rmt",
                    "baseline_captured": self.baseline_mp_stats is not None,
                    "baseline_count": len(self.baseline_mp_stats)
                    if self.baseline_mp_stats
                    else 0,
                    "flagged_layer_names": [v["module_name"] for v in violations],
                    "per_layer_results": result.get("per_layer", []),
                    "policy": {
                        "q": self.q,
                        "deadband": self.deadband,
                        "margin": self.margin,
                        "correct": self.correct,
                        "epsilon": self.epsilon_by_family.copy(),
                    },
                    "scope_suffixes": self.allowed_suffixes,
                },
            }
def policy(self) -> RMTPolicyDict:
|
|
1979
|
+
"""
|
|
1980
|
+
Get default policy for RMT guard.
|
|
1981
|
+
|
|
1982
|
+
Returns:
|
|
1983
|
+
RMTPolicyDict with current configuration
|
|
1984
|
+
"""
|
|
1985
|
+
return RMTPolicyDict(
|
|
1986
|
+
q=self.q,
|
|
1987
|
+
deadband=self.deadband,
|
|
1988
|
+
margin=self.margin,
|
|
1989
|
+
correct=self.correct,
|
|
1990
|
+
epsilon=self.epsilon_by_family.copy(),
|
|
1991
|
+
)
|
|
1992
|
+
|
|
1993
|
+
|
|
1994
|
+
# === Policy Utilities ===
|
|
1995
|
+
|
|
1996
|
+
|
|
1997
|
+
def get_rmt_policy(name: str = "balanced") -> RMTPolicyDict:
    """
    Look up a predefined RMT policy by name.

    Args:
        name: Policy name ("conservative", "balanced", "aggressive")

    Returns:
        RMTPolicyDict configuration

    Raises:
        GuardError: If ``name`` is not one of the known presets.
    """
    # Per-family ε values match tiers.yaml (November 2025 calibration)
    presets = {
        "conservative": RMTPolicyDict(
            q="auto",
            deadband=0.05,
            margin=1.3,
            correct=True,
            epsilon={"ffn": 0.06, "attn": 0.05, "embed": 0.07, "other": 0.07},
        ),
        "balanced": RMTPolicyDict(
            q="auto",
            deadband=0.10,
            margin=1.5,
            correct=True,
            epsilon={"ffn": 0.10, "attn": 0.08, "embed": 0.12, "other": 0.12},
        ),
        "aggressive": RMTPolicyDict(
            q="auto",
            deadband=0.15,
            margin=1.8,
            correct=True,
            epsilon={"ffn": 0.14, "attn": 0.12, "embed": 0.18, "other": 0.18},
        ),
    }

    if name in presets:
        return presets[name]

    from invarlock.core.exceptions import GuardError

    raise GuardError(
        code="E502",
        message="POLICY-NOT-FOUND",
        details={"name": name, "available": list(presets.keys())},
    )
def _raise_invalid_policy_param(param: str, value: Any) -> None:
    """Raise a ValidationError (E501) for an out-of-range policy parameter."""
    from invarlock.core.exceptions import ValidationError

    raise ValidationError(
        code="E501",
        message="POLICY-PARAM-INVALID",
        details={"param": param, "value": value},
    )


def create_custom_rmt_policy(
    q: float | Literal["auto"] = "auto",
    deadband: float = 0.10,
    margin: float = 1.5,
    correct: bool = True,
    epsilon: float | dict[str, float] | None = None,
) -> RMTPolicyDict:
    """
    Create a custom RMT policy after validating its parameters.

    Args:
        q: MP aspect ratio ("auto" to derive, or a manual float in [0.1, 10.0])
        deadband: Tolerance margin (0.0-0.5)
        margin: RMT threshold ratio (>= 1.0)
        correct: Enable automatic correction
        epsilon: Optional ε tolerance — a single float for all module families
            or a per-family mapping (e.g. {"ffn": 0.10, "attn": 0.08})

    Returns:
        Custom RMTPolicyDict configuration

    Raises:
        ValidationError: If q, deadband, or margin is out of range.
    """
    # Only manual (float) q values are range-checked; "auto" is always valid.
    if isinstance(q, float) and not 0.1 <= q <= 10.0:
        _raise_invalid_policy_param("q", q)

    if not 0.0 <= deadband <= 0.5:
        _raise_invalid_policy_param("deadband", deadband)

    if not margin >= 1.0:
        _raise_invalid_policy_param("margin", margin)

    return RMTPolicyDict(
        q=q,
        deadband=deadband,
        margin=margin,
        correct=correct,
        epsilon=epsilon,
    )