invarlock 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2097 @@
1
+ """
2
+ InvarLock – Safety: Random Matrix Theory (RMT) Health Check
3
+ =======================================================
4
+
5
+ Detect-only mode for v0: identifies singular value outliers that
6
+ deviate from the Marchenko-Pastur bulk distribution.
7
+
8
+ Based on insights from Słowik et al., 2025 linking MP outliers
9
+ to training instability.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ from dataclasses import dataclass
16
+ from datetime import datetime
17
+ from typing import Any, Literal, TypedDict
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.linalg as tla
22
+ import torch.nn as nn
23
+
24
+ from invarlock.cli._evidence import maybe_dump_guard_evidence
25
+ from invarlock.core.api import Guard
26
+
27
+ from ._contracts import guard_assert
28
+
29
+ __all__ = [
30
+ # Utility functions
31
+ "mp_bulk_edges",
32
+ "mp_bulk_edge",
33
+ "layer_svd_stats",
34
+ "rmt_detect",
35
+ "rmt_detect_report",
36
+ "rmt_detect_with_names",
37
+ "clip_full_svd",
38
+ "analyze_weight_distribution",
39
+ "rmt_growth_ratio",
40
+ "within_deadband",
41
+ "capture_baseline_mp_stats",
42
+ # Guard classes and types
43
+ "RMTGuard",
44
+ "RMTPolicy",
45
+ "RMTPolicyDict",
46
+ # Policy utilities
47
+ "get_rmt_policy",
48
+ "create_custom_rmt_policy",
49
+ ]
50
+
51
+
52
def mp_bulk_edges(m: int, n: int, whitened: bool = True) -> tuple[float, float]:
    """
    Return the (lower, upper) Marchenko-Pastur bulk edges for an m×n matrix.

    For a weight matrix W ∈ ℝ^{m×n} with i.i.d. entries of variance 1/m, the
    MP law describes the eigenvalue bulk of (W^T W)/m; the singular-value
    edges are 1 ± √q with q = n/m (the lower edge collapses to 0 when q > 1,
    since the matrix is rank-deficient in that regime).

    Args:
        m: Number of rows (input features for Conv1D)
        n: Number of columns (output features for Conv1D)
        whitened: If True, W is assumed already whitened by √m; otherwise the
            edges are scaled back up by √m.

    Returns:
        (σ_min, σ_max) theoretical bulk edges for the singular values.
    """
    # Degenerate matrix: no singular values at all.
    if m == 0 or n == 0:
        return 0.0, 0.0

    aspect = n / m  # q = n/m, the MP aspect ratio
    root_q = np.sqrt(aspect)

    if whitened:
        upper = 1.0 + root_q
        lower = abs(1.0 - root_q) if aspect <= 1 else 0.0
    else:
        # Undo the 1/√m whitening scale.
        scale = np.sqrt(m)
        upper = scale * (1.0 + root_q)
        lower = scale * abs(1.0 - root_q) if aspect <= 1 else 0.0

    return lower, upper
83
+
84
+
85
def mp_bulk_edge(m: int, n: int, whitened: bool = False) -> float:
    """
    Return the upper Marchenko-Pastur bulk edge σ_max for an m×n matrix.

    This is the theoretical largest singular value of a random matrix with
    i.i.d. entries: 1 + √(n/m) for a whitened matrix, scaled by √m otherwise.

    Args:
        m: Number of rows (input features for Conv1D)
        n: Number of columns (output features for Conv1D)
        whitened: If True, assumes W is already whitened by √m

    Returns:
        σ_max theoretical bulk edge for singular values.
    """
    # Degenerate matrix: no spectrum to bound.
    if m == 0 or n == 0:
        return 0.0

    # Aspect ratio q = n/m drives the edge location.
    edge = 1.0 + np.sqrt(n / m)
    if not whitened:
        # Unwhitened weights carry an extra √m scale.
        edge = np.sqrt(m) * edge

    return float(edge)
115
+
116
+
117
+ def _iter_weight_matrices(layer: nn.Module):
118
+ """Iterate over 2D weight matrices in a layer."""
119
+ for name, param in layer.named_parameters():
120
+ if param.ndim == 2 and "weight" in name:
121
+ yield name, param.detach()
122
+
123
+
124
def rmt_growth_ratio(
    sigma_cur: float, mp_cur: float, sigma_base: float, mp_base: float
) -> float:
    """
    Compute the baseline-aware growth ratio used for RMT outlier detection.

    Each spectrum is first normalized by its own MP bulk edge (σ / mp_edge),
    then the current normalized value is compared against the baseline one.

    Args:
        sigma_cur: Current maximum singular value
        mp_cur: Current MP bulk edge
        sigma_base: Baseline maximum singular value
        mp_base: Baseline MP bulk edge

    Returns:
        Growth ratio: (σ_cur / mp_cur) / (σ_base / mp_base)
    """
    eps = 1e-12  # guards every divisor against zero/negative edges
    baseline_ratio = sigma_base / max(mp_base, eps)
    current_ratio = sigma_cur / max(mp_cur, eps)
    return current_ratio / max(baseline_ratio, eps)
144
+
145
+
146
def within_deadband(sigma_cur: float, sigma_base: float, deadband: float) -> bool:
    """
    Report whether the current spectral norm stays inside the deadband.

    Args:
        sigma_cur: Current spectral norm
        sigma_base: Baseline spectral norm
        deadband: Relative slack around the baseline (e.g. 0.1 allows +10%)

    Returns:
        True when ``sigma_cur`` does not exceed ``(1 + deadband) * sigma_base``.
    """
    allowed_ceiling = (1.0 + deadband) * sigma_base
    return sigma_cur <= allowed_ceiling
159
+
160
+
161
def layer_svd_stats(
    layer: nn.Module,
    baseline_sigmas: dict[str, float] | None = None,
    baseline_mp_stats: dict[str, dict[str, float]] | None = None,
    module_name: str | None = None,
) -> dict[str, Any]:
    """
    Compute SVD statistics for a single layer with baseline-aware normalization.

    Every 2-D ``weight`` parameter in ``layer`` is decomposed with
    ``torch.linalg.svdvals``. For each matrix a "worst ratio" is computed:
    either a baseline-aware growth ratio (when ``module_name`` is found in
    ``baseline_sigmas``) or, as a fallback, σ_max relative to the matrix's own
    98th-percentile singular value.

    For HuggingFace Conv1D layers:
    - Weight shape is (in_features, out_features)
    - m = in_features, n = out_features

    Args:
        layer: Transformer layer to analyze
        baseline_sigmas: Optional baseline singular values for baseline-aware comparison
        baseline_mp_stats: Optional baseline MP statistics (mp_bulk_edge, r_mp_base) for each weight matrix
        module_name: Optional module name for baseline lookups

    Returns:
        Dict with ``sigma_min``, ``sigma_max``, ``worst_ratio`` and, when at
        least one matrix was analyzed, a nested ``worst_details`` dict for the
        matrix with the largest ratio.
    """
    # Aggregates across every weight matrix found in this layer.
    sigma_min_global = float("inf")
    sigma_max_global = 0.0
    worst_ratio = 0.0
    worst_details: dict[str, Any] | None = None

    for name, W in _iter_weight_matrices(layer):
        # Skip empty or non-finite matrices: SVD would fail or be meaningless.
        if W.numel() == 0:
            continue
        if not torch.isfinite(W).all():
            continue

        # For Conv1D: W.shape = (in_features, out_features)
        m, n = W.shape  # m = in_features, n = out_features

        # Compute singular values of the actual matrix (on CPU, in fp32, for
        # numerical stability and device independence).
        try:
            s_actual = tla.svdvals(W.float().cpu())
            s_min = s_actual[-1].item()  # svdvals returns descending order
            s_max = s_actual[0].item()
        except (RuntimeError, torch.linalg.LinAlgError):
            # SVD convergence failures: skip this matrix rather than abort.
            continue

        # Track global min/max across all matrices in the layer.
        sigma_min_global = min(sigma_min_global, s_min)
        sigma_max_global = max(sigma_max_global, s_max)

        # Baseline-aware ratio computation for better outlier detection.
        if baseline_sigmas and module_name and module_name in baseline_sigmas:
            # Preferred method: compare growth of σ/mp_edge against baseline.
            baseline_sigma = baseline_sigmas[module_name]
            if baseline_sigma > 0:
                # Compute current MP edge for this matrix's shape.
                mp_edge_current = mp_bulk_edge(m, n, whitened=False)

                # Get baseline MP edge from stored stats, or fall back to the
                # current edge (assumes the shape has not changed).
                if baseline_mp_stats and module_name in baseline_mp_stats:
                    mp_edge_baseline = baseline_mp_stats[module_name].get(
                        "mp_bulk_edge_base", mp_edge_current
                    )
                else:
                    # Fallback: assume same shape so use same MP edge.
                    mp_edge_baseline = mp_edge_current

                # Shared helper keeps the growth-ratio formula consistent.
                ratio = rmt_growth_ratio(
                    s_max, mp_edge_current, baseline_sigma, mp_edge_baseline
                )
            else:
                # Non-positive baseline sigma: ratio is undefined, treat as neutral.
                ratio = 1.0
        else:
            # Fallback: quantile-based normalization when no baseline available.
            if len(s_actual) > 1:
                # 98th percentile is a robust reference (less sensitive to the
                # very outliers we are trying to detect).
                s_sorted = s_actual.sort()[0]
                idx_98 = int(0.98 * len(s_sorted))
                s_98 = s_sorted[idx_98].item()

                if s_98 > 0:
                    # Ratio of the top singular value to the 98th percentile.
                    ratio = s_max / s_98
                else:
                    ratio = 1.0
            else:
                # A single singular value has nothing to be compared against.
                ratio = 1.0

        # Track the worst (largest) deviation seen so far in this layer.
        if ratio > worst_ratio:
            worst_ratio = ratio
            worst_details = {
                "name": name,
                "shape": (m, n),
                "s_max": s_max,
                "s_min": s_min,
                "s_median": s_actual.median().item() if len(s_actual) > 1 else s_max,
                "s_98": s_actual.sort()[0][int(0.98 * len(s_actual))].item()
                if len(s_actual) > 1
                else s_max,
                "ratio": ratio,
                "mp_edge": mp_bulk_edge(m, n, whitened=False),
                # Record which normalization produced the ratio, for logging.
                "normalization": "baseline_aware"
                if baseline_sigmas and module_name and module_name in baseline_sigmas
                else "98th_percentile",
            }

    result = {
        "sigma_min": sigma_min_global,
        "sigma_max": sigma_max_global,
        "worst_ratio": worst_ratio,
    }

    # Only attach details when at least one matrix was successfully analyzed.
    if worst_details:
        result["worst_details"] = worst_details

    return result
278
+
279
+
280
def capture_baseline_mp_stats(model: nn.Module) -> dict[str, dict[str, float]]:
    """
    Capture baseline MP statistics for linear layers only.

    CRITICAL: Only includes layers where MP analysis makes sense:
    - attn.c_attn, attn.c_proj, mlp.c_fc, mlp.c_proj
    - EXCLUDES: wte, wpe, lm_head, layer norms, biases

    Stores mp_bulk_edge and r_mp_base (sigma/mp_edge ratio) for each weight matrix.
    This enables true baseline-aware RMT detection.

    Args:
        model: Model to analyze

    Returns:
        Dict mapping module names to their MP statistics:
        {
            'module_name': {
                'mp_bulk_edge_base': float,
                'r_mp_base': float,
                'sigma_base': float
            }
        }
    """
    mp_stats = {}

    # Build the tuple of eligible module types; transformers' Conv1D is
    # included only when the library is importable (it is optional here).
    try:
        from transformers.pytorch_utils import Conv1D

        module_types_with_conv1d: tuple[
            type[nn.Linear], type[nn.Conv1d], type[Conv1D]
        ] = (nn.Linear, nn.Conv1d, Conv1D)
        module_types = module_types_with_conv1d
    except ImportError:
        module_types_without_conv1d: tuple[type[nn.Linear], type[nn.Conv1d]] = (
            nn.Linear,
            nn.Conv1d,
        )
        module_types = module_types_without_conv1d

    # Allowlist for RMT analysis - only linear layers where MP makes sense.
    # NOTE(review): these suffixes are GPT-2-style names; other architectures
    # would produce an empty baseline here — confirm intended coverage.
    allowed_suffixes = [".attn.c_attn", ".attn.c_proj", ".mlp.c_fc", ".mlp.c_proj"]

    for name, module in model.named_modules():
        if isinstance(module, module_types) and hasattr(module, "weight"):
            # CRITICAL: Restrict to only linear layers where MP analysis is meaningful.
            # Skip embeddings, LM head, layer norms - MP heuristics don't apply there.
            if any(name.endswith(suffix) for suffix in allowed_suffixes):
                # Find the module's own 2-D weight matrix (recurse=False keeps
                # the search local to this module).
                for param_name, param in module.named_parameters(recurse=False):
                    if param.ndim == 2 and "weight" in param_name:
                        W = param.detach()

                        # Conv1D stores weights transposed relative to Linear;
                        # normalize the orientation before taking shapes.
                        try:
                            from transformers.pytorch_utils import Conv1D

                            if isinstance(module, Conv1D):
                                W = W.T
                        except ImportError:
                            pass

                        if W.ndim == 2:
                            m, n = W.shape

                            # Skip matrices with NaN/Inf entries: SVD would fail.
                            if not torch.isfinite(W).all():
                                continue
                            try:
                                # Largest singular value on CPU in fp32.
                                s_actual = torch.linalg.svdvals(W.float().cpu())
                                sigma_base = s_actual[0].item()
                                mp_edge_base = mp_bulk_edge(m, n, whitened=False)

                                # Baseline σ/mp_edge ratio, guarded against a
                                # zero edge.
                                r_mp_base = sigma_base / max(mp_edge_base, 1e-12)

                                # Store statistics with consistent naming.
                                mp_stats[name] = {
                                    "mp_bulk_edge_base": mp_edge_base,
                                    "r_mp_base": r_mp_base,
                                    "sigma_base": sigma_base,
                                }
                            except (RuntimeError, torch.linalg.LinAlgError):
                                # Skip this module if SVD fails to converge.
                                continue
                        break  # Only process first weight parameter

    return mp_stats
369
+
370
+
371
+ def _iter_transformer_layers(model: nn.Module):
372
+ """Iterate over transformer layers in a model."""
373
+ if hasattr(model, "transformer") and hasattr(model.transformer, "h"):
374
+ # GPT-2 style
375
+ h_layers = model.transformer.h
376
+ if hasattr(h_layers, "__iter__") and hasattr(h_layers, "__len__"):
377
+ try:
378
+ for layer in h_layers:
379
+ yield layer
380
+ except (TypeError, AttributeError):
381
+ pass
382
+ elif hasattr(model, "model") and hasattr(model.model, "layers"):
383
+ # LLaMA style
384
+ layers = model.model.layers
385
+ if hasattr(layers, "__iter__") and hasattr(layers, "__len__"):
386
+ try:
387
+ for layer in layers:
388
+ yield layer
389
+ except (TypeError, AttributeError):
390
+ pass
391
+ elif hasattr(model, "encoder") and hasattr(model.encoder, "layer"):
392
+ # BERT style
393
+ layer_attr = model.encoder.layer
394
+ if hasattr(layer_attr, "__iter__") and hasattr(layer_attr, "__len__"):
395
+ try:
396
+ for layer in layer_attr:
397
+ yield layer
398
+ except (TypeError, AttributeError):
399
+ pass
400
+ else:
401
+ # Fallback
402
+ for module in model.modules():
403
+ if hasattr(module, "attn") and hasattr(module, "mlp"):
404
+ yield module
405
+
406
+
407
def rmt_detect(
    model: nn.Module,
    threshold: float = 1.5,
    detect_only: bool = True,
    correction_factor: float | None = None,
    layer_indices: list[int] | None = None,
    target_layers: list[str] | None = None,  # Alternative layer specification
    verbose: bool = False,
    max_iterations: int = 2,  # Iteration guard for the correction loop
    baseline_sigmas: dict[str, float]
    | None = None,  # Baseline sigmas for baseline-aware checking
    baseline_mp_stats: dict[str, dict[str, float]]
    | None = None,  # Baseline MP statistics per weight matrix
    deadband: float = 0.0,  # Deadband to align with spectral control
    use_quantile_mp: bool = False,  # Use quantile-based MP edge for heavy-tailed spectra
) -> dict[str, Any]:
    """
    Detect RMT outliers in model with baseline-aware checking and iteration guard.

    Selects the linear modules where MP analysis is meaningful, computes SVD
    statistics for each, and flags modules whose ratio exceeds the (deadband-
    adjusted) threshold. When ``detect_only`` is False and a
    ``correction_factor`` is given, a correction pass is attempted and
    re-checked for improvement, bounded by ``max_iterations``.

    Args:
        model: Model to analyze
        threshold: Ratio threshold for flagging outliers (default 1.5)
        detect_only: If True, only detect outliers without correction
        correction_factor: Factor to apply for correction (if not detect_only)
        layer_indices: Specific layers to analyze by index (None = all)
        target_layers: Specific layers to analyze by name (None = all)
        verbose: Whether to print warnings and details
        max_iterations: Maximum iterations for correction (default 2)
        baseline_sigmas: Baseline sigmas for baseline-aware checking
        baseline_mp_stats: Baseline MP statistics (mp_bulk_edge, r_mp_base) for each weight matrix
        deadband: Deadband threshold to align with spectral control
        use_quantile_mp: Use quantile-based MP edge for heavy-tailed spectra.
            NOTE(review): this flag is never read in this function body —
            confirm whether it is consumed elsewhere or is dead.

    Returns:
        Dict with detection results including per-layer details
    """
    per_layer: list[dict[str, Any]] = []
    flagged_layers: list[int] = []

    # Modules selected for analysis, as (name, module) pairs.
    modules_to_analyze = []

    # Allowlist for RMT analysis - same as in capture_baseline_mp_stats.
    allowed_suffixes = [".attn.c_attn", ".attn.c_proj", ".mlp.c_fc", ".mlp.c_proj"]

    if layer_indices is not None or target_layers is not None:
        # If specific layers were requested, only analyze whole transformer layers.
        for idx, layer in enumerate(_iter_transformer_layers(model)):
            # Skip if not in specified layers (by index).
            if layer_indices is not None and idx not in layer_indices:
                continue

            # Skip if not in specified layers (by name). The layer object is
            # mapped back to its qualified name via identity search.
            if target_layers is not None:
                layer_name = None
                for name, module in model.named_modules():
                    if module is layer:
                        layer_name = name
                        break
                if layer_name is None or not any(
                    target in layer_name for target in target_layers
                ):
                    continue

            modules_to_analyze.append((f"transformer_layer_{idx}", layer))
    else:
        # CRITICAL: Only analyze modules where MP analysis makes sense.
        # Exclude embeddings, LM head, layer norms - they have different
        # spectral properties.
        for name, module in model.named_modules():
            # Check if this is an allowed module with at least one 2-D weight.
            if any(name.endswith(suffix) for suffix in allowed_suffixes):
                has_2d_weights = any(
                    param.ndim == 2 and "weight" in param_name
                    for param_name, param in module.named_parameters(recurse=False)
                )
                if has_2d_weights:
                    modules_to_analyze.append((name, module))

    # Iteration guard for the correction loop.
    prev_outlier_count = float("inf")
    correction_iterations = 0

    while correction_iterations < max_iterations:
        current_outliers = 0
        per_layer = []  # Reset per iteration
        flagged_layers = []

        for idx, (module_name, module) in enumerate(modules_to_analyze):
            # Use baseline-aware stats if available.
            stats = layer_svd_stats(
                module, baseline_sigmas, baseline_mp_stats, module_name
            )

            # Apply baseline-aware RMT detection with deadband support.
            has_outlier = False
            skip_reason = None

            if (
                baseline_sigmas
                and baseline_mp_stats
                and module_name in baseline_sigmas
                and module_name in baseline_mp_stats
            ):
                # Step 5 spec: ratio = σ_max_post / bulk_edge_base, flag if
                # ratio > (1+deadband)*threshold.
                sigma_post = stats["sigma_max"]
                mp_stats = baseline_mp_stats[module_name]
                bulk_edge_base = mp_stats.get("mp_bulk_edge_base", 1.0)

                # Exact Step 5 detection rule (epsilon guards a zero edge).
                ratio = sigma_post / max(bulk_edge_base, 1e-12)
                detection_threshold = (1.0 + deadband) * threshold

                if ratio > detection_threshold:
                    has_outlier = True
                    skip_reason = None
                else:
                    # Record why this module was not flagged, for logging.
                    skip_reason = (
                        f"≤ threshold (ratio={ratio:.2f} ≤ {detection_threshold:.2f})"
                    )
            elif deadband > 0.0 and baseline_sigmas and module_name in baseline_sigmas:
                # Partial baseline-aware: deadband check only (fallback when
                # no MP stats were captured for this module).
                baseline_sigma = baseline_sigmas[module_name]
                sigma_post = stats["sigma_max"]
                ratio = sigma_post / max(baseline_sigma, 1e-12)
                detection_threshold = (1.0 + deadband) * threshold

                if ratio > detection_threshold:
                    has_outlier = True
                    skip_reason = None
                else:
                    skip_reason = (
                        f"≤ threshold (ratio={ratio:.2f} ≤ {detection_threshold:.2f})"
                    )
            else:
                # Standard check without baseline awareness (fallback).
                ratio = stats["worst_ratio"]
                if ratio > threshold:
                    has_outlier = True
                    skip_reason = None
                else:
                    skip_reason = f"≤ threshold (ratio={ratio:.2f} ≤ {threshold:.2f})"

            layer_info = {
                "layer": idx,
                "module_name": module_name,
                "sigma_min": stats["sigma_min"],
                "sigma_max": stats["sigma_max"],
                "worst_ratio": stats["worst_ratio"],
                "has_outlier": has_outlier,
            }

            # Add detailed info if available.
            if "worst_details" in stats:
                layer_info["details"] = stats["worst_details"]

            per_layer.append(layer_info)

            # Store skip reason for better logging (mutating after append is
            # fine: the list holds a reference to this same dict).
            layer_info["skip_reason"] = skip_reason

            if has_outlier:
                flagged_layers.append(idx)
                current_outliers += 1
                if verbose:
                    normalization = stats.get("worst_details", {}).get(
                        "normalization", "unknown"
                    )
                    print(
                        f" Module {module_name}: ratio={stats['worst_ratio']:.2f} "
                        f"(σ_max={stats['sigma_max']:.2f}, norm={normalization})"
                    )
            elif verbose and skip_reason:
                print(f" Module {module_name}: SKIP: {skip_reason}")

        # Apply correction if requested and not detect-only.
        # NOTE(review): the correction is only applied on iteration 0; later
        # iterations merely re-measure and check for improvement — confirm
        # this single-shot behavior is intended.
        if not detect_only and current_outliers > 0 and correction_factor is not None:
            if correction_iterations == 0:
                if verbose:
                    print(
                        f" Applying RMT correction (iteration {correction_iterations + 1})..."
                    )
                # Apply correction to flagged modules.
                for idx in flagged_layers:
                    module_name, module = modules_to_analyze[idx]
                    _apply_rmt_correction(
                        module,
                        correction_factor,
                        baseline_sigmas,
                        baseline_mp_stats,
                        module_name,
                        deadband,
                        verbose,
                        adapter=None,
                    )
            else:
                # Check whether the previous correction pass improved things.
                if current_outliers >= prev_outlier_count:
                    if verbose:
                        print(
                            f" RMT correction stalled ({current_outliers} outliers unchanged), "
                            f"downgrading to warning"
                        )
                    break
                elif verbose:
                    print(
                        f" RMT correction improving ({prev_outlier_count} → {current_outliers} outliers)"
                    )
        else:
            # No correction requested, exit after the first detection pass.
            break

        prev_outlier_count = current_outliers
        correction_iterations += 1

        # Exit early once no outliers remain.
        if current_outliers == 0:
            break

    # Aggregate results from the final iteration.
    n_outliers = len(flagged_layers)
    max_ratio = max((item["worst_ratio"] for item in per_layer), default=0.0)
    has_outliers = n_outliers > 0

    if verbose and has_outliers:
        baseline_note = (
            " (baseline-aware)"
            if baseline_sigmas and baseline_mp_stats
            else " (absolute)"
        )
        deadband_note = f" with {deadband:.0%} deadband" if deadband > 0.0 else ""

        # Count detected vs will-be-capped.
        n_detected = n_outliers
        n_will_be_capped = n_outliers if not detect_only else 0

        print(f" ⚠️ RMT outliers detected{baseline_note}{deadband_note}:")
        print(f" Detected: {n_detected}, will correct: {n_will_be_capped}")
        print(f" Max ratio: {max_ratio:.2f}")
        print(" Top offenders (σ_post / σ_ref):")

        # Show top 3 offenders with detailed information.
        top_offenders = sorted(
            [
                (item["worst_ratio"], item["module_name"], item.get("details", {}))
                for item in per_layer
                if item["has_outlier"]
            ],
            reverse=True,
        )[:3]

        for ratio, module_name, details in top_offenders:
            sigma_max = details.get("s_max", 0.0)
            ref_type = "mp_bulk_edge" if not baseline_sigmas else "baseline-aware"
            print(
                f" - {module_name}: {ratio:.2f} (σ_post={sigma_max:.2f}, ref={ref_type})"
            )

        if len(top_offenders) < n_outliers:
            print(
                f" ... and {n_outliers - len(top_offenders)} more layers flagged"
            )

    return {
        "has_outliers": has_outliers,
        "n_layers_flagged": n_outliers,
        "outlier_count": n_outliers,  # Alias for compatibility
        "max_ratio": max_ratio,
        "threshold": threshold,
        "correction_iterations": correction_iterations,
        "per_layer": per_layer,
        "flagged_layers": flagged_layers,
        "layers": {
            f"layer_{item['layer']}": item for item in per_layer
        },  # Alternative format
    }
683
+
684
+
685
def rmt_detect_report(
    model: nn.Module, threshold: float = 1.5
) -> tuple[dict, list[dict]]:
    """
    Generate an RMT health report.

    Thin wrapper around :func:`rmt_detect` that condenses its output into a
    summary dict (with compatibility aliases) plus the per-layer details.

    Args:
        model: Model to analyze
        threshold: Ratio threshold for outliers

    Returns:
        (summary_dict, per_layer_list)
    """
    detection = rmt_detect(model, threshold, verbose=False)

    flagged = detection["has_outliers"]
    peak_ratio = detection["max_ratio"]

    summary = {
        "has_outliers": flagged,
        "n_layers_flagged": detection["n_layers_flagged"],
        "max_ratio": peak_ratio,
        "rmt_max_ratio": peak_ratio,  # Alias for compatibility
        "rmt_has_outliers": flagged,  # Alias
    }

    return summary, detection["per_layer"]
709
+
710
+
711
+ def rmt_detect_with_names(
712
+ model: nn.Module, threshold: float = 1.5, verbose: bool = False
713
+ ) -> dict[str, Any]:
714
+ """
715
+ Detect RMT outliers in model and return detailed information with module names.
716
+
717
+ Args:
718
+ model: Model to analyze
719
+ threshold: Ratio threshold for flagging outliers (default 1.5)
720
+ verbose: Whether to print warnings and details
721
+
722
+ Returns:
723
+ Dict with detection results including per-layer details and module names
724
+ """
725
+ outliers = []
726
+ per_layer = []
727
+ flagged_layers = []
728
+
729
+ # Get all transformer layers with their names
730
+ layer_modules = []
731
+ if hasattr(model, "transformer") and hasattr(model.transformer, "h"):
732
+ # GPT-2 style
733
+ h_layers = model.transformer.h
734
+ if hasattr(h_layers, "__iter__"):
735
+ for idx, layer in enumerate(h_layers):
736
+ layer_modules.append((f"transformer.h.{idx}", layer))
737
+ elif hasattr(model, "model") and hasattr(model.model, "layers"):
738
+ # LLaMA style
739
+ layers = model.model.layers
740
+ if hasattr(layers, "__iter__"):
741
+ for idx, layer in enumerate(layers):
742
+ layer_modules.append((f"model.layers.{idx}", layer))
743
+ elif hasattr(model, "encoder") and hasattr(model.encoder, "layer"):
744
+ # BERT style
745
+ layer_attr = model.encoder.layer
746
+ if hasattr(layer_attr, "__iter__"):
747
+ for idx, layer in enumerate(layer_attr):
748
+ layer_modules.append((f"encoder.layer.{idx}", layer))
749
+ else:
750
+ # Fallback - try to find transformer layers by attributes
751
+ for name, module in model.named_modules():
752
+ if hasattr(module, "attn") and hasattr(module, "mlp"):
753
+ layer_modules.append((name, module))
754
+
755
+ for layer_name, layer in layer_modules:
756
+ stats = layer_svd_stats(layer, module_name=layer_name)
757
+
758
+ # Check if layer has outliers
759
+ has_outlier = stats["worst_ratio"] > threshold
760
+
761
+ # Add detailed info if available
762
+ if "worst_details" in stats:
763
+ layer_info = {
764
+ "layer_name": layer_name,
765
+ "sigma_min": stats["sigma_min"],
766
+ "sigma_max": stats["sigma_max"],
767
+ "worst_ratio": stats["worst_ratio"],
768
+ "has_outlier": has_outlier,
769
+ "details": stats["worst_details"],
770
+ }
771
+
772
+ # Add module name to outlier details
773
+ if has_outlier:
774
+ outlier_info = {
775
+ "layer_name": layer_name,
776
+ "module_name": f"{layer_name}.{stats['worst_details']['name']}",
777
+ "sigma_max": stats["sigma_max"],
778
+ "ratio": stats["worst_ratio"],
779
+ "details": stats["worst_details"],
780
+ }
781
+ outliers.append(outlier_info)
782
+ flagged_layers.append(layer_name)
783
+ else:
784
+ layer_info = {
785
+ "layer_name": layer_name,
786
+ "sigma_min": stats["sigma_min"],
787
+ "sigma_max": stats["sigma_max"],
788
+ "worst_ratio": stats["worst_ratio"],
789
+ "has_outlier": has_outlier,
790
+ }
791
+
792
+ per_layer.append(layer_info)
793
+
794
+ # Aggregate results
795
+ n_outliers = len(flagged_layers)
796
+ max_ratio = 0.0
797
+ if per_layer:
798
+ try:
799
+ max_ratio = max(float(item.get("worst_ratio", 0.0)) for item in per_layer)
800
+ except (TypeError, ValueError):
801
+ max_ratio = 0.0
802
+ has_outliers = n_outliers > 0
803
+
804
+ if verbose and has_outliers:
805
+ print(" ⚠️ RMT outliers detected:")
806
+ print(f" Layers flagged: {n_outliers}")
807
+ print(f" Max ratio: {max_ratio:.2f}")
808
+ print(f" Threshold: {threshold:.2f}")
809
+ print(" Top offenders (σ_post / σ_ref):")
810
+ # Show top offenders with full module names and consistent formatting
811
+ for outlier in outliers[:3]:
812
+ print(
813
+ f" - {outlier['module_name']}: {outlier['ratio']:.2f} (σ_post={outlier['sigma_max']:.2f}, ref=mp_bulk_edge)"
814
+ )
815
+ if len(outliers) > 3:
816
+ print(f" ... and {len(outliers) - 3} more layers flagged")
817
+
818
+ return {
819
+ "has_outliers": has_outliers,
820
+ "n_layers_flagged": n_outliers,
821
+ "outlier_count": n_outliers,
822
+ "max_ratio": max_ratio,
823
+ "threshold": threshold,
824
+ "per_layer": per_layer,
825
+ "flagged_layers": flagged_layers,
826
+ "outliers": outliers, # Add the outliers list with full module names
827
+ "layers": {item["layer_name"]: item for item in per_layer},
828
+ }
829
+
830
+
831
def _apply_rmt_correction(
    layer: nn.Module,
    factor: float,
    baseline_sigmas: dict[str, float] | None = None,
    baseline_mp_stats: dict[str, dict[str, float]] | None = None,
    layer_name: str = "",
    deadband: float = 0.0,
    verbose: bool = False,
    adapter=None,
):
    """
    Apply RMT-based correction to layer weights with proper cap application.

    Scans every 2-D ``*weight*`` parameter of ``layer`` and, when its top
    singular value exceeds a target cap, rescales the whole parameter
    in place so the spectral norm drops to the cap.

    Enhanced for Step 5 with:
    - Step 5 detection rule: target = bulk_edge_base * margin * (1 - deadband)
    - Adapter tying map support for preserving weight tying relationships
    - IN-PLACE scaling (param.mul_) to preserve weight tying
    - Never rewraps Parameters to avoid breaking lm_head ↔ wte aliasing

    Args:
        layer: Module whose 2-D weight parameters are inspected/corrected.
        factor: Fallback multiplicative shrink applied only when SVD fails.
        baseline_sigmas: Per-layer baseline σ_max (used together with
            ``baseline_mp_stats`` to derive the target cap).
        baseline_mp_stats: Per-layer baseline MP statistics keyed by
            ``layer_name``; ``sigma_base`` is read from each entry.
        layer_name: Dotted name of ``layer`` in the parent model (used for
            baseline lookup and log messages).
        deadband: Tolerance fraction subtracted from the target cap.
        verbose: Print per-parameter correction/skip lines when True.
        adapter: Optional model adapter exposing ``get_tying_map()`` /
            ``get_parameter_by_name()`` for tied-weight handling.
    """
    for name, param in layer.named_parameters():
        # Only 2-D weight matrices participate; biases/norms are skipped.
        if param.ndim == 2 and "weight" in name:
            with torch.no_grad():
                # Get current spectral norm
                try:
                    W = param.detach()
                    # Handle Conv1D transposition (HF Conv1D stores the
                    # transposed layout relative to nn.Linear).
                    Conv1D = None
                    try:
                        from transformers.pytorch_utils import Conv1D as _Conv1D

                        Conv1D = _Conv1D

                        if isinstance(layer, Conv1D):
                            W = W.T
                    except ImportError:
                        pass

                    # Non-finite weights make SVD meaningless; skip them.
                    if not torch.isfinite(W).all():
                        continue
                    # SVD on CPU/float32 for numerical stability.
                    s_vals = torch.linalg.svdvals(W.float().cpu())
                    sigma_pre = s_vals[0].item()

                    # Step 5 correction logic: target based on MP bulk edge
                    target_sigma = None

                    if (
                        baseline_sigmas
                        and baseline_mp_stats
                        and layer_name in baseline_mp_stats
                    ):
                        # CORRECTED Step 5: Use baseline sigma for target calculation
                        mp_stats = baseline_mp_stats[layer_name]
                        sigma_base = mp_stats.get("sigma_base", 1.0)

                        # Step 5 target: baseline * margin * (1 - deadband)
                        # for conservative correction.
                        margin = (
                            1.5  # Default from policy, could be passed as parameter
                        )
                        target_sigma = sigma_base * margin * (1.0 - deadband)
                    else:
                        # Fallback: cap at the current Marchenko-Pastur edge
                        # derived from this matrix's own shape.
                        m, n = W.shape
                        mp_edge = mp_bulk_edge(m, n, whitened=False)
                        target_sigma = mp_edge * 1.0  # Conservative cap at edge

                    # Apply correction only if needed
                    if sigma_pre > target_sigma:
                        # Compute proper scale: target/σ_pre
                        scale = target_sigma / sigma_pre
                        scale = max(
                            scale, 0.1
                        )  # Floor at 10% to avoid extreme shrinkage

                        # Check for tied parameters using adapter's tying map
                        tied_params = []
                        if adapter and hasattr(adapter, "get_tying_map"):
                            try:
                                tying_map = adapter.get_tying_map()
                                full_param_name = f"{layer_name}.{name}"
                                tied_params = tying_map.get(full_param_name, [])
                            except Exception:
                                # Fallback if adapter doesn't support tying map
                                tied_params = []

                        # CRITICAL: Apply IN-PLACE scaling to preserve weight tying
                        param.mul_(scale)  # PRESERVES TYING - same data pointer

                        # Apply same scaling to tied parameters if any.
                        # NOTE(review): if a tied parameter truly shares
                        # storage with `param`, this applies the scale a
                        # second time to the same data — presumably the
                        # tying map only lists *distinct* tensors; confirm
                        # against the adapter implementation.
                        if tied_params and adapter:
                            for tied_name in tied_params:
                                try:
                                    # Get tied parameter and apply same scale
                                    tied_param = adapter.get_parameter_by_name(
                                        tied_name
                                    )
                                    if tied_param is not None:
                                        tied_param.mul_(scale)
                                except Exception:
                                    # Continue if tied parameter access fails
                                    pass

                        # Recompute sigma after scaling for accurate logging
                        W_after = param.detach()
                        if Conv1D is not None and isinstance(layer, Conv1D):
                            W_after = W_after.T
                        s_vals_after = torch.linalg.svdvals(W_after.float().cpu())
                        sigma_post = s_vals_after[0].item()

                        # Log the correction with proper values
                        if verbose:
                            tied_info = (
                                f", tied to {len(tied_params)} params"
                                if tied_params
                                else ""
                            )
                            print(
                                f"    {layer_name}.{name}: σ={sigma_pre:.2f}→{sigma_post:.2f} "
                                f"(scale={scale:.3f}, target={target_sigma:.2f}{tied_info})"
                            )
                    else:
                        # No correction needed - log skip reason
                        if verbose:
                            print(
                                f"    {layer_name}.{name}: SKIP: ≤ target (σ={sigma_pre:.2f} ≤ {target_sigma:.2f})"
                            )

                except (RuntimeError, torch.linalg.LinAlgError):
                    # CRITICAL: Even fallback must use in-place scaling —
                    # blind shrink by `factor` when SVD does not converge.
                    param.mul_(factor)
                    if verbose:
                        print(
                            f"    {layer_name}.{name}: fallback scaling (SVD failed)"
                        )
964
+
965
+
966
+ def clip_full_svd(
967
+ W: torch.Tensor, clip_val: float, return_components: bool = False
968
+ ) -> torch.Tensor:
969
+ """
970
+ Clip singular values of a matrix using full SVD.
971
+
972
+ Args:
973
+ W: Weight matrix
974
+ clip_val: Maximum singular value
975
+ return_components: If True, return (U, S_clipped, Vt)
976
+
977
+ Returns:
978
+ Clipped weight matrix or components
979
+ """
980
+ if not torch.isfinite(W).all():
981
+ if return_components:
982
+ return None, None, None
983
+ return W
984
+
985
+ try:
986
+ U, S, Vt = torch.linalg.svd(W.float(), full_matrices=False)
987
+ S_clipped = torch.clamp(S, max=clip_val)
988
+
989
+ if return_components:
990
+ return U, S_clipped, Vt
991
+ else:
992
+ return (U @ torch.diag(S_clipped) @ Vt).to(W.dtype)
993
+ except (RuntimeError, torch.linalg.LinAlgError):
994
+ # Return original on error
995
+ if return_components:
996
+ return None, None, None
997
+ return W
998
+
999
+
1000
+ def analyze_weight_distribution(model: nn.Module, n_bins: int = 50) -> dict[str, Any]:
1001
+ """
1002
+ Analyze weight distribution statistics for RMT analysis.
1003
+
1004
+ Args:
1005
+ model: Model to analyze
1006
+ n_bins: Number of histogram bins
1007
+
1008
+ Returns:
1009
+ Dict with distribution statistics
1010
+ """
1011
+ all_weights = []
1012
+ all_singular_values = []
1013
+
1014
+ for name, param in model.named_parameters():
1015
+ if param.ndim == 2 and "weight" in name:
1016
+ param_cpu = param.detach().cpu()
1017
+ if not torch.isfinite(param_cpu).all():
1018
+ continue
1019
+
1020
+ # Collect weights
1021
+ all_weights.append(param_cpu.flatten())
1022
+
1023
+ # Collect singular values
1024
+ try:
1025
+ s = torch.linalg.svdvals(param_cpu.float())
1026
+ all_singular_values.append(s)
1027
+ except (RuntimeError, torch.linalg.LinAlgError):
1028
+ continue
1029
+
1030
+ if not all_weights:
1031
+ return {}
1032
+
1033
+ # Concatenate all weights
1034
+ weights = torch.cat(all_weights)
1035
+
1036
+ # Compute statistics
1037
+ stats = {
1038
+ "mean": weights.mean().item(),
1039
+ "std": weights.std().item(),
1040
+ "min": weights.min().item(),
1041
+ "max": weights.max().item(),
1042
+ "sparsity": (weights.abs() < 1e-6).float().mean().item(),
1043
+ }
1044
+
1045
+ # Compute histogram
1046
+ hist, edges = torch.histogram(weights, bins=n_bins)
1047
+ stats["histogram"] = hist.tolist()
1048
+ stats["bin_edges"] = edges.tolist()
1049
+
1050
+ # Singular value statistics
1051
+ if all_singular_values:
1052
+ s_all = torch.cat(all_singular_values)
1053
+ singular_values_dict: dict[str, float] = {
1054
+ "mean": s_all.mean().item(),
1055
+ "std": s_all.std().item(),
1056
+ "min": s_all.min().item(),
1057
+ "max": s_all.max().item(),
1058
+ "condition_number": (s_all.max() / (s_all.min() + 1e-8)).item(),
1059
+ }
1060
+ stats["singular_values"] = singular_values_dict
1061
+
1062
+ # Add MP edge information
1063
+ if all_singular_values:
1064
+ # Estimate MP edges from data
1065
+ n_samples: float = sum(s.shape[0] for s in all_singular_values)
1066
+ n_features: float = np.mean([s.shape[0] for s in all_singular_values])
1067
+ mp_min, mp_max = mp_bulk_edges(int(n_samples), int(n_features))
1068
+ mp_edges_dict: dict[str, float] = {"min": mp_min, "max": mp_max}
1069
+ stats["mp_edges"] = mp_edges_dict
1070
+
1071
+ # Add eigenvalue stats (alias for singular values)
1072
+ stats["eigenvalue_stats"] = stats["singular_values"]
1073
+
1074
+ return stats
1075
+
1076
+
1077
+ # === Guard Implementation ===
1078
+
1079
# Import GuardOutcome types if available
try:
    from invarlock.core.types import GuardOutcome

    HAS_GUARD_OUTCOME = True
except ImportError:
    # Fallback for standalone usage or when types not available:
    # GuardOutcome degrades to a plain dict so guard results can still
    # be constructed and returned without the core types package.
    HAS_GUARD_OUTCOME = False
    GuardOutcome = dict
1088
+
1089
+
1090
+ @dataclass
1091
+ class RMTPolicy:
1092
+ """
1093
+ RMT Guard Policy Configuration.
1094
+
1095
+ Defines parameters for baseline-aware RMT outlier detection and correction.
1096
+ """
1097
+
1098
+ q: float | Literal["auto"] = (
1099
+ "auto" # MP aspect ratio m/n (auto-derived from weights)
1100
+ )
1101
+ deadband: float = 0.10 # Tolerance margin (10%)
1102
+ margin: float = 1.5 # RMT threshold ratio
1103
+ correct: bool = True # Enable automatic correction
1104
+
1105
+
1106
+ class RMTPolicyDict(TypedDict):
1107
+ """TypedDict version of RMTPolicy for compatibility."""
1108
+
1109
+ q: float | Literal["auto"]
1110
+ deadband: float
1111
+ margin: float
1112
+ correct: bool
1113
+ epsilon: float | dict[str, float] | None
1114
+
1115
+
1116
+ class RMTGuard(Guard):
1117
+ """
1118
+ Standalone RMT Guard for baseline-aware outlier detection and correction.
1119
+
1120
+ Implements Marchenko-Pastur theory-based spectral health checking with:
1121
+ - Baseline capture of MP bulk edges for linear layers
1122
+ - Conservative outlier detection with deadband support
1123
+ - Optional in-place correction preserving weight tying
1124
+ - Comprehensive event logging and metrics
1125
+
1126
+ Policy Structure:
1127
+ - q: MP aspect ratio (auto-derived or manual)
1128
+ - deadband: Tolerance margin before flagging (default 0.10 = 10%)
1129
+ - margin: RMT threshold ratio (default 1.5)
1130
+ - correct: Enable automatic correction (default True)
1131
+
1132
+ Linear Layer Scope (enforced):
1133
+ - attn.c_attn, attn.c_proj, mlp.c_fc, mlp.c_proj
1134
+ - Excludes: embeddings, LM head, layer norms, biases
1135
+ """
1136
+
1137
+ name = "rmt"
1138
+
1139
+ def __init__(
1140
+ self,
1141
+ q: float | Literal["auto"] = "auto",
1142
+ deadband: float = 0.10,
1143
+ margin: float = 1.5,
1144
+ correct: bool = True,
1145
+ epsilon: float | dict[str, float] | None = None,
1146
+ ):
1147
+ """
1148
+ Initialize RMT Guard.
1149
+
1150
+ Args:
1151
+ q: MP aspect ratio (auto-derived from weight shapes if "auto")
1152
+ deadband: Tolerance margin before flagging outliers (0.10 = 10%)
1153
+ margin: RMT threshold ratio for outlier detection (1.5)
1154
+ correct: Enable automatic correction when outliers detected
1155
+ """
1156
+ self.q = q
1157
+ self.deadband = deadband
1158
+ self.margin = margin
1159
+ self.correct = correct
1160
+ self.epsilon_default = 0.10
1161
+ self.epsilon_by_family: dict[str, float] = {}
1162
+ self._set_epsilon(epsilon)
1163
+ for family_key in ("attn", "ffn", "embed", "other"):
1164
+ self.epsilon_by_family.setdefault(family_key, self.epsilon_default)
1165
+
1166
+ # Internal state
1167
+ self.baseline_mp_stats: dict[str, dict[str, float]] | None = None
1168
+ self.baseline_sigmas: dict[str, float] | None = None
1169
+ self.prepared = False
1170
+ self.events: list[dict[str, Any]] = []
1171
+ self._last_result: dict[str, Any] | None = None
1172
+ self.adapter = None # Store adapter for tying map access
1173
+
1174
+ # Linear layer scope enforcement - same as existing RMT
1175
+ self.allowed_suffixes = [
1176
+ ".attn.c_attn",
1177
+ ".attn.c_proj",
1178
+ ".mlp.c_fc",
1179
+ ".mlp.c_proj",
1180
+ ]
1181
+ self.baseline_outliers_per_family: dict[str, int] = {}
1182
+ self.baseline_total_outliers: int = 0
1183
+ self.outliers_per_family: dict[str, int] = {}
1184
+ self.outliers_total: int = 0
1185
+ self.epsilon_violations: list[dict[str, Any]] = []
1186
+
1187
+ def _log_event(
1188
+ self, operation: str, level: str = "INFO", message: str = "", **data
1189
+ ):
1190
+ """Log an event with timestamp."""
1191
+ event = {
1192
+ "timestamp": datetime.utcnow().isoformat(),
1193
+ "component": "rmt_guard",
1194
+ "operation": operation,
1195
+ "level": level,
1196
+ "message": message,
1197
+ "data": data,
1198
+ }
1199
+ self.events.append(event)
1200
+
1201
+ def _set_epsilon(self, epsilon: float | dict[str, float] | None) -> None:
1202
+ """Configure epsilon defaults and per-family overrides."""
1203
+ if isinstance(epsilon, dict):
1204
+ mapped: dict[str, float] = {}
1205
+ for family, value in epsilon.items():
1206
+ try:
1207
+ mapped[str(family)] = float(value)
1208
+ except (TypeError, ValueError):
1209
+ continue
1210
+ if mapped:
1211
+ self.epsilon_by_family.update(mapped)
1212
+ self.epsilon_default = max(mapped.values())
1213
+ elif isinstance(epsilon, int | float):
1214
+ self.epsilon_default = float(epsilon)
1215
+ if self.epsilon_by_family:
1216
+ for family in list(self.epsilon_by_family):
1217
+ self.epsilon_by_family[family] = self.epsilon_default
1218
+
1219
+ @staticmethod
1220
+ def _classify_family(module_name: str) -> str:
1221
+ """Classify module name into a guard family."""
1222
+ lower = module_name.lower()
1223
+ # MoE
1224
+ if any(
1225
+ tok in lower
1226
+ for tok in ("router", "routing", "gate", "gating", "dispatch", "switch")
1227
+ ):
1228
+ return "router"
1229
+ if any(
1230
+ tok in lower for tok in ("experts", "expert", "moe", "mixture_of_experts")
1231
+ ):
1232
+ return "expert_ffn"
1233
+ if ".attn." in lower or "attention" in lower:
1234
+ return "attn"
1235
+ if ".mlp." in lower or "ffn" in lower or ".c_fc" in lower:
1236
+ return "ffn"
1237
+ if "embed" in lower or "wte" in lower or "wpe" in lower:
1238
+ return "embed"
1239
+ return "other"
1240
+
1241
+ def _count_outliers_per_family(
1242
+ self, per_layer: list[dict[str, Any]]
1243
+ ) -> dict[str, int]:
1244
+ """Count outliers grouped by family."""
1245
+ counts: dict[str, int] = {}
1246
+ for layer_info in per_layer:
1247
+ if not layer_info.get("has_outlier"):
1248
+ continue
1249
+ module_name = layer_info.get("module_name", "")
1250
+ family = self._classify_family(module_name)
1251
+ counts[family] = counts.get(family, 0) + 1
1252
+ return counts
1253
+
1254
+ def _compute_epsilon_violations(self) -> list[dict[str, Any]]:
1255
+ """Compute epsilon-rule violations per family."""
1256
+ violations: list[dict[str, Any]] = []
1257
+ families = set(self.outliers_per_family) | set(
1258
+ self.baseline_outliers_per_family
1259
+ )
1260
+ for family in families:
1261
+ bare = int(self.baseline_outliers_per_family.get(family, 0) or 0)
1262
+ guarded = int(self.outliers_per_family.get(family, 0) or 0)
1263
+ epsilon_val = float(
1264
+ self.epsilon_by_family.get(family, self.epsilon_default)
1265
+ )
1266
+ allowed = math.ceil(bare * (1 + epsilon_val))
1267
+ if guarded > allowed:
1268
+ violations.append(
1269
+ {
1270
+ "family": family,
1271
+ "bare": bare,
1272
+ "guarded": guarded,
1273
+ "allowed": allowed,
1274
+ "epsilon": epsilon_val,
1275
+ }
1276
+ )
1277
+ return violations
1278
+
1279
+ def _get_linear_modules(self, model: nn.Module) -> list[tuple[str, nn.Module]]:
1280
+ """
1281
+ Get linear modules that are in scope for RMT analysis.
1282
+
1283
+ Args:
1284
+ model: Model to analyze
1285
+
1286
+ Returns:
1287
+ List of (name, module) tuples for linear layers in scope
1288
+ """
1289
+ modules = []
1290
+
1291
+ # Get module types
1292
+ try:
1293
+ from transformers.pytorch_utils import Conv1D
1294
+
1295
+ module_types_with_conv1d_2: tuple[
1296
+ type[nn.Linear], type[nn.Conv1d], type[Conv1D]
1297
+ ] = (nn.Linear, nn.Conv1d, Conv1D)
1298
+ module_types = module_types_with_conv1d_2
1299
+ except ImportError:
1300
+ module_types_without_conv1d_2: tuple[type[nn.Linear], type[nn.Conv1d]] = (
1301
+ nn.Linear,
1302
+ nn.Conv1d,
1303
+ )
1304
+ module_types = module_types_without_conv1d_2
1305
+
1306
+ modules: list[tuple[str, nn.Module]] = []
1307
+ for name, module in model.named_modules():
1308
+ if isinstance(module, module_types) and hasattr(module, "weight"):
1309
+ # Strict scope enforcement - only allowed linear layers
1310
+ if any(name.endswith(suffix) for suffix in self.allowed_suffixes):
1311
+ modules.append((name, module))
1312
+
1313
+ return modules
1314
+
1315
    def _apply_rmt_detection_and_correction(self, model: nn.Module) -> dict[str, Any]:
        """
        Apply Step 5 RMT detection and correction with adapter support.

        Uses exact Step 5 detection rule: ratio = σ_max_post / bulk_edge_base
        Flag if ratio > (1+deadband)*margin

        Iterates every in-scope linear module, flags spectral outliers
        relative to the baseline, optionally corrects them in place via
        :func:`_apply_rmt_correction`, and returns an aggregate report.

        Args:
            model: Model whose in-scope linear modules are analyzed.

        Returns:
            Dict with flags, counts, per-layer details, and a name-keyed
            ``layers`` mirror of ``per_layer``.
        """
        per_layer = []
        flagged_layers = []
        corrected_layers = 0

        # Get linear modules in scope
        modules_to_analyze = self._get_linear_modules(model)

        self._log_event(
            "rmt_correction",
            message=f"Applying Step 5 detection and correction to {len(modules_to_analyze)} modules",
        )

        for idx, (module_name, module) in enumerate(modules_to_analyze):
            # Get current stats
            stats = layer_svd_stats(
                module, self.baseline_sigmas, self.baseline_mp_stats, module_name
            )

            # Step 5 detection rule
            has_outlier = False
            skip_reason = None

            if self.baseline_mp_stats and module_name in self.baseline_mp_stats:
                sigma_post = stats["sigma_max"]
                mp_stats = self.baseline_mp_stats[module_name]
                sigma_base = mp_stats.get("sigma_base", 1.0)

                # CORRECTED Step 5 detection rule: baseline-aware growth ratio
                # Compare current σ_max to baseline σ_max, normalized for stability
                ratio = sigma_post / max(sigma_base, 1e-12)
                detection_threshold = (1.0 + self.deadband) * self.margin

                if ratio > detection_threshold:
                    has_outlier = True

                    # Apply correction using enhanced logic with adapter support
                    if self.correct:
                        try:
                            _apply_rmt_correction(
                                module,
                                0.95,  # Conservative factor (not used in Step 5 logic)
                                self.baseline_sigmas,
                                self.baseline_mp_stats,
                                module_name,
                                self.deadband,
                                verbose=False,
                                adapter=self.adapter,
                            )
                            corrected_layers += 1

                            self._log_event(
                                "rmt_correct",
                                message=f"Applied correction to {module_name}",
                                module_name=module_name,
                                pre_ratio=ratio,
                                threshold=detection_threshold,
                            )

                            # Re-compute stats after correction
                            stats_post = layer_svd_stats(
                                module,
                                self.baseline_sigmas,
                                self.baseline_mp_stats,
                                module_name,
                            )
                            # NOTE(review): detection above divides by
                            # sigma_base, but the post-correction check
                            # divides by mp_bulk_edge_base — the two ratios
                            # use different denominators against the same
                            # threshold. Confirm this asymmetry is intended.
                            mp_stats = self.baseline_mp_stats[module_name]
                            bulk_edge_base = mp_stats.get("mp_bulk_edge_base", 1.0)
                            ratio_post = stats_post["sigma_max"] / max(
                                bulk_edge_base, 1e-12
                            )

                            # Update has_outlier based on post-correction ratio
                            has_outlier = ratio_post > detection_threshold

                        except Exception as e:
                            # Correction failed — layer stays flagged.
                            self._log_event(
                                "rmt_correct_failed",
                                level="ERROR",
                                message=f"Correction failed for {module_name}: {str(e)}",
                                module_name=module_name,
                                error=str(e),
                            )
                else:
                    skip_reason = (
                        f"≤ threshold (ratio={ratio:.2f} ≤ {detection_threshold:.2f})"
                    )
            else:
                # Fallback when no baseline MP stats: use the layer's own
                # worst ratio against the bare margin (no deadband).
                ratio = stats["worst_ratio"]
                if ratio > self.margin:
                    has_outlier = True
                else:
                    skip_reason = f"≤ margin (ratio={ratio:.2f} ≤ {self.margin:.2f})"

            layer_info = {
                "layer": idx,
                "module_name": module_name,
                "sigma_min": stats["sigma_min"],
                "sigma_max": stats["sigma_max"],
                "worst_ratio": stats["worst_ratio"],
                "has_outlier": has_outlier,
                "skip_reason": skip_reason,
            }

            if "worst_details" in stats:
                layer_info["details"] = stats["worst_details"]

            per_layer.append(layer_info)

            if has_outlier:
                flagged_layers.append(idx)

        # Aggregate results
        n_outliers = len(flagged_layers)
        max_ratio = max((float(item["worst_ratio"]) for item in per_layer), default=0.0)
        has_outliers = n_outliers > 0

        return {
            "has_outliers": has_outliers,
            "n_layers_flagged": n_outliers,
            "outlier_count": n_outliers,
            "max_ratio": max_ratio,
            "threshold": self.margin,
            "correction_iterations": 1 if corrected_layers > 0 else 0,
            "corrected_layers": corrected_layers,
            "per_layer": per_layer,
            "flagged_layers": flagged_layers,
            "layers": {f"layer_{item['layer']}": item for item in per_layer},
        }
1451
+
1452
    def prepare(
        self,
        model: nn.Module,
        adapter=None,
        calib=None,
        policy: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Prepare RMT guard by capturing baseline MP statistics.

        Args:
            model: The model that will be edited
            adapter: ModelAdapter (optional, for tying map access)
            calib: Calibration data (unused for RMT)
            policy: Guard policy parameters (optional)

        Returns:
            Dictionary with preparation results and baseline metrics
        """
        import time

        start_time = time.time()

        # Store adapter for tying map access during correction
        self.adapter = adapter

        # Update parameters from policy if provided
        if policy:
            self.q = policy.get("q", self.q)
            self.deadband = policy.get("deadband", self.deadband)
            self.margin = policy.get("margin", self.margin)
            self.correct = policy.get("correct", self.correct)
            # Both keys route through _set_epsilon; "epsilon_by_family"
            # is applied last and therefore wins on overlap.
            if "epsilon" in policy:
                self._set_epsilon(policy["epsilon"])
            if "epsilon_by_family" in policy:
                self._set_epsilon(policy["epsilon_by_family"])

        self._log_event(
            "prepare",
            message=f"Preparing RMT guard with q={self.q}, deadband={self.deadband}, margin={self.margin}, correct={self.correct}",
        )

        try:
            # Capture baseline MP statistics for linear layers
            self.baseline_mp_stats = capture_baseline_mp_stats(model)

            # Extract baseline sigmas for compatibility with existing detection
            self.baseline_sigmas = {}
            for name, stats in self.baseline_mp_stats.items():
                self.baseline_sigmas[name] = stats.get("sigma_base", 0.0)

            # Get linear modules in scope
            linear_modules = self._get_linear_modules(model)

            # Record the pre-edit ("bare") outlier counts used later by
            # the epsilon rule in _compute_epsilon_violations().
            baseline_detection = rmt_detect(
                model=model,
                threshold=self.margin,
                detect_only=True,
                baseline_sigmas=self.baseline_sigmas,
                baseline_mp_stats=self.baseline_mp_stats,
                deadband=self.deadband,
            )
            self.baseline_total_outliers = baseline_detection.get("n_layers_flagged", 0)
            self.baseline_outliers_per_family = self._count_outliers_per_family(
                baseline_detection.get("per_layer", [])
            )
            for family_key in ("attn", "ffn", "embed", "other"):
                self.baseline_outliers_per_family.setdefault(family_key, 0)
            # Reset post-edit bookkeeping.
            self.outliers_per_family = {}
            self.outliers_total = 0
            self.epsilon_violations = []

            self.prepared = True
            preparation_time = time.time() - start_time

            self._log_event(
                "prepare_success",
                message=f"Captured {len(self.baseline_mp_stats)} baseline MP statistics",
                baseline_count=len(self.baseline_mp_stats),
                linear_modules_count=len(linear_modules),
                preparation_time=preparation_time,
            )

            return {
                "baseline_metrics": {
                    # First three entries only, as a lightweight sample.
                    "mp_stats_sample": dict(list(self.baseline_mp_stats.items())[:3]),
                    "total_layers": len(self.baseline_mp_stats),
                    "linear_modules_in_scope": len(linear_modules),
                    "scope_suffixes": self.allowed_suffixes,
                    # NOTE(review): np.mean of an empty list yields nan
                    # (with a warning) when no baselines were captured;
                    # the max/min below guard that case but this does not.
                    "average_baseline_sigma": np.mean(
                        list(self.baseline_sigmas.values())
                    ),
                    "max_baseline_sigma": max(self.baseline_sigmas.values())
                    if self.baseline_sigmas
                    else 0.0,
                    "min_baseline_sigma": min(self.baseline_sigmas.values())
                    if self.baseline_sigmas
                    else 0.0,
                },
                "policy_applied": {
                    "q": self.q,
                    "deadband": self.deadband,
                    "margin": self.margin,
                    "correct": self.correct,
                },
                "preparation_time": preparation_time,
                "ready": True,
            }

        except Exception as e:
            # Any failure leaves the guard unprepared; after_edit()/finalize()
            # will then degrade gracefully instead of raising.
            self.prepared = False
            self._log_event(
                "prepare_failed",
                level="ERROR",
                message=f"Failed to prepare RMT guard: {str(e)}",
                error=str(e),
            )

            return {
                "baseline_metrics": {},
                "policy_applied": policy or {},
                "preparation_time": time.time() - start_time,
                "ready": False,
                "error": str(e),
            }
