invarlock-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +33 -0
- invarlock/__main__.py +10 -0
- invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
- invarlock/_data/runtime/profiles/release.yaml +23 -0
- invarlock/_data/runtime/tiers.yaml +76 -0
- invarlock/adapters/__init__.py +102 -0
- invarlock/adapters/_capabilities.py +45 -0
- invarlock/adapters/auto.py +99 -0
- invarlock/adapters/base.py +530 -0
- invarlock/adapters/base_types.py +85 -0
- invarlock/adapters/hf_bert.py +852 -0
- invarlock/adapters/hf_gpt2.py +403 -0
- invarlock/adapters/hf_llama.py +485 -0
- invarlock/adapters/hf_mixin.py +383 -0
- invarlock/adapters/hf_onnx.py +112 -0
- invarlock/adapters/hf_t5.py +137 -0
- invarlock/adapters/py.typed +1 -0
- invarlock/assurance/__init__.py +43 -0
- invarlock/cli/__init__.py +8 -0
- invarlock/cli/__main__.py +8 -0
- invarlock/cli/_evidence.py +25 -0
- invarlock/cli/_json.py +75 -0
- invarlock/cli/adapter_auto.py +162 -0
- invarlock/cli/app.py +287 -0
- invarlock/cli/commands/__init__.py +26 -0
- invarlock/cli/commands/certify.py +403 -0
- invarlock/cli/commands/doctor.py +1358 -0
- invarlock/cli/commands/explain_gates.py +151 -0
- invarlock/cli/commands/export_html.py +100 -0
- invarlock/cli/commands/plugins.py +1331 -0
- invarlock/cli/commands/report.py +354 -0
- invarlock/cli/commands/run.py +4146 -0
- invarlock/cli/commands/verify.py +1040 -0
- invarlock/cli/config.py +396 -0
- invarlock/cli/constants.py +68 -0
- invarlock/cli/device.py +92 -0
- invarlock/cli/doctor_helpers.py +74 -0
- invarlock/cli/errors.py +6 -0
- invarlock/cli/overhead_utils.py +60 -0
- invarlock/cli/provenance.py +66 -0
- invarlock/cli/utils.py +41 -0
- invarlock/config.py +56 -0
- invarlock/core/__init__.py +62 -0
- invarlock/core/abi.py +15 -0
- invarlock/core/api.py +274 -0
- invarlock/core/auto_tuning.py +317 -0
- invarlock/core/bootstrap.py +226 -0
- invarlock/core/checkpoint.py +221 -0
- invarlock/core/contracts.py +73 -0
- invarlock/core/error_utils.py +64 -0
- invarlock/core/events.py +298 -0
- invarlock/core/exceptions.py +95 -0
- invarlock/core/registry.py +481 -0
- invarlock/core/retry.py +146 -0
- invarlock/core/runner.py +2041 -0
- invarlock/core/types.py +154 -0
- invarlock/edits/__init__.py +12 -0
- invarlock/edits/_edit_utils.py +249 -0
- invarlock/edits/_external_utils.py +268 -0
- invarlock/edits/noop.py +47 -0
- invarlock/edits/py.typed +1 -0
- invarlock/edits/quant_rtn.py +801 -0
- invarlock/edits/registry.py +166 -0
- invarlock/eval/__init__.py +23 -0
- invarlock/eval/bench.py +1207 -0
- invarlock/eval/bootstrap.py +50 -0
- invarlock/eval/data.py +2052 -0
- invarlock/eval/metrics.py +2167 -0
- invarlock/eval/primary_metric.py +767 -0
- invarlock/eval/probes/__init__.py +24 -0
- invarlock/eval/probes/fft.py +139 -0
- invarlock/eval/probes/mi.py +213 -0
- invarlock/eval/probes/post_attention.py +323 -0
- invarlock/eval/providers/base.py +67 -0
- invarlock/eval/providers/seq2seq.py +111 -0
- invarlock/eval/providers/text_lm.py +113 -0
- invarlock/eval/providers/vision_text.py +93 -0
- invarlock/eval/py.typed +1 -0
- invarlock/guards/__init__.py +18 -0
- invarlock/guards/_contracts.py +9 -0
- invarlock/guards/invariants.py +640 -0
- invarlock/guards/policies.py +805 -0
- invarlock/guards/py.typed +1 -0
- invarlock/guards/rmt.py +2097 -0
- invarlock/guards/spectral.py +1419 -0
- invarlock/guards/tier_config.py +354 -0
- invarlock/guards/variance.py +3298 -0
- invarlock/guards_ref/__init__.py +15 -0
- invarlock/guards_ref/rmt_ref.py +40 -0
- invarlock/guards_ref/spectral_ref.py +135 -0
- invarlock/guards_ref/variance_ref.py +60 -0
- invarlock/model_profile.py +353 -0
- invarlock/model_utils.py +221 -0
- invarlock/observability/__init__.py +10 -0
- invarlock/observability/alerting.py +535 -0
- invarlock/observability/core.py +546 -0
- invarlock/observability/exporters.py +565 -0
- invarlock/observability/health.py +588 -0
- invarlock/observability/metrics.py +457 -0
- invarlock/observability/py.typed +1 -0
- invarlock/observability/utils.py +553 -0
- invarlock/plugins/__init__.py +12 -0
- invarlock/plugins/hello_guard.py +33 -0
- invarlock/plugins/hf_awq_adapter.py +82 -0
- invarlock/plugins/hf_bnb_adapter.py +79 -0
- invarlock/plugins/hf_gptq_adapter.py +78 -0
- invarlock/plugins/py.typed +1 -0
- invarlock/py.typed +1 -0
- invarlock/reporting/__init__.py +7 -0
- invarlock/reporting/certificate.py +3221 -0
- invarlock/reporting/certificate_schema.py +244 -0
- invarlock/reporting/dataset_hashing.py +215 -0
- invarlock/reporting/guards_analysis.py +948 -0
- invarlock/reporting/html.py +32 -0
- invarlock/reporting/normalizer.py +235 -0
- invarlock/reporting/policy_utils.py +517 -0
- invarlock/reporting/primary_metric_utils.py +265 -0
- invarlock/reporting/render.py +1442 -0
- invarlock/reporting/report.py +903 -0
- invarlock/reporting/report_types.py +278 -0
- invarlock/reporting/utils.py +175 -0
- invarlock/reporting/validate.py +631 -0
- invarlock/security.py +176 -0
- invarlock/sparsity_utils.py +323 -0
- invarlock/utils/__init__.py +150 -0
- invarlock/utils/digest.py +45 -0
- invarlock-0.2.0.dist-info/METADATA +586 -0
- invarlock-0.2.0.dist-info/RECORD +132 -0
- invarlock-0.2.0.dist-info/WHEEL +5 -0
- invarlock-0.2.0.dist-info/entry_points.txt +20 -0
- invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
- invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/guards/variance.py
@@ -0,0 +1,3298 @@
"""
InvarLock – Safety: Data-Driven Variance Equalization (DD-VE)
=========================================================

Branch-level variance equalizer for transformer blocks to maintain
stable residual stream dynamics after edits.

For each transformer block, measures the variance of residual branch
outputs (attention and MLP) and scales projection weights to maintain
Var(x_out) ≈ 1 when Var(x_in) ≈ 1.
"""

from __future__ import annotations

import copy
import fnmatch
import hashlib
import itertools
import math
import time
from collections import defaultdict
from collections.abc import Iterable, Sequence
from datetime import datetime
from typing import Any

import numpy as np
import torch
import torch.nn as nn

from invarlock.cli._evidence import maybe_dump_guard_evidence
from invarlock.core.api import Guard
from invarlock.core.bootstrap import compute_paired_delta_log_ci

from ._contracts import guard_assert

# Import the policy type and Guard interface
from .policies import VariancePolicyDict

__all__ = ["equalise_residual_variance", "equalise_branch_variance", "VarianceGuard"]


try:  # Optional dependency: tqdm (progress bars)
    from tqdm.auto import tqdm as _tqdm
except Exception:  # pragma: no cover - exercised only when tqdm is absent

    class _TqdmShim:
        def __init__(self, iterable=None, total=None, **kwargs):
            self._iterable = iterable
            self.total = total

        def __iter__(self):
            if self._iterable is None:
                return iter(())
            return iter(self._iterable)

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def update(self, n: int = 1) -> None:
            return None

    def _tqdm(iterable=None, *args, **kwargs):
        return _TqdmShim(iterable=iterable, **kwargs)


tqdm = _tqdm


def _unwrap_model(model: nn.Module) -> nn.Module:
    """Unwrap DataParallel/DDP wrappers to get the underlying model.

    PyTorch's DataParallel and DistributedDataParallel wrap models with a
    `.module` attribute. This function traverses that chain to get the
    actual model, enabling consistent layer iteration regardless of how
    the model is wrapped for training/inference.
    """
    unwrapped = model
    while hasattr(unwrapped, "module"):
        unwrapped = unwrapped.module
    return unwrapped


def _iter_transformer_layers(model: nn.Module):
    """Iterate over transformer layers in a model.

    Handles multiple transformer architectures and automatically unwraps
    DataParallel/DDP wrappers.
    """
    # Unwrap DataParallel/DDP wrappers first
    model = _unwrap_model(model)

    # Handle different model architectures
    if hasattr(model, "transformer") and hasattr(model.transformer, "h"):
        # GPT-2 style
        yield from model.transformer.h
    elif hasattr(model, "model") and hasattr(model.model, "layers"):
        # LLaMA style
        yield from model.model.layers
    elif hasattr(model, "encoder") and hasattr(model.encoder, "layer"):
        # BERT style
        yield from model.encoder.layer
    elif hasattr(model, "decoder") and hasattr(model.decoder, "layers"):
        # T5/BART decoder style
        yield from model.decoder.layers
    elif hasattr(model, "layers"):
        # Generic transformer with top-level layers attribute
        yield from model.layers
    else:
        # Fallback: look for modules with attention
        for module in model.modules():
            if hasattr(module, "attn") and hasattr(module, "mlp"):
                yield module

@torch.no_grad()
def equalise_residual_variance(
    model: nn.Module,
    dataloader,
    *,
    windows: int = 32,
    tol: float = 0.02,
    scale_bias: bool = True,
    seed: int = 42,
    device: str | None = None,
    allow_empty: bool = False,
    clamp_range: tuple | None = (0.9, 1.1),
) -> dict[str, float]:
    """
    Apply data-driven variance equalization to transformer branches.

    This function measures the variance of each residual branch output
    (attention-proj and MLP-proj) and scales projection weights so that
    adding the branch back to the residual stream maintains stable variance.

    The scaling factor alpha = 1 / sqrt(1 + Var(F)) is used, where F is the
    branch output.

    Args:
        model: Transformer model to equalize
        dataloader: DataLoader for calibration
        windows: Number of calibration batches
        tol: Tolerance for skipping near-unity scales
        scale_bias: Whether to scale biases along with weights
        seed: Random seed for reproducibility
        device: Device to use (auto-detected if None)
        allow_empty: Whether to allow empty dataloader (returns empty dict)
        clamp_range: Optional (min, max) to clamp scaling factors (e.g., (0.9, 1.1))

    Returns:
        Dict mapping layer names to applied scaling factors
    """
    torch.manual_seed(seed)

    if device is None:
        device = next(model.parameters()).device
    else:
        device = torch.device(device)

    model.eval()

    # Storage for variance measurements
    hooks: dict[str, Any] = {}
    sample_values: dict[str, list[float]] = defaultdict(list)

    def _branch_hook(name):
        def fn(_, __, out):
            y = out[0] if isinstance(out, tuple) else out
            y = y.detach().float()
            # Skip if tensor has zero elements
            if y.numel() == 0:
                return
            mean_square = float(y.pow(2).mean().item())
            sample_values[name].append(mean_square)

        return fn

    # Register hooks on projection layers
    for i, blk in enumerate(_iter_transformer_layers(model)):
        # Handle GPT-2 style architecture
        if hasattr(blk, "attn"):
            # Check for c_proj (GPT-2) or out_proj (generic)
            attn_proj = getattr(blk.attn, "c_proj", None) or getattr(
                blk.attn, "out_proj", None
            )
            if attn_proj is not None:
                name = f"block{i}.attn"
                hooks[name] = attn_proj.register_forward_hook(_branch_hook(name))

        if hasattr(blk, "mlp"):
            # Check for c_proj (GPT-2) or down_proj (LLaMA) or fc2 (generic)
            mlp_proj = (
                getattr(blk.mlp, "c_proj", None)
                or getattr(blk.mlp, "down_proj", None)
                or getattr(blk.mlp, "fc2", None)
            )
            if mlp_proj is not None:
                name = f"block{i}.mlp"
                hooks[name] = mlp_proj.register_forward_hook(_branch_hook(name))

    # Collect variance statistics
    try:
        it = itertools.islice(iter(dataloader), windows)
        batches = list(it)
    except (StopIteration, TypeError):
        batches = []

    if not batches and not allow_empty:
        raise ValueError("Empty dataloader provided and allow_empty=False")

    for batch in tqdm(batches, desc="DD-VE Calibration", leave=False):
        if isinstance(batch, dict):
            input_ids = batch.get("input_ids", batch.get("inputs", None))
        elif isinstance(batch, tuple | list):
            # Handle tuple/list from TensorDataset
            input_ids = batch[0] if len(batch) > 0 else None
        else:
            input_ids = batch

        if input_ids is not None:
            # Convert to tensor if needed
            if not isinstance(input_ids, torch.Tensor):
                input_ids = torch.as_tensor(input_ids)

            # Ensure input has batch dimension [batch, seq_len]
            # HF models (GPT-2, etc.) expect 2-D input tensors
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)

            with torch.no_grad():
                model(input_ids.to(device))

    # Remove hooks
    for h in hooks.values():
        h.remove()

    # Apply scaling factors
    applied_scales: dict[str, float] = {}

    for i, blk in enumerate(_iter_transformer_layers(model)):
        # Handle attention projection
        if hasattr(blk, "attn"):
            attn_proj = getattr(blk.attn, "c_proj", None) or getattr(
                blk.attn, "out_proj", None
            )
            if attn_proj is not None:
                name = f"block{i}.attn"
                values = sample_values.get(name, [])
                if values:
                    tensor_vals = torch.tensor(values, dtype=torch.float64)

                    # Winsorize to remove extreme outliers (≈1-2%)
                    if tensor_vals.numel() >= 10:
                        lower = torch.quantile(tensor_vals, 0.02)
                        upper = torch.quantile(tensor_vals, 0.98)
                        tensor_vals = torch.clamp(
                            tensor_vals, lower.item(), upper.item()
                        )

                    group_count = 8 if tensor_vals.numel() >= 8 else tensor_vals.numel()
                    if group_count > 1:
                        chunks = torch.chunk(tensor_vals, group_count)
                        group_means = torch.stack([chunk.mean() for chunk in chunks])
                        var_F = torch.median(group_means).item()
                    else:
                        var_F = tensor_vals.mean().item()

                    alpha = (1.0 / max(var_F, 1e-9)) ** 0.5

                    # Apply clamping if specified
                    if clamp_range is not None:
                        alpha = max(clamp_range[0], min(alpha, clamp_range[1]))

                    if abs(alpha - 1.0) >= tol:
                        with torch.no_grad():
                            attn_proj.weight.mul_(alpha)
                            if scale_bias and attn_proj.bias is not None:
                                attn_proj.bias.mul_(alpha)
                        applied_scales[name] = alpha

        # Handle MLP projection
        if hasattr(blk, "mlp"):
            mlp_proj = (
                getattr(blk.mlp, "c_proj", None)
                or getattr(blk.mlp, "down_proj", None)
                or getattr(blk.mlp, "fc2", None)
            )
            if mlp_proj is not None:
                name = f"block{i}.mlp"
                values = sample_values.get(name, [])
                if values:
                    tensor_vals = torch.tensor(values, dtype=torch.float64)

                    if tensor_vals.numel() >= 10:
                        lower = torch.quantile(tensor_vals, 0.02)
                        upper = torch.quantile(tensor_vals, 0.98)
                        tensor_vals = torch.clamp(
                            tensor_vals, lower.item(), upper.item()
                        )

                    group_count = 8 if tensor_vals.numel() >= 8 else tensor_vals.numel()
                    if group_count > 1:
                        chunks = torch.chunk(tensor_vals, group_count)
                        group_means = torch.stack([chunk.mean() for chunk in chunks])
                        var_F = torch.median(group_means).item()
                    else:
                        var_F = tensor_vals.mean().item()

                    alpha = (1.0 / max(var_F, 1e-9)) ** 0.5

                    # Apply clamping if specified
                    if clamp_range is not None:
                        alpha = max(clamp_range[0], min(alpha, clamp_range[1]))

                    if abs(alpha - 1.0) >= tol:
                        with torch.no_grad():
                            mlp_proj.weight.mul_(alpha)
                            if scale_bias and mlp_proj.bias is not None:
                                mlp_proj.bias.mul_(alpha)
                        applied_scales[name] = alpha

    return applied_scales

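The function above estimates one scale per residual branch from hook-captured mean-square activations and multiplies the projection weights in place; for instance, a branch measured at 1.21 yields alpha = 1/sqrt(1.21) ≈ 0.909, which the default clamp of (0.9, 1.1) leaves untouched, while a measurement of 4.0 would map to 0.5 and be clamped up to 0.9. A minimal usage sketch, assuming a Hugging Face GPT-2 checkpoint and a hand-built list of tokenized batches (illustrative only, not part of the packaged source):

# Editor's illustrative sketch - not part of invarlock/guards/variance.py
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from invarlock.guards.variance import equalise_residual_variance

model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
tok = GPT2TokenizerFast.from_pretrained("gpt2")
texts = ["The quick brown fox jumps over the lazy dog.", "Variance equalization demo."]
batches = [{"input_ids": tok(t, return_tensors="pt")["input_ids"]} for t in texts]

# Scales attention/MLP projection weights in place and returns {branch: alpha}
# for every branch whose proposed scale moved past the tolerance.
scales = equalise_residual_variance(model, batches, windows=len(batches), tol=0.02)
print(scales)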
def equalise_branch_variance(
    model: nn.Module,
    dataloader,
    windows: int = 32,
    tol: float = 0.02,
    scale_bias: bool = True,
    seed: int = 42,
    device: str | None = None,
    allow_empty: bool = False,
) -> dict[str, float]:
    """
    Legacy alias for equalise_residual_variance.

    Maintained for backward compatibility.
    """
    return equalise_residual_variance(
        model=model,
        dataloader=dataloader,
        windows=windows,
        tol=tol,
        scale_bias=scale_bias,
        seed=seed,
        device=device,
        allow_empty=allow_empty,
    )


def _predictive_gate_outcome(
    mean_delta: float,
    delta_ci: tuple[float, float] | None,
    min_effect: float,
    one_sided: bool,
) -> tuple[bool, str]:
    """
    Decide whether the predictive gate passes given the CI and tier semantics.

    Args:
        mean_delta: Mean ΔlogNLL (virtual VE − no VE) from paired calibration.
        delta_ci: BCa confidence interval on ΔlogNLL (lower, upper).
        min_effect: Minimum absolute improvement required.
        one_sided: Whether to require a one-sided improvement (balanced tier).

    Returns:
        Tuple of (passed, reason) where reason is a canonical string used in stats.
    """
    guard_assert(min_effect >= 0.0, "variance.min_effect must be >= 0")
    if (
        delta_ci is None
        or len(delta_ci) != 2
        or not all(
            isinstance(val, (int | float)) and math.isfinite(val) for val in delta_ci
        )
    ):
        return False, "ci_unavailable"

    lower, upper = float(delta_ci[0]), float(delta_ci[1])
    min_effect = float(min_effect or 0.0)

    if one_sided:
        if lower >= 0.0:
            return False, "ci_contains_zero"
        if mean_delta >= 0.0:
            return False, "mean_not_negative"
        if min_effect > 0.0 and (-mean_delta) < min_effect:
            return False, "gain_below_threshold"
        return True, "ci_gain_met"

    # Two-sided improvement: CI must be strictly below zero.
    if upper >= 0.0:
        return False, "ci_contains_zero"

    gain_lower_bound = -upper  # Convert ΔlogNLL CI to gain CI lower bound.
    if gain_lower_bound < min_effect:
        return False, "gain_below_threshold"

    return True, "ci_gain_met"

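To make the gate semantics above concrete, here are a few hypothetical confidence intervals evaluated against the rules in _predictive_gate_outcome (illustrative values only, not results from any run; the function is module-private):

# Editor's illustrative sketch - hypothetical values, not from a real run
# Two-sided tier: the whole ΔlogNLL CI must sit below zero and the implied
# gain lower bound (-upper) must reach min_effect.
_predictive_gate_outcome(-0.02, (-0.03, -0.01), 0.005, False)   # (True, "ci_gain_met")
_predictive_gate_outcome(-0.02, (-0.03, 0.002), 0.005, False)   # (False, "ci_contains_zero")
# One-sided (balanced) tier: the lower bound and the mean must be negative
# and the mean improvement must clear min_effect.
_predictive_gate_outcome(-0.003, (-0.03, 0.01), 0.005, True)    # (False, "gain_below_threshold")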
# === Standalone Variance Guard Implementation ===


class VarianceGuard(Guard):
    """
    Standalone Variance Guard with A/B testing for data-driven variance equalization.

    Implements branch-level variance equalization with reinforced A/B gate functionality:
    - Measures variance of residual branch outputs during calibration
    - Computes scaling factors to maintain stable variance dynamics
    - A/B tests whether VE improves perplexity by at least min_gain
    - Only enables VE if it demonstrably helps (validation gate compliance)

    Policy Structure:
    - min_gain: Minimum primary-metric improvement required to enable VE
    - max_calib: Maximum calibration samples for A/B testing
    - scope: Which layers to process ("ffn", "attn", "both")
    - clamp: Scaling factor limits (min, max)
    - deadband: Tolerance margin before scaling
    - seed: Random seed for deterministic evaluation

    Reinforced A/B Testing Flow:
    1. Capture baseline model state with checkpoint discipline
    2. Measure variance and compute proposed scales during prepare
    3. A/B test with identical windows: evaluate the primary metric without VE, then with VE
    4. Apply robust gain math with tie-breaker deadband and absolute floor
    5. Enable VE only if improvement meets all criteria
    6. Idempotent enable/disable with exact state restoration
    """

    name = "variance"

    def __init__(self, policy: VariancePolicyDict | None = None):
        """
        Initialize Variance Guard with reinforced A/B gate logic.

        Args:
            policy: Variance policy configuration (uses balanced default if None)
        """
        from .policies import get_variance_policy

        self._policy = policy or get_variance_policy("balanced")
        self._policy.setdefault("mode", "ci")
        self._policy.setdefault("min_rel_gain", 0.001)
        self._policy.setdefault("alpha", 0.05)
        self._policy.setdefault("clamp", (0.5, 2.0))
        self._policy.setdefault("seed", 123)
        self._policy.setdefault("tie_breaker_deadband", 0.005)
        self._policy.setdefault("min_abs_adjust", 0.012)
        self._policy.setdefault("max_scale_step", 0.02)
        self._policy.setdefault("topk_backstop", 1)
        self._policy.setdefault("max_adjusted_modules", 0)
        self._policy.setdefault("predictive_gate", True)
        self._policy.setdefault("predictive_one_sided", False)
        self._policy.setdefault("absolute_floor_ppl", 0.05)
        if self._policy.get("min_effect_lognll") is not None:
            self._policy["min_effect_lognll"] = float(self._policy["min_effect_lognll"])
        self._refresh_calibration_defaults()
        self._scales: dict[str, float] = {}
        self._raw_scales: dict[str, float] = {}
        self._enabled = False
        self._stats: dict[str, Any] = {}
        self._prepared = False
        self._baseline_state: dict[str, Any] | None = None
        self.events: list[dict[str, Any]] = []
        self._calibration_stats: dict[str, Any] = {
            "requested": 0,
            "coverage": 0,
            "min_coverage": 0,
            "seed": self._policy["calibration"]["seed"],
            "status": "uninitialized",
        }
        self.ABSOLUTE_FLOOR = float(
            self._policy.get(
                "absolute_floor_pm", self._policy.get("absolute_floor_ppl", 0.05)
            )
        )
        self._monitor_only = bool(self._policy.get("monitor_only", False))
        self._params_changed: int | None = None
        self._run_context: dict[str, Any] | None = None
        self._report_meta: dict[str, Any] | None = None
        self._dataset_meta: dict[str, Any] | None = None
        self._pairing_reference: list[str] = []
        self._pairing_digest: str | None = None
        self._adapter_ref: Any | None = None

        # A/B testing results with reinforced validation
        self._ppl_no_ve: float | None = None
        self._ppl_with_ve: float | None = None
        self._ab_gain: float | None = None
        self._ab_windows_used: int | None = None
        self._ab_seed_used: int | None = None
        self._ratio_ci: tuple[float, float] | None = None
        self._predictive_gate_state: dict[str, Any] = {
            "evaluated": False,
            "passed": False,
            "reason": "not_evaluated",
            "delta_ci": (None, None),
            "gain_ci": (None, None),
            "mean_delta": None,
        }

        # Module tracking for safe scaling
        self._target_modules: dict[str, nn.Module] = {}
        self._original_scales: dict[str, float] = {}
        self._focus_modules = {
            self._normalize_module_name(name)
            for name in (self._policy.get("target_modules") or [])
            if isinstance(name, str)
        }
        if self._focus_modules:
            self._policy["target_modules"] = sorted(self._focus_modules)

        tap_config = self._policy.get("tap")
        if isinstance(tap_config, str):
            tap_patterns = [tap_config]
        elif isinstance(tap_config, Sequence):
            tap_patterns = [
                str(pattern)
                for pattern in tap_config
                if isinstance(pattern, str) and pattern.strip()
            ]
        else:
            tap_patterns = []
        if not tap_patterns:
            tap_patterns = ["transformer.h.*.mlp.c_proj"]
        self._tap_patterns = tap_patterns

        # Checkpoint discipline for robust state management
        self._checkpoint_stack: list[dict[str, torch.Tensor]] = []
        self._enable_attempt_count = 0
        self._disable_attempt_count = 0

        # Constants for reinforced A/B gate
        self.TIE_BREAKER_DEADBAND = float(
            self._policy.get("tie_breaker_deadband", 0.005)
        )  # Extra deadband to avoid flapping on noise
        self.ABSOLUTE_FLOOR = 0.05  # Minimum improvement (ppl-like) to consider

        # Calibration storage for post-edit evaluation
        self._calibration_batches: list[Any] = []
        self._calibration_window_ids: list[str] = []
        self._calibration_context: dict[str, Any] = {}
        self._calibration_stats_pre_edit: dict[str, Any] | None = None
        self._post_edit_evaluated = False
        self._raw_scales_pre_edit: dict[str, float] = {}
        self._raw_scales_post_edit: dict[str, float] = {}
        self._stats["tap"] = list(self._tap_patterns)
        if self._focus_modules:
            self._stats["focus_modules"] = sorted(self._focus_modules)
        self._stats.setdefault("ab_provenance", {})

    def _refresh_calibration_defaults(self) -> None:
        """Ensure calibration config contains required defaults."""
        default_calibration = {
            "windows": 6,
            "min_coverage": 4,
            "seed": self._policy.get("seed", 123),
        }
        calibration_cfg = self._policy.get("calibration", {}) or {}
        if not isinstance(calibration_cfg, dict):
            calibration_cfg = {}
        merged_calibration = {**default_calibration, **calibration_cfg}
        self._policy["calibration"] = merged_calibration

    def _log_event(
        self, operation: str, level: str = "INFO", message: str = "", **data
    ):
        """Log an event with timestamp."""
        event = {
            "timestamp": datetime.utcnow().isoformat(),
            "component": "variance_guard",
            "operation": operation,
            "level": level,
            "message": message,
            "data": data,
        }
        self.events.append(event)

    def set_run_context(self, report: Any) -> None:
        """Capture run-level context (edit metadata, pairing reference, etc.)."""
        self._report_meta = getattr(report, "meta", {}) or {}
        self._run_context = getattr(report, "context", {}) or {}
        if isinstance(self._run_context, dict):
            self._dataset_meta = self._run_context.get("dataset_meta")
        else:
            self._dataset_meta = None
        if isinstance(self._dataset_meta, dict):
            self._stats.setdefault("dataset_meta", self._dataset_meta)

        pairing_reference: list[str] = []
        pairing_digest: str | None = None
        if isinstance(self._run_context, dict):
            pairing_baseline = self._run_context.get("pairing_baseline")
        else:
            pairing_baseline = None
        if isinstance(pairing_baseline, dict):
            preview_section = pairing_baseline.get("preview") or {}
            final_section = pairing_baseline.get("final") or {}
            pairing_reference.extend(
                self._normalize_pairing_ids(
                    "preview", preview_section.get("window_ids") or []
                )
            )
            pairing_reference.extend(
                self._normalize_pairing_ids(
                    "final", final_section.get("window_ids") or []
                )
            )
        if pairing_reference:
            joined = "||".join(pairing_reference)
            pairing_digest = hashlib.blake2s(
                joined.encode("utf-8"), digest_size=16
            ).hexdigest()
            pairing_stats = self._stats.setdefault("pairing_reference", {})
            pairing_stats.update(
                {
                    "count": len(pairing_reference),
                    "digest": pairing_digest,
                }
            )
        self._pairing_reference = pairing_reference
        self._pairing_digest = pairing_digest
        if pairing_digest is None:
            self._stats.pop("pairing_reference", None)

        edit_info = getattr(report, "edit", {}) or {}
        params_changed = None
        if isinstance(edit_info, dict):
            deltas = edit_info.get("deltas") or {}
            if isinstance(deltas, dict):
                params_changed = deltas.get("params_changed")
        if params_changed is None:
            params_changed = (
                0 if edit_info and edit_info.get("name") in {"noop"} else None
            )
        self._params_changed = params_changed
        if params_changed == 0:
            self._monitor_only = True
            self._log_event(
                "monitor_only",
                message="Variance guard forcing monitor-only mode (no parameters changed)",
            )
            # Clear proposed scales in monitor-only mode
            self._scales = {}

    def _normalize_module_name(self, name: str) -> str:
        """Normalize module names to transformer.h.<idx>.<branch>.c_proj form."""
        if not isinstance(name, str):
            return ""

        normalized = name.strip()
        if not normalized:
            return normalized

        if normalized.startswith("block"):
            parts = normalized.split(".")
            if len(parts) >= 2 and parts[0].startswith("block"):
                layer_idx = parts[0][5:]
                branch = parts[1]
                branch = "attn" if branch.startswith("attn") else "mlp"
                return f"transformer.h.{layer_idx}.{branch}.c_proj"

        if normalized.startswith("transformer.h."):
            if normalized.endswith(".c_proj"):
                return normalized
            if ".mlp" in normalized and ".c_proj" not in normalized:
                return f"{normalized}.c_proj"
            if ".attn" in normalized and ".c_proj" not in normalized:
                return f"{normalized}.c_proj"

        return normalized

    def _matches_tap(self, name: str) -> bool:
        """Return True if a module name matches configured tap patterns."""
        normalized = self._normalize_module_name(name)
        for pattern in self._tap_patterns:
            if fnmatch.fnmatch(normalized, pattern) or fnmatch.fnmatch(name, pattern):
                return True
        return False

    def _normalize_pairing_ids(
        self, prefix: str, window_ids: Sequence[Any]
    ) -> list[str]:
        normalized: list[str] = []
        for idx in window_ids:
            token = str(idx)
            if "::" in token:
                normalized.append(token)
            else:
                normalized.append(f"{prefix}::{token}")
        return normalized

    def _expected_window_ids(self) -> list[str]:
        return list(self._pairing_reference)

    def _normalize_scale_name(self, name: str) -> str:
        """Normalize a scale name to the canonical module path."""
        return self._normalize_module_name(name)

    def _scale_matches_target(self, scale_name: str, target_name: str) -> bool:
        """Check if a scale name from equalise_residual_variance matches a target module name.

        Handles the format mismatch between:
        - Scale names: block0.mlp, block0.attn
        - Target names: transformer.h.0.mlp.c_proj, transformer.h.0.attn.c_proj
        """
        # Normalize scale name to target format and check direct match
        normalized_scale = self._normalize_scale_name(scale_name)
        if normalized_scale == target_name:
            return True

        # Convert block format to layer-component extraction
        if scale_name.startswith("block") and (
            "attn" in scale_name or "mlp" in scale_name
        ):
            parts = scale_name.split(".")
            if len(parts) == 2:
                layer_part = parts[0]  # e.g., "block0"
                component = parts[1]  # e.g., "attn" or "mlp"
                if layer_part.startswith("block"):
                    try:
                        layer_num = layer_part[5:]  # Extract number from "block0"
                        # Check if target matches this pattern
                        if f"h.{layer_num}.{component}" in target_name:
                            return True
                    except (ValueError, IndexError):
                        pass

        return False

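The helpers above reconcile the two naming schemes used in this module (the block<i>.<branch> names produced by equalise_residual_variance and the transformer.h.<i>.<branch>.c_proj module paths used for targeting). A small sketch of that mapping, assuming the balanced default policy constructs without a model and using GPT-2-style names purely for illustration (these are private helpers, shown only to clarify the convention):

# Editor's illustrative sketch - private helpers, GPT-2-style names assumed
guard = VarianceGuard()  # balanced default policy
guard._normalize_module_name("block0.mlp")            # "transformer.h.0.mlp.c_proj"
guard._normalize_module_name("transformer.h.3.attn")  # "transformer.h.3.attn.c_proj"
guard._scale_matches_target("block0.mlp", "transformer.h.0.mlp.c_proj")  # True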
    def _is_focus_match(self, name: str) -> bool:
        """Check whether a module name matches the configured focus list."""
        if not self._focus_modules:
            return True
        normalized = self._normalize_module_name(name)
        return normalized in self._focus_modules

    def _materialize_batch(self, batch: Any) -> Any:
        """Detach tensors from device and clone calibration batches for reuse."""
        if isinstance(batch, dict):
            return {key: self._materialize_batch(val) for key, val in batch.items()}
        if isinstance(batch, list | tuple):
            return type(batch)(self._materialize_batch(val) for val in batch)
        if isinstance(batch, torch.Tensor):
            return batch.detach().cpu()
        try:
            return copy.deepcopy(batch)
        except Exception:
            return batch

    def _ensure_tensor_value(self, value: Any) -> Any:
        """Convert common calibration value types to torch tensors."""
        if isinstance(value, torch.Tensor):
            return value
        if isinstance(value, np.ndarray):
            return torch.as_tensor(value)
        if isinstance(value, list | tuple):
            try:
                return torch.as_tensor(value)
            except Exception:
                return value
        if isinstance(value, int | float):
            return torch.tensor(value)
        return value

    def _tensorize_calibration_batches(self, batches: Sequence[Any]) -> list[Any]:
        """Ensure calibration batches contain tensor payloads for model execution."""
        tensor_batches: list[Any] = []
        for batch in batches:
            if isinstance(batch, dict):
                converted: dict[str, Any] = {}
                for key, value in batch.items():
                    if key in {"input_ids", "inputs", "attention_mask", "labels"}:
                        converted[key] = self._ensure_tensor_value(value)
                    else:
                        converted[key] = value
                tensor_batches.append(converted)
            elif isinstance(batch, list | tuple):
                converted_list = [self._ensure_tensor_value(val) for val in batch]
                tensor_batches.append(type(batch)(converted_list))
            else:
                tensor_batches.append(self._ensure_tensor_value(batch))
        return tensor_batches

    def _extract_window_ids(self, batches: Sequence[Any]) -> list[str]:
        """Extract window identifiers from calibration batches when present."""
        window_ids: list[str] = []
        for batch in batches:
            candidate: Any | None = None
            if isinstance(batch, dict):
                if "window_id" in batch:
                    candidate = batch["window_id"]
                elif "window_ids" in batch:
                    candidate = batch["window_ids"]
                elif isinstance(batch.get("metadata"), dict):
                    meta = batch["metadata"]
                    candidate = meta.get("window_id") or meta.get("window_ids")

            if candidate is None:
                continue

            if isinstance(candidate, list | tuple):
                window_ids.extend(str(item) for item in candidate)
            else:
                window_ids.append(str(candidate))
        if not window_ids and batches:
            window_ids = [str(idx) for idx in range(len(batches))]
        return window_ids

    def _store_calibration_batches(self, batches: list[Any]) -> None:
        """Persist calibration batches for deterministic post-edit evaluation."""
        materialized = [self._materialize_batch(b) for b in batches]
        self._calibration_batches = self._tensorize_calibration_batches(materialized)
        self._calibration_window_ids = self._extract_window_ids(
            self._calibration_batches
        )
        observed_ids = list(self._calibration_window_ids)
        observed_digest = (
            hashlib.blake2s(
                "||".join(observed_ids).encode("utf-8"), digest_size=16
            ).hexdigest()
            if observed_ids
            else None
        )
        self._calibration_context = {
            "window_ids": list(self._calibration_window_ids),
            "count": len(self._calibration_batches),
            "observed_digest": observed_digest,
        }
        expected_ids = self._expected_window_ids()
        if expected_ids:
            self._calibration_context["expected_digest"] = self._pairing_digest
            expected_subset = expected_ids[: len(observed_ids)] if observed_ids else []
            if observed_ids != expected_subset:
                mismatch = {
                    "expected_count": len(expected_ids),
                    "observed_count": len(observed_ids),
                    "expected_sample": expected_subset[:5]
                    if expected_subset
                    else expected_ids[:5],
                    "observed_sample": observed_ids[:5],
                }
                self._log_event(
                    "pairing_mismatch",
                    level="ERROR",
                    message="Variance guard calibration windows do not match baseline pairing",
                    **mismatch,
                )
                self._prepared = False
                raise RuntimeError(
                    "Variance guard pairing mismatch: calibration windows diverge from baseline schedule"
                )
        self._stats.setdefault("calibration", {})
        self._stats["calibration"].update(self._calibration_context)

    def _fingerprint_targets(self) -> str | None:
        """Compute a lightweight fingerprint of targeted module weights."""
        if not self._target_modules:
            return None

        hasher = hashlib.sha256()
        try:
            for name in sorted(self._target_modules.keys()):
                module = self._target_modules[name]
                state = getattr(module, "state_dict", None)
                if not callable(state):
                    continue
                module_state = state()
                for key in sorted(module_state.keys()):
                    tensor = module_state[key]
                    if hasattr(tensor, "detach"):
                        data = tensor.detach().cpu().numpy().tobytes()
                    else:
                        data = bytes(str(tensor), "utf-8")
                    hasher.update(name.encode("utf-8"))
                    hasher.update(key.encode("utf-8"))
                    hasher.update(data)
            return hasher.hexdigest()[:16]
        except Exception:
            return None

    def _record_ab_provenance(
        self,
        condition: str,
        *,
        tag: str,
        window_ids: Sequence[str],
        fingerprint: str | None,
        mode: str,
        status: str,
    ) -> None:
        """Record provenance metadata for A/B evaluation conditions."""
        provenance = self._stats.setdefault("ab_provenance", {})
        window_list = list(window_ids)
        provenance[condition] = {
            "tag": tag,
            "mode": mode,
            "window_ids": window_list,
            "window_count": len(window_list),
            "target_fingerprint": fingerprint,
            "status": status,
            "pairing_digest": self._pairing_digest,
            "dataset_hash": (self._dataset_meta or {}).get("dataset_hash"),
            "tokenizer_hash": (self._dataset_meta or {}).get("tokenizer_hash"),
            "model_id": (self._report_meta or {}).get("model_id"),
            "run_seed": (self._report_meta or {}).get("seed"),
        }

    def _resolve_target_modules(
        self, model: nn.Module, adapter: Any | None = None
    ) -> dict[str, nn.Module]:
        """
        Resolve target modules based on scope policy.

        Args:
            model: Model to analyze
            adapter: Optional adapter used to query layer modules

        Returns:
            Dict mapping module names to modules
        """
        targets = {}
        scope = self._policy["scope"]
        audit_candidates: list[dict[str, Any]] = []
        audit_rejections: list[dict[str, Any]] = []

        def _record_match(name: str, module: nn.Module) -> None:
            audit_candidates.append(
                {
                    "name": name,
                    "class": module.__class__.__name__,
                    "source": "direct",
                }
            )

        def _record_rejection(name: str, reason: str, module: Any | None) -> None:
            audit_rejections.append(
                {
                    "name": name,
                    "reason": reason,
                    "class": getattr(module, "__class__", type(None)).__name__
                    if module is not None
                    else None,
                }
            )

        # Get module types
        try:
            from transformers.pytorch_utils import Conv1D

            module_types = (nn.Linear, nn.Conv1d, Conv1D)
        except ImportError:
            module_types = (nn.Linear, nn.Conv1d)

        def _is_supported_module(module: Any) -> bool:
            """Heuristic check that a module looks like a projection."""
            if isinstance(module, module_types):
                return True
            class_name = module.__class__.__name__ if module is not None else ""
            if class_name in {"Conv1D", "Linear"}:
                return True
            weight = getattr(module, "weight", None)
            if weight is None:
                return False
            try:
                dim = weight.dim()
            except Exception:
                dim = getattr(weight, "ndim", None)
            return dim == 2

        for i, blk in enumerate(_iter_transformer_layers(model)):
            # Handle attention projection based on scope
            if scope in ["attn", "both"] and hasattr(blk, "attn"):
                attn_proj = getattr(blk.attn, "c_proj", None) or getattr(
                    blk.attn, "out_proj", None
                )
                name = f"transformer.h.{i}.attn.c_proj"
                if attn_proj is None:
                    _record_rejection(name, "missing_module", None)
                elif not self._matches_tap(name):
                    _record_rejection(name, "tap_mismatch", attn_proj)
                elif not _is_supported_module(attn_proj):
                    _record_rejection(name, "unsupported_type", attn_proj)
                else:
                    targets[name] = attn_proj
                    _record_match(name, attn_proj)

            # Handle MLP projection based on scope
            if scope in ["ffn", "both"] and hasattr(blk, "mlp"):
                mlp_proj = (
                    getattr(blk.mlp, "c_proj", None)
                    or getattr(blk.mlp, "down_proj", None)
                    or getattr(blk.mlp, "fc2", None)
                )
                name = f"transformer.h.{i}.mlp.c_proj"
                if mlp_proj is None:
                    _record_rejection(name, "missing_module", None)
                elif not self._matches_tap(name):
                    _record_rejection(name, "tap_mismatch", mlp_proj)
                elif not _is_supported_module(mlp_proj):
                    _record_rejection(name, "unsupported_type", mlp_proj)
                else:
                    targets[name] = mlp_proj
                    _record_match(name, mlp_proj)

        fallback_used = False

        # Fallback: ask adapter for layer modules if we could not resolve anything
        # Strategy:
        # 1. Try adapter.describe() for layer count - works even when model structure is unknown
        # 2. If that fails, try _iter_transformer_layers() to count layers
        # 3. If that fails, try model.config for layer count
        if (
            not targets
            and adapter is not None
            and hasattr(adapter, "get_layer_modules")
        ):
            try:
                # Get layer count from adapter.describe() first
                n_layers = 0
                if hasattr(adapter, "describe"):
                    try:
                        desc = adapter.describe(model)
                        if isinstance(desc, dict):
                            n_layers = int(desc.get("n_layer", 0) or 0)
                    except Exception as desc_exc:
                        self._log_event(
                            "adapter_describe_error",
                            level="DEBUG",
                            message=f"adapter.describe() failed: {desc_exc}",
                        )

                # Fallback: count layers via _iter_transformer_layers()
                # This works when model has standard structure but no c_proj
                if n_layers == 0:
                    try:
                        n_layers = sum(1 for _ in _iter_transformer_layers(model))
                    except Exception:
                        pass

                # Fallback: try model.config for layer count
                if n_layers == 0:
                    config = getattr(_unwrap_model(model), "config", None)
                    if config is not None:
                        n_layers = (
                            getattr(config, "n_layer", 0)
                            or getattr(config, "num_hidden_layers", 0)
                            or getattr(config, "num_layers", 0)
                            or 0
                        )

                if n_layers == 0:
                    self._log_event(
                        "adapter_fallback_no_layers",
                        level="WARN",
                        message="Adapter fallback: could not determine layer count",
                    )

                for i in range(n_layers):
                    try:
                        modules = adapter.get_layer_modules(model, i) or {}
                    except Exception as exc:
                        _record_rejection(
                            f"transformer.h.{i}",
                            f"adapter_error:{exc}",
                            None,
                        )
                        continue

                    for key, module in modules.items():
                        if not isinstance(key, str) or not key.endswith("c_proj"):
                            continue
                        branch = "attn" if "attn" in key else "mlp"
                        name = f"transformer.h.{i}.{branch}.c_proj"
                        if not self._matches_tap(name):
                            _record_rejection(name, "tap_mismatch", module)
                            continue
                        if not _is_supported_module(module):
                            _record_rejection(name, "unsupported_type", module)
                            continue
                        targets[name] = module
                        audit_candidates.append(
                            {
                                "name": name,
                                "class": module.__class__.__name__,
                                "source": "adapter_fallback",
                            }
                        )
                if targets:
                    fallback_used = True
            except Exception as exc:  # pragma: no cover - defensive logging
                self._log_event(
                    "target_resolution_fallback_error",
                    level="WARN",
                    message="Adapter fallback failed during VE target resolution",
                    error=str(exc),
                )

        if self._focus_modules:
            focused: dict[str, nn.Module] = {}
            for name, module in targets.items():
                norm_name = self._normalize_module_name(name)
                if norm_name in self._focus_modules:
                    focused[name] = module

            if not focused:
                self._log_event(
                    "focus_miss",
                    level="WARN",
                    message="No target modules matched focus list",
                    focus_modules=sorted(self._focus_modules),
                    available=list(targets.keys()),
                )
            else:
                targets = focused

        # Persist audit statistics for reports
        rejected_summary: dict[str, Any] = {}
        for item in audit_rejections:
            reason = item["reason"]
            bucket = rejected_summary.setdefault(reason, {"count": 0, "examples": []})
            bucket["count"] += 1
            if len(bucket["examples"]) < 5:
                bucket["examples"].append(
                    {
                        "name": item["name"],
                        "class": item["class"],
                    }
                )

        self._stats["target_resolution"] = {
            "scope": scope,
            "tap": list(self._tap_patterns),
            "total_matched": len(targets),
            "matched": sorted(targets.keys()),
            "fallback_used": fallback_used,
            "candidates_recorded": len(audit_candidates),
            "rejected": rejected_summary,
        }

        self._log_event(
            "target_resolution",
            message="Resolved variance guard targets",
            scope=scope,
            tap=list(self._tap_patterns),
            matched=len(targets),
            rejected=sum(item["count"] for item in rejected_summary.values())
            if rejected_summary
            else 0,
            fallback_used=fallback_used,
        )

        return targets

    def _compute_variance_scales(
        self, model: nn.Module, dataloader
    ) -> dict[str, float]:
        """
        Compute variance-based scaling factors using existing implementation.

        Args:
            model: Model to analyze
            dataloader: Calibration data

        Returns:
            Dict mapping module names to proposed scaling factors
        """
        if self._monitor_only:
            self._log_event(
                "monitor_only",
                message="Skipping variance scale computation in monitor-only mode",
            )
            self._raw_scales = {}
            return {}

        # Use existing equalise_residual_variance but don't apply yet
        # We'll capture the proposed scales and apply them later in enable()

        # Temporarily capture the current model state
        original_state = copy.deepcopy(model.state_dict())

        try:
            tensor_ready_batches = self._tensorize_calibration_batches(dataloader)

            # Run variance equalization to get proposed scales
            proposed_scales = equalise_residual_variance(
                model=model,
                dataloader=tensor_ready_batches,
                windows=min(
                    self._policy["max_calib"] // 10, 50
                ),  # Limit calibration windows
                tol=self._policy["deadband"],
                scale_bias=False,  # Don't scale biases to preserve operating points
                seed=self._policy["seed"],
                clamp_range=self._policy["clamp"],
                allow_empty=True,
            )

            if not proposed_scales and self._policy.get("deadband", 0.0) > 0.0:
                relaxed_tol = max(self._policy["deadband"] * 0.5, 1e-4)
                model.load_state_dict(original_state)
                tensor_ready_batches = self._tensorize_calibration_batches(dataloader)
                proposed_scales = equalise_residual_variance(
                    model=model,
                    dataloader=tensor_ready_batches,
                    windows=min(self._policy["max_calib"] // 10, 50),
                    tol=relaxed_tol,
                    scale_bias=False,
                    seed=self._policy["seed"] + 7,
                    clamp_range=self._policy["clamp"],
                    allow_empty=True,
                )

            raw_scales = dict(proposed_scales)

            # Filter raw_scales to only those that have corresponding target modules
            # This is critical when scope limits targets (e.g., scope=ffn only has mlp targets)
            # Only apply this filtering when target modules have been resolved
            if self._target_modules:
                filtered_raw_scales: dict[str, float] = {}
                for scale_name, scale_value in raw_scales.items():
                    # Convert scale name to target module name format
                    target_name = self._normalize_scale_name(scale_name)
                    if target_name in self._target_modules:
                        filtered_raw_scales[scale_name] = scale_value
                    elif self._is_focus_match(scale_name):
                        # Fallback: check if any target module matches via pattern
                        for tm_name in self._target_modules:
                            if self._scale_matches_target(scale_name, tm_name):
                                filtered_raw_scales[scale_name] = scale_value
                                break
                raw_scales = filtered_raw_scales

            focus_raw_scales = {
                self._normalize_scale_name(name): scale
                for name, scale in raw_scales.items()
                if self._is_focus_match(name)
            }
            if focus_raw_scales:
                self._log_event(
                    "variance_raw_scales",
                    message="Captured raw VE scales",
                    count=len(focus_raw_scales),
                    min_scale=min(focus_raw_scales.values()),
                    max_scale=max(focus_raw_scales.values()),
                )
                self._stats.setdefault("raw_scales_observations", []).append(
                    {
                        "timestamp": datetime.utcnow().isoformat(),
                        "count": len(focus_raw_scales),
                        "scales": focus_raw_scales,
                    }
                )

            # Restore original state since we only wanted the proposed scales
            model.load_state_dict(original_state)

            filtered_scales: dict[str, float] = {}
            raw_delta_map: dict[str, float] = {}
            min_abs = float(max(self._policy.get("min_abs_adjust", 0.0), 0.0))
            max_step = float(max(self._policy.get("max_scale_step", 0.0), 0.0))
            topk = int(max(self._policy.get("topk_backstop", 0) or 0, 0))
            best_candidate: tuple[str, float] | None = None
            best_delta = 0.0

            for name, scale in raw_scales.items():
                normalized_name = self._normalize_scale_name(name)
                if not self._is_focus_match(normalized_name):
                    continue

                raw_delta = abs(scale - 1.0)
                raw_delta_map[name] = raw_delta

                delta = raw_delta
                if delta > best_delta:
                    best_candidate = (name, scale)
                    best_delta = delta

                if delta < min_abs:
                    continue

                if max_step > 0.0:
                    limited_delta = min(delta, max_step)
                    scale = 1.0 + math.copysign(limited_delta, scale - 1.0)

                filtered_scales[name] = scale

            backstop_used = False
            if not filtered_scales and topk > 0 and best_candidate:
                name, scale = best_candidate
                deadband = float(self._policy.get("deadband", 0.0) or 0.0)
                threshold = max(deadband * 0.5, min_abs)
                if best_delta >= threshold:
                    if max_step > 0.0:
                        limited_delta = min(best_delta, max_step)
                        scale = 1.0 + math.copysign(limited_delta, scale - 1.0)
                    filtered_scales[name] = scale
                    raw_delta_map.setdefault(name, best_delta)
                    backstop_used = True

            trimmed_to_limit = False
            max_adjusted = int(max(self._policy.get("max_adjusted_modules", 0) or 0, 0))
            if max_adjusted > 0 and len(filtered_scales) > max_adjusted:
                sorted_candidates = sorted(
                    filtered_scales.items(),
                    key=lambda item: (
                        raw_delta_map.get(item[0], abs(item[1] - 1.0))
                        + (2.0 if item[1] >= 1.0 else 0.0),
                        raw_delta_map.get(item[0], abs(item[1] - 1.0)),
                        item[1],
                    ),
                    reverse=True,
                )
                filtered_scales = dict(sorted_candidates[:max_adjusted])
                trimmed_to_limit = True

            self._raw_scales = raw_scales
            if backstop_used:
                self._log_event(
                    "scale_backstop",
                    message=f"Top-{topk} backstop injected {len(filtered_scales)} scale",
                    count=len(filtered_scales),
                    candidate=best_candidate[0] if best_candidate else None,
                    candidate_normalized=self._normalize_scale_name(best_candidate[0])
|
|
1330
|
+
if best_candidate
|
|
1331
|
+
else None,
|
|
1332
|
+
delta=best_delta,
|
|
1333
|
+
)
|
|
1334
|
+
if trimmed_to_limit:
|
|
1335
|
+
self._log_event(
|
|
1336
|
+
"scale_limit",
|
|
1337
|
+
message="Trimmed VE scales to max_adjusted_modules",
|
|
1338
|
+
limit=max_adjusted,
|
|
1339
|
+
count=len(filtered_scales),
|
|
1340
|
+
)
|
|
1341
|
+
|
|
1342
|
+
filtered_normalized = {
|
|
1343
|
+
self._normalize_scale_name(name): scale
|
|
1344
|
+
for name, scale in filtered_scales.items()
|
|
1345
|
+
}
|
|
1346
|
+
self._stats.setdefault("filtered_scales_observations", []).append(
|
|
1347
|
+
{
|
|
1348
|
+
"timestamp": datetime.utcnow().isoformat(),
|
|
1349
|
+
"count": len(filtered_normalized),
|
|
1350
|
+
"scales": filtered_normalized,
|
|
1351
|
+
"backstop_used": backstop_used,
|
|
1352
|
+
}
|
|
1353
|
+
)
|
|
1354
|
+
|
|
1355
|
+
return filtered_scales
|
|
1356
|
+
|
|
1357
|
+
except Exception as e:
|
|
1358
|
+
# Restore state on any error
|
|
1359
|
+
model.load_state_dict(original_state)
|
|
1360
|
+
raise e
|
|
1361
|
+
|
|
1362
|
+
def _evaluate_calibration_pass(
|
|
1363
|
+
self,
|
|
1364
|
+
model: nn.Module,
|
|
1365
|
+
calibration_batches: list[Any],
|
|
1366
|
+
min_coverage: int,
|
|
1367
|
+
calib_seed: int,
|
|
1368
|
+
tag: str,
|
|
1369
|
+
) -> None:
|
|
1370
|
+
"""Run deterministic calibration for A/B evaluation and predictive gating."""
|
|
1371
|
+
predictive_state: dict[str, Any] = {
|
|
1372
|
+
"evaluated": False,
|
|
1373
|
+
"passed": not bool(self._policy.get("predictive_gate", True)),
|
|
1374
|
+
"reason": "disabled"
|
|
1375
|
+
if not bool(self._policy.get("predictive_gate", True))
|
|
1376
|
+
else "no_calibration",
|
|
1377
|
+
"delta_ci": (None, None),
|
|
1378
|
+
"gain_ci": (None, None),
|
|
1379
|
+
"mean_delta": None,
|
|
1380
|
+
}
|
|
1381
|
+
|
|
1382
|
+
requested = len(calibration_batches)
|
|
1383
|
+
self._calibration_stats.update(
|
|
1384
|
+
{
|
|
1385
|
+
"requested": requested,
|
|
1386
|
+
"coverage": 0,
|
|
1387
|
+
"min_coverage": min_coverage,
|
|
1388
|
+
"seed": calib_seed,
|
|
1389
|
+
"status": "no_calibration"
|
|
1390
|
+
if not calibration_batches
|
|
1391
|
+
else "insufficient",
|
|
1392
|
+
"tag": tag,
|
|
1393
|
+
}
|
|
1394
|
+
)
|
|
1395
|
+
self._stats.setdefault("calibration", {})
|
|
1396
|
+
self._stats["calibration"].update(
|
|
1397
|
+
{
|
|
1398
|
+
"requested": requested,
|
|
1399
|
+
"min_coverage": min_coverage,
|
|
1400
|
+
"seed": calib_seed,
|
|
1401
|
+
"tag": tag,
|
|
1402
|
+
}
|
|
1403
|
+
)
|
|
1404
|
+
|
|
1405
|
+
fingerprint = self._fingerprint_targets()
|
|
1406
|
+
if fingerprint:
|
|
1407
|
+
self._stats["target_fingerprint"] = fingerprint
|
|
1408
|
+
|
|
1409
|
+
if not calibration_batches:
|
|
1410
|
+
self._ratio_ci = None
|
|
1411
|
+
self._predictive_gate_state = predictive_state
|
|
1412
|
+
self._stats["predictive_gate"] = predictive_state.copy()
|
|
1413
|
+
return
|
|
1414
|
+
|
|
1415
|
+
device = next(model.parameters()).device
|
|
1416
|
+
torch.manual_seed(calib_seed)
|
|
1417
|
+
ppl_no_ve_samples, loss_no_ve_samples = self._compute_ppl_for_batches(
|
|
1418
|
+
model, calibration_batches, device
|
|
1419
|
+
)
|
|
1420
|
+
coverage = min(len(calibration_batches), len(ppl_no_ve_samples))
|
|
1421
|
+
ppl_with_ve_samples: list[float] = []
|
|
1422
|
+
loss_with_ve_samples: list[float] = []
|
|
1423
|
+
ratio_ci: tuple[float, float] | None = None
|
|
1424
|
+
|
|
1425
|
+
enable_success = False
|
|
1426
|
+
if coverage >= min_coverage and self._scales:
|
|
1427
|
+
prev_enable_attempts = self._enable_attempt_count
|
|
1428
|
+
prev_disable_attempts = self._disable_attempt_count
|
|
1429
|
+
prev_prepared_flag = self._prepared
|
|
1430
|
+
try:
|
|
1431
|
+
self._prepared = True
|
|
1432
|
+
enable_success = self.enable(model)
|
|
1433
|
+
finally:
|
|
1434
|
+
self._prepared = prev_prepared_flag
|
|
1435
|
+
try:
|
|
1436
|
+
torch.manual_seed(calib_seed)
|
|
1437
|
+
if enable_success:
|
|
1438
|
+
ppl_with_ve_samples, loss_with_ve_samples = (
|
|
1439
|
+
self._compute_ppl_for_batches(
|
|
1440
|
+
model, calibration_batches, device
|
|
1441
|
+
)
|
|
1442
|
+
)
|
|
1443
|
+
finally:
|
|
1444
|
+
if enable_success:
|
|
1445
|
+
self.disable(model)
|
|
1446
|
+
# Restore attempt counters to avoid skewing metrics
|
|
1447
|
+
self._enable_attempt_count = prev_enable_attempts
|
|
1448
|
+
self._disable_attempt_count = prev_disable_attempts
|
|
1449
|
+
|
|
1450
|
+
coverage = min(
|
|
1451
|
+
coverage,
|
|
1452
|
+
len(ppl_with_ve_samples) if ppl_with_ve_samples else coverage,
|
|
1453
|
+
len(loss_with_ve_samples) if loss_with_ve_samples else coverage,
|
|
1454
|
+
)
|
|
1455
|
+
self._calibration_stats.update(
|
|
1456
|
+
{
|
|
1457
|
+
"coverage": coverage,
|
|
1458
|
+
"status": "insufficient" if coverage < min_coverage else "pending",
|
|
1459
|
+
}
|
|
1460
|
+
)
|
|
1461
|
+
|
|
1462
|
+
window_ids = self._calibration_window_ids
|
|
1463
|
+
status_a = "evaluated" if coverage > 0 else "no_data"
|
|
1464
|
+
self._record_ab_provenance(
|
|
1465
|
+
"condition_a",
|
|
1466
|
+
tag=tag,
|
|
1467
|
+
mode="edited_no_ve",
|
|
1468
|
+
window_ids=window_ids,
|
|
1469
|
+
fingerprint=fingerprint,
|
|
1470
|
+
status=status_a,
|
|
1471
|
+
)
|
|
1472
|
+
|
|
1473
|
+
if coverage >= min_coverage and not self._scales:
|
|
1474
|
+
ppl_no_ve_samples = ppl_no_ve_samples[:coverage]
|
|
1475
|
+
ppl_no_ve_mean = float(np.mean(ppl_no_ve_samples))
|
|
1476
|
+
self.set_ab_results(
|
|
1477
|
+
ppl_no_ve=ppl_no_ve_mean,
|
|
1478
|
+
ppl_with_ve=ppl_no_ve_mean,
|
|
1479
|
+
windows_used=coverage,
|
|
1480
|
+
seed_used=calib_seed,
|
|
1481
|
+
ratio_ci=(1.0, 1.0),
|
|
1482
|
+
)
|
|
1483
|
+
self._calibration_stats.update(
|
|
1484
|
+
{
|
|
1485
|
+
"status": "no_scaling_required",
|
|
1486
|
+
"ppl_no_ve": ppl_no_ve_mean,
|
|
1487
|
+
"ratio_ci": (1.0, 1.0),
|
|
1488
|
+
}
|
|
1489
|
+
)
|
|
1490
|
+
self._stats["ab_point_estimates"] = {
|
|
1491
|
+
"tag": tag,
|
|
1492
|
+
"ppl_no_ve": ppl_no_ve_mean,
|
|
1493
|
+
"ppl_with_ve": ppl_no_ve_mean,
|
|
1494
|
+
}
|
|
1495
|
+
self._record_ab_provenance(
|
|
1496
|
+
"condition_b",
|
|
1497
|
+
tag=tag,
|
|
1498
|
+
mode="virtual_ve",
|
|
1499
|
+
window_ids=window_ids,
|
|
1500
|
+
fingerprint=fingerprint,
|
|
1501
|
+
status="no_scales",
|
|
1502
|
+
)
|
|
1503
|
+
predictive_state["evaluated"] = True
|
|
1504
|
+
predictive_state["passed"] = False
|
|
1505
|
+
predictive_state["reason"] = "no_scales"
|
|
1506
|
+
self._predictive_gate_state = predictive_state
|
|
1507
|
+
self._stats["predictive_gate"] = predictive_state.copy()
|
|
1508
|
+
return
|
|
1509
|
+
|
|
1510
|
+
if coverage >= min_coverage and ppl_with_ve_samples and loss_with_ve_samples:
|
|
1511
|
+
ppl_no_ve_samples = ppl_no_ve_samples[:coverage]
|
|
1512
|
+
loss_no_ve_samples = loss_no_ve_samples[:coverage]
|
|
1513
|
+
ppl_with_ve_samples = ppl_with_ve_samples[:coverage]
|
|
1514
|
+
loss_with_ve_samples = loss_with_ve_samples[:coverage]
|
|
1515
|
+
|
|
1516
|
+
ratios = [
|
|
1517
|
+
with_val / no_val
|
|
1518
|
+
for with_val, no_val in zip(
|
|
1519
|
+
ppl_with_ve_samples, ppl_no_ve_samples, strict=False
|
|
1520
|
+
)
|
|
1521
|
+
if no_val > 0
|
|
1522
|
+
]
|
|
1523
|
+
if ratios:
|
|
1524
|
+
ratio_ci = self._bootstrap_mean_ci(
|
|
1525
|
+
ratios,
|
|
1526
|
+
alpha=self._policy.get("alpha", 0.05),
|
|
1527
|
+
n_bootstrap=500,
|
|
1528
|
+
seed=calib_seed,
|
|
1529
|
+
)
|
|
1530
|
+
ppl_no_ve_mean = float(np.mean(ppl_no_ve_samples))
|
|
1531
|
+
ppl_with_ve_mean = float(np.mean(ppl_with_ve_samples))
|
|
1532
|
+
self.set_ab_results(
|
|
1533
|
+
ppl_no_ve=ppl_no_ve_mean,
|
|
1534
|
+
ppl_with_ve=ppl_with_ve_mean,
|
|
1535
|
+
windows_used=coverage,
|
|
1536
|
+
seed_used=calib_seed,
|
|
1537
|
+
ratio_ci=ratio_ci,
|
|
1538
|
+
)
|
|
1539
|
+
self._calibration_stats.update(
|
|
1540
|
+
{
|
|
1541
|
+
"status": "complete",
|
|
1542
|
+
"ppl_no_ve": ppl_no_ve_mean,
|
|
1543
|
+
"ppl_with_ve": ppl_with_ve_mean,
|
|
1544
|
+
"ratio_ci": ratio_ci,
|
|
1545
|
+
}
|
|
1546
|
+
)
|
|
1547
|
+
self._record_ab_provenance(
|
|
1548
|
+
"condition_b",
|
|
1549
|
+
tag=tag,
|
|
1550
|
+
mode="virtual_ve",
|
|
1551
|
+
window_ids=window_ids,
|
|
1552
|
+
fingerprint=fingerprint,
|
|
1553
|
+
status="evaluated",
|
|
1554
|
+
)
|
|
1555
|
+
self._stats["ab_point_estimates"] = {
|
|
1556
|
+
"tag": tag,
|
|
1557
|
+
"ppl_no_ve": ppl_no_ve_mean,
|
|
1558
|
+
"ppl_with_ve": ppl_with_ve_mean,
|
|
1559
|
+
"coverage": coverage,
|
|
1560
|
+
}
|
|
1561
|
+
|
|
1562
|
+
delta_ci: tuple[float, float] | None = None
|
|
1563
|
+
try:
|
|
1564
|
+
delta_ci = compute_paired_delta_log_ci(
|
|
1565
|
+
loss_with_ve_samples,
|
|
1566
|
+
loss_no_ve_samples,
|
|
1567
|
+
method="bca",
|
|
1568
|
+
replicates=500,
|
|
1569
|
+
alpha=self._policy.get("alpha", 0.05),
|
|
1570
|
+
seed=calib_seed + 211,
|
|
1571
|
+
)
|
|
1572
|
+
except Exception as exc:
|
|
1573
|
+
delta_ci = None
|
|
1574
|
+
self._log_event(
|
|
1575
|
+
"predictive_gate_error",
|
|
1576
|
+
level="WARN",
|
|
1577
|
+
message="Failed to compute predictive ΔlogNLL CI",
|
|
1578
|
+
error=str(exc),
|
|
1579
|
+
)
|
|
1580
|
+
|
|
1581
|
+
predictive_state["evaluated"] = True
|
|
1582
|
+
mean_delta = float(
|
|
1583
|
+
np.mean(
|
|
1584
|
+
[
|
|
1585
|
+
with_loss - no_loss
|
|
1586
|
+
for with_loss, no_loss in zip(
|
|
1587
|
+
loss_with_ve_samples,
|
|
1588
|
+
loss_no_ve_samples,
|
|
1589
|
+
strict=False,
|
|
1590
|
+
)
|
|
1591
|
+
]
|
|
1592
|
+
)
|
|
1593
|
+
)
|
|
1594
|
+
predictive_state["mean_delta"] = mean_delta
|
|
1595
|
+
|
|
1596
|
+
if delta_ci is not None and all(
|
|
1597
|
+
isinstance(val, (int | float)) and math.isfinite(val)
|
|
1598
|
+
for val in delta_ci
|
|
1599
|
+
):
|
|
1600
|
+
delta_ci = (float(delta_ci[0]), float(delta_ci[1]))
|
|
1601
|
+
gain_ci = (-delta_ci[1], -delta_ci[0])
|
|
1602
|
+
predictive_state["delta_ci"] = delta_ci
|
|
1603
|
+
predictive_state["gain_ci"] = gain_ci
|
|
1604
|
+
|
|
1605
|
+
if not self._policy.get("predictive_gate", True):
|
|
1606
|
+
predictive_state["passed"] = True
|
|
1607
|
+
predictive_state["reason"] = "disabled"
|
|
1608
|
+
else:
|
|
1609
|
+
one_sided = bool(self._policy.get("predictive_one_sided", False))
|
|
1610
|
+
min_effect = float(
|
|
1611
|
+
self._policy.get("min_effect_lognll", 0.0) or 0.0
|
|
1612
|
+
)
|
|
1613
|
+
passed, reason = _predictive_gate_outcome(
|
|
1614
|
+
mean_delta=mean_delta,
|
|
1615
|
+
delta_ci=delta_ci,
|
|
1616
|
+
min_effect=min_effect,
|
|
1617
|
+
one_sided=one_sided,
|
|
1618
|
+
)
|
|
1619
|
+
predictive_state["passed"] = passed
|
|
1620
|
+
predictive_state["reason"] = reason
|
|
1621
|
+
else:
|
|
1622
|
+
predictive_state["delta_ci"] = (None, None)
|
|
1623
|
+
predictive_state["gain_ci"] = (None, None)
|
|
1624
|
+
predictive_state["reason"] = (
|
|
1625
|
+
predictive_state.get("reason", "ci_unavailable")
|
|
1626
|
+
if predictive_state.get("reason") != "disabled"
|
|
1627
|
+
else "disabled"
|
|
1628
|
+
)
|
|
1629
|
+
else:
|
|
1630
|
+
# Fail-open monitor mode
|
|
1631
|
+
self._ratio_ci = None
|
|
1632
|
+
self._log_event(
|
|
1633
|
+
"prepare_monitor_mode",
|
|
1634
|
+
level="WARN",
|
|
1635
|
+
message="VE calibration coverage insufficient; guard will monitor only",
|
|
1636
|
+
requested=requested,
|
|
1637
|
+
coverage=coverage,
|
|
1638
|
+
min_coverage=min_coverage,
|
|
1639
|
+
tag=tag,
|
|
1640
|
+
)
|
|
1641
|
+
if predictive_state.get("reason") not in {"disabled"}:
|
|
1642
|
+
if coverage < min_coverage:
|
|
1643
|
+
predictive_state["reason"] = "insufficient_coverage"
|
|
1644
|
+
elif not self._scales:
|
|
1645
|
+
predictive_state["reason"] = "no_scales"
|
|
1646
|
+
elif not ppl_with_ve_samples:
|
|
1647
|
+
predictive_state["reason"] = "ve_enable_failed"
|
|
1648
|
+
|
|
1649
|
+
if "condition_b" not in self._stats.get("ab_provenance", {}):
|
|
1650
|
+
self._record_ab_provenance(
|
|
1651
|
+
"condition_b",
|
|
1652
|
+
tag=tag,
|
|
1653
|
+
mode="virtual_ve",
|
|
1654
|
+
window_ids=window_ids,
|
|
1655
|
+
fingerprint=fingerprint,
|
|
1656
|
+
status="not_evaluated",
|
|
1657
|
+
)
|
|
1658
|
+
|
|
1659
|
+
if (
|
|
1660
|
+
"ab_point_estimates" not in self._stats
|
|
1661
|
+
or self._stats["ab_point_estimates"].get("tag") != tag
|
|
1662
|
+
):
|
|
1663
|
+
ppl_no_ve_mean = (
|
|
1664
|
+
float(np.mean(ppl_no_ve_samples[:coverage])) if coverage > 0 else None
|
|
1665
|
+
)
|
|
1666
|
+
ppl_with_ve_mean = (
|
|
1667
|
+
float(np.mean(ppl_with_ve_samples[:coverage]))
|
|
1668
|
+
if ppl_with_ve_samples and coverage > 0
|
|
1669
|
+
else None
|
|
1670
|
+
)
|
|
1671
|
+
self._stats["ab_point_estimates"] = {
|
|
1672
|
+
"tag": tag,
|
|
1673
|
+
"ppl_no_ve": ppl_no_ve_mean,
|
|
1674
|
+
"ppl_with_ve": ppl_with_ve_mean,
|
|
1675
|
+
"coverage": coverage,
|
|
1676
|
+
}
|
|
1677
|
+
|
|
1678
|
+
self._predictive_gate_state = predictive_state
|
|
1679
|
+
self._stats["predictive_gate"] = predictive_state.copy()
|
|
1680
|
+
|
|
1681
|
+
    def _refresh_after_edit_metrics(
        self,
        model: nn.Module,
        tag: str = "post_edit",
        adapter: Any | None = None,
    ) -> None:
        """Ensure VE metrics are recomputed on the edited model."""
        if not self._prepared:
            return
        if self._post_edit_evaluated and tag == "post_edit":
            return
        if not self._calibration_batches:
            self._log_event(
                "post_edit_calibration_skipped",
                level="WARN",
                message="Skipping post-edit VE evaluation (no calibration batches)",
            )
            self._post_edit_evaluated = True
            return

        # Refresh target modules in case adapters swapped modules during edit
        adapter_ref = adapter or self._adapter_ref
        self._target_modules = self._resolve_target_modules(model, adapter_ref)
        self._stats["target_module_names"] = sorted(self._target_modules.keys())

        # Recompute scales against the edited model
        try:
            self._scales = self._compute_variance_scales(
                model, self._calibration_batches
            )
        except Exception as exc:
            self._log_event(
                "post_edit_scale_failure",
                level="ERROR",
                message="Failed to recompute VE scales after edit",
                error=str(exc),
            )
            self._scales = {}

        if self._focus_modules:
            self._scales = {
                name: scale
                for name, scale in self._scales.items()
                if self._is_focus_match(name)
            }

        self._stats.setdefault(
            "target_module_names", sorted(self._target_modules.keys())
        )
        self._stats["target_modules_post_edit"] = list(self._target_modules.keys())
        normalized_post_scales = {
            self._normalize_scale_name(name): scale
            for name, scale in self._scales.items()
        }
        self._stats["proposed_scales_post_edit"] = normalized_post_scales.copy()
        self._stats["raw_scales_post_edit"] = self._raw_scales.copy()
        self._stats["raw_scales_post_edit_normalized"] = {
            self._normalize_scale_name(name): scale
            for name, scale in self._raw_scales.items()
        }
        self._raw_scales_post_edit = {
            self._normalize_scale_name(name): scale
            for name, scale in self._raw_scales.items()
            if self._is_focus_match(name)
        }
        if normalized_post_scales:
            self._log_event(
                "post_edit_scales",
                message="Post-edit VE proposed scales",
                count=len(normalized_post_scales),
                min_scale=min(normalized_post_scales.values()),
                max_scale=max(normalized_post_scales.values()),
            )

        calibration_cfg = self._policy.get("calibration", {})
        requested_windows = int(calibration_cfg.get("windows", 0) or 0)
        min_coverage = int(
            calibration_cfg.get(
                "min_coverage",
                max(1, requested_windows // 2 if requested_windows else 1),
            )
        )
        calib_seed = int(calibration_cfg.get("seed", self._policy.get("seed", 123)))

        self._calibration_stats = {
            "requested": len(self._calibration_batches)
            if requested_windows == 0
            else requested_windows,
            "coverage": 0,
            "min_coverage": min_coverage,
            "seed": calib_seed,
            "status": "pending",
            "tag": tag,
        }

        self._evaluate_calibration_pass(
            model, self._calibration_batches, min_coverage, calib_seed, tag
        )
        self._post_edit_evaluated = True

    def _collect_calibration_batches(self, dataloader, windows: int) -> list[Any]:
        """Collect a deterministic slice of calibration batches."""
        batches: list[Any] = []
        iterator = iter(dataloader)
        for _ in range(max(windows, 0)):
            try:
                batches.append(next(iterator))
            except StopIteration:
                break
        return batches

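A minimal standalone sketch of the same deterministic slicing, assuming only that the dataloader is any iterable of batches; the helper name `take_windows` is illustrative and not part of invarlock:

    import itertools
    from typing import Any, Iterable

    def take_windows(dataloader: Iterable[Any], windows: int) -> list[Any]:
        # Deterministic prefix of the batch stream; stops early if the loader
        # yields fewer than `windows` batches, mirroring the StopIteration break above.
        return list(itertools.islice(iter(dataloader), max(windows, 0)))
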
    def _prepare_batch_tensors(
        self, batch: Any, device: torch.device
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Normalize batch inputs to tensors on the target device."""
        if isinstance(batch, dict):
            input_ids = batch.get("input_ids", batch.get("inputs"))
            attention_mask = batch.get("attention_mask")
        elif isinstance(batch, tuple | list) and batch:
            input_ids = batch[0]
            attention_mask = batch[1] if len(batch) > 1 else None
        else:
            input_ids = batch
            attention_mask = None

        if input_ids is None:
            return None, None

        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.as_tensor(input_ids)

        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        try:
            input_ids = input_ids.to(device)
        except Exception:
            input_ids = input_ids.clone()

        labels = input_ids.clone()

        if attention_mask is not None:
            if not isinstance(attention_mask, torch.Tensor):
                attention_mask = torch.as_tensor(attention_mask)
            if attention_mask.dim() == 1:
                attention_mask = attention_mask.unsqueeze(0)
            try:
                attention_mask = attention_mask.to(device)
            except Exception:
                attention_mask = attention_mask.clone()
            labels = labels.masked_fill(attention_mask == 0, -100)

        return input_ids, labels

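A hedged usage sketch of the normalization above: dict-style batches keep their attention mask, and padded positions are re-labelled with -100 so cross-entropy losses ignore them. The `guard` instance below is a hypothetical toy example, not code from the package:

    import torch

    batch_as_dict = {
        "input_ids": torch.tensor([[5, 6, 7, 0]]),
        "attention_mask": torch.tensor([[1, 1, 1, 0]]),
    }
    inputs, labels = guard._prepare_batch_tensors(batch_as_dict, torch.device("cpu"))
    # The padded position (mask == 0) is masked out of the labels:
    assert labels[0, -1].item() == -100
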
    def _compute_ppl_for_batches(
        self,
        model: nn.Module,
        batches: list[Any],
        device: torch.device,
    ) -> tuple[list[float], list[float]]:
        """Compute per-batch perplexity and log-loss values for deterministic calibration."""
        ppl_values: list[float] = []
        loss_values: list[float] = []
        if not batches:
            return ppl_values, loss_values

        model_was_training = model.training
        model.eval()

        with torch.no_grad():
            for batch in batches:
                try:
                    inputs, labels = self._prepare_batch_tensors(batch, device)
                    if inputs is None or labels is None:
                        continue

                    try:
                        outputs = model(inputs, labels=labels)
                    except TypeError:
                        outputs = model(inputs)
                    loss_val = None
                    if hasattr(outputs, "loss") and hasattr(outputs.loss, "item"):
                        loss_val = outputs.loss.item()

                    if loss_val is None and isinstance(outputs, torch.Tensor):
                        try:
                            if labels is not None and outputs.shape == labels.shape:
                                loss_val = torch.nn.functional.mse_loss(
                                    outputs.float(), labels.float()
                                ).item()
                            else:
                                loss_val = outputs.float().pow(2).mean().item()
                        except Exception:
                            loss_val = None

                    if loss_val is None or not math.isfinite(loss_val):
                        continue

                    loss = float(loss_val)
                    ppl = math.exp(loss)
                    if math.isfinite(ppl):
                        ppl_values.append(ppl)
                        loss_values.append(loss)
                except Exception:
                    continue

        if model_was_training:
            model.train()

        return ppl_values, loss_values

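The per-batch perplexity used above is simply exp of the mean token negative log-likelihood returned by the model. A tiny worked example, independent of any model:

    import math

    loss = 2.0              # mean NLL in nats per token
    ppl = math.exp(loss)    # ~7.389
    assert abs(ppl - 7.389056) < 1e-6
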
    def _bootstrap_mean_ci(
        self,
        samples: list[float],
        alpha: float,
        n_bootstrap: int = 500,
        seed: int | None = None,
    ) -> tuple[float, float]:
        """Compute bootstrap confidence interval for the sample mean."""
        if not samples:
            raise ValueError("Cannot compute CI on empty samples")
        data = np.asarray(samples, dtype=float)
        rng = np.random.default_rng(seed)
        stats = np.empty(n_bootstrap, dtype=float)
        for i in range(n_bootstrap):
            indices = rng.integers(0, data.size, size=data.size)
            stats[i] = float(np.mean(data[indices]))
        lower = float(np.percentile(stats, 100 * (alpha / 2)))
        upper = float(np.percentile(stats, 100 * (1 - alpha / 2)))
        return lower, upper

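The helper above is a plain percentile bootstrap: resample the values with replacement, record each resample mean, and report the alpha/2 and 1 - alpha/2 percentiles of those means. A self-contained sketch under those assumptions (the function name is illustrative):

    import numpy as np

    def percentile_bootstrap_ci(samples, alpha=0.05, n_bootstrap=500, seed=0):
        data = np.asarray(samples, dtype=float)
        rng = np.random.default_rng(seed)
        means = np.array(
            [data[rng.integers(0, data.size, size=data.size)].mean()
             for _ in range(n_bootstrap)]
        )
        return (
            float(np.percentile(means, 100 * (alpha / 2))),
            float(np.percentile(means, 100 * (1 - alpha / 2))),
        )

    # Ratios clustered near 1.0 yield a tight interval around 1.0:
    print(percentile_bootstrap_ci([0.98, 1.01, 1.00, 0.99, 1.02]))
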
def prepare(
|
|
1913
|
+
self,
|
|
1914
|
+
model: nn.Module,
|
|
1915
|
+
adapter=None,
|
|
1916
|
+
calib=None,
|
|
1917
|
+
policy: dict[str, Any] | None = None,
|
|
1918
|
+
) -> dict[str, Any]:
|
|
1919
|
+
"""
|
|
1920
|
+
Prepare variance guard by computing proposed scaling factors.
|
|
1921
|
+
|
|
1922
|
+
Args:
|
|
1923
|
+
model: The model that will be edited
|
|
1924
|
+
adapter: ModelAdapter (optional, for compatibility)
|
|
1925
|
+
calib: Calibration data for variance measurement
|
|
1926
|
+
policy: Guard policy parameters (optional)
|
|
1927
|
+
|
|
1928
|
+
Returns:
|
|
1929
|
+
Dictionary with preparation results and proposed scales
|
|
1930
|
+
"""
|
|
1931
|
+
start_time = time.time()
|
|
1932
|
+
|
|
1933
|
+
# Update policy if provided
|
|
1934
|
+
if policy:
|
|
1935
|
+
for key in [
|
|
1936
|
+
"min_gain",
|
|
1937
|
+
"max_calib",
|
|
1938
|
+
"scope",
|
|
1939
|
+
"clamp",
|
|
1940
|
+
"deadband",
|
|
1941
|
+
"seed",
|
|
1942
|
+
"mode",
|
|
1943
|
+
"min_rel_gain",
|
|
1944
|
+
"alpha",
|
|
1945
|
+
"tie_breaker_deadband",
|
|
1946
|
+
"min_effect_lognll",
|
|
1947
|
+
"min_abs_adjust",
|
|
1948
|
+
"max_scale_step",
|
|
1949
|
+
"topk_backstop",
|
|
1950
|
+
"max_adjusted_modules",
|
|
1951
|
+
"predictive_gate",
|
|
1952
|
+
"predictive_one_sided",
|
|
1953
|
+
"absolute_floor_ppl",
|
|
1954
|
+
"monitor_only",
|
|
1955
|
+
"calibration",
|
|
1956
|
+
"target_modules",
|
|
1957
|
+
]:
|
|
1958
|
+
if key in policy:
|
|
1959
|
+
self._policy[key] = policy[key]
|
|
1960
|
+
if self._policy.get("min_effect_lognll") is not None:
|
|
1961
|
+
self._policy["min_effect_lognll"] = float(
|
|
1962
|
+
self._policy["min_effect_lognll"]
|
|
1963
|
+
)
|
|
1964
|
+
self.TIE_BREAKER_DEADBAND = float(
|
|
1965
|
+
self._policy.get("tie_breaker_deadband", self.TIE_BREAKER_DEADBAND)
|
|
1966
|
+
)
|
|
1967
|
+
self._refresh_calibration_defaults()
|
|
1968
|
+
if "absolute_floor_ppl" in policy:
|
|
1969
|
+
self.ABSOLUTE_FLOOR = float(
|
|
1970
|
+
self._policy.get(
|
|
1971
|
+
"absolute_floor_pm",
|
|
1972
|
+
self._policy.get("absolute_floor_ppl", self.ABSOLUTE_FLOOR),
|
|
1973
|
+
)
|
|
1974
|
+
)
|
|
1975
|
+
if "target_modules" in policy:
|
|
1976
|
+
focus_list = [
|
|
1977
|
+
self._normalize_module_name(name)
|
|
1978
|
+
for name in (policy.get("target_modules") or [])
|
|
1979
|
+
if isinstance(name, str)
|
|
1980
|
+
]
|
|
1981
|
+
self._focus_modules = set(focus_list)
|
|
1982
|
+
if self._focus_modules:
|
|
1983
|
+
self._policy["target_modules"] = sorted(self._focus_modules)
|
|
1984
|
+
self._stats["focus_modules"] = sorted(self._focus_modules)
|
|
1985
|
+
|
|
1986
|
+
self._log_event(
|
|
1987
|
+
"prepare",
|
|
1988
|
+
message=f"Preparing variance guard with scope={self._policy['scope']}, min_gain={self._policy['min_gain']}",
|
|
1989
|
+
)
|
|
1990
|
+
|
|
1991
|
+
try:
|
|
1992
|
+
# Resolve target modules
|
|
1993
|
+
self._target_modules = self._resolve_target_modules(model, adapter)
|
|
1994
|
+
self._stats["target_module_names"] = sorted(self._target_modules.keys())
|
|
1995
|
+
|
|
1996
|
+
if not self._target_modules:
|
|
1997
|
+
self._prepared = False
|
|
1998
|
+
self._adapter_ref = adapter
|
|
1999
|
+
return {
|
|
2000
|
+
"baseline_metrics": {},
|
|
2001
|
+
"policy_applied": self._policy,
|
|
2002
|
+
"preparation_time": time.time() - start_time,
|
|
2003
|
+
"ready": False,
|
|
2004
|
+
"warning": "No target modules found for variance equalization",
|
|
2005
|
+
}
|
|
2006
|
+
|
|
2007
|
+
self._adapter_ref = adapter
|
|
2008
|
+
|
|
2009
|
+
calibration_cfg = self._policy.get("calibration", {})
|
|
2010
|
+
requested_windows = int(calibration_cfg.get("windows", 0) or 0)
|
|
2011
|
+
min_coverage = int(
|
|
2012
|
+
calibration_cfg.get(
|
|
2013
|
+
"min_coverage",
|
|
2014
|
+
max(1, requested_windows // 2 if requested_windows else 1),
|
|
2015
|
+
)
|
|
2016
|
+
)
|
|
2017
|
+
calib_seed = int(calibration_cfg.get("seed", self._policy.get("seed", 123)))
|
|
2018
|
+
|
|
2019
|
+
scale_windows = min(self._policy["max_calib"] // 10, 50)
|
|
2020
|
+
limit_for_batches = max(scale_windows, requested_windows)
|
|
2021
|
+
|
|
2022
|
+
calib_batches: list[Any] = []
|
|
2023
|
+
dataloader_source = None
|
|
2024
|
+
|
|
2025
|
+
if calib is not None:
|
|
2026
|
+
if hasattr(calib, "dataloader"):
|
|
2027
|
+
dataloader_source = calib.dataloader
|
|
2028
|
+
calib_batches = self._collect_calibration_batches(
|
|
2029
|
+
dataloader_source, limit_for_batches
|
|
2030
|
+
)
|
|
2031
|
+
elif isinstance(calib, Sequence):
|
|
2032
|
+
calib_batches = list(
|
|
2033
|
+
itertools.islice(iter(calib), limit_for_batches)
|
|
2034
|
+
)
|
|
2035
|
+
elif isinstance(calib, Iterable):
|
|
2036
|
+
calib_batches = list(
|
|
2037
|
+
itertools.islice(iter(calib), limit_for_batches)
|
|
2038
|
+
)
|
|
2039
|
+
|
|
2040
|
+
if calib_batches:
|
|
2041
|
+
self._scales = self._compute_variance_scales(model, calib_batches)
|
|
2042
|
+
else:
|
|
2043
|
+
self._scales = {}
|
|
2044
|
+
self._raw_scales = {}
|
|
2045
|
+
self._log_event(
|
|
2046
|
+
"prepare_warning",
|
|
2047
|
+
level="WARN",
|
|
2048
|
+
message="No calibration data provided, VE will be disabled",
|
|
2049
|
+
)
|
|
2050
|
+
|
|
2051
|
+
# Deterministic VE calibration pass for A/B readiness
|
|
2052
|
+
self._calibration_stats = {
|
|
2053
|
+
"requested": requested_windows,
|
|
2054
|
+
"coverage": 0,
|
|
2055
|
+
"min_coverage": min_coverage,
|
|
2056
|
+
"seed": calib_seed,
|
|
2057
|
+
"status": "skipped" if requested_windows == 0 else "insufficient",
|
|
2058
|
+
}
|
|
2059
|
+
|
|
2060
|
+
calibration_batches = calib_batches[:requested_windows]
|
|
2061
|
+
self._store_calibration_batches(calibration_batches)
|
|
2062
|
+
predictive_state: dict[str, Any] = {
|
|
2063
|
+
"evaluated": False,
|
|
2064
|
+
"passed": not bool(self._policy.get("predictive_gate", True)),
|
|
2065
|
+
"reason": "disabled"
|
|
2066
|
+
if not bool(self._policy.get("predictive_gate", True))
|
|
2067
|
+
else "no_calibration",
|
|
2068
|
+
"delta_ci": (None, None),
|
|
2069
|
+
"gain_ci": (None, None),
|
|
2070
|
+
"mean_delta": None,
|
|
2071
|
+
}
|
|
2072
|
+
|
|
2073
|
+
if calibration_batches:
|
|
2074
|
+
device = next(model.parameters()).device
|
|
2075
|
+
torch.manual_seed(calib_seed)
|
|
2076
|
+
ppl_no_ve_samples, loss_no_ve_samples = self._compute_ppl_for_batches(
|
|
2077
|
+
model, calibration_batches, device
|
|
2078
|
+
)
|
|
2079
|
+
coverage = min(len(calibration_batches), len(ppl_no_ve_samples))
|
|
2080
|
+
ppl_with_ve_samples: list[float] = []
|
|
2081
|
+
loss_with_ve_samples: list[float] = []
|
|
2082
|
+
ratio_ci: tuple[float, float] | None = None
|
|
2083
|
+
|
|
2084
|
+
enable_success = False
|
|
2085
|
+
if coverage >= min_coverage and self._scales:
|
|
2086
|
+
prev_enable_attempts = self._enable_attempt_count
|
|
2087
|
+
prev_disable_attempts = self._disable_attempt_count
|
|
2088
|
+
prev_prepared_flag = self._prepared
|
|
2089
|
+
try:
|
|
2090
|
+
self._prepared = True
|
|
2091
|
+
enable_success = self.enable(model)
|
|
2092
|
+
finally:
|
|
2093
|
+
self._prepared = prev_prepared_flag
|
|
2094
|
+
try:
|
|
2095
|
+
torch.manual_seed(calib_seed)
|
|
2096
|
+
if enable_success:
|
|
2097
|
+
(
|
|
2098
|
+
ppl_with_ve_samples,
|
|
2099
|
+
loss_with_ve_samples,
|
|
2100
|
+
) = self._compute_ppl_for_batches(
|
|
2101
|
+
model, calibration_batches, device
|
|
2102
|
+
)
|
|
2103
|
+
finally:
|
|
2104
|
+
if enable_success:
|
|
2105
|
+
self.disable(model)
|
|
2106
|
+
# Restore attempt counters to avoid skewing metrics
|
|
2107
|
+
self._enable_attempt_count = prev_enable_attempts
|
|
2108
|
+
self._disable_attempt_count = prev_disable_attempts
|
|
2109
|
+
|
|
2110
|
+
coverage = min(
|
|
2111
|
+
coverage,
|
|
2112
|
+
len(ppl_with_ve_samples) if ppl_with_ve_samples else coverage,
|
|
2113
|
+
len(loss_with_ve_samples) if loss_with_ve_samples else coverage,
|
|
2114
|
+
)
|
|
2115
|
+
self._calibration_stats.update(
|
|
2116
|
+
{"coverage": coverage, "status": "insufficient"}
|
|
2117
|
+
)
|
|
2118
|
+
|
|
2119
|
+
if coverage >= min_coverage and not self._scales:
|
|
2120
|
+
ppl_no_ve_samples = ppl_no_ve_samples[:coverage]
|
|
2121
|
+
ppl_no_ve_mean = float(np.mean(ppl_no_ve_samples))
|
|
2122
|
+
self.set_ab_results(
|
|
2123
|
+
ppl_no_ve=ppl_no_ve_mean,
|
|
2124
|
+
ppl_with_ve=ppl_no_ve_mean,
|
|
2125
|
+
windows_used=coverage,
|
|
2126
|
+
seed_used=calib_seed,
|
|
2127
|
+
ratio_ci=(1.0, 1.0),
|
|
2128
|
+
)
|
|
2129
|
+
self._calibration_stats.update(
|
|
2130
|
+
{
|
|
2131
|
+
"status": "no_scaling_required",
|
|
2132
|
+
"ppl_no_ve": ppl_no_ve_mean,
|
|
2133
|
+
"ratio_ci": (1.0, 1.0),
|
|
2134
|
+
}
|
|
2135
|
+
)
|
|
2136
|
+
|
|
2137
|
+
if (
|
|
2138
|
+
coverage >= min_coverage
|
|
2139
|
+
and ppl_with_ve_samples
|
|
2140
|
+
and loss_with_ve_samples
|
|
2141
|
+
):
|
|
2142
|
+
ppl_no_ve_samples = ppl_no_ve_samples[:coverage]
|
|
2143
|
+
loss_no_ve_samples = loss_no_ve_samples[:coverage]
|
|
2144
|
+
ppl_with_ve_samples = ppl_with_ve_samples[:coverage]
|
|
2145
|
+
loss_with_ve_samples = loss_with_ve_samples[:coverage]
|
|
2146
|
+
|
|
2147
|
+
ratios = [
|
|
2148
|
+
with_val / no_val
|
|
2149
|
+
for with_val, no_val in zip(
|
|
2150
|
+
ppl_with_ve_samples, ppl_no_ve_samples, strict=False
|
|
2151
|
+
)
|
|
2152
|
+
if no_val > 0
|
|
2153
|
+
]
|
|
2154
|
+
if ratios:
|
|
2155
|
+
ratio_ci = self._bootstrap_mean_ci(
|
|
2156
|
+
ratios,
|
|
2157
|
+
alpha=self._policy.get("alpha", 0.05),
|
|
2158
|
+
n_bootstrap=500,
|
|
2159
|
+
seed=calib_seed,
|
|
2160
|
+
)
|
|
2161
|
+
ppl_no_ve_mean = float(np.mean(ppl_no_ve_samples))
|
|
2162
|
+
ppl_with_ve_mean = float(np.mean(ppl_with_ve_samples))
|
|
2163
|
+
self.set_ab_results(
|
|
2164
|
+
ppl_no_ve=ppl_no_ve_mean,
|
|
2165
|
+
ppl_with_ve=ppl_with_ve_mean,
|
|
2166
|
+
windows_used=coverage,
|
|
2167
|
+
seed_used=calib_seed,
|
|
2168
|
+
ratio_ci=ratio_ci,
|
|
2169
|
+
)
|
|
2170
|
+
self._calibration_stats.update(
|
|
2171
|
+
{
|
|
2172
|
+
"status": "complete",
|
|
2173
|
+
"ppl_no_ve": ppl_no_ve_mean,
|
|
2174
|
+
"ppl_with_ve": ppl_with_ve_mean,
|
|
2175
|
+
"ratio_ci": ratio_ci,
|
|
2176
|
+
}
|
|
2177
|
+
)
|
|
2178
|
+
|
|
2179
|
+
delta_ci: tuple[float, float] | None = None
|
|
2180
|
+
try:
|
|
2181
|
+
delta_ci = compute_paired_delta_log_ci(
|
|
2182
|
+
loss_with_ve_samples,
|
|
2183
|
+
loss_no_ve_samples,
|
|
2184
|
+
method="bca",
|
|
2185
|
+
replicates=500,
|
|
2186
|
+
alpha=self._policy.get("alpha", 0.05),
|
|
2187
|
+
seed=calib_seed + 211,
|
|
2188
|
+
)
|
|
2189
|
+
except Exception as exc:
|
|
2190
|
+
delta_ci = None
|
|
2191
|
+
self._log_event(
|
|
2192
|
+
"predictive_gate_error",
|
|
2193
|
+
level="WARN",
|
|
2194
|
+
message="Failed to compute predictive ΔlogNLL CI",
|
|
2195
|
+
error=str(exc),
|
|
2196
|
+
)
|
|
2197
|
+
|
|
2198
|
+
predictive_state["evaluated"] = True
|
|
2199
|
+
mean_delta = float(
|
|
2200
|
+
np.mean(
|
|
2201
|
+
[
|
|
2202
|
+
with_loss - no_loss
|
|
2203
|
+
for with_loss, no_loss in zip(
|
|
2204
|
+
loss_with_ve_samples,
|
|
2205
|
+
loss_no_ve_samples,
|
|
2206
|
+
strict=False,
|
|
2207
|
+
)
|
|
2208
|
+
]
|
|
2209
|
+
)
|
|
2210
|
+
)
|
|
2211
|
+
predictive_state["mean_delta"] = mean_delta
|
|
2212
|
+
|
|
2213
|
+
if delta_ci is not None and all(
|
|
2214
|
+
isinstance(val, (int | float)) and math.isfinite(val)
|
|
2215
|
+
for val in delta_ci
|
|
2216
|
+
):
|
|
2217
|
+
delta_ci = (float(delta_ci[0]), float(delta_ci[1]))
|
|
2218
|
+
gain_ci = (-delta_ci[1], -delta_ci[0])
|
|
2219
|
+
predictive_state["delta_ci"] = delta_ci
|
|
2220
|
+
predictive_state["gain_ci"] = gain_ci
|
|
2221
|
+
|
|
2222
|
+
if not self._policy.get("predictive_gate", True):
|
|
2223
|
+
predictive_state["passed"] = True
|
|
2224
|
+
predictive_state["reason"] = "disabled"
|
|
2225
|
+
else:
|
|
2226
|
+
one_sided = bool(
|
|
2227
|
+
self._policy.get("predictive_one_sided", False)
|
|
2228
|
+
)
|
|
2229
|
+
min_effect = float(
|
|
2230
|
+
self._policy.get("min_effect_lognll", 0.0) or 0.0
|
|
2231
|
+
)
|
|
2232
|
+
passed, reason = _predictive_gate_outcome(
|
|
2233
|
+
mean_delta=mean_delta,
|
|
2234
|
+
delta_ci=delta_ci,
|
|
2235
|
+
min_effect=min_effect,
|
|
2236
|
+
one_sided=one_sided,
|
|
2237
|
+
)
|
|
2238
|
+
predictive_state["passed"] = passed
|
|
2239
|
+
predictive_state["reason"] = reason
|
|
2240
|
+
else:
|
|
2241
|
+
predictive_state["delta_ci"] = (None, None)
|
|
2242
|
+
predictive_state["gain_ci"] = (None, None)
|
|
2243
|
+
predictive_state["reason"] = (
|
|
2244
|
+
predictive_state.get("reason", "ci_unavailable")
|
|
2245
|
+
if predictive_state.get("reason") != "disabled"
|
|
2246
|
+
else "disabled"
|
|
2247
|
+
)
|
|
2248
|
+
else:
|
|
2249
|
+
# Fail-open monitor mode
|
|
2250
|
+
self._ratio_ci = None
|
|
2251
|
+
self._log_event(
|
|
2252
|
+
"prepare_monitor_mode",
|
|
2253
|
+
level="WARN",
|
|
2254
|
+
message="VE calibration coverage insufficient; guard will monitor only",
|
|
2255
|
+
requested=requested_windows,
|
|
2256
|
+
coverage=coverage,
|
|
2257
|
+
min_coverage=min_coverage,
|
|
2258
|
+
)
|
|
2259
|
+
if predictive_state.get("reason") not in {"disabled"}:
|
|
2260
|
+
if coverage < min_coverage:
|
|
2261
|
+
predictive_state["reason"] = "insufficient_coverage"
|
|
2262
|
+
elif not self._scales:
|
|
2263
|
+
predictive_state["reason"] = "no_scales"
|
|
2264
|
+
elif not ppl_with_ve_samples:
|
|
2265
|
+
predictive_state["reason"] = "ve_enable_failed"
|
|
2266
|
+
else:
|
|
2267
|
+
self._ratio_ci = None
|
|
2268
|
+
if predictive_state.get("reason") != "disabled":
|
|
2269
|
+
predictive_state["reason"] = "no_calibration"
|
|
2270
|
+
|
|
2271
|
+
self._predictive_gate_state = predictive_state
|
|
2272
|
+
|
|
2273
|
+
# Store baseline statistics without overwriting pre-populated instrumentation
|
|
2274
|
+
self._stats.setdefault(
|
|
2275
|
+
"target_module_names", sorted(self._target_modules.keys())
|
|
2276
|
+
)
|
|
2277
|
+
self._stats["target_modules"] = list(self._target_modules.keys())
|
|
2278
|
+
normalized_scales = {
|
|
2279
|
+
self._normalize_scale_name(name): scale
|
|
2280
|
+
for name, scale in self._scales.items()
|
|
2281
|
+
}
|
|
2282
|
+
self._stats["proposed_scales_pre_edit"] = normalized_scales.copy()
|
|
2283
|
+
self._stats["raw_scales_pre_edit"] = self._raw_scales.copy()
|
|
2284
|
+
self._stats["raw_scales_pre_edit_normalized"] = {
|
|
2285
|
+
self._normalize_scale_name(name): scale
|
|
2286
|
+
for name, scale in self._raw_scales.items()
|
|
2287
|
+
}
|
|
2288
|
+
self._stats["total_target_modules"] = len(self._target_modules)
|
|
2289
|
+
self._stats["modules_with_scales_pre_edit"] = len(self._scales)
|
|
2290
|
+
self._stats.setdefault("calibration", {}).update(
|
|
2291
|
+
self._calibration_stats.copy()
|
|
2292
|
+
)
|
|
2293
|
+
self._stats["scale_filtering"] = {
|
|
2294
|
+
"raw_scales": len(self._raw_scales),
|
|
2295
|
+
"filtered_scales": len(self._scales),
|
|
2296
|
+
"min_abs_adjust": float(self._policy.get("min_abs_adjust", 0.0)),
|
|
2297
|
+
"max_scale_step": float(self._policy.get("max_scale_step", 0.0)),
|
|
2298
|
+
"topk_backstop": int(self._policy.get("topk_backstop", 0)),
|
|
2299
|
+
}
|
|
2300
|
+
self._stats["predictive_gate"] = predictive_state.copy()
|
|
2301
|
+
self._calibration_stats_pre_edit = self._calibration_stats.copy()
|
|
2302
|
+
self._post_edit_evaluated = False
|
|
2303
|
+
self._raw_scales_pre_edit = {
|
|
2304
|
+
self._normalize_scale_name(name): scale
|
|
2305
|
+
for name, scale in self._raw_scales.items()
|
|
2306
|
+
}
|
|
2307
|
+
|
|
2308
|
+
self._prepared = True
|
|
2309
|
+
preparation_time = time.time() - start_time
|
|
2310
|
+
|
|
2311
|
+
self._log_event(
|
|
2312
|
+
"prepare_success",
|
|
2313
|
+
message=f"Prepared variance guard with {len(self._target_modules)} target modules",
|
|
2314
|
+
target_modules=len(self._target_modules),
|
|
2315
|
+
proposed_scales=len(self._scales),
|
|
2316
|
+
preparation_time=preparation_time,
|
|
2317
|
+
)
|
|
2318
|
+
|
|
2319
|
+
return {
|
|
2320
|
+
"baseline_metrics": {
|
|
2321
|
+
"target_modules": len(self._target_modules),
|
|
2322
|
+
"proposed_scales": len(self._scales),
|
|
2323
|
+
"scope": self._policy["scope"],
|
|
2324
|
+
"scale_statistics": {
|
|
2325
|
+
"mean_scale": float(
|
|
2326
|
+
sum(self._scales.values()) / len(self._scales)
|
|
2327
|
+
)
|
|
2328
|
+
if self._scales
|
|
2329
|
+
else 1.0,
|
|
2330
|
+
"min_scale": min(self._scales.values())
|
|
2331
|
+
if self._scales
|
|
2332
|
+
else 1.0,
|
|
2333
|
+
"max_scale": max(self._scales.values())
|
|
2334
|
+
if self._scales
|
|
2335
|
+
else 1.0,
|
|
2336
|
+
},
|
|
2337
|
+
"calibration": self._calibration_stats.copy(),
|
|
2338
|
+
},
|
|
2339
|
+
"policy_applied": self._policy.copy(),
|
|
2340
|
+
"preparation_time": preparation_time,
|
|
2341
|
+
"ready": True,
|
|
2342
|
+
}
|
|
2343
|
+
|
|
2344
|
+
except Exception as e:
|
|
2345
|
+
self._prepared = False
|
|
2346
|
+
self._adapter_ref = adapter
|
|
2347
|
+
self._log_event(
|
|
2348
|
+
"prepare_failed",
|
|
2349
|
+
level="ERROR",
|
|
2350
|
+
message=f"Failed to prepare variance guard: {str(e)}",
|
|
2351
|
+
error=str(e),
|
|
2352
|
+
)
|
|
2353
|
+
|
|
2354
|
+
return {
|
|
2355
|
+
"baseline_metrics": {},
|
|
2356
|
+
"policy_applied": self._policy,
|
|
2357
|
+
"preparation_time": time.time() - start_time,
|
|
2358
|
+
"ready": False,
|
|
2359
|
+
"error": str(e),
|
|
2360
|
+
}
|
|
2361
|
+
|
|
2362
|
+
    def before_edit(self, model: nn.Module) -> None:
        """
        Execute before edit (no action needed for variance guard).

        Args:
            model: The model about to be edited
        """
        if self._prepared:
            self._log_event(
                "before_edit", message="Variance guard ready for A/B testing"
            )

    def after_edit(self, model: nn.Module) -> None:
        """
        Execute after edit (A/B testing happens via enable/disable calls).

        Args:
            model: The model that was just edited
        """
        if not self._prepared:
            self._log_event(
                "after_edit_skipped",
                level="WARN",
                message="Variance guard not prepared, skipping",
            )
            return

        self._refresh_after_edit_metrics(model)
        self._log_event(
            "after_edit",
            message="Variance guard refreshed post-edit metrics",
            evaluated=self._post_edit_evaluated,
            proposed_scales=len(self._scales),
        )
|
|
2396
|
+
|
|
2397
|
+
def enable(self, model: nn.Module, adapter=None) -> bool:
|
|
2398
|
+
"""
|
|
2399
|
+
Enable variance equalization with checkpoint discipline and idempotent operation.
|
|
2400
|
+
|
|
2401
|
+
Args:
|
|
2402
|
+
model: Model to apply VE to
|
|
2403
|
+
adapter: ModelAdapter (optional, for tying preservation)
|
|
2404
|
+
|
|
2405
|
+
Returns:
|
|
2406
|
+
True if VE was successfully enabled, False otherwise
|
|
2407
|
+
"""
|
|
2408
|
+
self._enable_attempt_count += 1
|
|
2409
|
+
|
|
2410
|
+
if self._monitor_only:
|
|
2411
|
+
self._log_event(
|
|
2412
|
+
"enable_skipped_monitor_only",
|
|
2413
|
+
level="INFO",
|
|
2414
|
+
message="Monitor-only mode: VE enable skipped",
|
|
2415
|
+
attempt_count=self._enable_attempt_count,
|
|
2416
|
+
)
|
|
2417
|
+
self._enabled = False
|
|
2418
|
+
return False
|
|
2419
|
+
|
|
2420
|
+
if not self._prepared or not self._scales:
|
|
2421
|
+
self._log_event(
|
|
2422
|
+
"enable_skipped",
|
|
2423
|
+
level="WARN",
|
|
2424
|
+
message="Cannot enable VE: not prepared or no scales computed",
|
|
2425
|
+
attempt_count=self._enable_attempt_count,
|
|
2426
|
+
)
|
|
2427
|
+
return False
|
|
2428
|
+
|
|
2429
|
+
# Idempotent check: if already enabled, verify state and return success
|
|
2430
|
+
if self._enabled:
|
|
2431
|
+
self._log_event(
|
|
2432
|
+
"enable_idempotent",
|
|
2433
|
+
message="VE already enabled, verifying state",
|
|
2434
|
+
attempt_count=self._enable_attempt_count,
|
|
2435
|
+
)
|
|
2436
|
+
return True
|
|
2437
|
+
|
|
2438
|
+
# Push checkpoint before attempting enable
|
|
2439
|
+
self._push_checkpoint(model)
|
|
2440
|
+
|
|
2441
|
+
self._log_event(
|
|
2442
|
+
"enable_start",
|
|
2443
|
+
message=f"Enabling VE with {len(self._scales)} scale factors",
|
|
2444
|
+
attempt_count=self._enable_attempt_count,
|
|
2445
|
+
)
|
|
2446
|
+
|
|
2447
|
+
try:
|
|
2448
|
+
# Apply scaling factors in-place with robust error handling
|
|
2449
|
+
applied_count = 0
|
|
2450
|
+
failed_modules = []
|
|
2451
|
+
|
|
2452
|
+
for scale_name, scale_factor in self._scales.items():
|
|
2453
|
+
try:
|
|
2454
|
+
# Find the actual module by matching scale name to target modules
|
|
2455
|
+
module = None
|
|
2456
|
+
for target_name, target_module in self._target_modules.items():
|
|
2457
|
+
# Match by exact name or by checking if they refer to the same component
|
|
2458
|
+
if scale_name == target_name:
|
|
2459
|
+
module = target_module
|
|
2460
|
+
break
|
|
2461
|
+
|
|
2462
|
+
# Convert blockX.attn/mlp format to transformer.h.X.attn/mlp.c_proj format
|
|
2463
|
+
if scale_name.startswith("block") and (
|
|
2464
|
+
"attn" in scale_name or "mlp" in scale_name
|
|
2465
|
+
):
|
|
2466
|
+
# Extract layer number and component (attn/mlp)
|
|
2467
|
+
parts = scale_name.split(".")
|
|
2468
|
+
if len(parts) == 2:
|
|
2469
|
+
layer_part = parts[0] # e.g., "block0"
|
|
2470
|
+
component = parts[1] # e.g., "attn" or "mlp"
|
|
2471
|
+
|
|
2472
|
+
if layer_part.startswith("block"):
|
|
2473
|
+
layer_num = layer_part[
|
|
2474
|
+
5:
|
|
2475
|
+
] # Extract number from "block0"
|
|
2476
|
+
expected_target = (
|
|
2477
|
+
f"transformer.h.{layer_num}.{component}.c_proj"
|
|
2478
|
+
)
|
|
2479
|
+
|
|
2480
|
+
if target_name == expected_target:
|
|
2481
|
+
module = target_module
|
|
2482
|
+
break
|
|
2483
|
+
|
|
2484
|
+
# Fallback: check if scale_name components match target_name components
|
|
2485
|
+
if (
|
|
2486
|
+
scale_name.endswith(target_name.split(".")[-1])
|
|
2487
|
+
or target_name.endswith(scale_name)
|
|
2488
|
+
or any(
|
|
2489
|
+
part in target_name for part in scale_name.split(".")
|
|
2490
|
+
)
|
|
2491
|
+
):
|
|
2492
|
+
module = target_module
|
|
2493
|
+
break
|
|
2494
|
+
|
|
2495
|
+
if module is not None and hasattr(module, "weight"):
|
|
2496
|
+
# Check for quantized weights (skip if unsupported)
|
|
2497
|
+
if hasattr(module.weight, "dtype") and module.weight.dtype in [
|
|
2498
|
+
torch.int8,
|
|
2499
|
+
]:
|
|
2500
|
+
self._log_event(
|
|
2501
|
+
"scale_skipped",
|
|
2502
|
+
level="WARN",
|
|
2503
|
+
message=f"Skipping quantized weights in {scale_name}",
|
|
2504
|
+
module_name=scale_name,
|
|
2505
|
+
dtype=str(module.weight.dtype),
|
|
2506
|
+
)
|
|
2507
|
+
continue
|
|
2508
|
+
|
|
2509
|
+
# Store original scale factor for exact reversion
|
|
2510
|
+
if scale_name not in self._original_scales:
|
|
2511
|
+
self._original_scales[scale_name] = 1.0
|
|
2512
|
+
|
|
2513
|
+
# Apply scaling with proper device handling
|
|
2514
|
+
with torch.no_grad():
|
|
2515
|
+
original_device = module.weight.device
|
|
2516
|
+
original_dtype = module.weight.dtype
|
|
2517
|
+
|
|
2518
|
+
# Use scalar multiplication to avoid MPS issues
|
|
2519
|
+
if str(original_device).startswith("mps"):
|
|
2520
|
+
module.weight.data = module.weight.data * scale_factor
|
|
2521
|
+
else:
|
|
2522
|
+
scale_tensor = torch.tensor(
|
|
2523
|
+
scale_factor,
|
|
2524
|
+
device=original_device,
|
|
2525
|
+
dtype=original_dtype,
|
|
2526
|
+
)
|
|
2527
|
+
module.weight.mul_(scale_tensor)
|
|
2528
|
+
|
|
2529
|
+
applied_count += 1
|
|
2530
|
+
|
|
2531
|
+
self._log_event(
|
|
2532
|
+
"scale_applied",
|
|
2533
|
+
message=f"Applied scale {scale_factor:.3f} to {scale_name}",
|
|
2534
|
+
module_name=scale_name,
|
|
2535
|
+
scale_factor=scale_factor,
|
|
2536
|
+
)
|
|
2537
|
+
else:
|
|
2538
|
+
failed_modules.append(scale_name)
|
|
2539
|
+
|
|
2540
|
+
except Exception as e:
|
|
2541
|
+
failed_modules.append(scale_name)
|
|
2542
|
+
self._log_event(
|
|
2543
|
+
"scale_apply_error",
|
|
2544
|
+
level="ERROR",
|
|
2545
|
+
message=f"Failed to apply scale to {scale_name}: {str(e)}",
|
|
2546
|
+
module_name=scale_name,
|
|
2547
|
+
error=str(e),
|
|
2548
|
+
)
|
|
2549
|
+
|
|
2550
|
+
# Check if enough modules were successfully scaled
|
|
2551
|
+
if applied_count == 0:
|
|
2552
|
+
# Complete failure - rollback
|
|
2553
|
+
self._pop_checkpoint(model)
|
|
2554
|
+
self._log_event(
|
|
2555
|
+
"enable_failed",
|
|
2556
|
+
level="ERROR",
|
|
2557
|
+
message="No modules were successfully scaled, rolling back",
|
|
2558
|
+
failed_modules=failed_modules,
|
|
2559
|
+
)
|
|
2560
|
+
return False
|
|
2561
|
+
|
|
2562
|
+
# Partial or complete success
|
|
2563
|
+
if failed_modules:
|
|
2564
|
+
self._log_event(
|
|
2565
|
+
"enable_partial",
|
|
2566
|
+
level="WARN",
|
|
2567
|
+
message=f"Partial success: {applied_count} succeeded, {len(failed_modules)} failed",
|
|
2568
|
+
applied_count=applied_count,
|
|
2569
|
+
failed_modules=failed_modules,
|
|
2570
|
+
)
|
|
2571
|
+
|
|
2572
|
+
# Commit the checkpoint on success
|
|
2573
|
+
self._commit_checkpoint()
|
|
2574
|
+
self._enabled = True
|
|
2575
|
+
|
|
2576
|
+
self._log_event(
|
|
2577
|
+
"enable_complete",
|
|
2578
|
+
message=f"Enabled VE on {applied_count}/{len(self._scales)} modules",
|
|
2579
|
+
applied_count=applied_count,
|
|
2580
|
+
total_scales=len(self._scales),
|
|
2581
|
+
attempt_count=self._enable_attempt_count,
|
|
2582
|
+
)
|
|
2583
|
+
|
|
2584
|
+
return True
|
|
2585
|
+
|
|
2586
|
+
except Exception as e:
|
|
2587
|
+
# Catastrophic failure - rollback
|
|
2588
|
+
self._pop_checkpoint(model)
|
|
2589
|
+
self._log_event(
|
|
2590
|
+
"enable_catastrophic_failure",
|
|
2591
|
+
level="ERROR",
|
|
2592
|
+
message=f"Catastrophic failure during enable: {str(e)}",
|
|
2593
|
+
error=str(e),
|
|
2594
|
+
attempt_count=self._enable_attempt_count,
|
|
2595
|
+
)
|
|
2596
|
+
return False
|
|
2597
|
+
|
|
2598
|
+
def disable(self, model: nn.Module, adapter=None) -> bool:
|
|
2599
|
+
"""
|
|
2600
|
+
Disable variance equalization with idempotent operation and exact state restoration.
|
|
2601
|
+
|
|
2602
|
+
Args:
|
|
2603
|
+
model: Model to revert VE on
|
|
2604
|
+
adapter: ModelAdapter (optional, for tying preservation)
|
|
2605
|
+
|
|
2606
|
+
Returns:
|
|
2607
|
+
True if VE was successfully disabled, False otherwise
|
|
2608
|
+
"""
|
|
2609
|
+
self._disable_attempt_count += 1
|
|
2610
|
+
|
|
2611
|
+
# Idempotent check: if already disabled, return success
|
|
2612
|
+
if not self._enabled:
|
|
2613
|
+
self._log_event(
|
|
2614
|
+
"disable_idempotent",
|
|
2615
|
+
message="VE already disabled",
|
|
2616
|
+
attempt_count=self._disable_attempt_count,
|
|
2617
|
+
)
|
|
2618
|
+
return True
|
|
2619
|
+
|
|
2620
|
+
self._log_event(
|
|
2621
|
+
"disable_start",
|
|
2622
|
+
message="Disabling VE by reverting to exact previous state",
|
|
2623
|
+
attempt_count=self._disable_attempt_count,
|
|
2624
|
+
)
|
|
2625
|
+
|
|
2626
|
+
try:
|
|
2627
|
+
# Attempt to use checkpoint for exact restoration if available
|
|
2628
|
+
if self._checkpoint_stack:
|
|
2629
|
+
success = self._pop_checkpoint(model)
|
|
2630
|
+
if success:
|
|
2631
|
+
self._enabled = False
|
|
2632
|
+
self._log_event(
|
|
2633
|
+
"disable_checkpoint_complete",
|
|
2634
|
+                        message="Disabled VE using checkpoint restoration",
+                        attempt_count=self._disable_attempt_count,
+                    )
+                    return True
+                else:
+                    self._log_event(
+                        "disable_checkpoint_failed",
+                        level="WARN",
+                        message="Checkpoint restoration failed, falling back to inverse scaling",
+                    )
+
+            # Fallback: revert using inverse scaling
+            reverted_count = 0
+            failed_modules = []
+
+            for scale_name, scale_factor in self._scales.items():
+                try:
+                    # Find the actual module (use same logic as enable())
+                    module = None
+                    for target_name, target_module in self._target_modules.items():
+                        # Match by exact name or by checking if they refer to the same component
+                        if scale_name == target_name:
+                            module = target_module
+                            break
+
+                        # Convert blockX.attn/mlp format to transformer.h.X.attn/mlp.c_proj format
+                        if scale_name.startswith("block") and (
+                            "attn" in scale_name or "mlp" in scale_name
+                        ):
+                            # Extract layer number and component (attn/mlp)
+                            parts = scale_name.split(".")
+                            if len(parts) == 2:
+                                layer_part = parts[0]  # e.g., "block0"
+                                component = parts[1]  # e.g., "attn" or "mlp"
+
+                                if layer_part.startswith("block"):
+                                    layer_num = layer_part[
+                                        5:
+                                    ]  # Extract number from "block0"
+                                    expected_target = (
+                                        f"transformer.h.{layer_num}.{component}.c_proj"
+                                    )
+
+                                    if target_name == expected_target:
+                                        module = target_module
+                                        break
+
+                        # Fallback: check if scale_name components match target_name components
+                        if (
+                            scale_name.endswith(target_name.split(".")[-1])
+                            or target_name.endswith(scale_name)
+                            or any(
+                                part in target_name for part in scale_name.split(".")
+                            )
+                        ):
+                            module = target_module
+                            break
+
+                    if module is not None and hasattr(module, "weight"):
+                        # Check for quantized weights (skip if unsupported)
+                        if hasattr(module.weight, "dtype") and module.weight.dtype in [
+                            torch.int8,
+                        ]:
+                            self._log_event(
+                                "revert_skipped",
+                                level="WARN",
+                                message=f"Skipping quantized weights in {scale_name}",
+                                module_name=scale_name,
+                                dtype=str(module.weight.dtype),
+                            )
+                            continue
+
+                        # Exact reversion using inverse scale
+                        revert_factor = 1.0 / scale_factor
+
+                        with torch.no_grad():
+                            original_device = module.weight.device
+                            original_dtype = module.weight.dtype
+
+                            # Use scalar multiplication to avoid MPS issues
+                            if str(original_device).startswith("mps"):
+                                module.weight.data = module.weight.data * revert_factor
+                            else:
+                                revert_tensor = torch.tensor(
+                                    revert_factor,
+                                    device=original_device,
+                                    dtype=original_dtype,
+                                )
+                                module.weight.mul_(revert_tensor)
+
+                        reverted_count += 1
+
+                        self._log_event(
+                            "scale_reverted",
+                            message=f"Reverted scale {scale_factor:.3f} from {scale_name} (factor: {revert_factor:.3f})",
+                            module_name=scale_name,
+                            original_scale=scale_factor,
+                            revert_factor=revert_factor,
+                        )
+                    else:
+                        failed_modules.append(scale_name)
+
+                except Exception as e:
+                    failed_modules.append(scale_name)
+                    self._log_event(
+                        "scale_revert_error",
+                        level="ERROR",
+                        message=f"Failed to revert scale from {scale_name}: {str(e)}",
+                        module_name=scale_name,
+                        error=str(e),
+                    )
+
+            # Check if enough modules were successfully reverted
+            if reverted_count == 0 and self._scales:
+                self._log_event(
+                    "disable_failed",
+                    level="ERROR",
+                    message="No modules were successfully reverted",
+                    failed_modules=failed_modules,
+                )
+                return False
+
+            # Success (even if partial)
+            if failed_modules:
+                self._log_event(
+                    "disable_partial",
+                    level="WARN",
+                    message=f"Partial success: {reverted_count} reverted, {len(failed_modules)} failed",
+                    reverted_count=reverted_count,
+                    failed_modules=failed_modules,
+                )
+
+            self._enabled = False
+            self._log_event(
+                "disable_complete",
+                message=f"Disabled VE on {reverted_count}/{len(self._scales)} modules",
+                reverted_count=reverted_count,
+                attempt_count=self._disable_attempt_count,
+            )
+
+            return True
+
+        except Exception as e:
+            self._log_event(
+                "disable_catastrophic_failure",
+                level="ERROR",
+                message=f"Catastrophic failure during disable: {str(e)}",
+                error=str(e),
+                attempt_count=self._disable_attempt_count,
+            )
+            return False
+
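The fallback path above assumes a multiplicative edit can be undone by the reciprocal scale, which holds only up to floating-point rounding. A minimal standalone sketch of that round trip (plain PyTorch with illustrative names and values, not invarlock's own API):

```python
# Sketch of the inverse-scaling round trip used by the fallback revert path.
# Assumes only PyTorch; the module and scale factor are hypothetical.
import torch
import torch.nn as nn

module = nn.Linear(4, 4)
original = module.weight.detach().clone()

scale_factor = 1.07  # hypothetical VE scale applied to a projection weight
with torch.no_grad():
    module.weight.mul_(scale_factor)        # apply the scale
    module.weight.mul_(1.0 / scale_factor)  # revert with the inverse scale

# Restoration is exact only up to rounding error, which is why the guard
# prefers checkpoint restoration and treats inverse scaling as a fallback.
print(torch.allclose(module.weight, original, atol=1e-6))  # True
```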
+    def set_ab_results(
+        self,
+        ppl_no_ve: float,
+        ppl_with_ve: float,
+        windows_used: int | None = None,
+        seed_used: int | None = None,
+        ratio_ci: tuple[float, float] | None = None,
+    ) -> None:
+        """
+        Store A/B testing results with reinforced validation logic.
+
+        Args:
+            ppl_no_ve: Perplexity without VE (A condition)
+            ppl_with_ve: Perplexity with VE (B condition)
+            windows_used: Number of calibration windows used (for determinism tracking)
+            seed_used: Random seed used (for determinism tracking)
+            ratio_ci: Tuple of (lower, upper) confidence interval for ppl_with_ve/ppl_no_ve
+        """
+        self._ppl_no_ve = ppl_no_ve
+        self._ppl_with_ve = ppl_with_ve
+        self._ab_windows_used = windows_used
+        self._ab_seed_used = seed_used
+        self._ratio_ci = ratio_ci
+
+        # Robust gain computation with NaN/Inf protection
+        if ppl_no_ve is None or ppl_with_ve is None or ppl_no_ve <= 0:
+            self._ab_gain = 0.0
+            gain_status = "invalid_ppl"
+        else:
+            try:
+                self._ab_gain = (ppl_no_ve - ppl_with_ve) / max(ppl_no_ve, 1e-9)
+                # Guard against NaN/Inf
+                if not (
+                    isinstance(self._ab_gain, int | float)
+                    and abs(self._ab_gain) < float("inf")
+                ):
+                    self._ab_gain = 0.0
+                    gain_status = "numeric_error"
+                else:
+                    gain_status = "computed"
+            except (ZeroDivisionError, OverflowError, TypeError):
+                self._ab_gain = 0.0
+                gain_status = "numeric_error"
+
+        # Safe formatting for None values
+        ppl_no_ve_str = f"{ppl_no_ve:.3f}" if ppl_no_ve is not None else "None"
+        ppl_with_ve_str = f"{ppl_with_ve:.3f}" if ppl_with_ve is not None else "None"
+
+        self._log_event(
+            "ab_results_stored",
+            message=f"A/B results: {ppl_no_ve_str} → {ppl_with_ve_str} (gain: {self._ab_gain:.3f}, status: {gain_status})",
+            ppl_no_ve=ppl_no_ve,
+            ppl_with_ve=ppl_with_ve,
+            gain=self._ab_gain,
+            gain_status=gain_status,
+            windows_used=windows_used,
+            seed_used=seed_used,
+            ratio_ci=ratio_ci,
+        )
+        self._post_edit_evaluated = True
+
+        upper_ratio = None
+        if isinstance(ratio_ci, tuple | list) and len(ratio_ci) == 2:
+            try:
+                upper_ratio = float(ratio_ci[1])
+            except (TypeError, ValueError):
+                upper_ratio = None
+
+        if upper_ratio is not None and upper_ratio < 1.0:
+            self._predictive_gate_state.update(
+                {
+                    "evaluated": True,
+                    "passed": True,
+                    "reason": "manual_override",
+                }
+            )
+
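As a worked example of the gain computation stored above: with `ppl_no_ve = 12.0` and `ppl_with_ve = 11.4`, the relative gain is (12.0 - 11.4) / 12.0 = 0.05. A small sketch mirroring the NaN/Inf protection (standalone, not the guard's API):

```python
import math

def relative_gain(ppl_no_ve: float | None, ppl_with_ve: float | None) -> float:
    """Relative perplexity gain with the same guard rails as set_ab_results."""
    if ppl_no_ve is None or ppl_with_ve is None or ppl_no_ve <= 0:
        return 0.0
    gain = (ppl_no_ve - ppl_with_ve) / max(ppl_no_ve, 1e-9)
    return gain if math.isfinite(gain) else 0.0

print(f"{relative_gain(12.0, 11.4):.3f}")  # 0.050 -> a 5% perplexity improvement
print(relative_gain(0.0, 11.4))            # 0.0   -> invalid baseline is neutralized
```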
+    def _push_checkpoint(self, model: nn.Module) -> None:
+        """
+        Push current model state to checkpoint stack for rollback capability.
+
+        Args:
+            model: Model to checkpoint
+        """
+        if not self._target_modules:
+            return
+
+        checkpoint = {}
+        for name, module in self._target_modules.items():
+            if hasattr(module, "weight"):
+                # Store deep copy of weights for exact restoration
+                checkpoint[name] = module.weight.data.clone().detach()
+
+        self._checkpoint_stack.append(checkpoint)
+
+        self._log_event(
+            "checkpoint_pushed",
+            message=f"Pushed checkpoint for {len(checkpoint)} modules",
+            modules_count=len(checkpoint),
+            stack_depth=len(self._checkpoint_stack),
+        )
+
+    def _pop_checkpoint(self, model: nn.Module) -> bool:
+        """
+        Pop and restore the most recent checkpoint.
+
+        Args:
+            model: Model to restore
+
+        Returns:
+            True if checkpoint was restored, False if no checkpoint available
+        """
+        if not self._checkpoint_stack:
+            self._log_event(
+                "checkpoint_pop_failed",
+                level="WARN",
+                message="No checkpoint available for rollback",
+            )
+            return False
+
+        checkpoint = self._checkpoint_stack.pop()
+        restored_count = 0
+
+        for name, saved_weight in checkpoint.items():
+            if name in self._target_modules:
+                module = self._target_modules[name]
+                if hasattr(module, "weight"):
+                    # Exact restoration using saved tensor
+                    module.weight.data.copy_(saved_weight)
+                    restored_count += 1
+
+        self._log_event(
+            "checkpoint_popped",
+            message=f"Restored checkpoint for {restored_count}/{len(checkpoint)} modules",
+            restored_count=restored_count,
+            stack_depth=len(self._checkpoint_stack),
+        )
+
+        return True
+
+    def _commit_checkpoint(self) -> None:
+        """
+        Commit current state by removing the most recent checkpoint.
+        """
+        if self._checkpoint_stack:
+            self._checkpoint_stack.pop()
+            self._log_event(
+                "checkpoint_committed",
+                message="Committed current state, removed checkpoint",
+                stack_depth=len(self._checkpoint_stack),
+            )
+
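The three helpers above implement a plain save/restore/commit stack over module weights. A self-contained sketch of the same pattern (illustrative class and module names, not invarlock's types):

```python
import torch
import torch.nn as nn

class WeightCheckpointStack:
    """Push/pop/commit snapshots of module weights, mirroring the guard's helpers."""

    def __init__(self, modules: dict[str, nn.Module]):
        self.modules = modules
        self.stack: list[dict[str, torch.Tensor]] = []

    def push(self) -> None:
        # Deep-copy current weights so a later pop can restore them exactly.
        self.stack.append(
            {name: m.weight.data.clone().detach() for name, m in self.modules.items()}
        )

    def pop_restore(self) -> bool:
        if not self.stack:
            return False
        for name, saved in self.stack.pop().items():
            self.modules[name].weight.data.copy_(saved)
        return True

    def commit(self) -> None:
        # Keep the current weights and drop the rollback point.
        if self.stack:
            self.stack.pop()

modules = {"proj": nn.Linear(4, 4)}
ckpt = WeightCheckpointStack(modules)
ckpt.push()
with torch.no_grad():
    modules["proj"].weight.mul_(1.05)  # speculative edit
ckpt.pop_restore()                     # roll back to the snapshot
```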
+    def _evaluate_ab_gate(self) -> tuple[bool, str]:
+        """
+        Evaluate A/B gate decision with reinforced criteria.
+
+        Returns:
+            (should_enable, reason) tuple
+        """
+        mode = self._policy.get("mode", "ci")
+        min_rel_gain = self._policy.get("min_rel_gain", 0.0)
+        tie_breaker = float(
+            self._policy.get("tie_breaker_deadband", self.TIE_BREAKER_DEADBAND)
+        )
+        min_effect_log = self._policy.get("min_effect_lognll")
+
+        predictive_enabled = bool(self._policy.get("predictive_gate", True))
+        gate_state = getattr(self, "_predictive_gate_state", {}) or {}
+        if (
+            predictive_enabled
+            and not gate_state.get("evaluated")
+            and self._ratio_ci is not None
+        ):
+            gate_state = {
+                **gate_state,
+                "evaluated": True,
+                "passed": True,
+                "reason": gate_state.get("reason", "synthetic_ab_gate"),
+            }
+            self._predictive_gate_state = gate_state
+
+        if self._ab_gain is None:
+            return False, "no_ab_results"
+
+        # Edge case: zero or negative PPLs
+        if (
+            self._ppl_no_ve is None
+            or self._ppl_with_ve is None
+            or self._ppl_no_ve <= 0
+            or self._ppl_with_ve <= 0
+        ):
+            return False, "invalid_ppl_values"
+
+        relative_gain = self._ab_gain
+        if relative_gain < min_rel_gain:
+            return (
+                False,
+                f"below_min_rel_gain (gain={relative_gain:.3f} < {min_rel_gain:.3f})",
+            )
+
+        if min_effect_log is not None:
+            log_gain = math.log(max(self._ppl_no_ve, 1e-9)) - math.log(
+                max(self._ppl_with_ve, 1e-9)
+            )
+            if log_gain < float(min_effect_log):
+                return (
+                    False,
+                    f"below_min_effect_lognll (gain={log_gain:.6f} < {float(min_effect_log):.6f})",
+                )
+
+        if mode == "ci":
+            if self._ratio_ci is None:
+                return False, "missing_ratio_ci"
+            ratio_lo, ratio_hi = self._ratio_ci
+            if not all(
+                isinstance(x, int | float) and math.isfinite(x) and x > 0
+                for x in (ratio_lo, ratio_hi)
+            ):
+                return False, "invalid_ratio_ci"
+            required_hi = 1.0 - min_rel_gain
+            if min_effect_log is not None:
+                required_hi = min(required_hi, math.exp(-float(min_effect_log)))
+            if ratio_hi > required_hi:
+                return (
+                    False,
+                    f"ci_interval_too_high (hi={ratio_hi:.3f} > {required_hi:.3f})",
+                )
+
+        # Absolute floor requirement: must have at least 0.05 improvement (ppl-like)
+        absolute_improvement = self._ppl_no_ve - self._ppl_with_ve
+        if absolute_improvement < self.ABSOLUTE_FLOOR:
+            return (
+                False,
+                f"below_absolute_floor (improvement={absolute_improvement:.3f} < {self.ABSOLUTE_FLOOR})",
+            )
+
+        # Tie-breaker deadband: require gain >= min_gain + 0.005 to avoid flapping
+        required_gain = self._policy["min_gain"] + tie_breaker
+        if self._ab_gain < required_gain:
+            return (
+                False,
+                f"below_threshold_with_deadband (gain={self._ab_gain:.3f} < {required_gain:.3f})",
+            )
+
+        if predictive_enabled and not gate_state.get("passed", False):
+            reason = gate_state.get("reason", "predictive_gate_failed")
+            return False, f"predictive_gate_failed ({reason})"
+
+        return (
+            True,
+            f"criteria_met (gain={self._ab_gain:.3f} >= {required_gain:.3f}, improvement={absolute_improvement:.3f})",
+        )
+
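In `mode == "ci"` the gate requires the upper end of the perplexity-ratio confidence interval to clear `min(1.0 - min_rel_gain, exp(-min_effect_lognll))`. A short sketch of just that check with illustrative numbers (not the full decision path above):

```python
import math

def ci_gate(ratio_hi: float, min_rel_gain: float, min_effect_lognll: float | None) -> bool:
    """Pass only if the worst plausible ppl_with_ve/ppl_no_ve ratio still
    represents the requested relative (and optional log-NLL) improvement."""
    required_hi = 1.0 - min_rel_gain
    if min_effect_lognll is not None:
        required_hi = min(required_hi, math.exp(-float(min_effect_lognll)))
    return ratio_hi <= required_hi

# With min_rel_gain = 0.02 the CI's upper bound must be <= 0.98.
print(ci_gate(ratio_hi=0.975, min_rel_gain=0.02, min_effect_lognll=None))  # True
print(ci_gate(ratio_hi=0.990, min_rel_gain=0.02, min_effect_lognll=None))  # False
```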
+    def validate(
+        self, model: Any, adapter: Any, context: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Validate model state (Guard ABC interface).
+
+        Args:
+            model: Model to validate
+            adapter: ModelAdapter instance
+            context: Validation context
+
+        Returns:
+            Dictionary with validation results
+        """
+        # Use finalize to get comprehensive results
+        result = self.finalize(model)
+
+        details = result.get("details", {}) or {}
+        errors = result.get("errors", []) or []
+        warnings = result.get("warnings", []) or []
+        passed = result.get("passed", False)
+
+        if passed:
+            action = "warn" if warnings else "continue"
+        else:
+            action = "warn" if self._monitor_only else "abort"
+
+        return {
+            "passed": passed,
+            "action": action,
+            "metrics": result.get("metrics", {}),
+            "violations": errors,
+            "message": "Variance guard validation completed",
+            "details": details,
+            "policy": details.get("policy", self._policy.copy()),
+            "warnings": warnings,
+            "errors": errors,
+        }
+
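The action mapping in `validate` is small enough to restate on its own; a standalone sketch of the same rule (illustrative helper, not the Guard ABC):

```python
def select_action(passed: bool, warnings: list[str], monitor_only: bool) -> str:
    """Failures abort unless the guard is monitor-only; passing runs downgrade
    to 'warn' when warnings were collected."""
    if passed:
        return "warn" if warnings else "continue"
    return "warn" if monitor_only else "abort"

print(select_action(True, [], False))         # continue
print(select_action(False, ["drift"], True))  # warn (monitor-only never aborts)
print(select_action(False, [], False))        # abort
```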
+    def finalize(self, model: nn.Module) -> dict[str, Any]:
+        """
+        Finalize variance guard and return comprehensive results.
+
+        Args:
+            model: The final edited model
+
+        Returns:
+            Dictionary with variance guard results and A/B testing metrics
+        """
+        start_time = time.time()
+
+        if not self._prepared:
+            self._log_event(
+                "finalize_failed",
+                level="ERROR",
+                message="Variance guard not properly prepared",
+            )
+            return {
+                "passed": False,
+                "metrics": {},
+                "warnings": ["Variance guard not properly prepared"],
+                "errors": ["Preparation failed or no target modules found"],
+                "finalize_time": time.time() - start_time,
+                "events": self.events,
+            }
+
+        if self._monitor_only:
+            self._enabled = False
+            self._scales = {}
+
+        if not self._post_edit_evaluated:
+            self._refresh_after_edit_metrics(model)
+
+        # Use reinforced A/B gate evaluation
+        should_enable, gate_reason = self._evaluate_ab_gate()
+        enabled_after_ab = self._enabled
+        ab_gain = self._ab_gain or 0.0
+
+        if should_enable and not enabled_after_ab:
+            enable_result = self.enable(model)
+            enabled_after_ab = enable_result or self._enabled
+        elif not should_enable and enabled_after_ab:
+            self.disable(model)
+            enabled_after_ab = False
+
+        # Enhanced validation gate criteria
+        passed = True
+        warnings = []
+        errors = []
+
+        # Log A/B gate decision for transparency
+        self._log_event(
+            "ab_gate_evaluation",
+            message=f"A/B gate decision: should_enable={should_enable}, reason={gate_reason}",
+            should_enable=should_enable,
+            reason=gate_reason,
+            current_enabled=enabled_after_ab,
+        )
+
+        # Primary validation: VE enabled state must match A/B gate decision
+        if enabled_after_ab != should_enable:
+            if enabled_after_ab and not should_enable:
+                errors.append(f"VE enabled despite A/B gate rejection: {gate_reason}")
+                passed = False
+            elif not enabled_after_ab and should_enable:
+                warnings.append(f"VE disabled despite A/B gate approval: {gate_reason}")
+                # This is a warning, not an error, as being conservative is safer
+
+        # Secondary validation: Check primary-metric degradation when VE is OFF (≤0.5 rise requirement, ppl-like)
+        if not enabled_after_ab and self._ppl_no_ve and self._ppl_with_ve:
+            # When VE is disabled, check that there's no significant degradation
+            # The requirement is ≤0.5 rise (ppl-like units) when VE is OFF
+            expected_final_ppl = self._ppl_no_ve  # Should be the no-VE result
+            if hasattr(self, "_final_ppl") and self._final_ppl is not None:
+                ppl_rise = self._final_ppl - expected_final_ppl
+                if ppl_rise > 0.5:
+                    errors.append(
+                        f"Primary-metric rise {ppl_rise:.3f} > 0.5 when VE disabled"
+                    )
+                    passed = False
+
+        # Tertiary validation: Check for deterministic A/B testing
+        if self._ab_windows_used is not None and self._ab_seed_used is not None:
+            expected_seed = self._policy.get("seed", 123)
+            if self._ab_seed_used != expected_seed:
+                warnings.append(
+                    f"A/B test used unexpected seed {self._ab_seed_used}, expected {expected_seed}"
+                )
+
+        # Additional robustness checks
+        if self._enable_attempt_count > 3:
+            warnings.append(
+                f"Multiple enable attempts ({self._enable_attempt_count}), may indicate instability"
+            )
+
+        if self._disable_attempt_count > 3:
+            warnings.append(
+                f"Multiple disable attempts ({self._disable_attempt_count}), may indicate instability"
+            )
+
+        if len(self._checkpoint_stack) > 0:
+            warnings.append(
+                f"Uncommitted checkpoints remaining: {len(self._checkpoint_stack)}"
+            )
+
+        # Validate tie-breaker deadband was applied
+        if self._ab_gain is not None and self._ab_gain > 0:
+            required_gain_with_deadband = self._policy["min_gain"] + float(
+                self._policy.get("tie_breaker_deadband", self.TIE_BREAKER_DEADBAND)
+            )
+            if enabled_after_ab and self._ab_gain < required_gain_with_deadband:
+                errors.append(
+                    f"VE enabled without meeting tie-breaker deadband: gain {self._ab_gain:.3f} < {required_gain_with_deadband:.3f}"
+                )
+                passed = False
+
+        # Validate absolute floor was checked
+        if self._ppl_no_ve and self._ppl_with_ve:
+            absolute_improvement = self._ppl_no_ve - self._ppl_with_ve
+            if enabled_after_ab and absolute_improvement < self.ABSOLUTE_FLOOR:
+                errors.append(
+                    f"VE enabled without meeting absolute floor: improvement {absolute_improvement:.3f} < {self.ABSOLUTE_FLOOR}"
+                )
+                passed = False
+
+        finalize_time = time.time() - start_time
+
+        # Final metrics
+        final_metrics = {
+            "proposed_scales": len(self._scales),
+            "target_modules": len(self._target_modules),
+            "target_module_names": self._stats.get("target_module_names", []),
+            "focus_modules": sorted(self._focus_modules) if self._focus_modules else [],
+            "tap": self._stats.get("tap"),
+            "ve_enabled": enabled_after_ab,
+            "ab_gain": ab_gain,
+            "ab_windows_used": self._ab_windows_used,
+            "ab_seed_used": self._ab_seed_used,
+            "monitor_only": self._monitor_only,
+            "min_gain_threshold": self._policy["min_gain"],
+            "met_threshold": should_enable,
+            "ppl_no_ve": self._ppl_no_ve,
+            "ppl_with_ve": self._ppl_with_ve,
+            "scope": self._policy["scope"],
+            "max_calib_used": self._policy["max_calib"],
+            "mode": self._policy.get("mode"),
+            "min_rel_gain": self._policy.get("min_rel_gain"),
+            "alpha": self._policy.get("alpha"),
+            "ratio_ci": self._ratio_ci,
+            "calibration": self._calibration_stats.copy(),
+            "predictive_gate": self._predictive_gate_state.copy(),
+            "ab_provenance": copy.deepcopy(self._stats.get("ab_provenance", {})),
+            "ab_point_estimates": copy.deepcopy(
+                self._stats.get("ab_point_estimates", {})
+            ),
+            "raw_scales_pre_edit": copy.deepcopy(self._raw_scales_pre_edit),
+            "raw_scales_post_edit": copy.deepcopy(self._raw_scales_post_edit),
+            "proposed_scales_pre_edit": self._stats.get("proposed_scales_pre_edit", {}),
+            "proposed_scales_post_edit": self._stats.get(
+                "proposed_scales_post_edit", {}
+            ),
+        }
+
+        if self._calibration_stats.get("status") != "complete":
+            warnings.append(
+                "Variance calibration coverage insufficient; operating in monitor mode"
+            )
+
+        self._log_event(
+            "finalize_complete",
+            message=f"Variance guard finalized - {'PASSED' if passed else 'FAILED'}",
+            passed=passed,
+            ve_enabled=enabled_after_ab,
+            ab_gain=ab_gain,
+            finalize_time=finalize_time,
+        )
+
+        result = {
+            "passed": passed,
+            "metrics": final_metrics,
+            "warnings": warnings,
+            "errors": errors,
+            "finalize_time": finalize_time,
+            "events": self.events,
+            "details": {
+                "guard_type": "variance",
+                "ve_applied": enabled_after_ab,
+                "ab_test_performed": self._ppl_no_ve is not None,
+                "proposed_scales": self._scales,
+                "stats": self._stats,
+                "policy": self._policy,
+            },
+        }
+
+        # Env-gated tiny evidence dump for auditors
+        try:
+            payload = {
+                "variance": {
+                    "mode": self._policy.get("mode"),
+                    "min_effect": self._policy.get("min_effect", self.MIN_EFFECT),
+                    "predictive_one_sided": bool(
+                        self._policy.get("predictive_one_sided", True)
+                    ),
+                    "evaluated": True,
+                }
+            }
+            maybe_dump_guard_evidence(".", payload)
+        except Exception:
+            pass
+
+        return result
+
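A caller consuming the dictionary returned by `finalize` might branch on the same keys; a hedged sketch using only field names visible in the code above:

```python
def summarize_finalize(result: dict) -> str:
    """Tiny consumer of the finalize() payload shown above (illustrative helper)."""
    metrics = result.get("metrics", {})
    status = "PASSED" if result.get("passed") else "FAILED"
    ve_state = "on" if metrics.get("ve_enabled") else "off"
    gain = metrics.get("ab_gain") or 0.0
    errors = "; ".join(result.get("errors", [])) or "none"
    return f"{status}: VE {ve_state}, A/B gain {gain:.3f}, errors: {errors}"

example = {"passed": True, "metrics": {"ve_enabled": True, "ab_gain": 0.031}, "errors": []}
print(summarize_finalize(example))  # PASSED: VE on, A/B gain 0.031, errors: none
```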
+    def policy(self) -> VariancePolicyDict:
+        """
+        Get current policy configuration.
+
+        Returns:
+            VariancePolicyDict with current configuration
+        """
+        return self._policy.copy()