invarlock-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/guards/variance.py
@@ -0,0 +1,3298 @@
+ """
+ InvarLock – Safety: Data-Driven Variance Equalization (DD-VE)
+ =============================================================
+
+ Branch-level variance equalizer for transformer blocks to maintain
+ stable residual stream dynamics after edits.
+
+ For each transformer block, measures the variance of residual branch
+ outputs (attention and MLP) and scales projection weights to maintain
+ Var(x_out) ≈ 1 when Var(x_in) ≈ 1.
+ """
+
+ from __future__ import annotations
+
+ import copy
+ import fnmatch
+ import hashlib
+ import itertools
+ import math
+ import time
+ from collections import defaultdict
+ from collections.abc import Iterable, Sequence
+ from datetime import datetime
+ from typing import Any
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from invarlock.cli._evidence import maybe_dump_guard_evidence
+ from invarlock.core.api import Guard
+ from invarlock.core.bootstrap import compute_paired_delta_log_ci
+
+ from ._contracts import guard_assert
+
+ # Import the policy type and Guard interface
+ from .policies import VariancePolicyDict
+
+ __all__ = ["equalise_residual_variance", "equalise_branch_variance", "VarianceGuard"]
+
+
+ try:  # Optional dependency: tqdm (progress bars)
+     from tqdm.auto import tqdm as _tqdm
+ except Exception:  # pragma: no cover - exercised only when tqdm is absent
+
+     class _TqdmShim:
+         def __init__(self, iterable=None, total=None, **kwargs):
+             self._iterable = iterable
+             self.total = total
+
+         def __iter__(self):
+             if self._iterable is None:
+                 return iter(())
+             return iter(self._iterable)
+
+         def __enter__(self):
+             return self
+
+         def __exit__(self, exc_type, exc, tb):
+             return False
+
+         def update(self, n: int = 1) -> None:
+             return None
+
+     def _tqdm(iterable=None, *args, **kwargs):
+         return _TqdmShim(iterable=iterable, **kwargs)
+
+
+ tqdm = _tqdm
+
+
+ def _unwrap_model(model: nn.Module) -> nn.Module:
+     """Unwrap DataParallel/DDP wrappers to get the underlying model.
+
+     PyTorch's DataParallel and DistributedDataParallel wrap models with a
+     `.module` attribute. This function traverses that chain to get the
+     actual model, enabling consistent layer iteration regardless of how
+     the model is wrapped for training/inference.
+     """
+     unwrapped = model
+     while hasattr(unwrapped, "module"):
+         unwrapped = unwrapped.module
+     return unwrapped
+
+
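A minimal standalone sketch of the unwrapping behaviour, runnable on CPU (assumes only that torch is installed; `unwrap` mirrors `_unwrap_model` above):

    import torch.nn as nn

    def unwrap(model: nn.Module) -> nn.Module:
        # Follow the `.module` chain added by DataParallel/DDP wrappers.
        while hasattr(model, "module"):
            model = model.module
        return model

    inner = nn.Linear(8, 8)
    wrapped = nn.DataParallel(inner)  # adds a `.module` attribute
    assert unwrap(wrapped) is inner   # wrapper chain unwinds to the raw module
    assert unwrap(inner) is inner     # unwrapped models pass through unchanged
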
+ def _iter_transformer_layers(model: nn.Module):
+     """Iterate over transformer layers in a model.
+
+     Handles multiple transformer architectures and automatically unwraps
+     DataParallel/DDP wrappers.
+     """
+     # Unwrap DataParallel/DDP wrappers first
+     model = _unwrap_model(model)
+
+     # Handle different model architectures
+     if hasattr(model, "transformer") and hasattr(model.transformer, "h"):
+         # GPT-2 style
+         yield from model.transformer.h
+     elif hasattr(model, "model") and hasattr(model.model, "layers"):
+         # LLaMA style
+         yield from model.model.layers
+     elif hasattr(model, "encoder") and hasattr(model.encoder, "layer"):
+         # BERT style
+         yield from model.encoder.layer
+     elif hasattr(model, "decoder") and hasattr(model.decoder, "layers"):
+         # T5/BART decoder style
+         yield from model.decoder.layers
+     elif hasattr(model, "layers"):
+         # Generic transformer with top-level layers attribute
+         yield from model.layers
+     else:
+         # Fallback: look for modules with attention
+         for module in model.modules():
+             if hasattr(module, "attn") and hasattr(module, "mlp"):
+                 yield module
+
+
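For illustration, a GPT-2-shaped dummy container that the first branch above would match (a sketch; `DummyBlock` and `DummyGPT2` are invented for this example, and the private helper is imported only to demonstrate it):

    import torch.nn as nn

    from invarlock.guards.variance import _iter_transformer_layers

    class DummyBlock(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.attn = nn.Module()
            self.attn.c_proj = nn.Linear(8, 8)
            self.mlp = nn.Module()
            self.mlp.c_proj = nn.Linear(8, 8)

    class DummyGPT2(nn.Module):
        def __init__(self, n_layer: int = 2) -> None:
            super().__init__()
            self.transformer = nn.Module()
            self.transformer.h = nn.ModuleList(DummyBlock() for _ in range(n_layer))

    # `model.transformer.h` exists, so the GPT-2 branch yields the blocks.
    blocks = list(_iter_transformer_layers(DummyGPT2()))
    assert len(blocks) == 2
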
+ @torch.no_grad()
+ def equalise_residual_variance(
+     model: nn.Module,
+     dataloader,
+     *,
+     windows: int = 32,
+     tol: float = 0.02,
+     scale_bias: bool = True,
+     seed: int = 42,
+     device: str | None = None,
+     allow_empty: bool = False,
+     clamp_range: tuple | None = (0.9, 1.1),
+ ) -> dict[str, float]:
+     """
+     Apply data-driven variance equalization to transformer branches.
+
+     This function measures the variance of each residual branch output
+     (attention-proj and MLP-proj) and scales projection weights so that
+     adding the branch back to the residual stream maintains stable variance.
+
+     The scaling factor alpha = 1 / sqrt(Var(F)) is used, where Var(F) is the
+     measured mean square of the branch output F.
+
+     Args:
+         model: Transformer model to equalize
+         dataloader: DataLoader for calibration
+         windows: Number of calibration batches
+         tol: Tolerance for skipping near-unity scales
+         scale_bias: Whether to scale biases along with weights
+         seed: Random seed for reproducibility
+         device: Device to use (auto-detected if None)
+         allow_empty: Whether to allow empty dataloader (returns empty dict)
+         clamp_range: Optional (min, max) to clamp scaling factors (e.g., (0.9, 1.1))
+
+     Returns:
+         Dict mapping layer names to applied scaling factors
+     """
+     torch.manual_seed(seed)
+
+     if device is None:
+         device = next(model.parameters()).device
+     else:
+         device = torch.device(device)
+
+     model.eval()
+
+     # Storage for variance measurements
+     hooks: dict[str, Any] = {}
+     sample_values: dict[str, list[float]] = defaultdict(list)
+
+     def _branch_hook(name):
+         def fn(_, __, out):
+             y = out[0] if isinstance(out, tuple) else out
+             y = y.detach().float()
+             # Skip if tensor has zero elements
+             if y.numel() == 0:
+                 return
+             mean_square = float(y.pow(2).mean().item())
+             sample_values[name].append(mean_square)
+
+         return fn
+
+     # Register hooks on projection layers
+     for i, blk in enumerate(_iter_transformer_layers(model)):
+         # Handle GPT-2 style architecture
+         if hasattr(blk, "attn"):
+             # Check for c_proj (GPT-2) or out_proj (generic)
+             attn_proj = getattr(blk.attn, "c_proj", None) or getattr(
+                 blk.attn, "out_proj", None
+             )
+             if attn_proj is not None:
+                 name = f"block{i}.attn"
+                 hooks[name] = attn_proj.register_forward_hook(_branch_hook(name))
+
+         if hasattr(blk, "mlp"):
+             # Check for c_proj (GPT-2) or down_proj (LLaMA) or fc2 (generic)
+             mlp_proj = (
+                 getattr(blk.mlp, "c_proj", None)
+                 or getattr(blk.mlp, "down_proj", None)
+                 or getattr(blk.mlp, "fc2", None)
+             )
+             if mlp_proj is not None:
+                 name = f"block{i}.mlp"
+                 hooks[name] = mlp_proj.register_forward_hook(_branch_hook(name))
+
+     # Collect variance statistics
+     try:
+         it = itertools.islice(iter(dataloader), windows)
+         batches = list(it)
+     except (StopIteration, TypeError):
+         batches = []
+
+     if not batches and not allow_empty:
+         raise ValueError("Empty dataloader provided and allow_empty=False")
+
+     for batch in tqdm(batches, desc="DD-VE Calibration", leave=False):
+         if isinstance(batch, dict):
+             input_ids = batch.get("input_ids", batch.get("inputs", None))
+         elif isinstance(batch, tuple | list):
+             # Handle tuple/list from TensorDataset
+             input_ids = batch[0] if len(batch) > 0 else None
+         else:
+             input_ids = batch
+
+         if input_ids is not None:
+             # Convert to tensor if needed
+             if not isinstance(input_ids, torch.Tensor):
+                 input_ids = torch.as_tensor(input_ids)
+
+             # Ensure input has batch dimension [batch, seq_len]
+             # HF models (GPT-2, etc.) expect 2-D input tensors
+             if input_ids.dim() == 1:
+                 input_ids = input_ids.unsqueeze(0)
+
+             with torch.no_grad():
+                 model(input_ids.to(device))
+
+     # Remove hooks
+     for h in hooks.values():
+         h.remove()
+
+     # Apply scaling factors
+     applied_scales: dict[str, float] = {}
+
+     for i, blk in enumerate(_iter_transformer_layers(model)):
+         # Handle attention projection
+         if hasattr(blk, "attn"):
+             attn_proj = getattr(blk.attn, "c_proj", None) or getattr(
+                 blk.attn, "out_proj", None
+             )
+             if attn_proj is not None:
+                 name = f"block{i}.attn"
+                 values = sample_values.get(name, [])
+                 if values:
+                     tensor_vals = torch.tensor(values, dtype=torch.float64)
+
+                     # Winsorize to remove extreme outliers (≈1-2%)
+                     if tensor_vals.numel() >= 10:
+                         lower = torch.quantile(tensor_vals, 0.02)
+                         upper = torch.quantile(tensor_vals, 0.98)
+                         tensor_vals = torch.clamp(
+                             tensor_vals, lower.item(), upper.item()
+                         )
+
+                     group_count = 8 if tensor_vals.numel() >= 8 else tensor_vals.numel()
+                     if group_count > 1:
+                         chunks = torch.chunk(tensor_vals, group_count)
+                         group_means = torch.stack([chunk.mean() for chunk in chunks])
+                         var_F = torch.median(group_means).item()
+                     else:
+                         var_F = tensor_vals.mean().item()
+
+                     alpha = (1.0 / max(var_F, 1e-9)) ** 0.5
+
+                     # Apply clamping if specified
+                     if clamp_range is not None:
+                         alpha = max(clamp_range[0], min(alpha, clamp_range[1]))
+
+                     if abs(alpha - 1.0) >= tol:
+                         with torch.no_grad():
+                             attn_proj.weight.mul_(alpha)
+                             if scale_bias and attn_proj.bias is not None:
+                                 attn_proj.bias.mul_(alpha)
+                         applied_scales[name] = alpha
+
+         # Handle MLP projection
+         if hasattr(blk, "mlp"):
+             mlp_proj = (
+                 getattr(blk.mlp, "c_proj", None)
+                 or getattr(blk.mlp, "down_proj", None)
+                 or getattr(blk.mlp, "fc2", None)
+             )
+             if mlp_proj is not None:
+                 name = f"block{i}.mlp"
+                 values = sample_values.get(name, [])
+                 if values:
+                     tensor_vals = torch.tensor(values, dtype=torch.float64)
+
+                     if tensor_vals.numel() >= 10:
+                         lower = torch.quantile(tensor_vals, 0.02)
+                         upper = torch.quantile(tensor_vals, 0.98)
+                         tensor_vals = torch.clamp(
+                             tensor_vals, lower.item(), upper.item()
+                         )
+
+                     group_count = 8 if tensor_vals.numel() >= 8 else tensor_vals.numel()
+                     if group_count > 1:
+                         chunks = torch.chunk(tensor_vals, group_count)
+                         group_means = torch.stack([chunk.mean() for chunk in chunks])
+                         var_F = torch.median(group_means).item()
+                     else:
+                         var_F = tensor_vals.mean().item()
+
+                     alpha = (1.0 / max(var_F, 1e-9)) ** 0.5
+
+                     # Apply clamping if specified
+                     if clamp_range is not None:
+                         alpha = max(clamp_range[0], min(alpha, clamp_range[1]))
+
+                     if abs(alpha - 1.0) >= tol:
+                         with torch.no_grad():
+                             mlp_proj.weight.mul_(alpha)
+                             if scale_bias and mlp_proj.bias is not None:
+                                 mlp_proj.bias.mul_(alpha)
+                         applied_scales[name] = alpha
+
+     return applied_scales
+
+
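A typical call, sketched against a small Hugging Face checkpoint (assumes `transformers` is installed; the tiny model id and single-window calibration are illustrative only):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from invarlock.guards.variance import equalise_residual_variance

    tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
    calib = [{"input_ids": tok("hello world", return_tensors="pt")["input_ids"]}]

    scales = equalise_residual_variance(
        model,
        calib,                   # any iterable of batches works as the dataloader
        windows=1,               # one calibration window for the sketch
        tol=0.0,                 # apply every measured scale, however small
        clamp_range=(0.9, 1.1),  # keep alphas near unity
    )
    # e.g. {'block0.attn': 1.1, 'block0.mlp': 0.93}; values depend on the data
    print(scales)
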
+ def equalise_branch_variance(
+     model: nn.Module,
+     dataloader,
+     windows: int = 32,
+     tol: float = 0.02,
+     scale_bias: bool = True,
+     seed: int = 42,
+     device: str | None = None,
+     allow_empty: bool = False,
+ ) -> dict[str, float]:
+     """
+     Legacy alias for equalise_residual_variance.
+
+     Maintained for backward compatibility.
+     """
+     return equalise_residual_variance(
+         model=model,
+         dataloader=dataloader,
+         windows=windows,
+         tol=tol,
+         scale_bias=scale_bias,
+         seed=seed,
+         device=device,
+         allow_empty=allow_empty,
+     )
+
+
+ def _predictive_gate_outcome(
+     mean_delta: float,
+     delta_ci: tuple[float, float] | None,
+     min_effect: float,
+     one_sided: bool,
+ ) -> tuple[bool, str]:
+     """
+     Decide whether the predictive gate passes given the CI and tier semantics.
+
+     Args:
+         mean_delta: Mean ΔlogNLL (virtual VE − no VE) from paired calibration.
+         delta_ci: BCa confidence interval on ΔlogNLL (lower, upper).
+         min_effect: Minimum absolute improvement required.
+         one_sided: Whether to require a one-sided improvement (balanced tier).
+
+     Returns:
+         Tuple of (passed, reason) where reason is a canonical string used in stats.
+     """
+     guard_assert(min_effect >= 0.0, "variance.min_effect must be >= 0")
+     if (
+         delta_ci is None
+         or len(delta_ci) != 2
+         or not all(
+             isinstance(val, (int | float)) and math.isfinite(val) for val in delta_ci
+         )
+     ):
+         return False, "ci_unavailable"
+
+     lower, upper = float(delta_ci[0]), float(delta_ci[1])
+     min_effect = float(min_effect or 0.0)
+
+     if one_sided:
+         if lower >= 0.0:
+             return False, "ci_contains_zero"
+         if mean_delta >= 0.0:
+             return False, "mean_not_negative"
+         if min_effect > 0.0 and (-mean_delta) < min_effect:
+             return False, "gain_below_threshold"
+         return True, "ci_gain_met"
+
+     # Two-sided improvement: CI must be strictly below zero.
+     if upper >= 0.0:
+         return False, "ci_contains_zero"
+
+     gain_lower_bound = -upper  # Convert ΔlogNLL CI to gain CI lower bound.
+     if gain_lower_bound < min_effect:
+         return False, "gain_below_threshold"
+
+     return True, "ci_gain_met"
+
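Worked examples of the gate semantics (illustrative numbers; `_predictive_gate_outcome` is module-private and is called here only to show the decision table):

    from invarlock.guards.variance import _predictive_gate_outcome

    # Two-sided: the whole ΔlogNLL CI must sit below zero, and the implied
    # gain lower bound (-upper) must reach min_effect.
    assert _predictive_gate_outcome(
        mean_delta=-0.04, delta_ci=(-0.06, -0.02), min_effect=0.01, one_sided=False
    ) == (True, "ci_gain_met")  # -upper = 0.02 >= 0.01

    assert _predictive_gate_outcome(
        mean_delta=-0.04, delta_ci=(-0.06, 0.01), min_effect=0.01, one_sided=False
    ) == (False, "ci_contains_zero")  # CI upper bound reaches zero

    # One-sided: the mean improvement itself must clear min_effect.
    assert _predictive_gate_outcome(
        mean_delta=-0.005, delta_ci=(-0.02, 0.03), min_effect=0.01, one_sided=True
    ) == (False, "gain_below_threshold")  # |mean_delta| = 0.005 < 0.01
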
+
+ # === Standalone Variance Guard Implementation ===
+
+
+ class VarianceGuard(Guard):
+     """
+     Standalone Variance Guard with A/B testing for data-driven variance equalization.
+
+     Implements branch-level variance equalization with reinforced A/B gate functionality:
+     - Measures variance of residual branch outputs during calibration
+     - Computes scaling factors to maintain stable variance dynamics
+     - A/B tests whether VE improves perplexity by at least min_gain
+     - Only enables VE if it demonstrably helps (validation gate compliance)
+
+     Policy Structure:
+     - min_gain: Minimum primary-metric improvement required to enable VE
+     - max_calib: Maximum calibration samples for A/B testing
+     - scope: Which layers to process ("ffn", "attn", "both")
+     - clamp: Scaling factor limits (min, max)
+     - deadband: Tolerance margin before scaling
+     - seed: Random seed for deterministic evaluation
+
+     Reinforced A/B Testing Flow:
+     1. Capture baseline model state with checkpoint discipline
+     2. Measure variance and compute proposed scales during prepare
+     3. A/B test with identical windows: evaluate the primary metric without VE, then with VE
+     4. Apply robust gain math with tie-breaker deadband and absolute floor
+     5. Enable VE only if improvement meets all criteria
+     6. Idempotent enable/disable with exact state restoration
+     """
+
+     name = "variance"
+
+     def __init__(self, policy: VariancePolicyDict | None = None):
+         """
+         Initialize Variance Guard with reinforced A/B gate logic.
+
+         Args:
+             policy: Variance policy configuration (uses balanced default if None)
+         """
+         from .policies import get_variance_policy
+
+         self._policy = policy or get_variance_policy("balanced")
+         self._policy.setdefault("mode", "ci")
+         self._policy.setdefault("min_rel_gain", 0.001)
+         self._policy.setdefault("alpha", 0.05)
+         self._policy.setdefault("clamp", (0.5, 2.0))
+         self._policy.setdefault("seed", 123)
+         self._policy.setdefault("tie_breaker_deadband", 0.005)
+         self._policy.setdefault("min_abs_adjust", 0.012)
+         self._policy.setdefault("max_scale_step", 0.02)
+         self._policy.setdefault("topk_backstop", 1)
+         self._policy.setdefault("max_adjusted_modules", 0)
+         self._policy.setdefault("predictive_gate", True)
+         self._policy.setdefault("predictive_one_sided", False)
+         self._policy.setdefault("absolute_floor_ppl", 0.05)
+         if self._policy.get("min_effect_lognll") is not None:
+             self._policy["min_effect_lognll"] = float(self._policy["min_effect_lognll"])
+         self._refresh_calibration_defaults()
+         self._scales: dict[str, float] = {}
+         self._raw_scales: dict[str, float] = {}
+         self._enabled = False
+         self._stats: dict[str, Any] = {}
+         self._prepared = False
+         self._baseline_state: dict[str, Any] | None = None
+         self.events: list[dict[str, Any]] = []
+         self._calibration_stats: dict[str, Any] = {
+             "requested": 0,
+             "coverage": 0,
+             "min_coverage": 0,
+             "seed": self._policy["calibration"]["seed"],
+             "status": "uninitialized",
+         }
+         self.ABSOLUTE_FLOOR = float(
+             self._policy.get(
+                 "absolute_floor_pm", self._policy.get("absolute_floor_ppl", 0.05)
+             )
+         )
+         self._monitor_only = bool(self._policy.get("monitor_only", False))
+         self._params_changed: int | None = None
+         self._run_context: dict[str, Any] | None = None
+         self._report_meta: dict[str, Any] | None = None
+         self._dataset_meta: dict[str, Any] | None = None
+         self._pairing_reference: list[str] = []
+         self._pairing_digest: str | None = None
+         self._adapter_ref: Any | None = None
+
+         # A/B testing results with reinforced validation
+         self._ppl_no_ve: float | None = None
+         self._ppl_with_ve: float | None = None
+         self._ab_gain: float | None = None
+         self._ab_windows_used: int | None = None
+         self._ab_seed_used: int | None = None
+         self._ratio_ci: tuple[float, float] | None = None
+         self._predictive_gate_state: dict[str, Any] = {
+             "evaluated": False,
+             "passed": False,
+             "reason": "not_evaluated",
+             "delta_ci": (None, None),
+             "gain_ci": (None, None),
+             "mean_delta": None,
+         }
+
+         # Module tracking for safe scaling
+         self._target_modules: dict[str, nn.Module] = {}
+         self._original_scales: dict[str, float] = {}
+         self._focus_modules = {
+             self._normalize_module_name(name)
+             for name in (self._policy.get("target_modules") or [])
+             if isinstance(name, str)
+         }
+         if self._focus_modules:
+             self._policy["target_modules"] = sorted(self._focus_modules)
+
+         tap_config = self._policy.get("tap")
+         if isinstance(tap_config, str):
+             tap_patterns = [tap_config]
+         elif isinstance(tap_config, Sequence):
+             tap_patterns = [
+                 str(pattern)
+                 for pattern in tap_config
+                 if isinstance(pattern, str) and pattern.strip()
+             ]
+         else:
+             tap_patterns = []
+         if not tap_patterns:
+             tap_patterns = ["transformer.h.*.mlp.c_proj"]
+         self._tap_patterns = tap_patterns
+
+         # Checkpoint discipline for robust state management
+         self._checkpoint_stack: list[dict[str, torch.Tensor]] = []
+         self._enable_attempt_count = 0
+         self._disable_attempt_count = 0
+
+         # Constants for reinforced A/B gate
+         self.TIE_BREAKER_DEADBAND = float(
+             self._policy.get("tie_breaker_deadband", 0.005)
+         )  # Extra deadband to avoid flapping on noise
+         self.ABSOLUTE_FLOOR = 0.05  # Minimum improvement (ppl-like) to consider
+
+         # Calibration storage for post-edit evaluation
+         self._calibration_batches: list[Any] = []
+         self._calibration_window_ids: list[str] = []
+         self._calibration_context: dict[str, Any] = {}
+         self._calibration_stats_pre_edit: dict[str, Any] | None = None
+         self._post_edit_evaluated = False
+         self._raw_scales_pre_edit: dict[str, float] = {}
+         self._raw_scales_post_edit: dict[str, float] = {}
+         self._stats["tap"] = list(self._tap_patterns)
+         if self._focus_modules:
+             self._stats["focus_modules"] = sorted(self._focus_modules)
+         self._stats.setdefault("ab_provenance", {})
+
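Construction sketch (assumes the public import path matches the package layout shown in the file list; any key passed in overrides the corresponding default filled in above):

    from invarlock.guards.variance import VarianceGuard

    guard = VarianceGuard({"scope": "ffn", "min_abs_adjust": 0.02})

    assert guard.name == "variance"
    assert guard._policy["mode"] == "ci"                 # filled in by setdefault
    assert guard._policy["calibration"]["windows"] == 6  # merged calibration default
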
+     def _refresh_calibration_defaults(self) -> None:
+         """Ensure calibration config contains required defaults."""
+         default_calibration = {
+             "windows": 6,
+             "min_coverage": 4,
+             "seed": self._policy.get("seed", 123),
+         }
+         calibration_cfg = self._policy.get("calibration", {}) or {}
+         if not isinstance(calibration_cfg, dict):
+             calibration_cfg = {}
+         merged_calibration = {**default_calibration, **calibration_cfg}
+         self._policy["calibration"] = merged_calibration
+
+     def _log_event(
+         self, operation: str, level: str = "INFO", message: str = "", **data
+     ):
+         """Log an event with timestamp."""
+         event = {
+             "timestamp": datetime.utcnow().isoformat(),
+             "component": "variance_guard",
+             "operation": operation,
+             "level": level,
+             "message": message,
+             "data": data,
+         }
+         self.events.append(event)
+
+     def set_run_context(self, report: Any) -> None:
+         """Capture run-level context (edit metadata, pairing reference, etc.)."""
+         self._report_meta = getattr(report, "meta", {}) or {}
+         self._run_context = getattr(report, "context", {}) or {}
+         if isinstance(self._run_context, dict):
+             self._dataset_meta = self._run_context.get("dataset_meta")
+         else:
+             self._dataset_meta = None
+         if isinstance(self._dataset_meta, dict):
+             self._stats.setdefault("dataset_meta", self._dataset_meta)
+
+         pairing_reference: list[str] = []
+         pairing_digest: str | None = None
+         if isinstance(self._run_context, dict):
+             pairing_baseline = self._run_context.get("pairing_baseline")
+         else:
+             pairing_baseline = None
+         if isinstance(pairing_baseline, dict):
+             preview_section = pairing_baseline.get("preview") or {}
+             final_section = pairing_baseline.get("final") or {}
+             pairing_reference.extend(
+                 self._normalize_pairing_ids(
+                     "preview", preview_section.get("window_ids") or []
+                 )
+             )
+             pairing_reference.extend(
+                 self._normalize_pairing_ids(
+                     "final", final_section.get("window_ids") or []
+                 )
+             )
+         if pairing_reference:
+             joined = "||".join(pairing_reference)
+             pairing_digest = hashlib.blake2s(
+                 joined.encode("utf-8"), digest_size=16
+             ).hexdigest()
+             pairing_stats = self._stats.setdefault("pairing_reference", {})
+             pairing_stats.update(
+                 {
+                     "count": len(pairing_reference),
+                     "digest": pairing_digest,
+                 }
+             )
+         self._pairing_reference = pairing_reference
+         self._pairing_digest = pairing_digest
+         if pairing_digest is None:
+             self._stats.pop("pairing_reference", None)
+
+         edit_info = getattr(report, "edit", {}) or {}
+         params_changed = None
+         if isinstance(edit_info, dict):
+             deltas = edit_info.get("deltas") or {}
+             if isinstance(deltas, dict):
+                 params_changed = deltas.get("params_changed")
+         if params_changed is None:
+             params_changed = (
+                 0 if edit_info and edit_info.get("name") in {"noop"} else None
+             )
+         self._params_changed = params_changed
+         if params_changed == 0:
+             self._monitor_only = True
+             self._log_event(
+                 "monitor_only",
+                 message="Variance guard forcing monitor-only mode (no parameters changed)",
+             )
+             # Clear proposed scales in monitor-only mode
+             self._scales = {}
+
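The pairing digest recorded above is plain stdlib hashing; a sketch of the same derivation:

    import hashlib

    window_ids = ["preview::0", "preview::1", "final::0"]  # illustrative ids
    joined = "||".join(window_ids)
    digest = hashlib.blake2s(joined.encode("utf-8"), digest_size=16).hexdigest()
    assert len(digest) == 32  # 16 bytes -> 32 hex characters
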
+     def _normalize_module_name(self, name: str) -> str:
+         """Normalize module names to transformer.h.<idx>.<branch>.c_proj form."""
+         if not isinstance(name, str):
+             return ""
+
+         normalized = name.strip()
+         if not normalized:
+             return normalized
+
+         if normalized.startswith("block"):
+             parts = normalized.split(".")
+             if len(parts) >= 2 and parts[0].startswith("block"):
+                 layer_idx = parts[0][5:]
+                 branch = parts[1]
+                 branch = "attn" if branch.startswith("attn") else "mlp"
+                 return f"transformer.h.{layer_idx}.{branch}.c_proj"
+
+         if normalized.startswith("transformer.h."):
+             if normalized.endswith(".c_proj"):
+                 return normalized
+             if ".mlp" in normalized and ".c_proj" not in normalized:
+                 return f"{normalized}.c_proj"
+             if ".attn" in normalized and ".c_proj" not in normalized:
+                 return f"{normalized}.c_proj"
+
+         return normalized
+
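Normalization examples (a sketch; the private method is called directly only for illustration):

    from invarlock.guards.variance import VarianceGuard

    guard = VarianceGuard()

    assert guard._normalize_module_name("block0.mlp") == "transformer.h.0.mlp.c_proj"
    assert guard._normalize_module_name("block3.attn") == "transformer.h.3.attn.c_proj"
    assert guard._normalize_module_name("transformer.h.2.mlp") == "transformer.h.2.mlp.c_proj"
    assert guard._normalize_module_name("lm_head") == "lm_head"  # unknown names pass through
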
+     def _matches_tap(self, name: str) -> bool:
+         """Return True if a module name matches configured tap patterns."""
+         normalized = self._normalize_module_name(name)
+         for pattern in self._tap_patterns:
+             if fnmatch.fnmatch(normalized, pattern) or fnmatch.fnmatch(name, pattern):
+                 return True
+         return False
+
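Tap matching is plain fnmatch globbing against the normalized name; for example, with the default pattern set in __init__:

    import fnmatch

    pattern = "transformer.h.*.mlp.c_proj"  # default tap
    assert fnmatch.fnmatch("transformer.h.0.mlp.c_proj", pattern)
    assert not fnmatch.fnmatch("transformer.h.0.attn.c_proj", pattern)
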
+     def _normalize_pairing_ids(
+         self, prefix: str, window_ids: Sequence[Any]
+     ) -> list[str]:
+         normalized: list[str] = []
+         for idx in window_ids:
+             token = str(idx)
+             if "::" in token:
+                 normalized.append(token)
+             else:
+                 normalized.append(f"{prefix}::{token}")
+         return normalized
+
+     def _expected_window_ids(self) -> list[str]:
+         return list(self._pairing_reference)
+
+     def _normalize_scale_name(self, name: str) -> str:
+         """Normalize a scale name to the canonical module path."""
+         return self._normalize_module_name(name)
+
+     def _scale_matches_target(self, scale_name: str, target_name: str) -> bool:
+         """Check if a scale name from equalise_residual_variance matches a target module name.
+
+         Handles the format mismatch between:
+         - Scale names: block0.mlp, block0.attn
+         - Target names: transformer.h.0.mlp.c_proj, transformer.h.0.attn.c_proj
+         """
+         # Normalize scale name to target format and check direct match
+         normalized_scale = self._normalize_scale_name(scale_name)
+         if normalized_scale == target_name:
+             return True
+
+         # Convert block format to layer-component extraction
+         if scale_name.startswith("block") and (
+             "attn" in scale_name or "mlp" in scale_name
+         ):
+             parts = scale_name.split(".")
+             if len(parts) == 2:
+                 layer_part = parts[0]  # e.g., "block0"
+                 component = parts[1]  # e.g., "attn" or "mlp"
+                 if layer_part.startswith("block"):
+                     try:
+                         layer_num = layer_part[5:]  # Extract number from "block0"
+                         # Check if target matches this pattern
+                         if f"h.{layer_num}.{component}" in target_name:
+                             return True
+                     except (ValueError, IndexError):
+                         pass
+
+         return False
+
+     def _is_focus_match(self, name: str) -> bool:
+         """Check whether a module name matches the configured focus list."""
+         if not self._focus_modules:
+             return True
+         normalized = self._normalize_module_name(name)
+         return normalized in self._focus_modules
+
+     def _materialize_batch(self, batch: Any) -> Any:
+         """Detach tensors from device and clone calibration batches for reuse."""
+         if isinstance(batch, dict):
+             return {key: self._materialize_batch(val) for key, val in batch.items()}
+         if isinstance(batch, list | tuple):
+             return type(batch)(self._materialize_batch(val) for val in batch)
+         if isinstance(batch, torch.Tensor):
+             return batch.detach().cpu()
+         try:
+             return copy.deepcopy(batch)
+         except Exception:
+             return batch
+
+     def _ensure_tensor_value(self, value: Any) -> Any:
+         """Convert common calibration value types to torch tensors."""
+         if isinstance(value, torch.Tensor):
+             return value
+         if isinstance(value, np.ndarray):
+             return torch.as_tensor(value)
+         if isinstance(value, list | tuple):
+             try:
+                 return torch.as_tensor(value)
+             except Exception:
+                 return value
+         if isinstance(value, int | float):
+             return torch.tensor(value)
+         return value
+
+     def _tensorize_calibration_batches(self, batches: Sequence[Any]) -> list[Any]:
+         """Ensure calibration batches contain tensor payloads for model execution."""
+         tensor_batches: list[Any] = []
+         for batch in batches:
+             if isinstance(batch, dict):
+                 converted: dict[str, Any] = {}
+                 for key, value in batch.items():
+                     if key in {"input_ids", "inputs", "attention_mask", "labels"}:
+                         converted[key] = self._ensure_tensor_value(value)
+                     else:
+                         converted[key] = value
+                 tensor_batches.append(converted)
+             elif isinstance(batch, list | tuple):
+                 converted_list = [self._ensure_tensor_value(val) for val in batch]
+                 tensor_batches.append(type(batch)(converted_list))
+             else:
+                 tensor_batches.append(self._ensure_tensor_value(batch))
+         return tensor_batches
+
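Effect of the batch tensorization helpers, sketched (payload keys become tensors, other keys are left alone):

    import torch

    from invarlock.guards.variance import VarianceGuard

    guard = VarianceGuard()
    batch = {"input_ids": [[1, 2, 3]], "window_id": "preview::0"}
    [converted] = guard._tensorize_calibration_batches([batch])

    assert isinstance(converted["input_ids"], torch.Tensor)  # payload tensorized
    assert converted["window_id"] == "preview::0"            # metadata untouched
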
+     def _extract_window_ids(self, batches: Sequence[Any]) -> list[str]:
+         """Extract window identifiers from calibration batches when present."""
+         window_ids: list[str] = []
+         for batch in batches:
+             candidate: Any | None = None
+             if isinstance(batch, dict):
+                 if "window_id" in batch:
+                     candidate = batch["window_id"]
+                 elif "window_ids" in batch:
+                     candidate = batch["window_ids"]
+                 elif isinstance(batch.get("metadata"), dict):
+                     meta = batch["metadata"]
+                     candidate = meta.get("window_id") or meta.get("window_ids")
+
+             if candidate is None:
+                 continue
+
+             if isinstance(candidate, list | tuple):
+                 window_ids.extend(str(item) for item in candidate)
+             else:
+                 window_ids.append(str(candidate))
+         if not window_ids and batches:
+             window_ids = [str(idx) for idx in range(len(batches))]
+         return window_ids
+
+     def _store_calibration_batches(self, batches: list[Any]) -> None:
+         """Persist calibration batches for deterministic post-edit evaluation."""
+         materialized = [self._materialize_batch(b) for b in batches]
+         self._calibration_batches = self._tensorize_calibration_batches(materialized)
+         self._calibration_window_ids = self._extract_window_ids(
+             self._calibration_batches
+         )
+         observed_ids = list(self._calibration_window_ids)
+         observed_digest = (
+             hashlib.blake2s(
+                 "||".join(observed_ids).encode("utf-8"), digest_size=16
+             ).hexdigest()
+             if observed_ids
+             else None
+         )
+         self._calibration_context = {
+             "window_ids": list(self._calibration_window_ids),
+             "count": len(self._calibration_batches),
+             "observed_digest": observed_digest,
+         }
+         expected_ids = self._expected_window_ids()
+         if expected_ids:
+             self._calibration_context["expected_digest"] = self._pairing_digest
+             expected_subset = expected_ids[: len(observed_ids)] if observed_ids else []
+             if observed_ids != expected_subset:
+                 mismatch = {
+                     "expected_count": len(expected_ids),
+                     "observed_count": len(observed_ids),
+                     "expected_sample": expected_subset[:5]
+                     if expected_subset
+                     else expected_ids[:5],
+                     "observed_sample": observed_ids[:5],
+                 }
+                 self._log_event(
+                     "pairing_mismatch",
+                     level="ERROR",
+                     message="Variance guard calibration windows do not match baseline pairing",
+                     **mismatch,
+                 )
+                 self._prepared = False
+                 raise RuntimeError(
+                     "Variance guard pairing mismatch: calibration windows diverge from baseline schedule"
+                 )
+         self._stats.setdefault("calibration", {})
+         self._stats["calibration"].update(self._calibration_context)
+
+     def _fingerprint_targets(self) -> str | None:
+         """Compute a lightweight fingerprint of targeted module weights."""
+         if not self._target_modules:
+             return None
+
+         hasher = hashlib.sha256()
+         try:
+             for name in sorted(self._target_modules.keys()):
+                 module = self._target_modules[name]
+                 state = getattr(module, "state_dict", None)
+                 if not callable(state):
+                     continue
+                 module_state = state()
+                 for key in sorted(module_state.keys()):
+                     tensor = module_state[key]
+                     if hasattr(tensor, "detach"):
+                         data = tensor.detach().cpu().numpy().tobytes()
+                     else:
+                         data = bytes(str(tensor), "utf-8")
+                     hasher.update(name.encode("utf-8"))
+                     hasher.update(key.encode("utf-8"))
+                     hasher.update(data)
+             return hasher.hexdigest()[:16]
+         except Exception:
+             return None
+
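A standalone sketch of the fingerprint scheme (same ingredients: sorted module names, sorted state-dict keys, raw tensor bytes):

    import hashlib

    import torch.nn as nn

    def fingerprint(modules: dict[str, nn.Module]) -> str:
        hasher = hashlib.sha256()
        for name in sorted(modules):
            state = modules[name].state_dict()
            for key in sorted(state):
                hasher.update(name.encode("utf-8"))
                hasher.update(key.encode("utf-8"))
                hasher.update(state[key].detach().cpu().numpy().tobytes())
        return hasher.hexdigest()[:16]

    mods = {"transformer.h.0.mlp.c_proj": nn.Linear(4, 4)}
    before = fingerprint(mods)
    mods["transformer.h.0.mlp.c_proj"].weight.data.mul_(1.01)  # perturb a weight
    assert fingerprint(mods) != before  # any weight change moves the digest
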
+     def _record_ab_provenance(
+         self,
+         condition: str,
+         *,
+         tag: str,
+         window_ids: Sequence[str],
+         fingerprint: str | None,
+         mode: str,
+         status: str,
+     ) -> None:
+         """Record provenance metadata for A/B evaluation conditions."""
+         provenance = self._stats.setdefault("ab_provenance", {})
+         window_list = list(window_ids)
+         provenance[condition] = {
+             "tag": tag,
+             "mode": mode,
+             "window_ids": window_list,
+             "window_count": len(window_list),
+             "target_fingerprint": fingerprint,
+             "status": status,
+             "pairing_digest": self._pairing_digest,
+             "dataset_hash": (self._dataset_meta or {}).get("dataset_hash"),
+             "tokenizer_hash": (self._dataset_meta or {}).get("tokenizer_hash"),
+             "model_id": (self._report_meta or {}).get("model_id"),
+             "run_seed": (self._report_meta or {}).get("seed"),
+         }
+
+     def _resolve_target_modules(
+         self, model: nn.Module, adapter: Any | None = None
+     ) -> dict[str, nn.Module]:
+         """
+         Resolve target modules based on scope policy.
+
+         Args:
+             model: Model to analyze
+             adapter: Optional adapter used to query layer modules
+
+         Returns:
+             Dict mapping module names to modules
+         """
+         targets = {}
+         scope = self._policy["scope"]
+         audit_candidates: list[dict[str, Any]] = []
+         audit_rejections: list[dict[str, Any]] = []
+
+         def _record_match(name: str, module: nn.Module) -> None:
+             audit_candidates.append(
+                 {
+                     "name": name,
+                     "class": module.__class__.__name__,
+                     "source": "direct",
+                 }
+             )
+
+         def _record_rejection(name: str, reason: str, module: Any | None) -> None:
+             audit_rejections.append(
+                 {
+                     "name": name,
+                     "reason": reason,
+                     "class": getattr(module, "__class__", type(None)).__name__
+                     if module is not None
+                     else None,
+                 }
+             )
+
+         # Get module types
+         try:
+             from transformers.pytorch_utils import Conv1D
+
+             module_types = (nn.Linear, nn.Conv1d, Conv1D)
+         except ImportError:
+             module_types = (nn.Linear, nn.Conv1d)
+
+         def _is_supported_module(module: Any) -> bool:
+             """Heuristic check that a module looks like a projection."""
+             if isinstance(module, module_types):
+                 return True
+             class_name = module.__class__.__name__ if module is not None else ""
+             if class_name in {"Conv1D", "Linear"}:
+                 return True
+             weight = getattr(module, "weight", None)
+             if weight is None:
+                 return False
+             try:
+                 dim = weight.dim()
+             except Exception:
+                 dim = getattr(weight, "ndim", None)
+             return dim == 2
+
+         for i, blk in enumerate(_iter_transformer_layers(model)):
+             # Handle attention projection based on scope
+             if scope in ["attn", "both"] and hasattr(blk, "attn"):
+                 attn_proj = getattr(blk.attn, "c_proj", None) or getattr(
+                     blk.attn, "out_proj", None
+                 )
+                 name = f"transformer.h.{i}.attn.c_proj"
+                 if attn_proj is None:
+                     _record_rejection(name, "missing_module", None)
+                 elif not self._matches_tap(name):
+                     _record_rejection(name, "tap_mismatch", attn_proj)
+                 elif not _is_supported_module(attn_proj):
+                     _record_rejection(name, "unsupported_type", attn_proj)
+                 else:
+                     targets[name] = attn_proj
+                     _record_match(name, attn_proj)
+
+             # Handle MLP projection based on scope
+             if scope in ["ffn", "both"] and hasattr(blk, "mlp"):
+                 mlp_proj = (
+                     getattr(blk.mlp, "c_proj", None)
+                     or getattr(blk.mlp, "down_proj", None)
+                     or getattr(blk.mlp, "fc2", None)
+                 )
+                 name = f"transformer.h.{i}.mlp.c_proj"
+                 if mlp_proj is None:
+                     _record_rejection(name, "missing_module", None)
+                 elif not self._matches_tap(name):
+                     _record_rejection(name, "tap_mismatch", mlp_proj)
+                 elif not _is_supported_module(mlp_proj):
+                     _record_rejection(name, "unsupported_type", mlp_proj)
+                 else:
+                     targets[name] = mlp_proj
+                     _record_match(name, mlp_proj)
+
+         fallback_used = False
+
+         # Fallback: ask adapter for layer modules if we could not resolve anything
+         # Strategy:
+         # 1. Try adapter.describe() for layer count - works even when model structure is unknown
+         # 2. If that fails, try _iter_transformer_layers() to count layers
+         # 3. If that fails, try model.config for layer count
+         if (
+             not targets
+             and adapter is not None
+             and hasattr(adapter, "get_layer_modules")
+         ):
+             try:
+                 # Get layer count from adapter.describe() first
+                 n_layers = 0
+                 if hasattr(adapter, "describe"):
+                     try:
+                         desc = adapter.describe(model)
+                         if isinstance(desc, dict):
+                             n_layers = int(desc.get("n_layer", 0) or 0)
+                     except Exception as desc_exc:
+                         self._log_event(
+                             "adapter_describe_error",
+                             level="DEBUG",
+                             message=f"adapter.describe() failed: {desc_exc}",
+                         )
+
+                 # Fallback: count layers via _iter_transformer_layers()
+                 # This works when model has standard structure but no c_proj
+                 if n_layers == 0:
+                     try:
+                         n_layers = sum(1 for _ in _iter_transformer_layers(model))
+                     except Exception:
+                         pass
+
+                 # Fallback: try model.config for layer count
+                 if n_layers == 0:
+                     config = getattr(_unwrap_model(model), "config", None)
+                     if config is not None:
+                         n_layers = (
+                             getattr(config, "n_layer", 0)
+                             or getattr(config, "num_hidden_layers", 0)
+                             or getattr(config, "num_layers", 0)
+                             or 0
+                         )
+
+                 if n_layers == 0:
+                     self._log_event(
+                         "adapter_fallback_no_layers",
+                         level="WARN",
+                         message="Adapter fallback: could not determine layer count",
+                     )
+
+                 for i in range(n_layers):
+                     try:
+                         modules = adapter.get_layer_modules(model, i) or {}
+                     except Exception as exc:
+                         _record_rejection(
+                             f"transformer.h.{i}",
+                             f"adapter_error:{exc}",
+                             None,
+                         )
+                         continue
+
+                     for key, module in modules.items():
+                         if not isinstance(key, str) or not key.endswith("c_proj"):
+                             continue
+                         branch = "attn" if "attn" in key else "mlp"
+                         name = f"transformer.h.{i}.{branch}.c_proj"
+                         if not self._matches_tap(name):
+                             _record_rejection(name, "tap_mismatch", module)
+                             continue
+                         if not _is_supported_module(module):
+                             _record_rejection(name, "unsupported_type", module)
+                             continue
+                         targets[name] = module
+                         audit_candidates.append(
+                             {
+                                 "name": name,
+                                 "class": module.__class__.__name__,
+                                 "source": "adapter_fallback",
+                             }
+                         )
+                 if targets:
+                     fallback_used = True
+             except Exception as exc:  # pragma: no cover - defensive logging
+                 self._log_event(
+                     "target_resolution_fallback_error",
+                     level="WARN",
+                     message="Adapter fallback failed during VE target resolution",
+                     error=str(exc),
+                 )
+
+         if self._focus_modules:
+             focused: dict[str, nn.Module] = {}
+             for name, module in targets.items():
+                 norm_name = self._normalize_module_name(name)
+                 if norm_name in self._focus_modules:
+                     focused[name] = module
+
+             if not focused:
+                 self._log_event(
+                     "focus_miss",
+                     level="WARN",
+                     message="No target modules matched focus list",
+                     focus_modules=sorted(self._focus_modules),
+                     available=list(targets.keys()),
+                 )
+             else:
+                 targets = focused
+
+         # Persist audit statistics for reports
+         rejected_summary: dict[str, Any] = {}
+         for item in audit_rejections:
+             reason = item["reason"]
+             bucket = rejected_summary.setdefault(reason, {"count": 0, "examples": []})
+             bucket["count"] += 1
+             if len(bucket["examples"]) < 5:
+                 bucket["examples"].append(
+                     {
+                         "name": item["name"],
+                         "class": item["class"],
+                     }
+                 )
+
+         self._stats["target_resolution"] = {
+             "scope": scope,
+             "tap": list(self._tap_patterns),
+             "total_matched": len(targets),
+             "matched": sorted(targets.keys()),
+             "fallback_used": fallback_used,
+             "candidates_recorded": len(audit_candidates),
+             "rejected": rejected_summary,
+         }
+
+         self._log_event(
+             "target_resolution",
+             message="Resolved variance guard targets",
+             scope=scope,
+             tap=list(self._tap_patterns),
+             matched=len(targets),
+             rejected=sum(item["count"] for item in rejected_summary.values())
+             if rejected_summary
+             else 0,
+             fallback_used=fallback_used,
+         )
+
+         return targets
+
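Resolution sketch on a GPT-2-shaped dummy (the private method is invoked directly for illustration; with scope "ffn", only the MLP projections survive):

    import torch.nn as nn

    from invarlock.guards.variance import VarianceGuard

    def make_block() -> nn.Module:
        blk = nn.Module()
        blk.attn = nn.Module()
        blk.attn.c_proj = nn.Linear(8, 8)
        blk.mlp = nn.Module()
        blk.mlp.c_proj = nn.Linear(8, 8)
        return blk

    model = nn.Module()
    model.transformer = nn.Module()
    model.transformer.h = nn.ModuleList([make_block(), make_block()])

    guard = VarianceGuard({"scope": "ffn", "tap": "transformer.h.*.mlp.c_proj"})
    targets = guard._resolve_target_modules(model)
    assert sorted(targets) == [
        "transformer.h.0.mlp.c_proj",
        "transformer.h.1.mlp.c_proj",
    ]
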
+     def _compute_variance_scales(
+         self, model: nn.Module, dataloader
+     ) -> dict[str, float]:
+         """
+         Compute variance-based scaling factors using existing implementation.
+
+         Args:
+             model: Model to analyze
+             dataloader: Calibration data
+
+         Returns:
+             Dict mapping module names to proposed scaling factors
+         """
+         if self._monitor_only:
+             self._log_event(
+                 "monitor_only",
+                 message="Skipping variance scale computation in monitor-only mode",
+             )
+             self._raw_scales = {}
+             return {}
+
+         # Use existing equalise_residual_variance but don't apply yet
+         # We'll capture the proposed scales and apply them later in enable()
+
+         # Temporarily capture the current model state
+         original_state = copy.deepcopy(model.state_dict())
+
+         try:
+             tensor_ready_batches = self._tensorize_calibration_batches(dataloader)
+
+             # Run variance equalization to get proposed scales
+             proposed_scales = equalise_residual_variance(
+                 model=model,
+                 dataloader=tensor_ready_batches,
+                 windows=min(
+                     self._policy["max_calib"] // 10, 50
+                 ),  # Limit calibration windows
+                 tol=self._policy["deadband"],
+                 scale_bias=False,  # Don't scale biases to preserve operating points
+                 seed=self._policy["seed"],
+                 clamp_range=self._policy["clamp"],
+                 allow_empty=True,
+             )
+
+             if not proposed_scales and self._policy.get("deadband", 0.0) > 0.0:
+                 relaxed_tol = max(self._policy["deadband"] * 0.5, 1e-4)
+                 model.load_state_dict(original_state)
+                 tensor_ready_batches = self._tensorize_calibration_batches(dataloader)
+                 proposed_scales = equalise_residual_variance(
+                     model=model,
+                     dataloader=tensor_ready_batches,
+                     windows=min(self._policy["max_calib"] // 10, 50),
+                     tol=relaxed_tol,
+                     scale_bias=False,
+                     seed=self._policy["seed"] + 7,
+                     clamp_range=self._policy["clamp"],
+                     allow_empty=True,
+                 )
+
+             raw_scales = dict(proposed_scales)
+
+             # Filter raw_scales to only those that have corresponding target modules
+             # This is critical when scope limits targets (e.g., scope=ffn only has mlp targets)
+             # Only apply this filtering when target modules have been resolved
+             if self._target_modules:
+                 filtered_raw_scales: dict[str, float] = {}
+                 for scale_name, scale_value in raw_scales.items():
+                     # Convert scale name to target module name format
+                     target_name = self._normalize_scale_name(scale_name)
+                     if target_name in self._target_modules:
+                         filtered_raw_scales[scale_name] = scale_value
+                     elif self._is_focus_match(scale_name):
+                         # Fallback: check if any target module matches via pattern
+                         for tm_name in self._target_modules:
+                             if self._scale_matches_target(scale_name, tm_name):
+                                 filtered_raw_scales[scale_name] = scale_value
+                                 break
+                 raw_scales = filtered_raw_scales
+
+             focus_raw_scales = {
+                 self._normalize_scale_name(name): scale
+                 for name, scale in raw_scales.items()
+                 if self._is_focus_match(name)
+             }
+             if focus_raw_scales:
+                 self._log_event(
+                     "variance_raw_scales",
+                     message="Captured raw VE scales",
+                     count=len(focus_raw_scales),
+                     min_scale=min(focus_raw_scales.values()),
+                     max_scale=max(focus_raw_scales.values()),
+                 )
+                 self._stats.setdefault("raw_scales_observations", []).append(
+                     {
+                         "timestamp": datetime.utcnow().isoformat(),
+                         "count": len(focus_raw_scales),
+                         "scales": focus_raw_scales,
+                     }
+                 )
+
+             # Restore original state since we only wanted the proposed scales
+             model.load_state_dict(original_state)
+
+             filtered_scales: dict[str, float] = {}
+             raw_delta_map: dict[str, float] = {}
+             min_abs = float(max(self._policy.get("min_abs_adjust", 0.0), 0.0))
+             max_step = float(max(self._policy.get("max_scale_step", 0.0), 0.0))
+             topk = int(max(self._policy.get("topk_backstop", 0) or 0, 0))
+             best_candidate: tuple[str, float] | None = None
+             best_delta = 0.0
+
+             for name, scale in raw_scales.items():
+                 normalized_name = self._normalize_scale_name(name)
+                 if not self._is_focus_match(normalized_name):
+                     continue
+
+                 raw_delta = abs(scale - 1.0)
+                 raw_delta_map[name] = raw_delta
+
+                 delta = raw_delta
+                 if delta > best_delta:
+                     best_candidate = (name, scale)
+                     best_delta = delta
+
+                 if delta < min_abs:
+                     continue
+
+                 if max_step > 0.0:
+                     limited_delta = min(delta, max_step)
+                     scale = 1.0 + math.copysign(limited_delta, scale - 1.0)
+
+                 filtered_scales[name] = scale
+
+             backstop_used = False
+             if not filtered_scales and topk > 0 and best_candidate:
+                 name, scale = best_candidate
+                 deadband = float(self._policy.get("deadband", 0.0) or 0.0)
+                 threshold = max(deadband * 0.5, min_abs)
+                 if best_delta >= threshold:
+                     if max_step > 0.0:
+                         limited_delta = min(best_delta, max_step)
+                         scale = 1.0 + math.copysign(limited_delta, scale - 1.0)
+                     filtered_scales[name] = scale
+                     raw_delta_map.setdefault(name, best_delta)
+                     backstop_used = True
+
+             trimmed_to_limit = False
+             max_adjusted = int(max(self._policy.get("max_adjusted_modules", 0) or 0, 0))
+             if max_adjusted > 0 and len(filtered_scales) > max_adjusted:
+                 sorted_candidates = sorted(
+                     filtered_scales.items(),
+                     key=lambda item: (
+                         raw_delta_map.get(item[0], abs(item[1] - 1.0))
+                         + (2.0 if item[1] >= 1.0 else 0.0),
+                         raw_delta_map.get(item[0], abs(item[1] - 1.0)),
+                         item[1],
+                     ),
+                     reverse=True,
+                 )
+                 filtered_scales = dict(sorted_candidates[:max_adjusted])
+                 trimmed_to_limit = True
+
+             self._raw_scales = raw_scales
+             if backstop_used:
+                 self._log_event(
+                     "scale_backstop",
+                     message=f"Top-{topk} backstop injected {len(filtered_scales)} scale",
+                     count=len(filtered_scales),
+                     candidate=best_candidate[0] if best_candidate else None,
+                     candidate_normalized=self._normalize_scale_name(best_candidate[0])
+                     if best_candidate
+                     else None,
+                     delta=best_delta,
+                 )
+             if trimmed_to_limit:
+                 self._log_event(
+                     "scale_limit",
+                     message="Trimmed VE scales to max_adjusted_modules",
+                     limit=max_adjusted,
+                     count=len(filtered_scales),
+                 )
+
+             filtered_normalized = {
+                 self._normalize_scale_name(name): scale
+                 for name, scale in filtered_scales.items()
+             }
+             self._stats.setdefault("filtered_scales_observations", []).append(
+                 {
+                     "timestamp": datetime.utcnow().isoformat(),
+                     "count": len(filtered_normalized),
+                     "scales": filtered_normalized,
+                     "backstop_used": backstop_used,
+                 }
+             )
+
+             return filtered_scales
+
+         except Exception as e:
+             # Restore state on any error
+             model.load_state_dict(original_state)
+             raise e
+
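The deadband and step-cap arithmetic applied to each proposed scale, in isolation (a sketch using the default min_abs_adjust=0.012 and max_scale_step=0.02 from __init__):

    import math

    def filter_scale(scale: float, min_abs: float = 0.012, max_step: float = 0.02):
        delta = abs(scale - 1.0)
        if delta < min_abs:
            return None  # inside the deadband: leave the module untouched
        limited = min(delta, max_step)  # cap the magnitude of the adjustment
        return 1.0 + math.copysign(limited, scale - 1.0)

    assert filter_scale(1.005) is None   # |delta| = 0.005 < 0.012, dropped
    assert filter_scale(0.95) == 0.98    # delta capped at 0.02, sign preserved
    assert filter_scale(1.015) == 1.015  # under the cap, applied unchanged
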
1362
+ def _evaluate_calibration_pass(
1363
+ self,
1364
+ model: nn.Module,
1365
+ calibration_batches: list[Any],
1366
+ min_coverage: int,
1367
+ calib_seed: int,
1368
+ tag: str,
1369
+ ) -> None:
1370
+ """Run deterministic calibration for A/B evaluation and predictive gating."""
1371
+ predictive_state: dict[str, Any] = {
1372
+ "evaluated": False,
1373
+ "passed": not bool(self._policy.get("predictive_gate", True)),
1374
+ "reason": "disabled"
1375
+ if not bool(self._policy.get("predictive_gate", True))
1376
+ else "no_calibration",
1377
+ "delta_ci": (None, None),
1378
+ "gain_ci": (None, None),
1379
+ "mean_delta": None,
1380
+ }
+
+ requested = len(calibration_batches)
+ self._calibration_stats.update(
+ {
+ "requested": requested,
+ "coverage": 0,
+ "min_coverage": min_coverage,
+ "seed": calib_seed,
+ "status": "no_calibration"
+ if not calibration_batches
+ else "insufficient",
+ "tag": tag,
+ }
+ )
+ self._stats.setdefault("calibration", {})
+ self._stats["calibration"].update(
+ {
+ "requested": requested,
+ "min_coverage": min_coverage,
+ "seed": calib_seed,
+ "tag": tag,
+ }
+ )
+
+ fingerprint = self._fingerprint_targets()
+ if fingerprint:
+ self._stats["target_fingerprint"] = fingerprint
+
+ if not calibration_batches:
+ self._ratio_ci = None
+ self._predictive_gate_state = predictive_state
+ self._stats["predictive_gate"] = predictive_state.copy()
+ return
+
+ device = next(model.parameters()).device
+ torch.manual_seed(calib_seed)
+ ppl_no_ve_samples, loss_no_ve_samples = self._compute_ppl_for_batches(
+ model, calibration_batches, device
+ )
+ coverage = min(len(calibration_batches), len(ppl_no_ve_samples))
+ ppl_with_ve_samples: list[float] = []
+ loss_with_ve_samples: list[float] = []
+ ratio_ci: tuple[float, float] | None = None
+
+ enable_success = False
+ if coverage >= min_coverage and self._scales:
+ prev_enable_attempts = self._enable_attempt_count
+ prev_disable_attempts = self._disable_attempt_count
+ prev_prepared_flag = self._prepared
+ try:
+ self._prepared = True
+ enable_success = self.enable(model)
+ finally:
+ self._prepared = prev_prepared_flag
+ try:
+ torch.manual_seed(calib_seed)
+ if enable_success:
+ ppl_with_ve_samples, loss_with_ve_samples = (
+ self._compute_ppl_for_batches(
+ model, calibration_batches, device
+ )
+ )
+ finally:
+ if enable_success:
+ self.disable(model)
+ # Restore attempt counters to avoid skewing metrics
+ self._enable_attempt_count = prev_enable_attempts
+ self._disable_attempt_count = prev_disable_attempts
+
+ coverage = min(
+ coverage,
+ len(ppl_with_ve_samples) if ppl_with_ve_samples else coverage,
+ len(loss_with_ve_samples) if loss_with_ve_samples else coverage,
+ )
+ self._calibration_stats.update(
+ {
+ "coverage": coverage,
+ "status": "insufficient" if coverage < min_coverage else "pending",
+ }
+ )
+
+ window_ids = self._calibration_window_ids
+ status_a = "evaluated" if coverage > 0 else "no_data"
+ self._record_ab_provenance(
+ "condition_a",
+ tag=tag,
+ mode="edited_no_ve",
+ window_ids=window_ids,
+ fingerprint=fingerprint,
+ status=status_a,
+ )
+
+ if coverage >= min_coverage and not self._scales:
+ ppl_no_ve_samples = ppl_no_ve_samples[:coverage]
+ ppl_no_ve_mean = float(np.mean(ppl_no_ve_samples))
+ self.set_ab_results(
+ ppl_no_ve=ppl_no_ve_mean,
+ ppl_with_ve=ppl_no_ve_mean,
+ windows_used=coverage,
+ seed_used=calib_seed,
+ ratio_ci=(1.0, 1.0),
+ )
+ self._calibration_stats.update(
+ {
+ "status": "no_scaling_required",
+ "ppl_no_ve": ppl_no_ve_mean,
+ "ratio_ci": (1.0, 1.0),
+ }
+ )
+ self._stats["ab_point_estimates"] = {
+ "tag": tag,
+ "ppl_no_ve": ppl_no_ve_mean,
+ "ppl_with_ve": ppl_no_ve_mean,
+ }
+ self._record_ab_provenance(
+ "condition_b",
+ tag=tag,
+ mode="virtual_ve",
+ window_ids=window_ids,
+ fingerprint=fingerprint,
+ status="no_scales",
+ )
+ predictive_state["evaluated"] = True
+ predictive_state["passed"] = False
+ predictive_state["reason"] = "no_scales"
+ self._predictive_gate_state = predictive_state
+ self._stats["predictive_gate"] = predictive_state.copy()
+ return
+
+ if coverage >= min_coverage and ppl_with_ve_samples and loss_with_ve_samples:
+ ppl_no_ve_samples = ppl_no_ve_samples[:coverage]
+ loss_no_ve_samples = loss_no_ve_samples[:coverage]
+ ppl_with_ve_samples = ppl_with_ve_samples[:coverage]
+ loss_with_ve_samples = loss_with_ve_samples[:coverage]
+
+ ratios = [
+ with_val / no_val
+ for with_val, no_val in zip(
+ ppl_with_ve_samples, ppl_no_ve_samples, strict=False
+ )
+ if no_val > 0
+ ]
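+ # Each ratio is PPL(with VE) / PPL(no VE) for one paired window; values
+ # below 1.0 mean VE helped on that window. Windows with a non-positive
+ # baseline PPL are dropped rather than allowed to poison the CI.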
+ if ratios:
+ ratio_ci = self._bootstrap_mean_ci(
+ ratios,
+ alpha=self._policy.get("alpha", 0.05),
+ n_bootstrap=500,
+ seed=calib_seed,
+ )
+ ppl_no_ve_mean = float(np.mean(ppl_no_ve_samples))
+ ppl_with_ve_mean = float(np.mean(ppl_with_ve_samples))
+ self.set_ab_results(
+ ppl_no_ve=ppl_no_ve_mean,
+ ppl_with_ve=ppl_with_ve_mean,
+ windows_used=coverage,
+ seed_used=calib_seed,
+ ratio_ci=ratio_ci,
+ )
+ self._calibration_stats.update(
+ {
+ "status": "complete",
+ "ppl_no_ve": ppl_no_ve_mean,
+ "ppl_with_ve": ppl_with_ve_mean,
+ "ratio_ci": ratio_ci,
+ }
+ )
+ self._record_ab_provenance(
+ "condition_b",
+ tag=tag,
+ mode="virtual_ve",
+ window_ids=window_ids,
+ fingerprint=fingerprint,
+ status="evaluated",
+ )
+ self._stats["ab_point_estimates"] = {
+ "tag": tag,
+ "ppl_no_ve": ppl_no_ve_mean,
+ "ppl_with_ve": ppl_with_ve_mean,
+ "coverage": coverage,
+ }
+
+ delta_ci: tuple[float, float] | None = None
+ try:
+ delta_ci = compute_paired_delta_log_ci(
+ loss_with_ve_samples,
+ loss_no_ve_samples,
+ method="bca",
+ replicates=500,
+ alpha=self._policy.get("alpha", 0.05),
+ seed=calib_seed + 211,
+ )
+ except Exception as exc:
+ delta_ci = None
+ self._log_event(
+ "predictive_gate_error",
+ level="WARN",
+ message="Failed to compute predictive ΔlogNLL CI",
+ error=str(exc),
+ )
+
+ predictive_state["evaluated"] = True
+ mean_delta = float(
+ np.mean(
+ [
+ with_loss - no_loss
+ for with_loss, no_loss in zip(
+ loss_with_ve_samples,
+ loss_no_ve_samples,
+ strict=False,
+ )
+ ]
+ )
+ )
+ predictive_state["mean_delta"] = mean_delta
+
+ if delta_ci is not None and all(
+ isinstance(val, (int | float)) and math.isfinite(val)
+ for val in delta_ci
+ ):
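+ # Gain is the negated loss delta (lower loss with VE = positive gain),
+ # so the CI endpoints flip sign and swap order below.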
+ delta_ci = (float(delta_ci[0]), float(delta_ci[1]))
+ gain_ci = (-delta_ci[1], -delta_ci[0])
+ predictive_state["delta_ci"] = delta_ci
+ predictive_state["gain_ci"] = gain_ci
+
+ if not self._policy.get("predictive_gate", True):
+ predictive_state["passed"] = True
+ predictive_state["reason"] = "disabled"
+ else:
+ one_sided = bool(self._policy.get("predictive_one_sided", False))
+ min_effect = float(
+ self._policy.get("min_effect_lognll", 0.0) or 0.0
+ )
+ passed, reason = _predictive_gate_outcome(
+ mean_delta=mean_delta,
+ delta_ci=delta_ci,
+ min_effect=min_effect,
+ one_sided=one_sided,
+ )
+ predictive_state["passed"] = passed
+ predictive_state["reason"] = reason
+ else:
+ predictive_state["delta_ci"] = (None, None)
+ predictive_state["gain_ci"] = (None, None)
+ predictive_state["reason"] = (
+ predictive_state.get("reason", "ci_unavailable")
+ if predictive_state.get("reason") != "disabled"
+ else "disabled"
+ )
+ else:
+ # Fail-open monitor mode
+ self._ratio_ci = None
+ self._log_event(
+ "prepare_monitor_mode",
+ level="WARN",
+ message="VE calibration coverage insufficient; guard will monitor only",
+ requested=requested,
+ coverage=coverage,
+ min_coverage=min_coverage,
+ tag=tag,
+ )
+ if predictive_state.get("reason") not in {"disabled"}:
+ if coverage < min_coverage:
+ predictive_state["reason"] = "insufficient_coverage"
+ elif not self._scales:
+ predictive_state["reason"] = "no_scales"
+ elif not ppl_with_ve_samples:
+ predictive_state["reason"] = "ve_enable_failed"
+
+ if "condition_b" not in self._stats.get("ab_provenance", {}):
+ self._record_ab_provenance(
+ "condition_b",
+ tag=tag,
+ mode="virtual_ve",
+ window_ids=window_ids,
+ fingerprint=fingerprint,
+ status="not_evaluated",
+ )
+
+ if (
+ "ab_point_estimates" not in self._stats
+ or self._stats["ab_point_estimates"].get("tag") != tag
+ ):
+ ppl_no_ve_mean = (
+ float(np.mean(ppl_no_ve_samples[:coverage])) if coverage > 0 else None
+ )
+ ppl_with_ve_mean = (
+ float(np.mean(ppl_with_ve_samples[:coverage]))
+ if ppl_with_ve_samples and coverage > 0
+ else None
+ )
+ self._stats["ab_point_estimates"] = {
+ "tag": tag,
+ "ppl_no_ve": ppl_no_ve_mean,
+ "ppl_with_ve": ppl_with_ve_mean,
+ "coverage": coverage,
+ }
+
+ self._predictive_gate_state = predictive_state
+ self._stats["predictive_gate"] = predictive_state.copy()
+
+ def _refresh_after_edit_metrics(
+ self,
+ model: nn.Module,
+ tag: str = "post_edit",
+ adapter: Any | None = None,
+ ) -> None:
+ """Ensure VE metrics are recomputed on the edited model."""
+ if not self._prepared:
+ return
+ if self._post_edit_evaluated and tag == "post_edit":
+ return
+ if not self._calibration_batches:
+ self._log_event(
+ "post_edit_calibration_skipped",
+ level="WARN",
+ message="Skipping post-edit VE evaluation (no calibration batches)",
+ )
+ self._post_edit_evaluated = True
+ return
+
+ # Refresh target modules in case adapters swapped modules during edit
+ adapter_ref = adapter or self._adapter_ref
+ self._target_modules = self._resolve_target_modules(model, adapter_ref)
+ self._stats["target_module_names"] = sorted(self._target_modules.keys())
+
+ # Recompute scales against the edited model
+ try:
+ self._scales = self._compute_variance_scales(
+ model, self._calibration_batches
+ )
+ except Exception as exc:
+ self._log_event(
+ "post_edit_scale_failure",
+ level="ERROR",
+ message="Failed to recompute VE scales after edit",
+ error=str(exc),
+ )
+ self._scales = {}
+
+ if self._focus_modules:
+ self._scales = {
+ name: scale
+ for name, scale in self._scales.items()
+ if self._is_focus_match(name)
+ }
+
+ self._stats.setdefault(
+ "target_module_names", sorted(self._target_modules.keys())
+ )
+ self._stats["target_modules_post_edit"] = list(self._target_modules.keys())
+ normalized_post_scales = {
+ self._normalize_scale_name(name): scale
+ for name, scale in self._scales.items()
+ }
+ self._stats["proposed_scales_post_edit"] = normalized_post_scales.copy()
+ self._stats["raw_scales_post_edit"] = self._raw_scales.copy()
+ self._stats["raw_scales_post_edit_normalized"] = {
+ self._normalize_scale_name(name): scale
+ for name, scale in self._raw_scales.items()
+ }
+ self._raw_scales_post_edit = {
+ self._normalize_scale_name(name): scale
+ for name, scale in self._raw_scales.items()
+ if self._is_focus_match(name)
+ }
+ if normalized_post_scales:
+ self._log_event(
+ "post_edit_scales",
+ message="Post-edit VE proposed scales",
+ count=len(normalized_post_scales),
+ min_scale=min(normalized_post_scales.values()),
+ max_scale=max(normalized_post_scales.values()),
+ )
+
+ calibration_cfg = self._policy.get("calibration", {})
+ requested_windows = int(calibration_cfg.get("windows", 0) or 0)
+ min_coverage = int(
+ calibration_cfg.get(
+ "min_coverage",
+ max(1, requested_windows // 2 if requested_windows else 1),
+ )
+ )
+ calib_seed = int(calibration_cfg.get("seed", self._policy.get("seed", 123)))
+
+ self._calibration_stats = {
+ "requested": len(self._calibration_batches)
+ if requested_windows == 0
+ else requested_windows,
+ "coverage": 0,
+ "min_coverage": min_coverage,
+ "seed": calib_seed,
+ "status": "pending",
+ "tag": tag,
+ }
+
+ self._evaluate_calibration_pass(
+ model, self._calibration_batches, min_coverage, calib_seed, tag
+ )
+ self._post_edit_evaluated = True
+
+ def _collect_calibration_batches(self, dataloader, windows: int) -> list[Any]:
+ """Collect a deterministic slice of calibration batches."""
+ batches: list[Any] = []
+ iterator = iter(dataloader)
+ for _ in range(max(windows, 0)):
+ try:
+ batches.append(next(iterator))
+ except StopIteration:
+ break
+ return batches
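+ # A hypothetical call: _collect_calibration_batches(loader, 32) pulls at
+ # most the first 32 batches and stops early if the loader is exhausted.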
+
+ def _prepare_batch_tensors(
+ self, batch: Any, device: torch.device
+ ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+ """Normalize batch inputs to tensors on the target device."""
+ if isinstance(batch, dict):
+ input_ids = batch.get("input_ids", batch.get("inputs"))
+ attention_mask = batch.get("attention_mask")
+ elif isinstance(batch, tuple | list) and batch:
+ input_ids = batch[0]
+ attention_mask = batch[1] if len(batch) > 1 else None
+ else:
+ input_ids = batch
+ attention_mask = None
+
+ if input_ids is None:
+ return None, None
+
+ if not isinstance(input_ids, torch.Tensor):
+ input_ids = torch.as_tensor(input_ids)
+
+ if input_ids.dim() == 1:
+ input_ids = input_ids.unsqueeze(0)
+
+ try:
+ input_ids = input_ids.to(device)
+ except Exception:
+ input_ids = input_ids.clone()
+
+ labels = input_ids.clone()
+
+ if attention_mask is not None:
+ if not isinstance(attention_mask, torch.Tensor):
+ attention_mask = torch.as_tensor(attention_mask)
+ if attention_mask.dim() == 1:
+ attention_mask = attention_mask.unsqueeze(0)
+ try:
+ attention_mask = attention_mask.to(device)
+ except Exception:
+ attention_mask = attention_mask.clone()
+ labels = labels.masked_fill(attention_mask == 0, -100)
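+ # -100 is the conventional ignore_index for cross-entropy in
+ # transformers-style models, so padded positions drop out of the loss.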
+
+ return input_ids, labels
+
+ def _compute_ppl_for_batches(
+ self,
+ model: nn.Module,
+ batches: list[Any],
+ device: torch.device,
+ ) -> tuple[list[float], list[float]]:
+ """Compute per-batch perplexity and log-loss values for deterministic calibration."""
+ ppl_values: list[float] = []
+ loss_values: list[float] = []
+ if not batches:
+ return ppl_values, loss_values
+
+ model_was_training = model.training
+ model.eval()
+
+ with torch.no_grad():
+ for batch in batches:
+ try:
+ inputs, labels = self._prepare_batch_tensors(batch, device)
+ if inputs is None or labels is None:
+ continue
+
+ try:
+ outputs = model(inputs, labels=labels)
+ except TypeError:
+ outputs = model(inputs)
+ loss_val = None
+ if hasattr(outputs, "loss") and hasattr(outputs.loss, "item"):
+ loss_val = outputs.loss.item()
+
+ if loss_val is None and isinstance(outputs, torch.Tensor):
+ try:
+ if labels is not None and outputs.shape == labels.shape:
+ loss_val = torch.nn.functional.mse_loss(
+ outputs.float(), labels.float()
+ ).item()
+ else:
+ loss_val = outputs.float().pow(2).mean().item()
+ except Exception:
+ loss_val = None
+
+ if loss_val is None or not math.isfinite(loss_val):
+ continue
+
+ loss = float(loss_val)
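+ # Perplexity = exp(mean NLL); math.exp can overflow for very large
+ # losses, which the surrounding try/except treats as a skipped batch.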
+ ppl = math.exp(loss)
+ if math.isfinite(ppl):
+ ppl_values.append(ppl)
+ loss_values.append(loss)
+ except Exception:
+ continue
+
+ if model_was_training:
+ model.train()
+
+ return ppl_values, loss_values
+
+ def _bootstrap_mean_ci(
+ self,
+ samples: list[float],
+ alpha: float,
+ n_bootstrap: int = 500,
+ seed: int | None = None,
+ ) -> tuple[float, float]:
+ """Compute bootstrap confidence interval for the sample mean."""
+ if not samples:
+ raise ValueError("Cannot compute CI on empty samples")
+ data = np.asarray(samples, dtype=float)
+ rng = np.random.default_rng(seed)
+ stats = np.empty(n_bootstrap, dtype=float)
+ for i in range(n_bootstrap):
+ indices = rng.integers(0, data.size, size=data.size)
+ stats[i] = float(np.mean(data[indices]))
+ lower = float(np.percentile(stats, 100 * (alpha / 2)))
+ upper = float(np.percentile(stats, 100 * (1 - alpha / 2)))
+ return lower, upper
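+ # Percentile bootstrap: e.g. alpha=0.05 returns the 2.5th and 97.5th
+ # percentiles of 500 resampled means, so near-unity ratio samples yield
+ # an interval tight around 1.0.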
+
+ def prepare(
+ self,
+ model: nn.Module,
+ adapter=None,
+ calib=None,
+ policy: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ """
+ Prepare variance guard by computing proposed scaling factors.
+
+ Args:
+ model: The model that will be edited
+ adapter: ModelAdapter (optional, for compatibility)
+ calib: Calibration data for variance measurement
+ policy: Guard policy parameters (optional)
+
+ Returns:
+ Dictionary with preparation results and proposed scales
+ """
+ start_time = time.time()
+
+ # Update policy if provided
+ if policy:
+ for key in [
+ "min_gain",
+ "max_calib",
+ "scope",
+ "clamp",
+ "deadband",
+ "seed",
+ "mode",
+ "min_rel_gain",
+ "alpha",
+ "tie_breaker_deadband",
+ "min_effect_lognll",
+ "min_abs_adjust",
+ "max_scale_step",
+ "topk_backstop",
+ "max_adjusted_modules",
+ "predictive_gate",
+ "predictive_one_sided",
+ "absolute_floor_ppl",
+ "monitor_only",
+ "calibration",
+ "target_modules",
+ ]:
+ if key in policy:
+ self._policy[key] = policy[key]
+ if self._policy.get("min_effect_lognll") is not None:
+ self._policy["min_effect_lognll"] = float(
+ self._policy["min_effect_lognll"]
+ )
+ self.TIE_BREAKER_DEADBAND = float(
+ self._policy.get("tie_breaker_deadband", self.TIE_BREAKER_DEADBAND)
+ )
+ self._refresh_calibration_defaults()
+ if "absolute_floor_ppl" in policy:
+ self.ABSOLUTE_FLOOR = float(
+ self._policy.get(
+ "absolute_floor_pm",
+ self._policy.get("absolute_floor_ppl", self.ABSOLUTE_FLOOR),
+ )
+ )
+ if "target_modules" in policy:
+ focus_list = [
+ self._normalize_module_name(name)
+ for name in (policy.get("target_modules") or [])
+ if isinstance(name, str)
+ ]
+ self._focus_modules = set(focus_list)
+ if self._focus_modules:
+ self._policy["target_modules"] = sorted(self._focus_modules)
+ self._stats["focus_modules"] = sorted(self._focus_modules)
+
+ self._log_event(
+ "prepare",
+ message=f"Preparing variance guard with scope={self._policy['scope']}, min_gain={self._policy['min_gain']}",
+ )
+
+ try:
+ # Resolve target modules
+ self._target_modules = self._resolve_target_modules(model, adapter)
+ self._stats["target_module_names"] = sorted(self._target_modules.keys())
+
+ if not self._target_modules:
+ self._prepared = False
+ self._adapter_ref = adapter
+ return {
+ "baseline_metrics": {},
+ "policy_applied": self._policy,
+ "preparation_time": time.time() - start_time,
+ "ready": False,
+ "warning": "No target modules found for variance equalization",
+ }
+
+ self._adapter_ref = adapter
+
+ calibration_cfg = self._policy.get("calibration", {})
+ requested_windows = int(calibration_cfg.get("windows", 0) or 0)
+ min_coverage = int(
+ calibration_cfg.get(
+ "min_coverage",
+ max(1, requested_windows // 2 if requested_windows else 1),
+ )
+ )
+ calib_seed = int(calibration_cfg.get("seed", self._policy.get("seed", 123)))
+
+ scale_windows = min(self._policy["max_calib"] // 10, 50)
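+ # e.g. max_calib=200 yields 20 scale-estimation windows, capped at 50.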
+ limit_for_batches = max(scale_windows, requested_windows)
+
+ calib_batches: list[Any] = []
+ dataloader_source = None
+
+ if calib is not None:
+ if hasattr(calib, "dataloader"):
+ dataloader_source = calib.dataloader
+ calib_batches = self._collect_calibration_batches(
+ dataloader_source, limit_for_batches
+ )
+ elif isinstance(calib, Sequence | Iterable):
+ # Sequences and other iterables are sliced the same way.
+ calib_batches = list(
+ itertools.islice(iter(calib), limit_for_batches)
+ )
+
+ if calib_batches:
+ self._scales = self._compute_variance_scales(model, calib_batches)
+ else:
+ self._scales = {}
+ self._raw_scales = {}
+ self._log_event(
+ "prepare_warning",
+ level="WARN",
+ message="No calibration data provided, VE will be disabled",
+ )
+
+ # Deterministic VE calibration pass for A/B readiness
+ self._calibration_stats = {
+ "requested": requested_windows,
+ "coverage": 0,
+ "min_coverage": min_coverage,
+ "seed": calib_seed,
+ "status": "skipped" if requested_windows == 0 else "insufficient",
+ }
+
+ calibration_batches = calib_batches[:requested_windows]
+ self._store_calibration_batches(calibration_batches)
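+ # The stored slice is reused by _refresh_after_edit_metrics for the
+ # post-edit A/B pass, keeping pre- and post-edit windows identical.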
+ predictive_state: dict[str, Any] = {
+ "evaluated": False,
+ "passed": not bool(self._policy.get("predictive_gate", True)),
+ "reason": "disabled"
+ if not bool(self._policy.get("predictive_gate", True))
+ else "no_calibration",
+ "delta_ci": (None, None),
+ "gain_ci": (None, None),
+ "mean_delta": None,
+ }
+
+ if calibration_batches:
+ device = next(model.parameters()).device
+ torch.manual_seed(calib_seed)
+ ppl_no_ve_samples, loss_no_ve_samples = self._compute_ppl_for_batches(
+ model, calibration_batches, device
+ )
+ coverage = min(len(calibration_batches), len(ppl_no_ve_samples))
+ ppl_with_ve_samples: list[float] = []
+ loss_with_ve_samples: list[float] = []
+ ratio_ci: tuple[float, float] | None = None
+
+ enable_success = False
+ if coverage >= min_coverage and self._scales:
+ prev_enable_attempts = self._enable_attempt_count
+ prev_disable_attempts = self._disable_attempt_count
+ prev_prepared_flag = self._prepared
+ try:
+ self._prepared = True
+ enable_success = self.enable(model)
+ finally:
+ self._prepared = prev_prepared_flag
+ try:
+ torch.manual_seed(calib_seed)
+ if enable_success:
+ (
+ ppl_with_ve_samples,
+ loss_with_ve_samples,
+ ) = self._compute_ppl_for_batches(
+ model, calibration_batches, device
+ )
+ finally:
+ if enable_success:
+ self.disable(model)
+ # Restore attempt counters to avoid skewing metrics
+ self._enable_attempt_count = prev_enable_attempts
+ self._disable_attempt_count = prev_disable_attempts
+
+ coverage = min(
+ coverage,
+ len(ppl_with_ve_samples) if ppl_with_ve_samples else coverage,
+ len(loss_with_ve_samples) if loss_with_ve_samples else coverage,
+ )
+ self._calibration_stats.update(
+ {"coverage": coverage, "status": "insufficient"}
+ )
+
+ if coverage >= min_coverage and not self._scales:
+ ppl_no_ve_samples = ppl_no_ve_samples[:coverage]
+ ppl_no_ve_mean = float(np.mean(ppl_no_ve_samples))
+ self.set_ab_results(
+ ppl_no_ve=ppl_no_ve_mean,
+ ppl_with_ve=ppl_no_ve_mean,
+ windows_used=coverage,
+ seed_used=calib_seed,
+ ratio_ci=(1.0, 1.0),
+ )
+ self._calibration_stats.update(
+ {
+ "status": "no_scaling_required",
+ "ppl_no_ve": ppl_no_ve_mean,
+ "ratio_ci": (1.0, 1.0),
+ }
+ )
+
+ if (
+ coverage >= min_coverage
+ and ppl_with_ve_samples
+ and loss_with_ve_samples
+ ):
+ ppl_no_ve_samples = ppl_no_ve_samples[:coverage]
+ loss_no_ve_samples = loss_no_ve_samples[:coverage]
+ ppl_with_ve_samples = ppl_with_ve_samples[:coverage]
+ loss_with_ve_samples = loss_with_ve_samples[:coverage]
+
+ ratios = [
+ with_val / no_val
+ for with_val, no_val in zip(
+ ppl_with_ve_samples, ppl_no_ve_samples, strict=False
+ )
+ if no_val > 0
+ ]
+ if ratios:
+ ratio_ci = self._bootstrap_mean_ci(
+ ratios,
+ alpha=self._policy.get("alpha", 0.05),
+ n_bootstrap=500,
+ seed=calib_seed,
+ )
+ ppl_no_ve_mean = float(np.mean(ppl_no_ve_samples))
+ ppl_with_ve_mean = float(np.mean(ppl_with_ve_samples))
+ self.set_ab_results(
+ ppl_no_ve=ppl_no_ve_mean,
+ ppl_with_ve=ppl_with_ve_mean,
+ windows_used=coverage,
+ seed_used=calib_seed,
+ ratio_ci=ratio_ci,
+ )
+ self._calibration_stats.update(
+ {
+ "status": "complete",
+ "ppl_no_ve": ppl_no_ve_mean,
+ "ppl_with_ve": ppl_with_ve_mean,
+ "ratio_ci": ratio_ci,
+ }
+ )
+
+ delta_ci: tuple[float, float] | None = None
+ try:
+ delta_ci = compute_paired_delta_log_ci(
+ loss_with_ve_samples,
+ loss_no_ve_samples,
+ method="bca",
+ replicates=500,
+ alpha=self._policy.get("alpha", 0.05),
+ seed=calib_seed + 211,
+ )
+ except Exception as exc:
+ delta_ci = None
+ self._log_event(
+ "predictive_gate_error",
+ level="WARN",
+ message="Failed to compute predictive ΔlogNLL CI",
+ error=str(exc),
+ )
+
+ predictive_state["evaluated"] = True
+ mean_delta = float(
+ np.mean(
+ [
+ with_loss - no_loss
+ for with_loss, no_loss in zip(
+ loss_with_ve_samples,
+ loss_no_ve_samples,
+ strict=False,
+ )
+ ]
+ )
+ )
+ predictive_state["mean_delta"] = mean_delta
+
+ if delta_ci is not None and all(
+ isinstance(val, (int | float)) and math.isfinite(val)
+ for val in delta_ci
+ ):
+ delta_ci = (float(delta_ci[0]), float(delta_ci[1]))
+ gain_ci = (-delta_ci[1], -delta_ci[0])
+ predictive_state["delta_ci"] = delta_ci
+ predictive_state["gain_ci"] = gain_ci
+
+ if not self._policy.get("predictive_gate", True):
+ predictive_state["passed"] = True
+ predictive_state["reason"] = "disabled"
+ else:
+ one_sided = bool(
+ self._policy.get("predictive_one_sided", False)
+ )
+ min_effect = float(
+ self._policy.get("min_effect_lognll", 0.0) or 0.0
+ )
+ passed, reason = _predictive_gate_outcome(
+ mean_delta=mean_delta,
+ delta_ci=delta_ci,
+ min_effect=min_effect,
+ one_sided=one_sided,
+ )
+ predictive_state["passed"] = passed
+ predictive_state["reason"] = reason
+ else:
+ predictive_state["delta_ci"] = (None, None)
+ predictive_state["gain_ci"] = (None, None)
+ predictive_state["reason"] = (
+ predictive_state.get("reason", "ci_unavailable")
+ if predictive_state.get("reason") != "disabled"
+ else "disabled"
+ )
+ else:
+ # Fail-open monitor mode
+ self._ratio_ci = None
+ self._log_event(
+ "prepare_monitor_mode",
+ level="WARN",
+ message="VE calibration coverage insufficient; guard will monitor only",
+ requested=requested_windows,
+ coverage=coverage,
+ min_coverage=min_coverage,
+ )
+ if predictive_state.get("reason") not in {"disabled"}:
+ if coverage < min_coverage:
+ predictive_state["reason"] = "insufficient_coverage"
+ elif not self._scales:
+ predictive_state["reason"] = "no_scales"
+ elif not ppl_with_ve_samples:
+ predictive_state["reason"] = "ve_enable_failed"
+ else:
+ self._ratio_ci = None
+ if predictive_state.get("reason") != "disabled":
+ predictive_state["reason"] = "no_calibration"
+
+ self._predictive_gate_state = predictive_state
+
+ # Store baseline statistics without overwriting pre-populated instrumentation
+ self._stats.setdefault(
+ "target_module_names", sorted(self._target_modules.keys())
+ )
+ self._stats["target_modules"] = list(self._target_modules.keys())
+ normalized_scales = {
+ self._normalize_scale_name(name): scale
+ for name, scale in self._scales.items()
+ }
+ self._stats["proposed_scales_pre_edit"] = normalized_scales.copy()
+ self._stats["raw_scales_pre_edit"] = self._raw_scales.copy()
+ self._stats["raw_scales_pre_edit_normalized"] = {
+ self._normalize_scale_name(name): scale
+ for name, scale in self._raw_scales.items()
+ }
+ self._stats["total_target_modules"] = len(self._target_modules)
+ self._stats["modules_with_scales_pre_edit"] = len(self._scales)
+ self._stats.setdefault("calibration", {}).update(
+ self._calibration_stats.copy()
+ )
+ self._stats["scale_filtering"] = {
+ "raw_scales": len(self._raw_scales),
+ "filtered_scales": len(self._scales),
+ "min_abs_adjust": float(self._policy.get("min_abs_adjust", 0.0)),
+ "max_scale_step": float(self._policy.get("max_scale_step", 0.0)),
+ "topk_backstop": int(self._policy.get("topk_backstop", 0)),
+ }
+ self._stats["predictive_gate"] = predictive_state.copy()
+ self._calibration_stats_pre_edit = self._calibration_stats.copy()
+ self._post_edit_evaluated = False
+ self._raw_scales_pre_edit = {
+ self._normalize_scale_name(name): scale
+ for name, scale in self._raw_scales.items()
+ }
+
+ self._prepared = True
+ preparation_time = time.time() - start_time
+
+ self._log_event(
+ "prepare_success",
+ message=f"Prepared variance guard with {len(self._target_modules)} target modules",
+ target_modules=len(self._target_modules),
+ proposed_scales=len(self._scales),
+ preparation_time=preparation_time,
+ )
+
+ return {
+ "baseline_metrics": {
+ "target_modules": len(self._target_modules),
+ "proposed_scales": len(self._scales),
+ "scope": self._policy["scope"],
+ "scale_statistics": {
+ "mean_scale": float(
+ sum(self._scales.values()) / len(self._scales)
+ )
+ if self._scales
+ else 1.0,
+ "min_scale": min(self._scales.values())
+ if self._scales
+ else 1.0,
+ "max_scale": max(self._scales.values())
+ if self._scales
+ else 1.0,
+ },
+ "calibration": self._calibration_stats.copy(),
+ },
+ "policy_applied": self._policy.copy(),
+ "preparation_time": preparation_time,
+ "ready": True,
+ }
+
+ except Exception as e:
+ self._prepared = False
+ self._adapter_ref = adapter
+ self._log_event(
+ "prepare_failed",
+ level="ERROR",
+ message=f"Failed to prepare variance guard: {str(e)}",
+ error=str(e),
+ )
+
+ return {
+ "baseline_metrics": {},
+ "policy_applied": self._policy,
+ "preparation_time": time.time() - start_time,
+ "ready": False,
+ "error": str(e),
+ }
+
+ def before_edit(self, model: nn.Module) -> None:
+ """
+ Execute before edit (no action needed for variance guard).
+
+ Args:
+ model: The model about to be edited
+ """
+ if self._prepared:
+ self._log_event(
+ "before_edit", message="Variance guard ready for A/B testing"
+ )
+
+ def after_edit(self, model: nn.Module) -> None:
+ """
+ Execute after edit (A/B testing happens via enable/disable calls).
+
+ Args:
+ model: The model that was just edited
+ """
+ if not self._prepared:
+ self._log_event(
+ "after_edit_skipped",
+ level="WARN",
+ message="Variance guard not prepared, skipping",
+ )
+ return
+
+ self._refresh_after_edit_metrics(model)
+ self._log_event(
+ "after_edit",
+ message="Variance guard refreshed post-edit metrics",
+ evaluated=self._post_edit_evaluated,
+ proposed_scales=len(self._scales),
+ )
+
+ def enable(self, model: nn.Module, adapter=None) -> bool:
+ """
+ Enable variance equalization with checkpoint discipline and idempotent operation.
+
+ Args:
+ model: Model to apply VE to
+ adapter: ModelAdapter (optional, for tying preservation)
+
+ Returns:
+ True if VE was successfully enabled, False otherwise
+ """
+ self._enable_attempt_count += 1
+
+ if self._monitor_only:
+ self._log_event(
+ "enable_skipped_monitor_only",
+ level="INFO",
+ message="Monitor-only mode: VE enable skipped",
+ attempt_count=self._enable_attempt_count,
+ )
+ self._enabled = False
+ return False
+
+ if not self._prepared or not self._scales:
+ self._log_event(
+ "enable_skipped",
+ level="WARN",
+ message="Cannot enable VE: not prepared or no scales computed",
+ attempt_count=self._enable_attempt_count,
+ )
+ return False
+
+ # Idempotent check: if already enabled, verify state and return success
+ if self._enabled:
+ self._log_event(
+ "enable_idempotent",
+ message="VE already enabled, verifying state",
+ attempt_count=self._enable_attempt_count,
+ )
+ return True
+
+ # Push checkpoint before attempting enable
+ self._push_checkpoint(model)
+
+ self._log_event(
+ "enable_start",
+ message=f"Enabling VE with {len(self._scales)} scale factors",
+ attempt_count=self._enable_attempt_count,
+ )
+
+ try:
+ # Apply scaling factors in-place with robust error handling
+ applied_count = 0
+ failed_modules = []
+
+ for scale_name, scale_factor in self._scales.items():
+ try:
+ # Find the actual module by matching scale name to target modules
+ module = None
+ for target_name, target_module in self._target_modules.items():
+ # Match by exact name or by checking if they refer to the same component
+ if scale_name == target_name:
+ module = target_module
+ break
+
+ # Convert blockX.attn/mlp format to transformer.h.X.attn/mlp.c_proj format
+ if scale_name.startswith("block") and (
+ "attn" in scale_name or "mlp" in scale_name
+ ):
+ # Extract layer number and component (attn/mlp)
+ parts = scale_name.split(".")
+ if len(parts) == 2:
+ layer_part = parts[0] # e.g., "block0"
+ component = parts[1] # e.g., "attn" or "mlp"
+
+ if layer_part.startswith("block"):
+ layer_num = layer_part[
+ 5:
+ ] # Extract number from "block0"
+ expected_target = (
+ f"transformer.h.{layer_num}.{component}.c_proj"
+ )
+
+ if target_name == expected_target:
+ module = target_module
+ break
+
+ # Fallback: check if scale_name components match target_name components
+ if (
+ scale_name.endswith(target_name.split(".")[-1])
+ or target_name.endswith(scale_name)
+ or any(
+ part in target_name for part in scale_name.split(".")
+ )
+ ):
+ module = target_module
+ break
+
+ if module is not None and hasattr(module, "weight"):
+ # Check for quantized weights (skip if unsupported)
+ if hasattr(module.weight, "dtype") and module.weight.dtype in [
+ torch.int8,
+ ]:
+ self._log_event(
+ "scale_skipped",
+ level="WARN",
+ message=f"Skipping quantized weights in {scale_name}",
+ module_name=scale_name,
+ dtype=str(module.weight.dtype),
+ )
+ continue
+
+ # Store original scale factor for exact reversion
+ if scale_name not in self._original_scales:
+ self._original_scales[scale_name] = 1.0
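+ # Records that the module started at unit scale, so disable() can
+ # revert exactly by dividing out scale_factor.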
+
+ # Apply scaling with proper device handling
+ with torch.no_grad():
+ original_device = module.weight.device
+ original_dtype = module.weight.dtype
+
+ # Use scalar multiplication to avoid MPS issues
+ if str(original_device).startswith("mps"):
+ module.weight.data = module.weight.data * scale_factor
+ else:
+ scale_tensor = torch.tensor(
+ scale_factor,
+ device=original_device,
+ dtype=original_dtype,
+ )
+ module.weight.mul_(scale_tensor)
+
+ applied_count += 1
+
+ self._log_event(
+ "scale_applied",
+ message=f"Applied scale {scale_factor:.3f} to {scale_name}",
+ module_name=scale_name,
+ scale_factor=scale_factor,
+ )
+ else:
+ failed_modules.append(scale_name)
+
+ except Exception as e:
+ failed_modules.append(scale_name)
+ self._log_event(
+ "scale_apply_error",
+ level="ERROR",
+ message=f"Failed to apply scale to {scale_name}: {str(e)}",
+ module_name=scale_name,
+ error=str(e),
+ )
+
+ # Check if enough modules were successfully scaled
+ if applied_count == 0:
+ # Complete failure - rollback
+ self._pop_checkpoint(model)
+ self._log_event(
+ "enable_failed",
+ level="ERROR",
+ message="No modules were successfully scaled, rolling back",
+ failed_modules=failed_modules,
+ )
+ return False
+
+ # Partial or complete success
+ if failed_modules:
+ self._log_event(
+ "enable_partial",
+ level="WARN",
+ message=f"Partial success: {applied_count} succeeded, {len(failed_modules)} failed",
+ applied_count=applied_count,
+ failed_modules=failed_modules,
+ )
+
+ # Commit the checkpoint on success
+ self._commit_checkpoint()
+ self._enabled = True
+
+ self._log_event(
+ "enable_complete",
+ message=f"Enabled VE on {applied_count}/{len(self._scales)} modules",
+ applied_count=applied_count,
+ total_scales=len(self._scales),
+ attempt_count=self._enable_attempt_count,
+ )
+
+ return True
+
+ except Exception as e:
+ # Catastrophic failure - rollback
+ self._pop_checkpoint(model)
+ self._log_event(
+ "enable_catastrophic_failure",
+ level="ERROR",
+ message=f"Catastrophic failure during enable: {str(e)}",
+ error=str(e),
+ attempt_count=self._enable_attempt_count,
+ )
+ return False
+
+ def disable(self, model: nn.Module, adapter=None) -> bool:
+ """
+ Disable variance equalization with idempotent operation and exact state restoration.
+
+ Args:
+ model: Model to revert VE on
+ adapter: ModelAdapter (optional, for tying preservation)
+
+ Returns:
+ True if VE was successfully disabled, False otherwise
+ """
+ self._disable_attempt_count += 1
+
+ # Idempotent check: if already disabled, return success
+ if not self._enabled:
+ self._log_event(
+ "disable_idempotent",
+ message="VE already disabled",
+ attempt_count=self._disable_attempt_count,
+ )
+ return True
+
+ self._log_event(
+ "disable_start",
+ message="Disabling VE by reverting to exact previous state",
+ attempt_count=self._disable_attempt_count,
+ )
+
+ try:
+ # Attempt to use checkpoint for exact restoration if available
+ if self._checkpoint_stack:
+ success = self._pop_checkpoint(model)
+ if success:
+ self._enabled = False
+ self._log_event(
+ "disable_checkpoint_complete",
+ message="Disabled VE using checkpoint restoration",
+ attempt_count=self._disable_attempt_count,
+ )
+ return True
+ else:
+ self._log_event(
+ "disable_checkpoint_failed",
+ level="WARN",
+ message="Checkpoint restoration failed, falling back to inverse scaling",
+ )
+
+ # Fallback: revert using inverse scaling
+ reverted_count = 0
+ failed_modules = []
+
+ for scale_name, scale_factor in self._scales.items():
+ try:
+ # Find the actual module (use same logic as enable())
+ module = None
+ for target_name, target_module in self._target_modules.items():
+ # Match by exact name or by checking if they refer to the same component
+ if scale_name == target_name:
+ module = target_module
+ break
+
+ # Convert blockX.attn/mlp format to transformer.h.X.attn/mlp.c_proj format
+ if scale_name.startswith("block") and (
+ "attn" in scale_name or "mlp" in scale_name
+ ):
+ # Extract layer number and component (attn/mlp)
+ parts = scale_name.split(".")
+ if len(parts) == 2:
+ layer_part = parts[0] # e.g., "block0"
+ component = parts[1] # e.g., "attn" or "mlp"
+
+ if layer_part.startswith("block"):
+ layer_num = layer_part[
+ 5:
+ ] # Extract number from "block0"
+ expected_target = (
+ f"transformer.h.{layer_num}.{component}.c_proj"
+ )
+
+ if target_name == expected_target:
+ module = target_module
+ break
+
+ # Fallback: check if scale_name components match target_name components
+ if (
+ scale_name.endswith(target_name.split(".")[-1])
+ or target_name.endswith(scale_name)
+ or any(
+ part in target_name for part in scale_name.split(".")
+ )
+ ):
+ module = target_module
+ break
+
+ if module is not None and hasattr(module, "weight"):
+ # Check for quantized weights (skip if unsupported)
+ if hasattr(module.weight, "dtype") and module.weight.dtype in [
+ torch.int8,
+ ]:
+ self._log_event(
+ "revert_skipped",
+ level="WARN",
+ message=f"Skipping quantized weights in {scale_name}",
+ module_name=scale_name,
+ dtype=str(module.weight.dtype),
+ )
+ continue
+
+ # Exact reversion using inverse scale
+ revert_factor = 1.0 / scale_factor
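+ # Inverse scaling is exact only up to floating-point rounding; the
+ # checkpoint path above is preferred when a snapshot is available.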
+
+ with torch.no_grad():
+ original_device = module.weight.device
+ original_dtype = module.weight.dtype
+
+ # Use scalar multiplication to avoid MPS issues
+ if str(original_device).startswith("mps"):
+ module.weight.data = module.weight.data * revert_factor
+ else:
+ revert_tensor = torch.tensor(
+ revert_factor,
+ device=original_device,
+ dtype=original_dtype,
+ )
+ module.weight.mul_(revert_tensor)
+
+ reverted_count += 1
+
+ self._log_event(
+ "scale_reverted",
+ message=f"Reverted scale {scale_factor:.3f} from {scale_name} (factor: {revert_factor:.3f})",
+ module_name=scale_name,
+ original_scale=scale_factor,
+ revert_factor=revert_factor,
+ )
+ else:
+ failed_modules.append(scale_name)
+
+ except Exception as e:
+ failed_modules.append(scale_name)
+ self._log_event(
+ "scale_revert_error",
+ level="ERROR",
+ message=f"Failed to revert scale from {scale_name}: {str(e)}",
+ module_name=scale_name,
+ error=str(e),
+ )
+
+ # Check if enough modules were successfully reverted
+ if reverted_count == 0 and self._scales:
+ self._log_event(
+ "disable_failed",
+ level="ERROR",
+ message="No modules were successfully reverted",
+ failed_modules=failed_modules,
+ )
+ return False
+
+ # Success (even if partial)
+ if failed_modules:
+ self._log_event(
+ "disable_partial",
+ level="WARN",
+ message=f"Partial success: {reverted_count} reverted, {len(failed_modules)} failed",
+ reverted_count=reverted_count,
+ failed_modules=failed_modules,
+ )
+
+ self._enabled = False
+ self._log_event(
+ "disable_complete",
+ message=f"Disabled VE on {reverted_count}/{len(self._scales)} modules",
+ reverted_count=reverted_count,
+ attempt_count=self._disable_attempt_count,
+ )
+
+ return True
+
+ except Exception as e:
+ self._log_event(
+ "disable_catastrophic_failure",
+ level="ERROR",
+ message=f"Catastrophic failure during disable: {str(e)}",
+ error=str(e),
+ attempt_count=self._disable_attempt_count,
+ )
+ return False
+
+ def set_ab_results(
+ self,
+ ppl_no_ve: float,
+ ppl_with_ve: float,
+ windows_used: int | None = None,
+ seed_used: int | None = None,
+ ratio_ci: tuple[float, float] | None = None,
+ ) -> None:
+ """
+ Store A/B testing results with reinforced validation logic.
+
+ Args:
+ ppl_no_ve: Perplexity without VE (A condition)
+ ppl_with_ve: Perplexity with VE (B condition)
+ windows_used: Number of calibration windows used (for determinism tracking)
+ seed_used: Random seed used (for determinism tracking)
+ ratio_ci: Tuple of (lower, upper) confidence interval for ppl_with_ve/ppl_no_ve
+ """
+ self._ppl_no_ve = ppl_no_ve
+ self._ppl_with_ve = ppl_with_ve
+ self._ab_windows_used = windows_used
+ self._ab_seed_used = seed_used
+ self._ratio_ci = ratio_ci
+
+ # Robust gain computation with NaN/Inf protection
+ if ppl_no_ve is None or ppl_with_ve is None or ppl_no_ve <= 0:
+ self._ab_gain = 0.0
+ gain_status = "invalid_ppl"
+ else:
+ try:
+ self._ab_gain = (ppl_no_ve - ppl_with_ve) / max(ppl_no_ve, 1e-9)
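+ # e.g. ppl_no_ve=20.0 and ppl_with_ve=19.0 gives gain 0.05 (a 5%
+ # relative improvement); the max() guards the division near zero.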
+ # Guard against NaN/Inf
+ if not (
+ isinstance(self._ab_gain, int | float)
+ and abs(self._ab_gain) < float("inf")
+ ):
+ self._ab_gain = 0.0
+ gain_status = "numeric_error"
+ else:
+ gain_status = "computed"
+ except (ZeroDivisionError, OverflowError, TypeError):
+ self._ab_gain = 0.0
+ gain_status = "numeric_error"
+
+ # Safe formatting for None values
+ ppl_no_ve_str = f"{ppl_no_ve:.3f}" if ppl_no_ve is not None else "None"
+ ppl_with_ve_str = f"{ppl_with_ve:.3f}" if ppl_with_ve is not None else "None"
+
+ self._log_event(
+ "ab_results_stored",
+ message=f"A/B results: {ppl_no_ve_str} → {ppl_with_ve_str} (gain: {self._ab_gain:.3f}, status: {gain_status})",
+ ppl_no_ve=ppl_no_ve,
+ ppl_with_ve=ppl_with_ve,
+ gain=self._ab_gain,
+ gain_status=gain_status,
+ windows_used=windows_used,
+ seed_used=seed_used,
+ ratio_ci=ratio_ci,
+ )
+ self._post_edit_evaluated = True
+
+ upper_ratio = None
+ if isinstance(ratio_ci, tuple | list) and len(ratio_ci) == 2:
+ try:
+ upper_ratio = float(ratio_ci[1])
+ except (TypeError, ValueError):
+ upper_ratio = None
+
+ if upper_ratio is not None and upper_ratio < 1.0:
+ self._predictive_gate_state.update(
+ {
+ "evaluated": True,
+ "passed": True,
+ "reason": "manual_override",
+ }
+ )
+
+ def _push_checkpoint(self, model: nn.Module) -> None:
+ """
+ Push current model state to checkpoint stack for rollback capability.
+
+ Args:
+ model: Model to checkpoint
+ """
+ if not self._target_modules:
+ return
+
+ checkpoint = {}
+ for name, module in self._target_modules.items():
+ if hasattr(module, "weight"):
+ # Store deep copy of weights for exact restoration
+ checkpoint[name] = module.weight.data.clone().detach()
+
+ self._checkpoint_stack.append(checkpoint)
+
+ self._log_event(
+ "checkpoint_pushed",
+ message=f"Pushed checkpoint for {len(checkpoint)} modules",
+ modules_count=len(checkpoint),
+ stack_depth=len(self._checkpoint_stack),
+ )
+
+ def _pop_checkpoint(self, model: nn.Module) -> bool:
+ """
+ Pop and restore the most recent checkpoint.
+
+ Args:
+ model: Model to restore
+
+ Returns:
+ True if checkpoint was restored, False if no checkpoint available
+ """
+ if not self._checkpoint_stack:
+ self._log_event(
+ "checkpoint_pop_failed",
+ level="WARN",
+ message="No checkpoint available for rollback",
+ )
+ return False
+
+ checkpoint = self._checkpoint_stack.pop()
+ restored_count = 0
+
+ for name, saved_weight in checkpoint.items():
+ if name in self._target_modules:
+ module = self._target_modules[name]
+ if hasattr(module, "weight"):
+ # Exact restoration using saved tensor
+ module.weight.data.copy_(saved_weight)
+ restored_count += 1
+
+ self._log_event(
+ "checkpoint_popped",
+ message=f"Restored checkpoint for {restored_count}/{len(checkpoint)} modules",
+ restored_count=restored_count,
+ stack_depth=len(self._checkpoint_stack),
+ )
+
+ return True
+
+ def _commit_checkpoint(self) -> None:
+ """
+ Commit current state by removing the most recent checkpoint.
+ """
+ if self._checkpoint_stack:
+ self._checkpoint_stack.pop()
+ self._log_event(
+ "checkpoint_committed",
+ message="Committed current state, removed checkpoint",
+ stack_depth=len(self._checkpoint_stack),
+ )
+
+ def _evaluate_ab_gate(self) -> tuple[bool, str]:
+ """
+ Evaluate A/B gate decision with reinforced criteria.
+
+ Returns:
+ (should_enable, reason) tuple
+ """
+ mode = self._policy.get("mode", "ci")
+ min_rel_gain = self._policy.get("min_rel_gain", 0.0)
+ tie_breaker = float(
+ self._policy.get("tie_breaker_deadband", self.TIE_BREAKER_DEADBAND)
+ )
+ min_effect_log = self._policy.get("min_effect_lognll")
+
+ predictive_enabled = bool(self._policy.get("predictive_gate", True))
+ gate_state = getattr(self, "_predictive_gate_state", {}) or {}
+ if (
+ predictive_enabled
+ and not gate_state.get("evaluated")
+ and self._ratio_ci is not None
+ ):
+ gate_state = {
+ **gate_state,
+ "evaluated": True,
+ "passed": True,
+ "reason": gate_state.get("reason", "synthetic_ab_gate"),
+ }
+ self._predictive_gate_state = gate_state
+
+ if self._ab_gain is None:
+ return False, "no_ab_results"
+
+ # Edge case: zero or negative PPLs
+ if (
+ self._ppl_no_ve is None
+ or self._ppl_with_ve is None
+ or self._ppl_no_ve <= 0
+ or self._ppl_with_ve <= 0
+ ):
+ return False, "invalid_ppl_values"
+
+ relative_gain = self._ab_gain
+ if relative_gain < min_rel_gain:
+ return (
+ False,
+ f"below_min_rel_gain (gain={relative_gain:.3f} < {min_rel_gain:.3f})",
+ )
+
+ if min_effect_log is not None:
+ log_gain = math.log(max(self._ppl_no_ve, 1e-9)) - math.log(
+ max(self._ppl_with_ve, 1e-9)
+ )
+ if log_gain < float(min_effect_log):
+ return (
+ False,
+ f"below_min_effect_lognll (gain={log_gain:.6f} < {float(min_effect_log):.6f})",
+ )
+
+ if mode == "ci":
+ if self._ratio_ci is None:
+ return False, "missing_ratio_ci"
+ ratio_lo, ratio_hi = self._ratio_ci
+ if not all(
+ isinstance(x, int | float) and math.isfinite(x) and x > 0
+ for x in (ratio_lo, ratio_hi)
+ ):
+ return False, "invalid_ratio_ci"
+ required_hi = 1.0 - min_rel_gain
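+ # A log-NLL floor translates to a ratio ceiling of exp(-min_effect),
+ # since the PPL ratio equals exp(delta log-NLL).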
+ if min_effect_log is not None:
+ required_hi = min(required_hi, math.exp(-float(min_effect_log)))
+ if ratio_hi > required_hi:
+ return (
+ False,
+ f"ci_interval_too_high (hi={ratio_hi:.3f} > {required_hi:.3f})",
+ )
+
+ # Absolute floor requirement: improvement must clear self.ABSOLUTE_FLOOR (ppl-like units; default 0.05)
+ absolute_improvement = self._ppl_no_ve - self._ppl_with_ve
+ if absolute_improvement < self.ABSOLUTE_FLOOR:
+ return (
+ False,
+ f"below_absolute_floor (improvement={absolute_improvement:.3f} < {self.ABSOLUTE_FLOOR})",
+ )
+
+ # Tie-breaker deadband: require gain >= min_gain + tie_breaker (default 0.005) to avoid flapping
3023
+ required_gain = self._policy["min_gain"] + tie_breaker
3024
+ if self._ab_gain < required_gain:
3025
+ return (
3026
+ False,
3027
+ f"below_threshold_with_deadband (gain={self._ab_gain:.3f} < {required_gain:.3f})",
3028
+ )
3029
+
3030
+ if predictive_enabled and not gate_state.get("passed", False):
3031
+ reason = gate_state.get("reason", "predictive_gate_failed")
3032
+ return False, f"predictive_gate_failed ({reason})"
3033
+
3034
+ return (
3035
+ True,
3036
+ f"criteria_met (gain={self._ab_gain:.3f} >= {required_gain:.3f}, improvement={absolute_improvement:.3f})",
3037
+ )
3038
+
3039
+ def validate(
3040
+ self, model: Any, adapter: Any, context: dict[str, Any]
3041
+ ) -> dict[str, Any]:
3042
+ """
3043
+ Validate model state (Guard ABC interface).
3044
+
3045
+ Args:
3046
+ model: Model to validate
3047
+ adapter: ModelAdapter instance
3048
+ context: Validation context
3049
+
3050
+ Returns:
3051
+ Dictionary with validation results
3052
+ """
3053
+ # Use finalize to get comprehensive results
3054
+ result = self.finalize(model)
3055
+
3056
+ details = result.get("details", {}) or {}
3057
+ errors = result.get("errors", []) or []
3058
+ warnings = result.get("warnings", []) or []
3059
+ passed = result.get("passed", False)
3060
+
3061
+ if passed:
3062
+ action = "warn" if warnings else "continue"
3063
+ else:
3064
+ action = "warn" if self._monitor_only else "abort"
3065
+
3066
+ return {
3067
+ "passed": passed,
3068
+ "action": action,
3069
+ "metrics": result.get("metrics", {}),
3070
+ "violations": errors,
3071
+ "message": "Variance guard validation completed",
3072
+ "details": details,
3073
+ "policy": details.get("policy", self._policy.copy()),
3074
+ "warnings": warnings,
3075
+ "errors": errors,
3076
+ }
3077
+
3078
+ def finalize(self, model: nn.Module) -> dict[str, Any]:
3079
+ """
3080
+ Finalize variance guard and return comprehensive results.
3081
+
3082
+ Args:
3083
+ model: The final edited model
3084
+
3085
+ Returns:
3086
+ Dictionary with variance guard results and A/B testing metrics
3087
+ """
3088
+ start_time = time.time()
3089
+
3090
+ if not self._prepared:
3091
+ self._log_event(
3092
+ "finalize_failed",
3093
+ level="ERROR",
3094
+ message="Variance guard not properly prepared",
3095
+ )
3096
+ return {
3097
+ "passed": False,
3098
+ "metrics": {},
3099
+ "warnings": ["Variance guard not properly prepared"],
3100
+ "errors": ["Preparation failed or no target modules found"],
3101
+ "finalize_time": time.time() - start_time,
3102
+ "events": self.events,
3103
+ }
3104
+
3105
+ if self._monitor_only:
3106
+ self._enabled = False
3107
+ self._scales = {}
3108
+
3109
+ if not self._post_edit_evaluated:
3110
+ self._refresh_after_edit_metrics(model)
3111
+
3112
+ # Use reinforced A/B gate evaluation
3113
+ should_enable, gate_reason = self._evaluate_ab_gate()
3114
+ enabled_after_ab = self._enabled
3115
+ ab_gain = self._ab_gain or 0.0
3116
+
3117
+ if should_enable and not enabled_after_ab:
3118
+ enable_result = self.enable(model)
3119
+ enabled_after_ab = enable_result or self._enabled
3120
+ elif not should_enable and enabled_after_ab:
3121
+ self.disable(model)
3122
+ enabled_after_ab = False
3123
+
3124
+ # Enhanced validation gate criteria
3125
+ passed = True
3126
+ warnings = []
3127
+ errors = []
3128
+
3129
+ # Log A/B gate decision for transparency
3130
+ self._log_event(
3131
+ "ab_gate_evaluation",
3132
+ message=f"A/B gate decision: should_enable={should_enable}, reason={gate_reason}",
3133
+ should_enable=should_enable,
3134
+ reason=gate_reason,
3135
+ current_enabled=enabled_after_ab,
3136
+ )
3137
+
3138
+ # Primary validation: VE enabled state must match A/B gate decision
3139
+ if enabled_after_ab != should_enable:
3140
+ if enabled_after_ab and not should_enable:
3141
+ errors.append(f"VE enabled despite A/B gate rejection: {gate_reason}")
3142
+ passed = False
3143
+ elif not enabled_after_ab and should_enable:
3144
+ warnings.append(f"VE disabled despite A/B gate approval: {gate_reason}")
3145
+ # This is a warning, not an error, as being conservative is safer
3146
+
3147
+ # Secondary validation: Check primary-metric degradation when VE is OFF (≤0.5 rise requirement, ppl-like)
3148
+ if not enabled_after_ab and self._ppl_no_ve and self._ppl_with_ve:
3149
+ # When VE is disabled, check that there's no significant degradation
3150
+ # The requirement is ≤0.5 rise (ppl-like units) when VE is OFF
3151
+ expected_final_ppl = self._ppl_no_ve # Should be the no-VE result
3152
+ if hasattr(self, "_final_ppl") and self._final_ppl is not None:
3153
+ ppl_rise = self._final_ppl - expected_final_ppl
3154
+ if ppl_rise > 0.5:
3155
+ errors.append(
3156
+ f"Primary-metric rise {ppl_rise:.3f} > 0.5 when VE disabled"
3157
+ )
3158
+ passed = False
3159
+
3160
+ # Tertiary validation: Check for deterministic A/B testing
3161
+ if self._ab_windows_used is not None and self._ab_seed_used is not None:
3162
+ expected_seed = self._policy.get("seed", 123)
3163
+ if self._ab_seed_used != expected_seed:
3164
+ warnings.append(
3165
+ f"A/B test used unexpected seed {self._ab_seed_used}, expected {expected_seed}"
3166
+ )
3167
+
3168
+ # Additional robustness checks
3169
+ if self._enable_attempt_count > 3:
3170
+ warnings.append(
3171
+ f"Multiple enable attempts ({self._enable_attempt_count}), may indicate instability"
3172
+ )
3173
+
3174
+ if self._disable_attempt_count > 3:
3175
+ warnings.append(
3176
+ f"Multiple disable attempts ({self._disable_attempt_count}), may indicate instability"
3177
+ )
3178
+
3179
+ if len(self._checkpoint_stack) > 0:
3180
+ warnings.append(
3181
+ f"Uncommitted checkpoints remaining: {len(self._checkpoint_stack)}"
3182
+ )
3183
+
3184
+ # Validate tie-breaker deadband was applied
3185
+ if self._ab_gain is not None and self._ab_gain > 0:
3186
+ required_gain_with_deadband = self._policy["min_gain"] + float(
3187
+ self._policy.get("tie_breaker_deadband", self.TIE_BREAKER_DEADBAND)
3188
+ )
3189
+ if enabled_after_ab and self._ab_gain < required_gain_with_deadband:
3190
+ errors.append(
3191
+ f"VE enabled without meeting tie-breaker deadband: gain {self._ab_gain:.3f} < {required_gain_with_deadband:.3f}"
3192
+ )
3193
+ passed = False
+
+        # Validate that the absolute floor was checked
+        if self._ppl_no_ve and self._ppl_with_ve:
+            absolute_improvement = self._ppl_no_ve - self._ppl_with_ve
+            if enabled_after_ab and absolute_improvement < self.ABSOLUTE_FLOOR:
+                errors.append(
+                    f"VE enabled without meeting absolute floor: improvement {absolute_improvement:.3f} < {self.ABSOLUTE_FLOOR}"
+                )
+                passed = False
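+        # Worked example (illustrative values): ppl_no_ve = 25.40 and
+        # ppl_with_ve = 25.35 give an absolute improvement of 0.05; if
+        # ABSOLUTE_FLOOR were 0.10, an enabled VE would be flagged as an error.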
+
+        finalize_time = time.time() - start_time
+
+        # Final metrics
+        final_metrics = {
+            "proposed_scales": len(self._scales),
+            "target_modules": len(self._target_modules),
+            "target_module_names": self._stats.get("target_module_names", []),
+            "focus_modules": sorted(self._focus_modules) if self._focus_modules else [],
+            "tap": self._stats.get("tap"),
+            "ve_enabled": enabled_after_ab,
+            "ab_gain": ab_gain,
+            "ab_windows_used": self._ab_windows_used,
+            "ab_seed_used": self._ab_seed_used,
+            "monitor_only": self._monitor_only,
+            "min_gain_threshold": self._policy["min_gain"],
+            "met_threshold": should_enable,
+            "ppl_no_ve": self._ppl_no_ve,
+            "ppl_with_ve": self._ppl_with_ve,
+            "scope": self._policy["scope"],
+            "max_calib_used": self._policy["max_calib"],
+            "mode": self._policy.get("mode"),
+            "min_rel_gain": self._policy.get("min_rel_gain"),
+            "alpha": self._policy.get("alpha"),
+            "ratio_ci": self._ratio_ci,
+            "calibration": self._calibration_stats.copy(),
+            "predictive_gate": self._predictive_gate_state.copy(),
+            "ab_provenance": copy.deepcopy(self._stats.get("ab_provenance", {})),
+            "ab_point_estimates": copy.deepcopy(
+                self._stats.get("ab_point_estimates", {})
+            ),
+            "raw_scales_pre_edit": copy.deepcopy(self._raw_scales_pre_edit),
+            "raw_scales_post_edit": copy.deepcopy(self._raw_scales_post_edit),
+            "proposed_scales_pre_edit": self._stats.get("proposed_scales_pre_edit", {}),
+            "proposed_scales_post_edit": self._stats.get(
+                "proposed_scales_post_edit", {}
+            ),
+        }
+
+        if self._calibration_stats.get("status") != "complete":
+            warnings.append(
+                "Variance calibration coverage insufficient; operating in monitor mode"
+            )
+
+        self._log_event(
+            "finalize_complete",
+            message=f"Variance guard finalized - {'PASSED' if passed else 'FAILED'}",
+            passed=passed,
+            ve_enabled=enabled_after_ab,
+            ab_gain=ab_gain,
+            finalize_time=finalize_time,
+        )
+
+        result = {
+            "passed": passed,
+            "metrics": final_metrics,
+            "warnings": warnings,
+            "errors": errors,
+            "finalize_time": finalize_time,
+            "events": self.events,
+            "details": {
+                "guard_type": "variance",
+                "ve_applied": enabled_after_ab,
+                "ab_test_performed": self._ppl_no_ve is not None,
+                "proposed_scales": self._scales,
+                "stats": self._stats,
+                "policy": self._policy,
+            },
+        }
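+        # Consumer sketch (hypothetical caller, outside this module): a runner
+        # would typically treat this dict as the gate verdict, e.g.
+        #     outcome = guard.finalize(model)
+        #     if not outcome["passed"]:
+        #         raise RuntimeError("; ".join(outcome["errors"]))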
+
+        # Env-gated tiny evidence dump for auditors
+        try:
+            payload = {
+                "variance": {
+                    "mode": self._policy.get("mode"),
+                    "min_effect": self._policy.get("min_effect", self.MIN_EFFECT),
+                    "predictive_one_sided": bool(
+                        self._policy.get("predictive_one_sided", True)
+                    ),
+                    "evaluated": True,
+                }
+            }
+            maybe_dump_guard_evidence(".", payload)
+        except Exception:
+            # Best-effort only: evidence dumping must never fail finalization
+            pass
+
+        return result
+
+    def policy(self) -> VariancePolicyDict:
+        """
+        Get the current policy configuration.
+
+        Returns:
+            A copy of the VariancePolicyDict with the current configuration.
+        """
+        return self._policy.copy()
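+    # Usage sketch (assumes a constructed guard instance `guard`): policy()
+    # returns a shallow copy, so callers may inspect or locally tweak the
+    # dict without mutating the guard's live configuration:
+    #     cfg = guard.policy()
+    #     cfg["min_gain"] = 0.5  # changes the local copy only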