1577
+
1578
+ def before_edit(self, model: nn.Module) -> None:
1579
+ """
1580
+ Execute before edit (no action needed for RMT).
1581
+
1582
+ Args:
1583
+ model: The model about to be edited
1584
+ """
1585
+ if self.prepared:
1586
+ self._log_event(
1587
+ "before_edit",
1588
+ message="RMT guard ready for post-edit detection and correction",
1589
+ )
1590
+
1591
    def after_edit(self, model: nn.Module) -> None:
        """
        Execute after edit - perform RMT detection and optional correction.

        Populates ``self._last_result``, the per-family outlier counts,
        and the epsilon-violation list consumed by finalize().

        Args:
            model: The model that was just edited
        """
        # Without baseline stats there is nothing to compare against.
        if not self.prepared or not self.baseline_mp_stats:
            self._log_event(
                "after_edit_skipped",
                level="WARN",
                message="RMT guard not prepared, skipping post-edit detection",
            )
            return

        self._log_event("after_edit", message="Applying RMT detection and correction")

        try:
            # Perform RMT detection with baseline awareness
            # Create custom detection with proper adapter support
            if self.correct:
                # Apply correction using enhanced logic with adapter support
                detection_result = self._apply_rmt_detection_and_correction(model)
            else:
                # Detection only
                detection_result = rmt_detect(
                    model=model,
                    threshold=self.margin,  # Use margin as threshold
                    detect_only=True,
                    verbose=False,
                    baseline_sigmas=self.baseline_sigmas,
                    baseline_mp_stats=self.baseline_mp_stats,
                    deadband=self.deadband,
                )

            # Store results
            self._last_result = detection_result
            self.outliers_per_family = self._count_outliers_per_family(
                detection_result.get("per_layer", [])
            )
            for family_key in ("attn", "ffn", "embed", "other"):
                self.outliers_per_family.setdefault(family_key, 0)
            # NOTE(review): the fallback here is len() of the per-family
            # dict (always >= 4 after the setdefaults above), not a sum of
            # its counts — confirm this default is intentional.
            self.outliers_total = detection_result.get(
                "n_layers_flagged", len(self.outliers_per_family)
            )
            self.epsilon_violations = self._compute_epsilon_violations()

            flagged_layers = detection_result.get("n_layers_flagged", 0)
            corrected_layers = detection_result.get("correction_iterations", 0)

            self._log_event(
                "rmt_detection_complete",
                message=f"Detected {flagged_layers} outlier layers, correction enabled: {self.correct}",
                layers_flagged=flagged_layers,
                correction_iterations=corrected_layers,
                has_outliers=detection_result.get("has_outliers", False),
                max_ratio=detection_result.get("max_ratio", 0.0),
            )

            # Log individual layer results
            for layer_info in detection_result.get("per_layer", []):
                if layer_info.get("has_outlier", False):
                    self._log_event(
                        "outlier_detected",
                        message=f"Outlier detected in {layer_info.get('module_name', 'unknown')}",
                        layer_name=layer_info.get("module_name"),
                        ratio=layer_info.get("worst_ratio", 0.0),
                        sigma_max=layer_info.get("sigma_max", 0.0),
                        corrected=self.correct,
                    )
                elif layer_info.get("skip_reason"):
                    self._log_event(
                        "layer_skipped",
                        message=f"Layer {layer_info.get('module_name', 'unknown')} skipped: {layer_info.get('skip_reason')}",
                        layer_name=layer_info.get("module_name"),
                        skip_reason=layer_info.get("skip_reason"),
                    )

        except Exception as e:
            self._log_event(
                "after_edit_failed",
                level="ERROR",
                message=f"RMT detection failed: {str(e)}",
                error=str(e),
            )
            # Store empty result for finalize — the guard degrades to a
            # "no outliers observed" state rather than propagating.
            self._last_result = {
                "has_outliers": False,
                "n_layers_flagged": 0,
                "per_layer": [],
                "max_ratio": 0.0,
            }
            self.outliers_per_family = {}
            self.outliers_total = 0
            self.epsilon_violations = []
1686
+
1687
def validate(
    self, model: Any, adapter: Any, context: dict[str, Any]
) -> dict[str, Any]:
    """
    Validate model state (Guard ABC interface).

    Delegates to :meth:`finalize` and normalizes its result — which may be
    either a ``GuardOutcome`` object or a legacy plain dict — into a simple
    dict shape shared by all guards.

    Args:
        model: Model to validate
        adapter: ModelAdapter instance
        context: Validation context

    Returns:
        Dictionary with ``passed``, ``action``, ``metrics``, ``violations``
        and ``message`` keys.
    """
    outcome = self.finalize(model, adapter)

    # GuardOutcome-like objects expose passed/action/metrics attributes;
    # anything else is treated as the legacy dict format.
    looks_like_outcome = all(
        hasattr(outcome, attr) for attr in ("passed", "action", "metrics")
    )

    if looks_like_outcome:
        raw_violations = getattr(outcome, "violations", None) or []
        return {
            "passed": bool(outcome.passed),
            "action": str(outcome.action),
            "metrics": dict(outcome.metrics),
            "violations": [str(v) for v in raw_violations],
            "message": "RMT guard validation completed",
        }

    did_pass = outcome.get("passed", False)
    return {
        "passed": did_pass,
        "action": "continue" if did_pass else "warn",
        "metrics": outcome.get("metrics", {}),
        "violations": outcome.get("errors", []),
        "message": "RMT guard validation completed",
    }
1728
+
1729
def finalize(self, model: nn.Module, adapter=None) -> GuardOutcome | dict[str, Any]:
    """
    Finalize RMT guard and return comprehensive results.

    Consumes the detection results stashed by ``after_edit`` (``self._last_result``),
    fills in per-family outlier counts, enforces the ε-rule and non-negativity
    contracts, and assembles the final pass/fail verdict with metrics.

    Args:
        model: The final edited model
        adapter: Optional adapter for tying map access

    Returns:
        GuardOutcome (when the GuardOutcome type is available) or a legacy
        dict with RMT detection and correction results.
    """
    import time

    start_time = time.time()

    # Guard must have captured baseline MP stats in prepare(); without them
    # no comparison is possible, so abort immediately.
    if not self.prepared:
        self._log_event(
            "finalize_failed",
            level="ERROR",
            message="RMT guard not properly prepared",
        )

        if HAS_GUARD_OUTCOME:
            return GuardOutcome(
                name=self.name,
                passed=False,
                action="abort",
                violations=[
                    {
                        "type": "preparation",
                        "severity": "error",
                        "message": "RMT guard not properly prepared",
                        "module_name": None,
                    }
                ],
                metrics={
                    "prepared": False,
                    "finalize_time": time.time() - start_time,
                },
            )
        else:
            # Legacy dict shape for callers without GuardOutcome support.
            return {
                "passed": False,
                "metrics": {
                    "prepared": False,
                    "finalize_time": time.time() - start_time,
                },
                "warnings": ["RMT guard not properly prepared"],
                "errors": ["Preparation failed or baseline MP stats not captured"],
                "events": self.events,
            }

    # Get results from after_edit; fall back to an empty result when
    # after_edit never ran (or failed and stored nothing).
    result = self._last_result or {
        "has_outliers": False,
        "n_layers_flagged": 0,
        "per_layer": [],
        "max_ratio": 0.0,
    }

    # Derive per-family outlier counts lazily if after_edit did not set them.
    if result and not self.outliers_per_family:
        self.outliers_per_family = self._count_outliers_per_family(
            result.get("per_layer", [])
        )
    # Ensure every known family key is present (with 0) on both current and
    # baseline counters so downstream comparisons never KeyError.
    for family_key in ("attn", "ffn", "embed", "other"):
        self.outliers_per_family.setdefault(family_key, 0)
        self.baseline_outliers_per_family.setdefault(family_key, 0)
    self.outliers_total = result.get("n_layers_flagged", self.outliers_total or 0)
    self.epsilon_violations = self._compute_epsilon_violations()
    # Contracts: epsilon non-negative, counts non-negative
    for fam, eps in self.epsilon_by_family.items():
        guard_assert(eps >= 0.0, f"rmt.epsilon[{fam}] must be >= 0")
    for fam in set(self.outliers_per_family) | set(
        self.baseline_outliers_per_family
    ):
        guard_assert(
            self.outliers_per_family.get(fam, 0) >= 0,
            "rmt.outliers_per_family negative",
        )
        guard_assert(
            self.baseline_outliers_per_family.get(fam, 0) >= 0,
            "rmt.baseline_outliers negative",
        )

    # Calculate metrics
    flagged_layers = result.get("n_layers_flagged", 0)
    total_layers = len(self.baseline_mp_stats) if self.baseline_mp_stats else 0
    flagged_rate = flagged_layers / total_layers if total_layers > 0 else 0.0

    # Step 5 validation gate: no increase in outliers vs bare edit, ≤1% primary-metric cost
    # For now, use flagged rate as proxy (will be enhanced with PM checking)
    passed = flagged_rate <= 0.5  # Allow up to 50% flagged for conservative gate

    # Generate violations for GuardOutcome
    violations = []
    warnings = []
    errors = []

    # Create violations for each flagged layer
    for layer_info in result.get("per_layer", []):
        if layer_info.get("has_outlier", False):
            violations.append(
                {
                    "type": "rmt_outlier",
                    # Corrected outliers downgrade to warnings; uncorrected
                    # ones are errors.
                    "severity": "warning" if self.correct else "error",
                    "message": f"RMT outlier detected: ratio={layer_info.get('worst_ratio', 0.0):.2f}",
                    "module_name": layer_info.get("module_name"),
                    "ratio": layer_info.get("worst_ratio", 0.0),
                    "threshold": (1.0 + self.deadband) * self.margin,
                    "corrected": self.correct,
                }
            )

    if flagged_rate > 0.3:  # Warning threshold at 30%
        warnings.append(
            f"High RMT outlier rate: {flagged_layers}/{total_layers} layers flagged ({flagged_rate:.1%})"
        )

    if flagged_rate > 0.7:  # Error threshold at 70%
        errors.append(
            f"Excessive RMT outliers: {flagged_layers}/{total_layers} layers flagged"
        )
        passed = False

    # Any ε-rule violation (per-family outlier growth beyond the allowed
    # slack vs the bare edit) fails the guard outright.
    if self.epsilon_violations:
        passed = False
        for failure in self.epsilon_violations:
            errors.append(
                "RMT ε-rule violation: "
                f"{failure['family']} bare={failure['bare']} "
                f"guarded={failure['guarded']} allowed={failure['allowed']} "
                f"(ε={failure['epsilon']:.3f})"
            )

    finalize_time = time.time() - start_time

    # Final metrics
    final_metrics = {
        "layers_flagged": flagged_layers,
        "total_layers": total_layers,
        "flagged_rate": flagged_rate,
        "rmt_outliers": flagged_layers,
        "max_ratio": result.get("max_ratio", 0.0),
        "correction_enabled": self.correct,
        "correction_iterations": result.get("correction_iterations", 0),
        "q_used": self.q,
        "deadband_used": self.deadband,
        "margin_used": self.margin,
        "detection_threshold": (1.0 + self.deadband) * self.margin,
        "baseline_layers_captured": len(self.baseline_mp_stats)
        if self.baseline_mp_stats
        else 0,
        "finalize_time": finalize_time,
        "baseline_outliers_per_family": {
            k: int(v) for k, v in self.baseline_outliers_per_family.items()
        },
        "outliers_per_family": {
            k: int(v) for k, v in self.outliers_per_family.items()
        },
        "baseline_outliers_total": int(self.baseline_total_outliers),
        "outliers_total": int(self.outliers_total),
        "epsilon_by_family": {
            k: float(v) for k, v in self.epsilon_by_family.items()
        },
        "epsilon_default": float(self.epsilon_default),
        "epsilon_violations": self.epsilon_violations,
    }

    self._log_event(
        "finalize_complete",
        message=f"RMT guard finalized - {'PASSED' if passed else 'FAILED'}",
        passed=passed,
        flagged_rate=flagged_rate,
        finalize_time=finalize_time,
    )

    # Return GuardOutcome if available, otherwise legacy dict
    # Env-gated tiny evidence dump for auditors; best-effort, never fatal.
    try:
        payload = {
            "rmt": {
                "epsilon_by_family": {
                    k: float(v) for k, v in self.epsilon_by_family.items()
                },
                "deadband": float(self.deadband),
                "margin": float(self.margin),
                "evaluated": True,
            }
        }
        maybe_dump_guard_evidence(".", payload)
    except Exception:
        pass

    if HAS_GUARD_OUTCOME:
        # Add details to metrics since GuardOutcome doesn't have a details field
        final_metrics.update(
            {
                "guard_type": "rmt",
                "baseline_captured": self.baseline_mp_stats is not None,
                "baseline_count": len(self.baseline_mp_stats)
                if self.baseline_mp_stats
                else 0,
                "flagged_layer_names": [v["module_name"] for v in violations],
                "per_layer_results": result.get("per_layer", []),
                "policy": {
                    "q": self.q,
                    "deadband": self.deadband,
                    "margin": self.margin,
                    "correct": self.correct,
                    "epsilon": self.epsilon_by_family.copy(),
                },
                "scope_suffixes": self.allowed_suffixes,
            }
        )

        return GuardOutcome(
            name=self.name,
            passed=passed,
            action="none" if passed else "rollback",
            violations=violations,
            metrics=final_metrics,
        )
    else:
        return {
            "passed": passed,
            "metrics": final_metrics,
            "warnings": warnings,
            "errors": errors,
            "violations": violations,
            "events": self.events,
            "details": {
                "guard_type": "rmt",
                "baseline_captured": self.baseline_mp_stats is not None,
                "baseline_count": len(self.baseline_mp_stats)
                if self.baseline_mp_stats
                else 0,
                "flagged_layer_names": [v["module_name"] for v in violations],
                "per_layer_results": result.get("per_layer", []),
                "policy": {
                    "q": self.q,
                    "deadband": self.deadband,
                    "margin": self.margin,
                    "correct": self.correct,
                    "epsilon": self.epsilon_by_family.copy(),
                },
                "scope_suffixes": self.allowed_suffixes,
            },
        }
1977
+
1978
def policy(self) -> RMTPolicyDict:
    """
    Get default policy for RMT guard.

    Returns:
        RMTPolicyDict with current configuration
    """
    # Snapshot the live configuration; ε is copied so callers can't
    # mutate the guard's per-family map through the returned policy.
    current = {
        "q": self.q,
        "deadband": self.deadband,
        "margin": self.margin,
        "correct": self.correct,
        "epsilon": self.epsilon_by_family.copy(),
    }
    return RMTPolicyDict(**current)
1992
+
1993
+
1994
+ # === Policy Utilities ===
1995
+
1996
+
1997
def get_rmt_policy(name: str = "balanced") -> RMTPolicyDict:
    """
    Get a RMT policy by name.

    Args:
        name: Policy name ("conservative", "balanced", "aggressive")

    Returns:
        RMTPolicyDict configuration
    """
    # Per-family ε values match tiers.yaml (November 2025 calibration)
    catalog: dict[str, RMTPolicyDict] = {
        "conservative": RMTPolicyDict(
            q="auto",
            deadband=0.05,
            margin=1.3,
            correct=True,
            epsilon={"ffn": 0.06, "attn": 0.05, "embed": 0.07, "other": 0.07},
        ),
        "balanced": RMTPolicyDict(
            q="auto",
            deadband=0.10,
            margin=1.5,
            correct=True,
            epsilon={"ffn": 0.10, "attn": 0.08, "embed": 0.12, "other": 0.12},
        ),
        "aggressive": RMTPolicyDict(
            q="auto",
            deadband=0.15,
            margin=1.8,
            correct=True,
            epsilon={"ffn": 0.14, "attn": 0.12, "embed": 0.18, "other": 0.18},
        ),
    }

    selected = catalog.get(name)
    if selected is None:
        # Lazy import keeps the error type local to the failure path.
        from invarlock.core.exceptions import GuardError

        raise GuardError(
            code="E502",
            message="POLICY-NOT-FOUND",
            details={"name": name, "available": list(catalog.keys())},
        )

    return selected
2043
+
2044
+
2045
def create_custom_rmt_policy(
    q: float | Literal["auto"] = "auto",
    deadband: float = 0.10,
    margin: float = 1.5,
    correct: bool = True,
    epsilon: float | dict[str, float] | None = None,
) -> RMTPolicyDict:
    """
    Create a custom RMT policy.

    Args:
        q: MP aspect ratio (auto-derived or manual); manual float values
            must lie in [0.1, 10.0]
        deadband: Tolerance margin (0.0-0.5)
        margin: RMT threshold ratio (>= 1.0)
        correct: Enable automatic correction
        epsilon: Optional ε allowance — a uniform float or a per-family
            dict (e.g. {"ffn": 0.10, ...}); all values must be non-negative

    Returns:
        Custom RMTPolicyDict configuration

    Raises:
        ValidationError: code "E501" when any parameter is outside its
            valid range.
    """

    def _reject(param: str, value: Any) -> None:
        # Single raise site for all parameter failures; the import stays
        # lazy so the exceptions module is only loaded on the error path.
        from invarlock.core.exceptions import ValidationError

        raise ValidationError(
            code="E501",
            message="POLICY-PARAM-INVALID",
            details={"param": param, "value": value},
        )

    # NOTE(review): per the declared signature only float q is range-checked;
    # int values bypass validation as in the original implementation.
    if isinstance(q, float) and not 0.1 <= q <= 10.0:
        _reject("q", q)

    if not 0.0 <= deadband <= 0.5:
        _reject("deadband", deadband)

    if not margin >= 1.0:
        _reject("margin", margin)

    # ε values feed guard_assert(eps >= 0.0) contracts during finalize;
    # reject negatives here instead of failing deep inside the guard.
    if isinstance(epsilon, dict):
        if any(v < 0.0 for v in epsilon.values()):
            _reject("epsilon", epsilon)
    elif epsilon is not None and epsilon < 0.0:
        _reject("epsilon", epsilon)

    return RMTPolicyDict(
        q=q,
        deadband=deadband,
        margin=margin,
        correct=correct,
        epsilon=epsilon,
    )