invarlock 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +33 -0
- invarlock/__main__.py +10 -0
- invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
- invarlock/_data/runtime/profiles/release.yaml +23 -0
- invarlock/_data/runtime/tiers.yaml +76 -0
- invarlock/adapters/__init__.py +102 -0
- invarlock/adapters/_capabilities.py +45 -0
- invarlock/adapters/auto.py +99 -0
- invarlock/adapters/base.py +530 -0
- invarlock/adapters/base_types.py +85 -0
- invarlock/adapters/hf_bert.py +852 -0
- invarlock/adapters/hf_gpt2.py +403 -0
- invarlock/adapters/hf_llama.py +485 -0
- invarlock/adapters/hf_mixin.py +383 -0
- invarlock/adapters/hf_onnx.py +112 -0
- invarlock/adapters/hf_t5.py +137 -0
- invarlock/adapters/py.typed +1 -0
- invarlock/assurance/__init__.py +43 -0
- invarlock/cli/__init__.py +8 -0
- invarlock/cli/__main__.py +8 -0
- invarlock/cli/_evidence.py +25 -0
- invarlock/cli/_json.py +75 -0
- invarlock/cli/adapter_auto.py +162 -0
- invarlock/cli/app.py +287 -0
- invarlock/cli/commands/__init__.py +26 -0
- invarlock/cli/commands/certify.py +403 -0
- invarlock/cli/commands/doctor.py +1358 -0
- invarlock/cli/commands/explain_gates.py +151 -0
- invarlock/cli/commands/export_html.py +100 -0
- invarlock/cli/commands/plugins.py +1331 -0
- invarlock/cli/commands/report.py +354 -0
- invarlock/cli/commands/run.py +4146 -0
- invarlock/cli/commands/verify.py +1040 -0
- invarlock/cli/config.py +396 -0
- invarlock/cli/constants.py +68 -0
- invarlock/cli/device.py +92 -0
- invarlock/cli/doctor_helpers.py +74 -0
- invarlock/cli/errors.py +6 -0
- invarlock/cli/overhead_utils.py +60 -0
- invarlock/cli/provenance.py +66 -0
- invarlock/cli/utils.py +41 -0
- invarlock/config.py +56 -0
- invarlock/core/__init__.py +62 -0
- invarlock/core/abi.py +15 -0
- invarlock/core/api.py +274 -0
- invarlock/core/auto_tuning.py +317 -0
- invarlock/core/bootstrap.py +226 -0
- invarlock/core/checkpoint.py +221 -0
- invarlock/core/contracts.py +73 -0
- invarlock/core/error_utils.py +64 -0
- invarlock/core/events.py +298 -0
- invarlock/core/exceptions.py +95 -0
- invarlock/core/registry.py +481 -0
- invarlock/core/retry.py +146 -0
- invarlock/core/runner.py +2041 -0
- invarlock/core/types.py +154 -0
- invarlock/edits/__init__.py +12 -0
- invarlock/edits/_edit_utils.py +249 -0
- invarlock/edits/_external_utils.py +268 -0
- invarlock/edits/noop.py +47 -0
- invarlock/edits/py.typed +1 -0
- invarlock/edits/quant_rtn.py +801 -0
- invarlock/edits/registry.py +166 -0
- invarlock/eval/__init__.py +23 -0
- invarlock/eval/bench.py +1207 -0
- invarlock/eval/bootstrap.py +50 -0
- invarlock/eval/data.py +2052 -0
- invarlock/eval/metrics.py +2167 -0
- invarlock/eval/primary_metric.py +767 -0
- invarlock/eval/probes/__init__.py +24 -0
- invarlock/eval/probes/fft.py +139 -0
- invarlock/eval/probes/mi.py +213 -0
- invarlock/eval/probes/post_attention.py +323 -0
- invarlock/eval/providers/base.py +67 -0
- invarlock/eval/providers/seq2seq.py +111 -0
- invarlock/eval/providers/text_lm.py +113 -0
- invarlock/eval/providers/vision_text.py +93 -0
- invarlock/eval/py.typed +1 -0
- invarlock/guards/__init__.py +18 -0
- invarlock/guards/_contracts.py +9 -0
- invarlock/guards/invariants.py +640 -0
- invarlock/guards/policies.py +805 -0
- invarlock/guards/py.typed +1 -0
- invarlock/guards/rmt.py +2097 -0
- invarlock/guards/spectral.py +1419 -0
- invarlock/guards/tier_config.py +354 -0
- invarlock/guards/variance.py +3298 -0
- invarlock/guards_ref/__init__.py +15 -0
- invarlock/guards_ref/rmt_ref.py +40 -0
- invarlock/guards_ref/spectral_ref.py +135 -0
- invarlock/guards_ref/variance_ref.py +60 -0
- invarlock/model_profile.py +353 -0
- invarlock/model_utils.py +221 -0
- invarlock/observability/__init__.py +10 -0
- invarlock/observability/alerting.py +535 -0
- invarlock/observability/core.py +546 -0
- invarlock/observability/exporters.py +565 -0
- invarlock/observability/health.py +588 -0
- invarlock/observability/metrics.py +457 -0
- invarlock/observability/py.typed +1 -0
- invarlock/observability/utils.py +553 -0
- invarlock/plugins/__init__.py +12 -0
- invarlock/plugins/hello_guard.py +33 -0
- invarlock/plugins/hf_awq_adapter.py +82 -0
- invarlock/plugins/hf_bnb_adapter.py +79 -0
- invarlock/plugins/hf_gptq_adapter.py +78 -0
- invarlock/plugins/py.typed +1 -0
- invarlock/py.typed +1 -0
- invarlock/reporting/__init__.py +7 -0
- invarlock/reporting/certificate.py +3221 -0
- invarlock/reporting/certificate_schema.py +244 -0
- invarlock/reporting/dataset_hashing.py +215 -0
- invarlock/reporting/guards_analysis.py +948 -0
- invarlock/reporting/html.py +32 -0
- invarlock/reporting/normalizer.py +235 -0
- invarlock/reporting/policy_utils.py +517 -0
- invarlock/reporting/primary_metric_utils.py +265 -0
- invarlock/reporting/render.py +1442 -0
- invarlock/reporting/report.py +903 -0
- invarlock/reporting/report_types.py +278 -0
- invarlock/reporting/utils.py +175 -0
- invarlock/reporting/validate.py +631 -0
- invarlock/security.py +176 -0
- invarlock/sparsity_utils.py +323 -0
- invarlock/utils/__init__.py +150 -0
- invarlock/utils/digest.py +45 -0
- invarlock-0.2.0.dist-info/METADATA +586 -0
- invarlock-0.2.0.dist-info/RECORD +132 -0
- invarlock-0.2.0.dist-info/WHEEL +5 -0
- invarlock-0.2.0.dist-info/entry_points.txt +20 -0
- invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
- invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/core/runner.py
ADDED
|
@@ -0,0 +1,2041 @@
|
|
|
1
|
+
"""
|
|
2
|
+
InvarLock Core Runner
|
|
3
|
+
=================
|
|
4
|
+
|
|
5
|
+
Main pipeline execution orchestrator: prepare → edit → guards → eval → finalize/rollback.
|
|
6
|
+
Torch-independent coordination with proper event logging and checkpoint management.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import hashlib
|
|
12
|
+
import math
|
|
13
|
+
import os
|
|
14
|
+
import time
|
|
15
|
+
from array import array
|
|
16
|
+
from collections.abc import Sequence
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
from .api import Guard, ModelAdapter, ModelEdit, RunConfig, RunReport
|
|
22
|
+
from .auto_tuning import resolve_tier_policies
|
|
23
|
+
from .bootstrap import (
|
|
24
|
+
compute_logloss_ci,
|
|
25
|
+
compute_paired_delta_log_ci,
|
|
26
|
+
logspace_to_ratio_ci,
|
|
27
|
+
)
|
|
28
|
+
from .checkpoint import CheckpointManager
|
|
29
|
+
from .events import EventLogger
|
|
30
|
+
from .types import LogLevel, RunStatus
|
|
31
|
+
|
|
32
|
+
BOOTSTRAP_COVERAGE_REQUIREMENTS = {
|
|
33
|
+
# Minimum window counts and bootstrap replicates expected per policy tier.
|
|
34
|
+
# Individual configs can request more aggressive settings, but these values
|
|
35
|
+
# represent the guard-rail floor that CI profiles should maintain.
|
|
36
|
+
"conservative": {"preview": 220, "final": 220, "replicates": 1500},
|
|
37
|
+
"balanced": {"preview": 180, "final": 180, "replicates": 1200},
|
|
38
|
+
"aggressive": {"preview": 140, "final": 140, "replicates": 800},
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
__all__ = ["CoreRunner"]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _collect_cuda_flags() -> dict[str, Any]:
|
|
45
|
+
"""Capture deterministic CUDA configuration for provenance."""
|
|
46
|
+
flags: dict[str, Any] = {}
|
|
47
|
+
try:
|
|
48
|
+
import torch
|
|
49
|
+
|
|
50
|
+
flags["deterministic_algorithms"] = bool(
|
|
51
|
+
torch.are_deterministic_algorithms_enabled()
|
|
52
|
+
)
|
|
53
|
+
if hasattr(torch.backends, "cudnn"):
|
|
54
|
+
flags["cudnn_deterministic"] = bool(torch.backends.cudnn.deterministic)
|
|
55
|
+
flags["cudnn_benchmark"] = bool(torch.backends.cudnn.benchmark)
|
|
56
|
+
if hasattr(torch.backends.cudnn, "allow_tf32"):
|
|
57
|
+
flags["cudnn_allow_tf32"] = bool(torch.backends.cudnn.allow_tf32)
|
|
58
|
+
if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
|
|
59
|
+
matmul = torch.backends.cuda.matmul
|
|
60
|
+
if hasattr(matmul, "allow_tf32"):
|
|
61
|
+
flags["cuda_matmul_allow_tf32"] = bool(matmul.allow_tf32)
|
|
62
|
+
except Exception: # pragma: no cover - fallback when torch missing
|
|
63
|
+
pass
|
|
64
|
+
|
|
65
|
+
workspace = os.environ.get("CUBLAS_WORKSPACE_CONFIG")
|
|
66
|
+
if workspace:
|
|
67
|
+
flags["CUBLAS_WORKSPACE_CONFIG"] = workspace
|
|
68
|
+
return flags
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class CoreRunner:
|
|
72
|
+
"""
|
|
73
|
+
Core pipeline execution orchestrator.
|
|
74
|
+
|
|
75
|
+
Coordinates the full InvarLock pipeline while maintaining torch-independence
|
|
76
|
+
in the core coordination logic. Provides event logging, checkpointing,
|
|
77
|
+
and rollback capabilities.
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
def __init__(self):
|
|
81
|
+
self.event_logger: EventLogger | None = None
|
|
82
|
+
self.checkpoint_manager: CheckpointManager | None = None
|
|
83
|
+
|
|
84
|
+
def execute(
|
|
85
|
+
self,
|
|
86
|
+
model: Any,
|
|
87
|
+
adapter: ModelAdapter,
|
|
88
|
+
edit: ModelEdit,
|
|
89
|
+
guards: list[Guard],
|
|
90
|
+
config: RunConfig,
|
|
91
|
+
calibration_data: Any = None,
|
|
92
|
+
auto_config: dict[str, Any] | None = None,
|
|
93
|
+
edit_config: dict[str, Any] | None = None,
|
|
94
|
+
preview_n: int | None = None,
|
|
95
|
+
final_n: int | None = None,
|
|
96
|
+
) -> RunReport:
|
|
97
|
+
"""
|
|
98
|
+
Execute the full InvarLock pipeline.
|
|
99
|
+
|
|
100
|
+
Args:
|
|
101
|
+
model: The model to process
|
|
102
|
+
adapter: Model adapter for model-specific operations
|
|
103
|
+
edit: Edit to apply
|
|
104
|
+
guards: Safety guards to run
|
|
105
|
+
config: Runtime configuration
|
|
106
|
+
calibration_data: Optional calibration/validation data
|
|
107
|
+
auto_config: Optional auto-tuning configuration
|
|
108
|
+
|
|
109
|
+
Returns:
|
|
110
|
+
RunReport with execution results
|
|
111
|
+
"""
|
|
112
|
+
# Initialize services
|
|
113
|
+
self._initialize_services(config)
|
|
114
|
+
|
|
115
|
+
# Create report
|
|
116
|
+
report = RunReport()
|
|
117
|
+
report.meta["cuda_flags"] = _collect_cuda_flags()
|
|
118
|
+
report.meta["start_time"] = time.time()
|
|
119
|
+
report.meta["config"] = self._serialize_config(config)
|
|
120
|
+
if config.context:
|
|
121
|
+
try:
|
|
122
|
+
report.context.update(config.context)
|
|
123
|
+
except Exception:
|
|
124
|
+
# Defensive: ensure context remains a dict even if update fails
|
|
125
|
+
report.context = dict(config.context)
|
|
126
|
+
|
|
127
|
+
if isinstance(config.context, dict):
|
|
128
|
+
run_id = config.context.get("run_id")
|
|
129
|
+
if run_id:
|
|
130
|
+
report.meta["run_id"] = run_id
|
|
131
|
+
plugins_meta = config.context.get("plugins")
|
|
132
|
+
if plugins_meta:
|
|
133
|
+
report.meta["plugins"] = plugins_meta
|
|
134
|
+
|
|
135
|
+
# Store auto configuration for tier resolution
|
|
136
|
+
if auto_config:
|
|
137
|
+
report.meta["auto"] = auto_config
|
|
138
|
+
|
|
139
|
+
report.status = RunStatus.RUNNING.value
|
|
140
|
+
|
|
141
|
+
try:
|
|
142
|
+
# Log start
|
|
143
|
+
self._log_event(
|
|
144
|
+
"runner",
|
|
145
|
+
"start",
|
|
146
|
+
LogLevel.INFO,
|
|
147
|
+
{
|
|
148
|
+
"edit": edit.name,
|
|
149
|
+
"guards": [g.name for g in guards],
|
|
150
|
+
"context": report.context,
|
|
151
|
+
},
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
# Phase 1: Prepare (describe model, create checkpoint)
|
|
155
|
+
model_desc = self._prepare_phase(model, adapter, report)
|
|
156
|
+
|
|
157
|
+
# Phase 2: Prepare guards (must happen before edit)
|
|
158
|
+
self._prepare_guards_phase(
|
|
159
|
+
model, adapter, guards, calibration_data, report, auto_config
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
# Phase 3: Apply edit
|
|
163
|
+
self._edit_phase(model, adapter, edit, model_desc, report, edit_config)
|
|
164
|
+
|
|
165
|
+
# Phase 4: Run guards
|
|
166
|
+
guard_results = self._guard_phase(model, adapter, guards, report)
|
|
167
|
+
|
|
168
|
+
# Phase 5: Evaluate final metrics
|
|
169
|
+
metrics = self._eval_phase(
|
|
170
|
+
model,
|
|
171
|
+
adapter,
|
|
172
|
+
calibration_data,
|
|
173
|
+
report,
|
|
174
|
+
preview_n,
|
|
175
|
+
final_n,
|
|
176
|
+
config,
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
# Phase 6: Finalize or rollback
|
|
180
|
+
final_status = self._finalize_phase(
|
|
181
|
+
model, adapter, guard_results, metrics, config, report
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
report.status = final_status
|
|
185
|
+
report.meta["end_time"] = time.time()
|
|
186
|
+
report.meta["duration"] = (
|
|
187
|
+
report.meta["end_time"] - report.meta["start_time"]
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
self._log_event(
|
|
191
|
+
"runner",
|
|
192
|
+
"complete",
|
|
193
|
+
LogLevel.INFO,
|
|
194
|
+
{"status": final_status, "duration": report.meta["duration"]},
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
return report
|
|
198
|
+
|
|
199
|
+
except Exception as e:
|
|
200
|
+
self._handle_error(e, report)
|
|
201
|
+
return report
|
|
202
|
+
|
|
203
|
+
finally:
|
|
204
|
+
self._cleanup_services()
|
|
205
|
+
|
|
206
|
+
def _initialize_services(self, config: RunConfig) -> None:
|
|
207
|
+
"""Initialize event logging and checkpoint services."""
|
|
208
|
+
if config.event_path:
|
|
209
|
+
run_id = None
|
|
210
|
+
if isinstance(config.context, dict):
|
|
211
|
+
run_id = config.context.get("run_id")
|
|
212
|
+
self.event_logger = EventLogger(config.event_path, run_id=run_id)
|
|
213
|
+
|
|
214
|
+
if config.checkpoint_interval > 0:
|
|
215
|
+
self.checkpoint_manager = CheckpointManager()
|
|
216
|
+
|
|
217
|
+
def _cleanup_services(self) -> None:
|
|
218
|
+
"""Clean up services."""
|
|
219
|
+
if self.event_logger:
|
|
220
|
+
self.event_logger.close()
|
|
221
|
+
|
|
222
|
+
if self.checkpoint_manager:
|
|
223
|
+
self.checkpoint_manager.cleanup()
|
|
224
|
+
|
|
225
|
+
def _prepare_phase(
|
|
226
|
+
self, model: Any, adapter: ModelAdapter, report: RunReport
|
|
227
|
+
) -> dict[str, Any]:
|
|
228
|
+
"""Phase 1: Model preparation and analysis."""
|
|
229
|
+
self._log_event("prepare", "start", LogLevel.INFO)
|
|
230
|
+
|
|
231
|
+
# Describe model structure
|
|
232
|
+
model_desc = adapter.describe(model)
|
|
233
|
+
report.meta["model"] = model_desc
|
|
234
|
+
|
|
235
|
+
# Create initial checkpoint
|
|
236
|
+
if self.checkpoint_manager:
|
|
237
|
+
checkpoint_id = self.checkpoint_manager.create_checkpoint(model, adapter)
|
|
238
|
+
report.meta["initial_checkpoint"] = checkpoint_id
|
|
239
|
+
self._log_event(
|
|
240
|
+
"prepare", "checkpoint_created", LogLevel.INFO, {"id": checkpoint_id}
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
self._log_event(
|
|
244
|
+
"prepare",
|
|
245
|
+
"complete",
|
|
246
|
+
LogLevel.INFO,
|
|
247
|
+
{"layers": model_desc.get("n_layer", 0)},
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
return model_desc
|
|
251
|
+
|
|
252
|
+
def _edit_phase(
|
|
253
|
+
self,
|
|
254
|
+
model: Any,
|
|
255
|
+
adapter: ModelAdapter,
|
|
256
|
+
edit: ModelEdit,
|
|
257
|
+
model_desc: dict[str, Any],
|
|
258
|
+
report: RunReport,
|
|
259
|
+
edit_config: dict[str, Any] | None = None,
|
|
260
|
+
) -> dict[str, Any]:
|
|
261
|
+
"""Phase 2: Apply edit operation."""
|
|
262
|
+
edit_label = "baseline" if edit.name == "baseline" else edit.name
|
|
263
|
+
self._log_event("edit", "start", LogLevel.INFO, {"edit": edit_label})
|
|
264
|
+
|
|
265
|
+
# Store edit name for tier resolution
|
|
266
|
+
report.meta["edit_name"] = edit.name
|
|
267
|
+
|
|
268
|
+
# Check if edit can be applied
|
|
269
|
+
if not edit.can_edit(model_desc):
|
|
270
|
+
raise ValueError(f"Edit '{edit.name}' cannot be applied to this model")
|
|
271
|
+
|
|
272
|
+
# Apply edit with configuration parameters
|
|
273
|
+
if edit_config:
|
|
274
|
+
edit_result = edit.apply(model, adapter, **edit_config)
|
|
275
|
+
else:
|
|
276
|
+
edit_result = edit.apply(model, adapter)
|
|
277
|
+
report.edit = edit_result
|
|
278
|
+
if not isinstance(report.context, dict):
|
|
279
|
+
report.context = {}
|
|
280
|
+
edit_context = report.context.setdefault("edit", {})
|
|
281
|
+
if isinstance(edit_result, dict):
|
|
282
|
+
edit_context.setdefault("name", edit_result.get("name", edit.name))
|
|
283
|
+
deltas = edit_result.get("deltas") or {}
|
|
284
|
+
if isinstance(deltas, dict):
|
|
285
|
+
edit_context["params_changed"] = deltas.get("params_changed", 0)
|
|
286
|
+
edit_context["layers_modified"] = deltas.get("layers_modified", 0)
|
|
287
|
+
else:
|
|
288
|
+
edit_context.setdefault("params_changed", 0)
|
|
289
|
+
else:
|
|
290
|
+
edit_context.setdefault("name", edit.name)
|
|
291
|
+
edit_context.setdefault("params_changed", 0)
|
|
292
|
+
|
|
293
|
+
self._log_event(
|
|
294
|
+
"edit",
|
|
295
|
+
"complete",
|
|
296
|
+
LogLevel.INFO,
|
|
297
|
+
{"edit": edit.name, "result": edit_result},
|
|
298
|
+
)
|
|
299
|
+
|
|
300
|
+
return edit_result
|
|
301
|
+
|
|
302
|
+
def _prepare_guards_phase(
|
|
303
|
+
self,
|
|
304
|
+
model: Any,
|
|
305
|
+
adapter: ModelAdapter,
|
|
306
|
+
guards: list[Guard],
|
|
307
|
+
calibration_data: Any,
|
|
308
|
+
report: RunReport,
|
|
309
|
+
auto_config: dict[str, Any] | None = None,
|
|
310
|
+
) -> None:
|
|
311
|
+
"""Phase 2: Prepare safety guards with tier-resolved policies."""
|
|
312
|
+
self._log_event(
|
|
313
|
+
"guards_prepare", "start", LogLevel.INFO, {"count": len(guards)}
|
|
314
|
+
)
|
|
315
|
+
|
|
316
|
+
# Resolve tier policies before guard preparation
|
|
317
|
+
tier_policies = self._resolve_guard_policies(report, auto_config)
|
|
318
|
+
|
|
319
|
+
for guard in guards:
|
|
320
|
+
self._log_event(
|
|
321
|
+
"guard_prepare", "start", LogLevel.INFO, {"guard": guard.name}
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
try:
|
|
325
|
+
guard_policy: dict[str, Any] = tier_policies.get(guard.name, {})
|
|
326
|
+
|
|
327
|
+
# Apply tier-resolved policy to guard
|
|
328
|
+
if guard_policy:
|
|
329
|
+
self._apply_guard_policy(guard, guard_policy)
|
|
330
|
+
self._log_event(
|
|
331
|
+
"guard_prepare",
|
|
332
|
+
"policy_applied",
|
|
333
|
+
LogLevel.INFO,
|
|
334
|
+
{"guard": guard.name, "policy": guard_policy},
|
|
335
|
+
)
|
|
336
|
+
|
|
337
|
+
if hasattr(guard, "set_run_context"):
|
|
338
|
+
try:
|
|
339
|
+
guard.set_run_context(report)
|
|
340
|
+
except Exception as exc:
|
|
341
|
+
self._log_event(
|
|
342
|
+
"guard_prepare",
|
|
343
|
+
"context_error",
|
|
344
|
+
LogLevel.WARNING,
|
|
345
|
+
{"guard": guard.name, "error": str(exc)},
|
|
346
|
+
)
|
|
347
|
+
|
|
348
|
+
# Call prepare method if it exists (most guards need this)
|
|
349
|
+
if hasattr(guard, "prepare"):
|
|
350
|
+
prepare_result = guard.prepare(
|
|
351
|
+
model, adapter, calibration_data, guard_policy
|
|
352
|
+
)
|
|
353
|
+
self._log_event(
|
|
354
|
+
"guard_prepare",
|
|
355
|
+
"complete",
|
|
356
|
+
LogLevel.INFO,
|
|
357
|
+
{
|
|
358
|
+
"guard": guard.name,
|
|
359
|
+
"ready": prepare_result.get("ready", False),
|
|
360
|
+
},
|
|
361
|
+
)
|
|
362
|
+
else:
|
|
363
|
+
self._log_event(
|
|
364
|
+
"guard_prepare",
|
|
365
|
+
"skipped",
|
|
366
|
+
LogLevel.INFO,
|
|
367
|
+
{"guard": guard.name, "reason": "no_prepare_method"},
|
|
368
|
+
)
|
|
369
|
+
|
|
370
|
+
except Exception as e:
|
|
371
|
+
self._log_event(
|
|
372
|
+
"guard_prepare",
|
|
373
|
+
"error",
|
|
374
|
+
LogLevel.ERROR,
|
|
375
|
+
{"guard": guard.name, "error": str(e)},
|
|
376
|
+
)
|
|
377
|
+
|
|
378
|
+
# Store resolved policies in report for certificate
|
|
379
|
+
report.meta["tier_policies"] = tier_policies
|
|
380
|
+
|
|
381
|
+
self._log_event(
|
|
382
|
+
"guards_prepare", "complete", LogLevel.INFO, {"count": len(guards)}
|
|
383
|
+
)
|
|
384
|
+
|
|
385
|
+
def _guard_phase(
|
|
386
|
+
self, model: Any, adapter: ModelAdapter, guards: list[Guard], report: RunReport
|
|
387
|
+
) -> dict[str, dict[str, Any]]:
|
|
388
|
+
"""Phase 4: Run safety guards."""
|
|
389
|
+
self._log_event("guards", "start", LogLevel.INFO, {"count": len(guards)})
|
|
390
|
+
|
|
391
|
+
guard_results = {}
|
|
392
|
+
|
|
393
|
+
for guard in guards:
|
|
394
|
+
self._log_event("guard", "start", LogLevel.INFO, {"guard": guard.name})
|
|
395
|
+
|
|
396
|
+
if hasattr(guard, "set_run_context"):
|
|
397
|
+
try:
|
|
398
|
+
guard.set_run_context(report)
|
|
399
|
+
except Exception as exc: # pragma: no cover - defensive
|
|
400
|
+
self._log_event(
|
|
401
|
+
"guard",
|
|
402
|
+
"context_error",
|
|
403
|
+
LogLevel.WARNING,
|
|
404
|
+
{"guard": guard.name, "error": str(exc)},
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
try:
|
|
408
|
+
result = guard.validate(model, adapter, report.context)
|
|
409
|
+
guard_results[guard.name] = result
|
|
410
|
+
|
|
411
|
+
# Log guard result
|
|
412
|
+
status = "passed" if result.get("passed", False) else "failed"
|
|
413
|
+
self._log_event(
|
|
414
|
+
"guard",
|
|
415
|
+
"complete",
|
|
416
|
+
LogLevel.INFO,
|
|
417
|
+
{"guard": guard.name, "status": status},
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
except Exception as e:
|
|
421
|
+
guard_results[guard.name] = {"passed": False, "error": str(e)}
|
|
422
|
+
self._log_event(
|
|
423
|
+
"guard",
|
|
424
|
+
"error",
|
|
425
|
+
LogLevel.ERROR,
|
|
426
|
+
{"guard": guard.name, "error": str(e)},
|
|
427
|
+
)
|
|
428
|
+
|
|
429
|
+
report.guards = guard_results
|
|
430
|
+
|
|
431
|
+
# Summary
|
|
432
|
+
passed_guards = sum(1 for r in guard_results.values() if r.get("passed", False))
|
|
433
|
+
self._log_event(
|
|
434
|
+
"guards",
|
|
435
|
+
"complete",
|
|
436
|
+
LogLevel.INFO,
|
|
437
|
+
{"total": len(guards), "passed": passed_guards},
|
|
438
|
+
)
|
|
439
|
+
|
|
440
|
+
return guard_results
|
|
441
|
+
|
|
442
|
+
def _eval_phase(
|
|
443
|
+
self,
|
|
444
|
+
model: Any,
|
|
445
|
+
adapter: ModelAdapter,
|
|
446
|
+
calibration_data: Any,
|
|
447
|
+
report: RunReport,
|
|
448
|
+
preview_n: int | None = None,
|
|
449
|
+
final_n: int | None = None,
|
|
450
|
+
config: RunConfig | None = None,
|
|
451
|
+
) -> dict[str, Any]:
|
|
452
|
+
"""Phase 4: Final evaluation metrics."""
|
|
453
|
+
self._log_event("eval", "start", LogLevel.INFO)
|
|
454
|
+
|
|
455
|
+
if calibration_data is not None:
|
|
456
|
+
if os.environ.get("INVARLOCK_DEBUG_TRACE"):
|
|
457
|
+
length_hint = None
|
|
458
|
+
try:
|
|
459
|
+
length_hint = len(calibration_data) # type: ignore[arg-type]
|
|
460
|
+
except Exception: # pragma: no cover - defensive
|
|
461
|
+
length_hint = None
|
|
462
|
+
first_batch = None
|
|
463
|
+
indexable = hasattr(calibration_data, "__getitem__")
|
|
464
|
+
if isinstance(calibration_data, list | tuple):
|
|
465
|
+
if calibration_data:
|
|
466
|
+
first_batch = calibration_data[0]
|
|
467
|
+
elif indexable:
|
|
468
|
+
try:
|
|
469
|
+
first_batch = calibration_data[0] # type: ignore[index]
|
|
470
|
+
except Exception: # pragma: no cover - defensive
|
|
471
|
+
first_batch = None
|
|
472
|
+
masked_preview = None
|
|
473
|
+
first_keys = None
|
|
474
|
+
if isinstance(first_batch, dict):
|
|
475
|
+
first_keys = list(first_batch.keys())
|
|
476
|
+
labels_preview = first_batch.get("labels")
|
|
477
|
+
if isinstance(labels_preview, list | tuple):
|
|
478
|
+
try:
|
|
479
|
+
masked_preview = sum(
|
|
480
|
+
1 for tok in labels_preview if tok != -100
|
|
481
|
+
)
|
|
482
|
+
except Exception: # pragma: no cover - defensive
|
|
483
|
+
masked_preview = None
|
|
484
|
+
self._log_event(
|
|
485
|
+
"eval",
|
|
486
|
+
"calibration_snapshot",
|
|
487
|
+
LogLevel.DEBUG,
|
|
488
|
+
{
|
|
489
|
+
"calibration_type": type(calibration_data).__name__,
|
|
490
|
+
"length_hint": length_hint,
|
|
491
|
+
"indexable": bool(indexable),
|
|
492
|
+
"first_batch_keys": first_keys,
|
|
493
|
+
"first_batch_masked": masked_preview,
|
|
494
|
+
},
|
|
495
|
+
)
|
|
496
|
+
# Compute real perplexity using calibration data
|
|
497
|
+
metrics, eval_windows = self._compute_real_metrics(
|
|
498
|
+
model,
|
|
499
|
+
calibration_data,
|
|
500
|
+
adapter,
|
|
501
|
+
preview_n,
|
|
502
|
+
final_n,
|
|
503
|
+
config,
|
|
504
|
+
)
|
|
505
|
+
else:
|
|
506
|
+
# Fallback to mock metrics if no calibration data
|
|
507
|
+
self._log_event(
|
|
508
|
+
"eval",
|
|
509
|
+
"warning",
|
|
510
|
+
LogLevel.WARNING,
|
|
511
|
+
{"message": "No calibration data provided, using mock metrics"},
|
|
512
|
+
)
|
|
513
|
+
# Provide a minimal primary_metric snapshot and basic perf counters
|
|
514
|
+
metrics = {
|
|
515
|
+
"primary_metric": {
|
|
516
|
+
"kind": "ppl_causal",
|
|
517
|
+
"preview": 25.0,
|
|
518
|
+
"final": 26.0,
|
|
519
|
+
},
|
|
520
|
+
"latency_ms_per_tok": 15.0,
|
|
521
|
+
"memory_mb_peak": 1024.0,
|
|
522
|
+
}
|
|
523
|
+
eval_windows = {"preview": {}, "final": {}}
|
|
524
|
+
|
|
525
|
+
# Store metrics in report
|
|
526
|
+
if hasattr(report, "metrics"):
|
|
527
|
+
report.metrics.update(metrics)
|
|
528
|
+
else:
|
|
529
|
+
report.metrics = metrics
|
|
530
|
+
|
|
531
|
+
report.evaluation_windows = eval_windows
|
|
532
|
+
|
|
533
|
+
self._log_event("eval", "complete", LogLevel.INFO, {"metrics": metrics})
|
|
534
|
+
|
|
535
|
+
return metrics
|
|
536
|
+
|
|
537
|
+
def _compute_real_metrics(
|
|
538
|
+
self,
|
|
539
|
+
model: Any,
|
|
540
|
+
calibration_data: Any,
|
|
541
|
+
adapter: ModelAdapter,
|
|
542
|
+
preview_n: int | None = None,
|
|
543
|
+
final_n: int | None = None,
|
|
544
|
+
config: RunConfig | None = None,
|
|
545
|
+
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
546
|
+
"""Compute real evaluation metrics using calibration data."""
|
|
547
|
+
import os
|
|
548
|
+
|
|
549
|
+
import psutil
|
|
550
|
+
import torch
|
|
551
|
+
|
|
552
|
+
_ = adapter # Adapter kept for API parity; direct HF forward used.
|
|
553
|
+
|
|
554
|
+
model.eval()
|
|
555
|
+
|
|
556
|
+
if os.environ.get("INVARLOCK_DEBUG_TRACE"):
|
|
557
|
+
print(
|
|
558
|
+
f"[debug] compute_real_metrics preview_n={preview_n} final_n={final_n} calibration_len={len(calibration_data) if hasattr(calibration_data, '__len__') else 'n/a'}"
|
|
559
|
+
)
|
|
560
|
+
device = next(model.parameters()).device
|
|
561
|
+
|
|
562
|
+
eval_device_override = os.environ.get("INVARLOCK_EVAL_DEVICE")
|
|
563
|
+
if eval_device_override:
|
|
564
|
+
override_device = torch.device(eval_device_override)
|
|
565
|
+
if override_device != device:
|
|
566
|
+
model.to(override_device)
|
|
567
|
+
device = override_device
|
|
568
|
+
|
|
569
|
+
process = psutil.Process(os.getpid())
|
|
570
|
+
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
|
|
571
|
+
|
|
572
|
+
total_available = (
|
|
573
|
+
len(calibration_data) if hasattr(calibration_data, "__len__") else 0
|
|
574
|
+
)
|
|
575
|
+
if total_available == 0:
|
|
576
|
+
raise ValueError("Calibration data is empty; cannot compute metrics.")
|
|
577
|
+
|
|
578
|
+
if preview_n is None:
|
|
579
|
+
preview_n = max(total_available // 2, 1)
|
|
580
|
+
if final_n is None:
|
|
581
|
+
final_n = preview_n
|
|
582
|
+
|
|
583
|
+
preview_n = max(int(preview_n), 0)
|
|
584
|
+
final_n = max(int(final_n), 0)
|
|
585
|
+
max_needed = max(preview_n, final_n)
|
|
586
|
+
if max_needed <= 0:
|
|
587
|
+
raise ValueError("preview_n and final_n cannot both be zero.")
|
|
588
|
+
|
|
589
|
+
if max_needed > total_available:
|
|
590
|
+
self._log_event(
|
|
591
|
+
"eval",
|
|
592
|
+
"data_scaled",
|
|
593
|
+
LogLevel.WARNING,
|
|
594
|
+
{
|
|
595
|
+
"requested_preview": preview_n,
|
|
596
|
+
"requested_final": final_n,
|
|
597
|
+
"available": total_available,
|
|
598
|
+
},
|
|
599
|
+
)
|
|
600
|
+
preview_n = min(preview_n, total_available)
|
|
601
|
+
final_n = min(final_n, total_available)
|
|
602
|
+
max_needed = max(preview_n, final_n)
|
|
603
|
+
|
|
604
|
+
requested_preview = preview_n
|
|
605
|
+
requested_final = final_n
|
|
606
|
+
|
|
607
|
+
max_needed = preview_n + final_n
|
|
608
|
+
if max_needed > total_available:
|
|
609
|
+
self._log_event(
|
|
610
|
+
"eval",
|
|
611
|
+
"window_shortage",
|
|
612
|
+
LogLevel.WARNING,
|
|
613
|
+
{
|
|
614
|
+
"requested_preview": preview_n,
|
|
615
|
+
"requested_final": final_n,
|
|
616
|
+
"available": total_available,
|
|
617
|
+
},
|
|
618
|
+
)
|
|
619
|
+
max_needed = min(total_available, max_needed)
|
|
620
|
+
|
|
621
|
+
preview_n = min(preview_n, total_available)
|
|
622
|
+
final_start = preview_n
|
|
623
|
+
remaining = max(total_available - preview_n, 0)
|
|
624
|
+
if final_n > remaining:
|
|
625
|
+
self._log_event(
|
|
626
|
+
"eval",
|
|
627
|
+
"final_window_shortage",
|
|
628
|
+
LogLevel.WARNING,
|
|
629
|
+
{
|
|
630
|
+
"requested_final": final_n,
|
|
631
|
+
"available_after_preview": remaining,
|
|
632
|
+
"requested_preview": requested_preview,
|
|
633
|
+
"requested_final_original": requested_final,
|
|
634
|
+
},
|
|
635
|
+
)
|
|
636
|
+
final_n = remaining
|
|
637
|
+
|
|
638
|
+
preview_data = calibration_data[:preview_n]
|
|
639
|
+
final_data = calibration_data[final_start : final_start + final_n]
|
|
640
|
+
|
|
641
|
+
eval_context: dict[str, Any] = {}
|
|
642
|
+
if config and isinstance(config.context, dict):
|
|
643
|
+
eval_context = config.context.get("eval", {}) or {}
|
|
644
|
+
|
|
645
|
+
loss_cfg = (
|
|
646
|
+
eval_context.get("loss", {}) if isinstance(eval_context, dict) else {}
|
|
647
|
+
)
|
|
648
|
+
resolved_loss_mode = str(
|
|
649
|
+
loss_cfg.get("resolved_type") or loss_cfg.get("type") or ""
|
|
650
|
+
).lower()
|
|
651
|
+
bootstrap_cfg = eval_context.get("bootstrap", {}) or {}
|
|
652
|
+
bootstrap_enabled = bool(bootstrap_cfg.get("enabled", True))
|
|
653
|
+
bootstrap_method = str(
|
|
654
|
+
bootstrap_cfg.get("method", "bca_paired_delta_log")
|
|
655
|
+
).lower()
|
|
656
|
+
bootstrap_replicates = int(
|
|
657
|
+
bootstrap_cfg.get("replicates", bootstrap_cfg.get("n", 1000) or 1000)
|
|
658
|
+
)
|
|
659
|
+
bootstrap_alpha = float(bootstrap_cfg.get("alpha", 0.05) or 0.05)
|
|
660
|
+
bootstrap_seed_cfg = bootstrap_cfg.get("seed")
|
|
661
|
+
ci_band = float(bootstrap_cfg.get("ci_band", 0.10) or 0.10)
|
|
662
|
+
|
|
663
|
+
single_method = "bca"
|
|
664
|
+
delta_method = "bca"
|
|
665
|
+
if bootstrap_method == "percentile":
|
|
666
|
+
single_method = "percentile"
|
|
667
|
+
delta_method = "percentile"
|
|
668
|
+
elif bootstrap_method == "bca_paired_delta_log":
|
|
669
|
+
single_method = "bca"
|
|
670
|
+
delta_method = "bca"
|
|
671
|
+
else:
|
|
672
|
+
single_method = bootstrap_method
|
|
673
|
+
delta_method = bootstrap_method
|
|
674
|
+
|
|
675
|
+
dataset_seed = None
|
|
676
|
+
profile_label = ""
|
|
677
|
+
pairing_context: dict[str, Any] = {}
|
|
678
|
+
if config and isinstance(config.context, dict):
|
|
679
|
+
dataset_cfg = config.context.get("dataset", {})
|
|
680
|
+
if isinstance(dataset_cfg, dict):
|
|
681
|
+
dataset_seed = dataset_cfg.get("seed")
|
|
682
|
+
profile_label = str(config.context.get("profile", "")).lower()
|
|
683
|
+
pairing_context = config.context.get("pairing_baseline", {}) or {}
|
|
684
|
+
|
|
685
|
+
bootstrap_seed = (
|
|
686
|
+
bootstrap_seed_cfg if bootstrap_seed_cfg is not None else dataset_seed
|
|
687
|
+
)
|
|
688
|
+
try:
|
|
689
|
+
bootstrap_seed = int(bootstrap_seed) if bootstrap_seed is not None else 0
|
|
690
|
+
except (TypeError, ValueError):
|
|
691
|
+
bootstrap_seed = 0
|
|
692
|
+
|
|
693
|
+
if bootstrap_replicates <= 0:
|
|
694
|
+
bootstrap_enabled = False
|
|
695
|
+
if not (0.0 < bootstrap_alpha < 1.0):
|
|
696
|
+
bootstrap_alpha = 0.05
|
|
697
|
+
|
|
698
|
+
pm_preview = 50.0
|
|
699
|
+
pm_final = 50.0
|
|
700
|
+
pm_ratio = 1.0
|
|
701
|
+
ratio_ci: tuple[float, float] = (pm_ratio, pm_ratio)
|
|
702
|
+
preview_log_ci: tuple[float, float] = (
|
|
703
|
+
math.log(pm_preview),
|
|
704
|
+
math.log(pm_preview),
|
|
705
|
+
)
|
|
706
|
+
final_log_ci: tuple[float, float] = (math.log(pm_final), math.log(pm_final))
|
|
707
|
+
delta_log_ci: tuple[float, float] = (0.0, 0.0)
|
|
708
|
+
preview_mean_log = math.log(pm_preview)
|
|
709
|
+
final_mean_log = math.log(pm_final)
|
|
710
|
+
delta_mean_log = 0.0
|
|
711
|
+
preview_log_losses: list[float] = []
|
|
712
|
+
final_log_losses: list[float] = []
|
|
713
|
+
preview_tokens_ct = 0
|
|
714
|
+
final_tokens_ct = 0
|
|
715
|
+
preview_batches_ct = 0
|
|
716
|
+
final_batches_ct = 0
|
|
717
|
+
window_overlap_fraction = 0.0
|
|
718
|
+
# Defaults for pairing metrics to avoid unbound locals on error paths
|
|
719
|
+
window_match_fraction = 1.0
|
|
720
|
+
pairing_reason = None
|
|
721
|
+
preview_pair_stats = {"matched": 0, "expected": 0}
|
|
722
|
+
final_pair_stats = {"matched": 0, "expected": 0}
|
|
723
|
+
preview_window_ids: list[int] = []
|
|
724
|
+
final_window_ids: list[int] = []
|
|
725
|
+
preview_tokens: list[list[int]] = []
|
|
726
|
+
final_tokens: list[list[int]] = []
|
|
727
|
+
preview_limit = min(preview_n, len(preview_data)) if preview_data else 0
|
|
728
|
+
final_limit = min(final_n, len(final_data)) if final_data else 0
|
|
729
|
+
|
|
730
|
+
# Safe defaults in case of early exceptions inside compute block
|
|
731
|
+
preview_actual_tokens_ct = int(preview_tokens_ct)
|
|
732
|
+
final_actual_tokens_ct = int(final_tokens_ct)
|
|
733
|
+
preview_masked_total = int(preview_tokens_ct)
|
|
734
|
+
final_masked_total = int(final_tokens_ct)
|
|
735
|
+
preview_token_counts = []
|
|
736
|
+
final_token_counts = []
|
|
737
|
+
preview_attention_masks: list[list[int]] = []
|
|
738
|
+
final_attention_masks: list[list[int]] = []
|
|
739
|
+
preview_mask_counts: list[int] = []
|
|
740
|
+
final_mask_counts: list[int] = []
|
|
741
|
+
preview_labels: list[list[int]] = []
|
|
742
|
+
final_labels: list[list[int]] = []
|
|
743
|
+
preview_actual_token_counts: list[int] = []
|
|
744
|
+
final_actual_token_counts: list[int] = []
|
|
745
|
+
|
|
746
|
+
# Defaults for degeneracy flags
|
|
747
|
+
degenerate_delta = False
|
|
748
|
+
degenerate_reason: str | None = None
|
|
749
|
+
|
|
750
|
+
bootstrap_info = {
|
|
751
|
+
"enabled": bool(bootstrap_enabled),
|
|
752
|
+
"method": bootstrap_method,
|
|
753
|
+
"alpha": float(bootstrap_alpha),
|
|
754
|
+
"replicates": int(bootstrap_replicates),
|
|
755
|
+
"seed": int(bootstrap_seed),
|
|
756
|
+
"ci_band": float(ci_band),
|
|
757
|
+
}
|
|
758
|
+
|
|
759
|
+
alignment_logged = False
|
|
760
|
+
|
|
761
|
+
# Initialize to safe defaults to ensure later metrics assembly succeeds
|
|
762
|
+
# even if an exception occurs during the main compute block.
|
|
763
|
+
delta_samples: list[float] = []
|
|
764
|
+
delta_weights: list[float] = []
|
|
765
|
+
|
|
766
|
+
try:
|
|
767
|
+
|
|
768
|
+
def _resolve_limit(batches: Sequence[Any], requested: int) -> int:
|
|
769
|
+
if not batches:
|
|
770
|
+
return 0
|
|
771
|
+
if requested <= 0:
|
|
772
|
+
return len(batches)
|
|
773
|
+
return min(len(batches), requested)
|
|
774
|
+
|
|
775
|
+
def _compute_slice_summary(
|
|
776
|
+
batches: Sequence[Any],
|
|
777
|
+
max_batches: int,
|
|
778
|
+
start_idx: int,
|
|
779
|
+
) -> dict[str, Any]:
|
|
780
|
+
nonlocal alignment_logged
|
|
781
|
+
|
|
782
|
+
total_tokens_local = 0
|
|
783
|
+
actual_tokens_local = 0
|
|
784
|
+
weighted_log_loss = 0.0
|
|
785
|
+
log_losses: list[float] = []
|
|
786
|
+
window_ids: list[int] = []
|
|
787
|
+
collected_tokens: list[list[int]] = []
|
|
788
|
+
collected_attn: list[list[int]] = []
|
|
789
|
+
collected_labels: list[list[int]] = []
|
|
790
|
+
token_counts: list[int] = []
|
|
791
|
+
masked_token_counts: list[int] = []
|
|
792
|
+
actual_token_counts: list[int] = []
|
|
793
|
+
count = 0
|
|
794
|
+
zero_mask_batches = 0
|
|
795
|
+
any_labels_seen = False
|
|
796
|
+
store_windows = os.environ.get(
|
|
797
|
+
"INVARLOCK_STORE_EVAL_WINDOWS", "1"
|
|
798
|
+
).lower() not in {"0", "false", "no"}
|
|
799
|
+
|
|
800
|
+
if not batches:
|
|
801
|
+
return {
|
|
802
|
+
"ppl": float("nan"),
|
|
803
|
+
"total_tokens": 0,
|
|
804
|
+
"num_batches": 0,
|
|
805
|
+
"log_losses": [],
|
|
806
|
+
"window_ids": [],
|
|
807
|
+
"tokens": [],
|
|
808
|
+
"attention_masks": [],
|
|
809
|
+
"weighted_log_loss": 0.0,
|
|
810
|
+
"window_token_counts": [],
|
|
811
|
+
}
|
|
812
|
+
|
|
813
|
+
limit = _resolve_limit(batches, max_batches)
|
|
814
|
+
|
|
815
|
+
for batch in batches[:limit]:
|
|
816
|
+
if max_batches > 0 and count >= max_batches:
|
|
817
|
+
break
|
|
818
|
+
|
|
819
|
+
labels = None
|
|
820
|
+
if isinstance(batch, dict):
|
|
821
|
+
input_ids = batch.get("input_ids", batch.get("inputs"))
|
|
822
|
+
attention_mask = batch.get("attention_mask")
|
|
823
|
+
labels = batch.get("labels")
|
|
824
|
+
else:
|
|
825
|
+
input_ids = batch
|
|
826
|
+
attention_mask = None
|
|
827
|
+
|
|
828
|
+
if input_ids is None:
|
|
829
|
+
continue
|
|
830
|
+
|
|
831
|
+
if isinstance(input_ids, torch.Tensor):
|
|
832
|
+
input_ids_t = input_ids.to(device=device, dtype=torch.long)
|
|
833
|
+
else:
|
|
834
|
+
input_ids_t = torch.as_tensor(
|
|
835
|
+
input_ids, device=device, dtype=torch.long
|
|
836
|
+
)
|
|
837
|
+
|
|
838
|
+
if input_ids_t.dim() == 1:
|
|
839
|
+
input_ids_t = input_ids_t.unsqueeze(0)
|
|
840
|
+
|
|
841
|
+
attn_t = None
|
|
842
|
+
if attention_mask is not None:
|
|
843
|
+
if isinstance(attention_mask, torch.Tensor):
|
|
844
|
+
attn_t = attention_mask.to(device=device, dtype=torch.long)
|
|
845
|
+
else:
|
|
846
|
+
attn_t = torch.as_tensor(
|
|
847
|
+
attention_mask, device=device, dtype=torch.long
|
|
848
|
+
)
|
|
849
|
+
if attn_t.dim() == 1:
|
|
850
|
+
attn_t = attn_t.unsqueeze(0)
|
|
851
|
+
|
|
852
|
+
if labels is not None:
|
|
853
|
+
any_labels_seen = True
|
|
854
|
+
if isinstance(labels, torch.Tensor):
|
|
855
|
+
labels_t = labels.to(device=device, dtype=torch.long)
|
|
856
|
+
else:
|
|
857
|
+
labels_t = torch.as_tensor(
|
|
858
|
+
labels, device=device, dtype=torch.long
|
|
859
|
+
)
|
|
860
|
+
if labels_t.dim() == 1:
|
|
861
|
+
labels_t = labels_t.unsqueeze(0)
|
|
862
|
+
else:
|
|
863
|
+
labels_t = input_ids_t.clone()
|
|
864
|
+
if attn_t is not None:
|
|
865
|
+
labels_t = labels_t.masked_fill(attn_t == 0, -100)
|
|
866
|
+
|
|
867
|
+
snapshot = input_ids_t.detach().cpu()
|
|
868
|
+
attn_snapshot = (
|
|
869
|
+
attn_t.detach().cpu() if attn_t is not None else None
|
|
870
|
+
)
|
|
871
|
+
|
|
872
|
+
with torch.no_grad():
|
|
873
|
+
if attn_t is not None:
|
|
874
|
+
outputs = model(
|
|
875
|
+
input_ids_t, attention_mask=attn_t, labels=labels_t
|
|
876
|
+
)
|
|
877
|
+
else:
|
|
878
|
+
outputs = model(input_ids_t, labels=labels_t)
|
|
879
|
+
|
|
880
|
+
loss_val = (
|
|
881
|
+
outputs.loss.item()
|
|
882
|
+
if hasattr(outputs, "loss") and hasattr(outputs.loss, "item")
|
|
883
|
+
else None
|
|
884
|
+
)
|
|
885
|
+
if loss_val is None:
|
|
886
|
+
if os.environ.get("INVARLOCK_DEBUG_TRACE"):
|
|
887
|
+
self._log_event(
|
|
888
|
+
"eval",
|
|
889
|
+
"missing_loss",
|
|
890
|
+
LogLevel.DEBUG,
|
|
891
|
+
{
|
|
892
|
+
"has_loss_attr": bool(hasattr(outputs, "loss")),
|
|
893
|
+
"labels_provided": bool(labels is not None),
|
|
894
|
+
"window_index": start_idx + count,
|
|
895
|
+
},
|
|
896
|
+
)
|
|
897
|
+
continue
|
|
898
|
+
|
|
899
|
+
if attn_snapshot is not None:
|
|
900
|
+
tokens_in_batch = int(attn_snapshot.sum().item())
|
|
901
|
+
else:
|
|
902
|
+
tokens_in_batch = int(input_ids_t.numel())
|
|
903
|
+
|
|
904
|
+
if tokens_in_batch <= 0:
|
|
905
|
+
continue
|
|
906
|
+
|
|
907
|
+
masked_tokens_batch = int((labels_t != -100).sum().item())
|
|
908
|
+
effective_masked = masked_tokens_batch
|
|
909
|
+
if labels is not None and masked_tokens_batch <= 0:
|
|
910
|
+
zero_mask_batches += 1
|
|
911
|
+
effective_masked = tokens_in_batch
|
|
912
|
+
if os.environ.get("INVARLOCK_DEBUG_TRACE"):
|
|
913
|
+
sample_labels = None
|
|
914
|
+
try:
|
|
915
|
+
sample_labels = labels_t[0].detach().cpu().tolist()[:8]
|
|
916
|
+
except Exception: # pragma: no cover - defensive
|
|
917
|
+
sample_labels = None
|
|
918
|
+
self._log_event(
|
|
919
|
+
"eval",
|
|
920
|
+
"zero_mask_batch",
|
|
921
|
+
LogLevel.WARNING,
|
|
922
|
+
{
|
|
923
|
+
"window_index": start_idx + count,
|
|
924
|
+
"tokens_in_batch": tokens_in_batch,
|
|
925
|
+
"masked_tokens": masked_tokens_batch,
|
|
926
|
+
"labels_sample": sample_labels,
|
|
927
|
+
"fallback_weight": effective_masked,
|
|
928
|
+
},
|
|
929
|
+
)
|
|
930
|
+
effective_weight = (
|
|
931
|
+
effective_masked if labels is not None else tokens_in_batch
|
|
932
|
+
)
|
|
933
|
+
if effective_weight <= 0:
|
|
934
|
+
continue
|
|
935
|
+
|
|
936
|
+
if os.environ.get("INVARLOCK_DEBUG_TRACE"):
|
|
937
|
+
print(
|
|
938
|
+
f"[debug] eval batch loss={float(loss_val):.6f} masked_tokens={masked_tokens_batch} tokens_in_batch={tokens_in_batch}"
|
|
939
|
+
)
|
|
940
|
+
|
|
941
|
+
if store_windows:
|
|
942
|
+
for row in snapshot:
|
|
943
|
+
collected_tokens.append(row.tolist())
|
|
944
|
+
|
|
945
|
+
if attn_snapshot is not None:
|
|
946
|
+
for row in attn_snapshot:
|
|
947
|
+
collected_attn.append(row.tolist())
|
|
948
|
+
else:
|
|
949
|
+
for row in snapshot:
|
|
950
|
+
collected_attn.append([1] * len(row))
|
|
951
|
+
collected_labels.extend(labels_t.detach().cpu().tolist())
|
|
952
|
+
|
|
953
|
+
if not alignment_logged:
|
|
954
|
+
self._log_event(
|
|
955
|
+
"eval",
|
|
956
|
+
"label_alignment",
|
|
957
|
+
LogLevel.INFO,
|
|
958
|
+
{
|
|
959
|
+
"ignore_index": -100,
|
|
960
|
+
"used_attention_mask": bool(attn_snapshot is not None),
|
|
961
|
+
"tokens_in_batch": tokens_in_batch,
|
|
962
|
+
"masked_tokens": masked_tokens_batch,
|
|
963
|
+
},
|
|
964
|
+
)
|
|
965
|
+
alignment_logged = True
|
|
966
|
+
|
|
967
|
+
log_losses.append(float(loss_val))
|
|
968
|
+
actual_tokens_local += tokens_in_batch
|
|
969
|
+
total_tokens_local += effective_weight
|
|
970
|
+
weighted_log_loss += float(loss_val) * effective_weight
|
|
971
|
+
token_counts.append(effective_weight)
|
|
972
|
+
masked_token_counts.append(masked_tokens_batch)
|
|
973
|
+
if labels is not None and masked_tokens_batch <= 0:
|
|
974
|
+
masked_token_counts[-1] = effective_masked
|
|
975
|
+
actual_token_counts.append(tokens_in_batch)
|
|
976
|
+
window_ids.append(start_idx + count)
|
|
977
|
+
count += 1
|
|
978
|
+
|
|
979
|
+
if count == 0:
|
|
980
|
+
if zero_mask_batches and os.environ.get("INVARLOCK_DEBUG_TRACE"):
|
|
981
|
+
self._log_event(
|
|
982
|
+
"eval",
|
|
983
|
+
"zero_mask_total",
|
|
984
|
+
LogLevel.ERROR,
|
|
985
|
+
{
|
|
986
|
+
"zero_mask_batches": zero_mask_batches,
|
|
987
|
+
"requested": limit,
|
|
988
|
+
},
|
|
989
|
+
)
|
|
990
|
+
if resolved_loss_mode == "mlm":
|
|
991
|
+
error_msg = (
|
|
992
|
+
"MLM evaluation produced zero usable batches; "
|
|
993
|
+
"ensure baseline pairing includes masked tokens."
|
|
994
|
+
)
|
|
995
|
+
if any_labels_seen:
|
|
996
|
+
error_msg = (
|
|
997
|
+
"MLM evaluation saw labels but zero masked tokens were accumulated; "
|
|
998
|
+
"check calibration data integrity."
|
|
999
|
+
)
|
|
1000
|
+
self._log_event(
|
|
1001
|
+
"eval",
|
|
1002
|
+
"mlm_missing_masks",
|
|
1003
|
+
LogLevel.ERROR,
|
|
1004
|
+
{
|
|
1005
|
+
"any_labels": bool(any_labels_seen),
|
|
1006
|
+
"requested": limit,
|
|
1007
|
+
"zero_mask_batches": zero_mask_batches,
|
|
1008
|
+
},
|
|
1009
|
+
)
|
|
1010
|
+
raise ValueError(error_msg)
|
|
1011
|
+
return {
|
|
1012
|
+
"ppl": float("nan"),
|
|
1013
|
+
"total_tokens": total_tokens_local,
|
|
1014
|
+
"actual_total_tokens": actual_tokens_local,
|
|
1015
|
+
"num_batches": 0,
|
|
1016
|
+
"log_losses": [],
|
|
1017
|
+
"window_ids": [],
|
|
1018
|
+
"tokens": [],
|
|
1019
|
+
"attention_masks": [],
|
|
1020
|
+
"weighted_log_loss": 0.0,
|
|
1021
|
+
"window_token_counts": [],
|
|
1022
|
+
"masked_token_counts": [],
|
|
1023
|
+
"actual_token_counts": [],
|
|
1024
|
+
"labels": [],
|
|
1025
|
+
}
|
|
1026
|
+
|
|
1027
|
+
mean_loss = (
|
|
1028
|
+
weighted_log_loss / total_tokens_local
|
|
1029
|
+
if total_tokens_local > 0
|
|
1030
|
+
else sum(log_losses) / max(count, 1)
|
|
1031
|
+
)
|
|
1032
|
+
return {
|
|
1033
|
+
"ppl": float(math.exp(mean_loss)),
|
|
1034
|
+
"total_tokens": total_tokens_local,
|
|
1035
|
+
"num_batches": count,
|
|
1036
|
+
"log_losses": log_losses,
|
|
1037
|
+
"window_ids": window_ids,
|
|
1038
|
+
"tokens": collected_tokens,
|
|
1039
|
+
"attention_masks": collected_attn,
|
|
1040
|
+
"weighted_log_loss": weighted_log_loss,
|
|
1041
|
+
"window_token_counts": token_counts,
|
|
1042
|
+
"masked_token_counts": masked_token_counts,
|
|
1043
|
+
"actual_token_counts": actual_token_counts,
|
|
1044
|
+
"labels": collected_labels,
|
|
1045
|
+
"actual_total_tokens": actual_tokens_local,
|
|
1046
|
+
}
|
|
1047
|
+
|
|
1048
|
+
preview_limit = _resolve_limit(preview_data, preview_n)
|
|
1049
|
+
final_limit = _resolve_limit(final_data, final_n)
|
|
1050
|
+
|
|
1051
|
+
preview_summary = _compute_slice_summary(preview_data, preview_limit, 0)
|
|
1052
|
+
final_summary = _compute_slice_summary(
|
|
1053
|
+
final_data, final_limit, preview_summary["num_batches"]
|
|
1054
|
+
)
|
|
1055
|
+
|
|
1056
|
+
preview_log_losses = preview_summary["log_losses"]
|
|
1057
|
+
final_log_losses = final_summary["log_losses"]
|
|
1058
|
+
preview_tokens_ct = preview_summary["total_tokens"]
|
|
1059
|
+
final_tokens_ct = final_summary["total_tokens"]
|
|
1060
|
+
preview_batches_ct = preview_summary["num_batches"]
|
|
1061
|
+
final_batches_ct = final_summary["num_batches"]
|
|
1062
|
+
preview_window_ids = list(preview_summary["window_ids"])
|
|
1063
|
+
final_window_ids = list(final_summary["window_ids"])
|
|
1064
|
+
preview_tokens = list(preview_summary["tokens"])
|
|
1065
|
+
final_tokens = list(final_summary["tokens"])
|
|
1066
|
+
preview_token_counts = list(preview_summary.get("window_token_counts", []))
|
|
1067
|
+
final_token_counts = list(final_summary.get("window_token_counts", []))
|
|
1068
|
+
preview_attention_masks = list(preview_summary.get("attention_masks", []))
|
|
1069
|
+
final_attention_masks = list(final_summary.get("attention_masks", []))
|
|
1070
|
+
preview_mask_counts = list(preview_summary.get("masked_token_counts", []))
|
|
1071
|
+
final_mask_counts = list(final_summary.get("masked_token_counts", []))
|
|
1072
|
+
preview_labels = list(preview_summary.get("labels", []))
|
|
1073
|
+
final_labels = list(final_summary.get("labels", []))
|
|
1074
|
+
preview_actual_token_counts = list(
|
|
1075
|
+
preview_summary.get("actual_token_counts", [])
|
|
1076
|
+
)
|
|
1077
|
+
final_actual_token_counts = list(
|
|
1078
|
+
final_summary.get("actual_token_counts", [])
|
|
1079
|
+
)
|
|
1080
|
+
preview_actual_tokens_ct = int(
|
|
1081
|
+
preview_summary.get("actual_total_tokens", preview_tokens_ct)
|
|
1082
|
+
)
|
|
1083
|
+
final_actual_tokens_ct = int(
|
|
1084
|
+
final_summary.get("actual_total_tokens", final_tokens_ct)
|
|
1085
|
+
)
|
|
1086
|
+
preview_masked_total = (
|
|
1087
|
+
sum(preview_mask_counts)
|
|
1088
|
+
if preview_mask_counts
|
|
1089
|
+
else int(preview_tokens_ct)
|
|
1090
|
+
)
|
|
1091
|
+
final_masked_total = (
|
|
1092
|
+
sum(final_mask_counts) if final_mask_counts else int(final_tokens_ct)
|
|
1093
|
+
)
|
|
1094
|
+
preview_weighted_loss = float(preview_summary.get("weighted_log_loss", 0.0))
|
|
1095
|
+
final_weighted_loss = float(final_summary.get("weighted_log_loss", 0.0))
|
|
1096
|
+
|
|
1097
|
+
if preview_tokens_ct > 0:
|
|
1098
|
+
preview_mean_log = float(preview_weighted_loss / preview_tokens_ct)
|
|
1099
|
+
pm_preview = math.exp(preview_mean_log)
|
|
1100
|
+
elif preview_log_losses:
|
|
1101
|
+
preview_mean_log = float(np.mean(preview_log_losses))
|
|
1102
|
+
pm_preview = math.exp(preview_mean_log)
|
|
1103
|
+
else:
|
|
1104
|
+
pm_preview = preview_summary["ppl"]
|
|
1105
|
+
if not math.isfinite(pm_preview) or pm_preview <= 0:
|
|
1106
|
+
pm_preview = 50.0
|
|
1107
|
+
preview_mean_log = math.log(pm_preview)
|
|
1108
|
+
|
|
1109
|
+
if final_tokens_ct > 0:
|
|
1110
|
+
final_mean_log = float(final_weighted_loss / final_tokens_ct)
|
|
1111
|
+
pm_final = math.exp(final_mean_log)
|
|
1112
|
+
elif final_log_losses:
|
|
1113
|
+
final_mean_log = float(np.mean(final_log_losses))
|
|
1114
|
+
pm_final = math.exp(final_mean_log)
|
|
1115
|
+
else:
|
|
1116
|
+
pm_final = final_summary["ppl"]
|
|
1117
|
+
if not math.isfinite(pm_final) or pm_final <= 0:
|
|
1118
|
+
pm_final = 50.0
|
|
1119
|
+
final_mean_log = math.log(pm_final)
|
|
1120
|
+
|
|
1121
|
+
delta_mean_log = final_mean_log - preview_mean_log
|
|
1122
|
+
pm_ratio = math.exp(delta_mean_log)
|
|
1123
|
+
|
|
1124
|
+
if not (math.isfinite(delta_mean_log) and math.isfinite(pm_ratio)):
|
|
1125
|
+
raise RuntimeError("Invalid perplexity ratio or delta")
|
|
1126
|
+
|
|
1127
|
+
expected_ratio = math.exp(delta_mean_log)
|
|
1128
|
+
if abs(pm_ratio - expected_ratio) > 1e-6:
|
|
1129
|
+
raise RuntimeError(
|
|
1130
|
+
"Primary-metric ratio mismatch with exp(mean ΔlogNLL)"
|
|
1131
|
+
)
|
|
1132
|
+
|
|
1133
|
+
if bootstrap_enabled and preview_log_losses:
|
|
1134
|
+
preview_log_ci = compute_logloss_ci(
|
|
1135
|
+
preview_log_losses,
|
|
1136
|
+
method=single_method,
|
|
1137
|
+
replicates=bootstrap_replicates,
|
|
1138
|
+
alpha=bootstrap_alpha,
|
|
1139
|
+
seed=bootstrap_seed + 7,
|
|
1140
|
+
)
|
|
1141
|
+
else:
|
|
1142
|
+
preview_log_ci = (preview_mean_log, preview_mean_log)
|
|
1143
|
+
|
|
1144
|
+
# primary_metric consumers use log-space intervals; skip ppl-space tuple here
|
|
1145
|
+
|
|
1146
|
+
if bootstrap_enabled and final_log_losses:
|
|
1147
|
+
final_log_ci = compute_logloss_ci(
|
|
1148
|
+
final_log_losses,
|
|
1149
|
+
method=single_method,
|
|
1150
|
+
replicates=bootstrap_replicates,
|
|
1151
|
+
alpha=bootstrap_alpha,
|
|
1152
|
+
seed=bootstrap_seed + 13,
|
|
1153
|
+
)
|
|
1154
|
+
else:
|
|
1155
|
+
final_log_ci = (final_mean_log, final_mean_log)
|
|
1156
|
+
|
|
1157
|
+
# primary_metric consumers use log-space intervals; skip ppl-space tuple here
|
|
1158
|
+
|
|
1159
|
+
if (
|
|
1160
|
+
bootstrap_enabled
|
|
1161
|
+
and final_log_losses
|
|
1162
|
+
and preview_log_losses
|
|
1163
|
+
and len(final_log_losses)
|
|
1164
|
+
and len(preview_log_losses)
|
|
1165
|
+
):
|
|
1166
|
+
delta_log_ci = compute_paired_delta_log_ci(
|
|
1167
|
+
final_log_losses,
|
|
1168
|
+
preview_log_losses,
|
|
1169
|
+
method=delta_method,
|
|
1170
|
+
replicates=bootstrap_replicates,
|
|
1171
|
+
alpha=bootstrap_alpha,
|
|
1172
|
+
seed=bootstrap_seed + 97,
|
|
1173
|
+
)
|
|
1174
|
+
ratio_ci = logspace_to_ratio_ci(delta_log_ci)
|
|
1175
|
+
expected_ratio_ci = tuple(math.exp(bound) for bound in delta_log_ci)
|
|
1176
|
+
if any(
|
|
1177
|
+
abs(r - e) > 1e-6
|
|
1178
|
+
for r, e in zip(ratio_ci, expected_ratio_ci, strict=False)
|
|
1179
|
+
):
|
|
1180
|
+
raise RuntimeError("Ratio CI inconsistent with Δlog CI")
|
|
1181
|
+
else:
|
|
1182
|
+
delta_log_ci = (delta_mean_log, delta_mean_log)
|
|
1183
|
+
ratio_ci = (pm_ratio, pm_ratio)
|
|
1184
|
+
|
|
1185
|
+
delta_samples: list[float] = []
|
|
1186
|
+
delta_weights: list[float] = []
|
|
1187
|
+
if final_log_losses and preview_log_losses:
|
|
1188
|
+
limit = min(len(final_log_losses), len(preview_log_losses))
|
|
1189
|
+
if limit:
|
|
1190
|
+
delta_samples = [
|
|
1191
|
+
final_log_losses[i] - preview_log_losses[i]
|
|
1192
|
+
for i in range(limit)
|
|
1193
|
+
]
|
|
1194
|
+
if preview_token_counts and len(preview_token_counts) >= limit:
|
|
1195
|
+
delta_weights = [
|
|
1196
|
+
float(max(preview_token_counts[i], 1)) for i in range(limit)
|
|
1197
|
+
]
|
|
1198
|
+
|
|
1199
|
+
degenerate_delta = False
|
|
1200
|
+
degenerate_reason: str | None = None
|
|
1201
|
+
if len(delta_samples) < 2:
|
|
1202
|
+
if len(delta_samples) == 0:
|
|
1203
|
+
degenerate_delta = True
|
|
1204
|
+
degenerate_reason = "no_pairs"
|
|
1205
|
+
else:
|
|
1206
|
+
degenerate_delta = True
|
|
1207
|
+
degenerate_reason = "single_pair"
|
|
1208
|
+
elif np.allclose(delta_samples, delta_samples[0]):
|
|
1209
|
+
degenerate_delta = True
|
|
1210
|
+
degenerate_reason = "no_variation"
|
|
1211
|
+
|
|
1212
|
+
if degenerate_delta:
|
|
1213
|
+
self._log_event(
|
|
1214
|
+
"eval",
|
|
1215
|
+
"degenerate_delta_samples",
|
|
1216
|
+
LogLevel.ERROR,
|
|
1217
|
+
{
|
|
1218
|
+
"reason": degenerate_reason,
|
|
1219
|
+
"sample_count": len(delta_samples),
|
|
1220
|
+
},
|
|
1221
|
+
)
|
|
1222
|
+
if profile_label in {"ci", "release"}:
|
|
1223
|
+
raise RuntimeError(
|
|
1224
|
+
f"Degenerate paired ΔlogNLL distribution ({degenerate_reason})"
|
|
1225
|
+
)
|
|
1226
|
+
|
|
1227
|
+
def _hash_tokens(tokens: list[int]) -> bytes:
|
|
1228
|
+
if not tokens:
|
|
1229
|
+
return b""
|
|
1230
|
+
token_array = array("I", (int(token) & 0xFFFFFFFF for token in tokens))
|
|
1231
|
+
return hashlib.blake2b(token_array.tobytes(), digest_size=16).digest()
|
|
1232
|
+
|
|
1233
|
+
def _duplicate_fraction(seqs: list[list[int]]) -> float:
|
|
1234
|
+
if not seqs:
|
|
1235
|
+
return 0.0
|
|
1236
|
+
hashes = [_hash_tokens(seq) for seq in seqs]
|
|
1237
|
+
unique = len(set(hashes))
|
|
1238
|
+
if not hashes:
|
|
1239
|
+
return 0.0
|
|
1240
|
+
return max(0.0, (len(hashes) - unique) / len(hashes))
|
|
1241
|
+
|
|
1242
|
+
def _compare_with_baseline(
|
|
1243
|
+
run_ids: list[int],
|
|
1244
|
+
run_tokens: list[list[int]],
|
|
1245
|
+
baseline_section: dict[str, Any] | None,
|
|
1246
|
+
split_label: str,
|
|
1247
|
+
) -> dict[str, Any]:
|
|
1248
|
+
stats = {
|
|
1249
|
+
"matched": 0,
|
|
1250
|
+
"expected": 0,
|
|
1251
|
+
"missing_ids": [],
|
|
1252
|
+
"mismatched_ids": [],
|
|
1253
|
+
"unexpected_ids": [],
|
|
1254
|
+
"reason": None,
|
|
1255
|
+
}
|
|
1256
|
+
|
|
1257
|
+
if not baseline_section:
|
|
1258
|
+
stats["matched"] = len(run_tokens)
|
|
1259
|
+
stats["expected"] = len(run_tokens)
|
|
1260
|
+
stats["reason"] = "no_baseline_reference"
|
|
1261
|
+
return stats
|
|
1262
|
+
|
|
1263
|
+
base_ids = baseline_section.get("window_ids") or []
|
|
1264
|
+
base_tokens = baseline_section.get("input_ids") or []
|
|
1265
|
+
if not isinstance(base_ids, list) or not isinstance(base_tokens, list):
|
|
1266
|
+
stats["matched"] = len(run_tokens)
|
|
1267
|
+
stats["expected"] = len(run_tokens)
|
|
1268
|
+
stats["reason"] = "invalid_baseline_reference"
|
|
1269
|
+
return stats
|
|
1270
|
+
|
|
1271
|
+
base_map: dict[int, bytes] = {}
|
|
1272
|
+
for bid, seq in zip(base_ids, base_tokens, strict=False):
|
|
1273
|
+
try:
|
|
1274
|
+
bid_int = int(bid)
|
|
1275
|
+
except Exception:
|
|
1276
|
+
continue
|
|
1277
|
+
seq_list = list(seq) if not isinstance(seq, list) else seq
|
|
1278
|
+
base_map[bid_int] = _hash_tokens(seq_list)
|
|
1279
|
+
|
|
1280
|
+
stats["expected"] = len(base_map)
|
|
1281
|
+
matched = 0
|
|
1282
|
+
seen_ids: set[int] = set()
|
|
1283
|
+
mismatched: list[int] = []
|
|
1284
|
+
unexpected: list[int] = []
|
|
1285
|
+
|
|
1286
|
+
for rid, seq in zip(run_ids, run_tokens, strict=False):
|
|
1287
|
+
try:
|
|
1288
|
+
rid_int = int(rid)
|
|
1289
|
+
except Exception:
|
|
1290
|
+
unexpected.append(rid)
|
|
1291
|
+
continue
|
|
1292
|
+
|
|
1293
|
+
hashed = _hash_tokens(seq)
|
|
1294
|
+
if rid_int not in base_map:
|
|
1295
|
+
unexpected.append(rid_int)
|
|
1296
|
+
continue
|
|
1297
|
+
|
|
1298
|
+
seen_ids.add(rid_int)
|
|
1299
|
+
if hashed == base_map[rid_int]:
|
|
1300
|
+
matched += 1
|
|
1301
|
+
else:
|
|
1302
|
+
mismatched.append(rid_int)
|
|
1303
|
+
|
|
1304
|
+
missing = [bid for bid in base_map if bid not in seen_ids]
|
|
1305
|
+
stats.update(
|
|
1306
|
+
{
|
|
1307
|
+
"matched": matched,
|
|
1308
|
+
"missing_ids": missing,
|
|
1309
|
+
"mismatched_ids": mismatched,
|
|
1310
|
+
"unexpected_ids": unexpected,
|
|
1311
|
+
}
|
|
1312
|
+
)
|
|
1313
|
+
|
|
1314
|
+
if missing:
|
|
1315
|
+
stats["reason"] = f"{split_label}_missing_ids:{missing[:3]}"
|
|
1316
|
+
elif mismatched:
|
|
1317
|
+
stats["reason"] = f"{split_label}_token_mismatch:{mismatched[:3]}"
|
|
1318
|
+
elif unexpected:
|
|
1319
|
+
stats["reason"] = f"{split_label}_unexpected_ids:{unexpected[:3]}"
|
|
1320
|
+
else:
|
|
1321
|
+
stats["reason"] = None
|
|
1322
|
+
|
|
1323
|
+
return stats
|
|
1324
|
+
|
|
1325
|
+
baseline_preview = (
|
|
1326
|
+
pairing_context.get("preview")
|
|
1327
|
+
if isinstance(pairing_context, dict)
|
|
1328
|
+
else {}
|
|
1329
|
+
)
|
|
1330
|
+
baseline_final = (
|
|
1331
|
+
pairing_context.get("final")
|
|
1332
|
+
if isinstance(pairing_context, dict)
|
|
1333
|
+
else {}
|
|
1334
|
+
)
|
|
1335
|
+
|
|
1336
|
+
preview_pair_stats = _compare_with_baseline(
|
|
1337
|
+
preview_window_ids, preview_tokens, baseline_preview, "preview"
|
|
1338
|
+
)
|
|
1339
|
+
final_pair_stats = _compare_with_baseline(
|
|
1340
|
+
final_window_ids, final_tokens, baseline_final, "final"
|
|
1341
|
+
)
|
|
1342
|
+
|
|
1343
|
+
total_expected = (
|
|
1344
|
+
preview_pair_stats["expected"] + final_pair_stats["expected"]
|
|
1345
|
+
)
|
|
1346
|
+
total_matched = preview_pair_stats["matched"] + final_pair_stats["matched"]
|
|
1347
|
+
window_match_fraction = (
|
|
1348
|
+
float(total_matched / total_expected) if total_expected > 0 else 1.0
|
|
1349
|
+
)
|
|
1350
|
+
window_overlap_fraction = _duplicate_fraction(preview_tokens + final_tokens)
|
|
1351
|
+
|
|
1352
|
+
pairing_reason = None
|
|
1353
|
+
if total_expected > 0:
|
|
1354
|
+
for stats_dict, label in (
|
|
1355
|
+
(preview_pair_stats, "preview"),
|
|
1356
|
+
(final_pair_stats, "final"),
|
|
1357
|
+
):
|
|
1358
|
+
if (
|
|
1359
|
+
stats_dict["expected"]
|
|
1360
|
+
and stats_dict["matched"] < stats_dict["expected"]
|
|
1361
|
+
):
|
|
1362
|
+
pairing_reason = stats_dict.get("reason") or f"{label}_mismatch"
|
|
1363
|
+
break
|
|
1364
|
+
if pairing_reason is None:
|
|
1365
|
+
if window_overlap_fraction > 0.0:
|
|
1366
|
+
pairing_reason = "duplicate_windows"
|
|
1367
|
+
elif not pairing_context:
|
|
1368
|
+
pairing_reason = preview_pair_stats.get(
|
|
1369
|
+
"reason"
|
|
1370
|
+
) or final_pair_stats.get("reason")
|
|
1371
|
+
|
|
1372
|
+
if pairing_context and window_match_fraction < 0.999999:
|
|
1373
|
+
self._log_event(
|
|
1374
|
+
"eval",
|
|
1375
|
+
"window_pairing_mismatch",
|
|
1376
|
+
LogLevel.ERROR,
|
|
1377
|
+
{
|
|
1378
|
+
"match_fraction": window_match_fraction,
|
|
1379
|
+
"overlap_fraction": window_overlap_fraction,
|
|
1380
|
+
"reason": pairing_reason,
|
|
1381
|
+
"preview": preview_pair_stats,
|
|
1382
|
+
"final": final_pair_stats,
|
|
1383
|
+
},
|
|
1384
|
+
)
|
|
1385
|
+
|
|
1386
|
+
if window_overlap_fraction > 0.0 and pairing_context:
|
|
1387
|
+
self._log_event(
|
|
1388
|
+
"eval",
|
|
1389
|
+
"window_overlap_warning",
|
|
1390
|
+
LogLevel.WARNING,
|
|
1391
|
+
{
|
|
1392
|
+
"duplicate_fraction": window_overlap_fraction,
|
|
1393
|
+
"match_fraction": window_match_fraction,
|
|
1394
|
+
"preview": preview_pair_stats,
|
|
1395
|
+
"final": final_pair_stats,
|
|
1396
|
+
},
|
|
1397
|
+
)
|
|
1398
|
+
|
|
1399
|
+
if pairing_context and profile_label in {"ci", "release"}:
|
|
1400
|
+
if window_match_fraction < 0.999999:
|
|
1401
|
+
raise RuntimeError(
|
|
1402
|
+
f"Window pairing mismatch detected (fraction={window_match_fraction:.3f}, reason={pairing_reason})"
|
|
1403
|
+
)
|
|
1404
|
+
if window_overlap_fraction > 0.0:
|
|
1405
|
+
raise RuntimeError(
|
|
1406
|
+
f"Window duplication detected (overlap_fraction={window_overlap_fraction:.3f})"
|
|
1407
|
+
)
|
|
1408
|
+
|
|
1409
|
+
tier = "balanced"
|
|
1410
|
+
if config and isinstance(config.context, dict):
|
|
1411
|
+
auto_section = config.context.get("auto", {})
|
|
1412
|
+
if isinstance(auto_section, dict):
|
|
1413
|
+
tier = str(auto_section.get("tier", tier)).lower()
|
|
1414
|
+
|
|
1415
|
+
coverage_requirements = BOOTSTRAP_COVERAGE_REQUIREMENTS.get(
|
|
1416
|
+
tier, BOOTSTRAP_COVERAGE_REQUIREMENTS["balanced"]
|
|
1417
|
+
)
|
|
1418
|
+
|
|
1419
|
+
def _meets_requirement(actual: int, required: int) -> bool:
|
|
1420
|
+
if required <= 0:
|
|
1421
|
+
return True
|
|
1422
|
+
slack = max(1, int(required * 0.95))
|
|
1423
|
+
return actual >= slack
|
|
1424
|
+
|
|
1425
|
+
preview_required = int(coverage_requirements.get("preview", 0))
|
|
1426
|
+
final_required = int(coverage_requirements.get("final", 0))
|
|
1427
|
+
replicates_required = int(coverage_requirements.get("replicates", 0))
|
|
1428
|
+
|
|
1429
|
+
preview_ok = _meets_requirement(preview_batches_ct, preview_required)
|
|
1430
|
+
final_ok = _meets_requirement(final_batches_ct, final_required)
|
|
1431
|
+
replicates_ok = (
|
|
1432
|
+
_meets_requirement(bootstrap_replicates, replicates_required)
|
|
1433
|
+
if bootstrap_enabled
|
|
1434
|
+
else True
|
|
1435
|
+
)
|
|
1436
|
+
|
|
1437
|
+
if not (preview_ok and final_ok and replicates_ok):
|
|
1438
|
+
self._log_event(
|
|
1439
|
+
"eval",
|
|
1440
|
+
"bootstrap_coverage_warning",
|
|
1441
|
+
LogLevel.WARNING,
|
|
1442
|
+
{
|
|
1443
|
+
"tier": tier,
|
|
1444
|
+
"preview_used": preview_batches_ct,
|
|
1445
|
+
"preview_required": preview_required,
|
|
1446
|
+
"final_used": final_batches_ct,
|
|
1447
|
+
"final_required": final_required,
|
|
1448
|
+
"replicates_used": bootstrap_replicates,
|
|
1449
|
+
"replicates_required": replicates_required,
|
|
1450
|
+
},
|
|
1451
|
+
)
|
|
1452
|
+
# In CI/Release profiles, treat insufficient coverage as a hard error
|
|
1453
|
+
if pairing_context and profile_label in {"ci", "release"}:
|
|
1454
|
+
from invarlock.cli.errors import InvarlockError
|
|
1455
|
+
|
|
1456
|
+
raise InvarlockError(
|
|
1457
|
+
code="E005",
|
|
1458
|
+
message=(
|
|
1459
|
+
"INSUFFICIENT-SAMPLE: bootstrap coverage below policy floors in CI/Release"
|
|
1460
|
+
),
|
|
1461
|
+
)
|
|
1462
|
+
|
|
1463
|
+
bootstrap_info.update(
|
|
1464
|
+
{
|
|
1465
|
+
"enabled": bool(bootstrap_enabled),
|
|
1466
|
+
"method": bootstrap_method,
|
|
1467
|
+
"alpha": float(bootstrap_alpha),
|
|
1468
|
+
"replicates": int(bootstrap_replicates),
|
|
1469
|
+
"seed": int(bootstrap_seed),
|
|
1470
|
+
"ci_band": float(ci_band),
|
|
1471
|
+
"window_duplicate_fraction": float(window_overlap_fraction),
|
|
1472
|
+
"window_match_fraction": float(window_match_fraction),
|
|
1473
|
+
"coverage": {
|
|
1474
|
+
"tier": tier,
|
|
1475
|
+
"preview": {
|
|
1476
|
+
"used": int(preview_batches_ct),
|
|
1477
|
+
"required": preview_required,
|
|
1478
|
+
"ok": bool(preview_ok),
|
|
1479
|
+
},
|
|
1480
|
+
"final": {
|
|
1481
|
+
"used": int(final_batches_ct),
|
|
1482
|
+
"required": final_required,
|
|
1483
|
+
"ok": bool(final_ok),
|
|
1484
|
+
},
|
|
1485
|
+
"replicates": {
|
|
1486
|
+
"used": int(bootstrap_replicates),
|
|
1487
|
+
"required": replicates_required,
|
|
1488
|
+
"ok": bool(replicates_ok),
|
|
1489
|
+
},
|
|
1490
|
+
},
|
|
1491
|
+
}
|
|
1492
|
+
)
|
|
1493
|
+
|
|
1494
|
+
except Exception as exc: # pragma: no cover - defensive fallback
|
|
1495
|
+
self._log_event(
|
|
1496
|
+
"eval",
|
|
1497
|
+
"error",
|
|
1498
|
+
LogLevel.ERROR,
|
|
1499
|
+
{"message": f"Primary-metric computation failed: {exc}"},
|
|
1500
|
+
)
|
|
1501
|
+
|
|
1502
|
+
pm_ratio = pm_final / pm_preview if pm_preview > 0 else 1.0
|
|
1503
|
+
|
|
1504
|
+
latency_ms_per_tok = self._measure_latency(
|
|
1505
|
+
model, preview_data[:1] if preview_data else final_data[:1], device
|
|
1506
|
+
)
|
|
1507
|
+
|
|
1508
|
+
current_memory = process.memory_info().rss / 1024 / 1024
|
|
1509
|
+
peak_memory = max(initial_memory, current_memory)
|
|
1510
|
+
|
|
1511
|
+
eval_samples = 0
|
|
1512
|
+
total_tokens = 0
|
|
1513
|
+
masked_total_tokens = 0
|
|
1514
|
+
try:
|
|
1515
|
+
eval_samples = int(preview_batches_ct) + int(final_batches_ct)
|
|
1516
|
+
total_tokens = int(preview_actual_tokens_ct) + int(final_actual_tokens_ct)
|
|
1517
|
+
masked_total_tokens = int(preview_masked_total) + int(final_masked_total)
|
|
1518
|
+
except Exception:
|
|
1519
|
+
pass
|
|
1520
|
+
|
|
1521
|
+
paired_windows_count = len(delta_samples)
|
|
1522
|
+
unweighted_delta_mean = (
|
|
1523
|
+
float(np.mean(delta_samples)) if delta_samples else float(delta_mean_log)
|
|
1524
|
+
)
|
|
1525
|
+
preview_weighted_delta_mean: float | None = None
|
|
1526
|
+
if delta_weights:
|
|
1527
|
+
total_weight = float(sum(delta_weights))
|
|
1528
|
+
if total_weight > 0.0:
|
|
1529
|
+
preview_weighted_delta_mean = float(
|
|
1530
|
+
np.dot(delta_samples, delta_weights) / total_weight
|
|
1531
|
+
)
|
|
1532
|
+
paired_delta_mean = float(delta_mean_log)
|
|
1533
|
+
paired_delta_std = (
|
|
1534
|
+
float(np.std(delta_samples, ddof=1)) if len(delta_samples) > 1 else 0.0
|
|
1535
|
+
)
|
|
1536
|
+
paired_delta_min = float(min(delta_samples)) if delta_samples else None
|
|
1537
|
+
paired_delta_max = float(max(delta_samples)) if delta_samples else None
|
|
1538
|
+
|
|
1539
|
+
# Resolve primary metric kind from resolved loss mode
|
|
1540
|
+
pm_kind = "ppl_causal"
|
|
1541
|
+
if resolved_loss_mode == "mlm":
|
|
1542
|
+
pm_kind = "ppl_mlm"
|
|
1543
|
+
elif resolved_loss_mode in {"seq2seq", "s2s", "t5"}:
|
|
1544
|
+
pm_kind = "ppl_seq2seq"
|
|
1545
|
+
|
|
1546
|
+
metrics = {
|
|
1547
|
+
"primary_metric": {
|
|
1548
|
+
"kind": pm_kind,
|
|
1549
|
+
"preview": float(pm_preview),
|
|
1550
|
+
"final": float(pm_final),
|
|
1551
|
+
},
|
|
1552
|
+
"logloss_preview": float(preview_mean_log),
|
|
1553
|
+
"logloss_final": float(final_mean_log),
|
|
1554
|
+
"logloss_delta": float(delta_mean_log),
|
|
1555
|
+
"logloss_preview_ci": tuple(map(float, preview_log_ci)),
|
|
1556
|
+
"logloss_final_ci": tuple(map(float, final_log_ci)),
|
|
1557
|
+
"logloss_delta_ci": tuple(map(float, delta_log_ci)),
|
|
1558
|
+
"latency_ms_per_tok": latency_ms_per_tok,
|
|
1559
|
+
"memory_mb_peak": peak_memory,
|
|
1560
|
+
"eval_samples": eval_samples,
|
|
1561
|
+
"total_tokens": total_tokens,
|
|
1562
|
+
"preview_total_tokens": int(preview_actual_tokens_ct),
|
|
1563
|
+
"final_total_tokens": int(final_actual_tokens_ct),
|
|
1564
|
+
"masked_tokens_total": masked_total_tokens,
|
|
1565
|
+
"masked_tokens_preview": int(preview_masked_total),
|
|
1566
|
+
"masked_tokens_final": int(final_masked_total),
|
|
1567
|
+
"reduction": {
|
|
1568
|
+
"mode": "token_mean",
|
|
1569
|
+
"implementation": "huggingface_cross_entropy",
|
|
1570
|
+
},
|
|
1571
|
+
"window_overlap_fraction": float(window_overlap_fraction),
|
|
1572
|
+
"window_match_fraction": float(window_match_fraction),
|
|
1573
|
+
"window_pairing_reason": pairing_reason,
|
|
1574
|
+
"window_pairing_preview": {
|
|
1575
|
+
"matched": preview_pair_stats["matched"],
|
|
1576
|
+
"expected": preview_pair_stats["expected"],
|
|
1577
|
+
"reason": preview_pair_stats.get("reason"),
|
|
1578
|
+
},
|
|
1579
|
+
"window_pairing_final": {
|
|
1580
|
+
"matched": final_pair_stats["matched"],
|
|
1581
|
+
"expected": final_pair_stats["expected"],
|
|
1582
|
+
"reason": final_pair_stats.get("reason"),
|
|
1583
|
+
},
|
|
1584
|
+
"bootstrap": bootstrap_info,
|
|
1585
|
+
"paired_windows": paired_windows_count,
|
|
1586
|
+
"paired_delta_summary": {
|
|
1587
|
+
"mean": paired_delta_mean,
|
|
1588
|
+
"mean_unweighted": unweighted_delta_mean,
|
|
1589
|
+
"mean_preview_weighted": (
|
|
1590
|
+
preview_weighted_delta_mean
|
|
1591
|
+
if preview_weighted_delta_mean is not None
|
|
1592
|
+
else unweighted_delta_mean
|
|
1593
|
+
),
|
|
1594
|
+
"std": paired_delta_std,
|
|
1595
|
+
"min": paired_delta_min,
|
|
1596
|
+
"max": paired_delta_max,
|
|
1597
|
+
"degenerate": degenerate_delta,
|
|
1598
|
+
"degenerate_reason": degenerate_reason,
|
|
1599
|
+
},
|
|
1600
|
+
}
|
|
1601
|
+
|
|
1602
|
+
eval_windows = {
|
|
1603
|
+
"preview": {
|
|
1604
|
+
"window_ids": preview_window_ids[:preview_limit],
|
|
1605
|
+
"logloss": list(preview_log_losses),
|
|
1606
|
+
"input_ids": preview_tokens,
|
|
1607
|
+
"attention_masks": preview_attention_masks,
|
|
1608
|
+
"token_counts": preview_token_counts,
|
|
1609
|
+
"masked_token_counts": preview_mask_counts,
|
|
1610
|
+
"actual_token_counts": preview_actual_token_counts,
|
|
1611
|
+
"labels": preview_labels,
|
|
1612
|
+
},
|
|
1613
|
+
"final": {
|
|
1614
|
+
"window_ids": final_window_ids[:final_limit],
|
|
1615
|
+
"logloss": list(final_log_losses),
|
|
1616
|
+
"input_ids": final_tokens,
|
|
1617
|
+
"attention_masks": final_attention_masks,
|
|
1618
|
+
"token_counts": final_token_counts,
|
|
1619
|
+
"masked_token_counts": final_mask_counts,
|
|
1620
|
+
"actual_token_counts": final_actual_token_counts,
|
|
1621
|
+
"labels": final_labels,
|
|
1622
|
+
},
|
|
1623
|
+
}
|
|
1624
|
+
|
|
1625
|
+
return metrics, eval_windows
|
|
1626
|
+
|
|
1627
|
+
def _measure_latency(self, model: Any, sample_data: Any, device: Any) -> float:
|
|
1628
|
+
"""Simple latency measurement for a sample."""
|
|
1629
|
+
import time
|
|
1630
|
+
|
|
1631
|
+
import torch
|
|
1632
|
+
|
|
1633
|
+
if not sample_data:
|
|
1634
|
+
return 0.0
|
|
1635
|
+
|
|
1636
|
+
# Model eval is managed by caller to avoid duplicate invocations in tests
|
|
1637
|
+
|
|
1638
|
+
# Get a sample for timing
|
|
1639
|
+
sample = sample_data[0] if sample_data else None
|
|
1640
|
+
if sample is None:
|
|
1641
|
+
return 0.0
|
|
1642
|
+
|
|
1643
|
+
if isinstance(sample, dict):
|
|
1644
|
+
input_ids = sample.get("input_ids", sample.get("inputs"))
|
|
1645
|
+
else:
|
|
1646
|
+
input_ids = sample
|
|
1647
|
+
|
|
1648
|
+
if input_ids is None:
|
|
1649
|
+
return 0.0
|
|
1650
|
+
|
|
1651
|
+
# Convert to tensor if needed
|
|
1652
|
+
if not isinstance(input_ids, torch.Tensor):
|
|
1653
|
+
input_ids = torch.tensor(input_ids)
|
|
1654
|
+
|
|
1655
|
+
# Some tests patch torch.tensor with dim side effects; guard against exceptions
|
|
1656
|
+
try:
|
|
1657
|
+
dim_val = input_ids.dim()
|
|
1658
|
+
except Exception:
|
|
1659
|
+
dim_val = 2 # assume already batched
|
|
1660
|
+
if dim_val == 1:
|
|
1661
|
+
try:
|
|
1662
|
+
input_ids = input_ids.unsqueeze(0)
|
|
1663
|
+
except Exception:
|
|
1664
|
+
pass
|
|
1665
|
+
|
|
1666
|
+
# to(device) may be a Mock; guard call
|
|
1667
|
+
try:
|
|
1668
|
+
input_ids = input_ids.to(device)
|
|
1669
|
+
except Exception:
|
|
1670
|
+
pass
|
|
1671
|
+
|
|
1672
|
+
# Simple timing measurement
|
|
1673
|
+
with torch.no_grad():
|
|
1674
|
+
try:
|
|
1675
|
+
# Prepare labels and attention mask if available
|
|
1676
|
+
labels_t = input_ids
|
|
1677
|
+
attn_t = None
|
|
1678
|
+
token_type_t = None
|
|
1679
|
+
if isinstance(sample, dict) and "attention_mask" in sample:
|
|
1680
|
+
try:
|
|
1681
|
+
attn_t = torch.tensor(sample["attention_mask"])
|
|
1682
|
+
try:
|
|
1683
|
+
if attn_t.dim() == 1:
|
|
1684
|
+
attn_t = attn_t.unsqueeze(0)
|
|
1685
|
+
except Exception:
|
|
1686
|
+
pass
|
|
1687
|
+
try:
|
|
1688
|
+
attn_t = attn_t.to(device)
|
|
1689
|
+
except Exception:
|
|
1690
|
+
pass
|
|
1691
|
+
except Exception:
|
|
1692
|
+
attn_t = None
|
|
1693
|
+
if isinstance(sample, dict) and "token_type_ids" in sample:
|
|
1694
|
+
try:
|
|
1695
|
+
token_type_t = torch.tensor(sample["token_type_ids"])
|
|
1696
|
+
if token_type_t.dim() == 1:
|
|
1697
|
+
token_type_t = token_type_t.unsqueeze(0)
|
|
1698
|
+
token_type_t = token_type_t.to(device)
|
|
1699
|
+
except Exception:
|
|
1700
|
+
token_type_t = None
|
|
1701
|
+
|
|
1702
|
+
def _call_model():
|
|
1703
|
+
if attn_t is not None:
|
|
1704
|
+
return model(
|
|
1705
|
+
input_ids,
|
|
1706
|
+
attention_mask=attn_t,
|
|
1707
|
+
labels=labels_t,
|
|
1708
|
+
token_type_ids=token_type_t,
|
|
1709
|
+
)
|
|
1710
|
+
else:
|
|
1711
|
+
return model(
|
|
1712
|
+
input_ids,
|
|
1713
|
+
labels=labels_t,
|
|
1714
|
+
token_type_ids=token_type_t,
|
|
1715
|
+
)
|
|
1716
|
+
|
|
1717
|
+
# Warmup
|
|
1718
|
+
for _ in range(3):
|
|
1719
|
+
_ = _call_model()
|
|
1720
|
+
|
|
1721
|
+
# Measure
|
|
1722
|
+
start_time = time.time()
|
|
1723
|
+
for _ in range(10):
|
|
1724
|
+
_ = _call_model()
|
|
1725
|
+
end_time = time.time()
|
|
1726
|
+
|
|
1727
|
+
total_time = (end_time - start_time) * 1000 # Convert to ms
|
|
1728
|
+
try:
|
|
1729
|
+
total_tokens = input_ids.numel() * 10 # 10 iterations
|
|
1730
|
+
except Exception:
|
|
1731
|
+
total_tokens = 0
|
|
1732
|
+
|
|
1733
|
+
return total_time / total_tokens if total_tokens > 0 else 0.0
|
|
1734
|
+
|
|
1735
|
+
except Exception:
|
|
1736
|
+
return 0.0
|
|
1737
|
+
|
|
1738
|
+
def _samples_to_dataloader(self, samples: list) -> Any:
|
|
1739
|
+
"""
|
|
1740
|
+
Convert list of samples to DataLoader-compatible format.
|
|
1741
|
+
|
|
1742
|
+
Args:
|
|
1743
|
+
samples: List of sample dictionaries with 'input_ids' and 'attention_mask'
|
|
1744
|
+
|
|
1745
|
+
Returns:
|
|
1746
|
+
Simple iterable that yields batches compatible with compute_perplexity()
|
|
1747
|
+
"""
|
|
1748
|
+
|
|
1749
|
+
class SampleDataLoader:
|
|
1750
|
+
def __init__(self, samples):
|
|
1751
|
+
self.samples = samples
|
|
1752
|
+
|
|
1753
|
+
def __iter__(self):
|
|
1754
|
+
for sample in self.samples:
|
|
1755
|
+
# Each sample is already a dict with 'input_ids' and 'attention_mask'
|
|
1756
|
+
# Convert to tensor format that compute_perplexity expects
|
|
1757
|
+
import torch
|
|
1758
|
+
|
|
1759
|
+
input_ids = sample.get("input_ids", sample.get("inputs"))
|
|
1760
|
+
attention_mask = sample.get("attention_mask")
|
|
1761
|
+
|
|
1762
|
+
if input_ids is None:
|
|
1763
|
+
continue
|
|
1764
|
+
|
|
1765
|
+
# Convert to tensors if needed and add batch dimension
|
|
1766
|
+
if not isinstance(input_ids, torch.Tensor):
|
|
1767
|
+
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
|
1768
|
+
if input_ids.dim() == 1:
|
|
1769
|
+
input_ids = input_ids.unsqueeze(0)
|
|
1770
|
+
|
|
1771
|
+
if attention_mask is not None:
|
|
1772
|
+
if not isinstance(attention_mask, torch.Tensor):
|
|
1773
|
+
attention_mask = torch.tensor(
|
|
1774
|
+
attention_mask, dtype=torch.long
|
|
1775
|
+
)
|
|
1776
|
+
if attention_mask.dim() == 1:
|
|
1777
|
+
attention_mask = attention_mask.unsqueeze(0)
|
|
1778
|
+
|
|
1779
|
+
batch = {"input_ids": input_ids}
|
|
1780
|
+
if attention_mask is not None:
|
|
1781
|
+
batch["attention_mask"] = attention_mask
|
|
1782
|
+
|
|
1783
|
+
token_type = sample.get("token_type_ids")
|
|
1784
|
+
if token_type is not None:
|
|
1785
|
+
if not isinstance(token_type, torch.Tensor):
|
|
1786
|
+
token_type = torch.tensor(token_type, dtype=torch.long)
|
|
1787
|
+
if token_type.dim() == 1:
|
|
1788
|
+
token_type = token_type.unsqueeze(0)
|
|
1789
|
+
batch["token_type_ids"] = token_type
|
|
1790
|
+
|
|
1791
|
+
labels = sample.get("labels")
|
|
1792
|
+
if labels is None:
|
|
1793
|
+
labels = input_ids.clone()
|
|
1794
|
+
if attention_mask is not None:
|
|
1795
|
+
labels = labels.masked_fill(attention_mask == 0, -100)
|
|
1796
|
+
else:
|
|
1797
|
+
if not isinstance(labels, torch.Tensor):
|
|
1798
|
+
labels = torch.tensor(labels, dtype=torch.long)
|
|
1799
|
+
if labels.dim() == 1:
|
|
1800
|
+
labels = labels.unsqueeze(0)
|
|
1801
|
+
batch["labels"] = labels
|
|
1802
|
+
|
|
1803
|
+
yield batch
|
|
1804
|
+
|
|
1805
|
+
def __len__(self):
|
|
1806
|
+
return len(self.samples)
|
|
1807
|
+
|
|
1808
|
+
return SampleDataLoader(samples)
|
|
1809
|
+
|
|
1810
|
+
def _finalize_phase(
|
|
1811
|
+
self,
|
|
1812
|
+
model: Any,
|
|
1813
|
+
adapter: ModelAdapter,
|
|
1814
|
+
guard_results: dict[str, dict[str, Any]],
|
|
1815
|
+
metrics: dict[str, Any],
|
|
1816
|
+
config: RunConfig,
|
|
1817
|
+
report: RunReport,
|
|
1818
|
+
) -> str:
|
|
1819
|
+
"""Phase 5: Finalize or rollback based on results."""
|
|
1820
|
+
self._log_event("finalize", "start", LogLevel.INFO)
|
|
1821
|
+
|
|
1822
|
+
# Check if guards passed
|
|
1823
|
+
all_guards_passed = all(r.get("passed", False) for r in guard_results.values())
|
|
1824
|
+
|
|
1825
|
+
# Check for catastrophic drift spike using primary metric preview/final
|
|
1826
|
+
pm = metrics.get("primary_metric", {}) if isinstance(metrics, dict) else {}
|
|
1827
|
+
pm_prev = pm.get("preview") if isinstance(pm, dict) else None
|
|
1828
|
+
pm_fin = pm.get("final") if isinstance(pm, dict) else None
|
|
1829
|
+
try:
|
|
1830
|
+
drift_ratio = (
|
|
1831
|
+
float(pm_fin) / float(pm_prev)
|
|
1832
|
+
if isinstance(pm_fin, (int | float))
|
|
1833
|
+
and isinstance(pm_prev, (int | float))
|
|
1834
|
+
and float(pm_prev) > 0.0
|
|
1835
|
+
else float("inf")
|
|
1836
|
+
)
|
|
1837
|
+
except Exception:
|
|
1838
|
+
drift_ratio = float("inf")
|
|
1839
|
+
spike_threshold = getattr(config, "spike_threshold", 2.0)
|
|
1840
|
+
is_catastrophic_spike = drift_ratio > spike_threshold
|
|
1841
|
+
|
|
1842
|
+
# Check if standard metrics are acceptable against configured max ratio
|
|
1843
|
+
metrics_acceptable = drift_ratio <= getattr(config, "max_pm_ratio", 2.0)
|
|
1844
|
+
|
|
1845
|
+
# Determine rollback reason and status
|
|
1846
|
+
rollback_reason = None
|
|
1847
|
+
if is_catastrophic_spike:
|
|
1848
|
+
rollback_reason = (
|
|
1849
|
+
f"catastrophic_ppl_spike (ratio: {drift_ratio:.3f} > {spike_threshold})"
|
|
1850
|
+
)
|
|
1851
|
+
status = RunStatus.ROLLBACK.value
|
|
1852
|
+
|
|
1853
|
+
self._log_event(
|
|
1854
|
+
"finalize",
|
|
1855
|
+
"catastrophic_spike_detected",
|
|
1856
|
+
LogLevel.ERROR,
|
|
1857
|
+
{
|
|
1858
|
+
"primary_metric_drift_ratio": drift_ratio,
|
|
1859
|
+
"spike_threshold": spike_threshold,
|
|
1860
|
+
"immediate_rollback": True,
|
|
1861
|
+
},
|
|
1862
|
+
)
|
|
1863
|
+
elif (not all_guards_passed) or (not metrics_acceptable):
|
|
1864
|
+
# Match historical/test expectation string exactly
|
|
1865
|
+
rollback_reason = "guards_failed or metrics_unacceptable"
|
|
1866
|
+
status = RunStatus.ROLLBACK.value
|
|
1867
|
+
else:
|
|
1868
|
+
status = RunStatus.SUCCESS.value
|
|
1869
|
+
|
|
1870
|
+
# Execute the determined action
|
|
1871
|
+
if status == RunStatus.SUCCESS.value:
|
|
1872
|
+
self._log_event(
|
|
1873
|
+
"finalize",
|
|
1874
|
+
"success",
|
|
1875
|
+
LogLevel.INFO,
|
|
1876
|
+
{
|
|
1877
|
+
"guards_passed": all_guards_passed,
|
|
1878
|
+
"metrics_ok": metrics_acceptable,
|
|
1879
|
+
},
|
|
1880
|
+
)
|
|
1881
|
+
else:
|
|
1882
|
+
# Perform rollback if checkpoint available
|
|
1883
|
+
if self.checkpoint_manager and "initial_checkpoint" in report.meta:
|
|
1884
|
+
checkpoint_id = report.meta["initial_checkpoint"]
|
|
1885
|
+
self.checkpoint_manager.restore_checkpoint(
|
|
1886
|
+
model, adapter, checkpoint_id
|
|
1887
|
+
)
|
|
1888
|
+
# Match test expectation: only include checkpoint and reason
|
|
1889
|
+
self._log_event(
|
|
1890
|
+
"finalize",
|
|
1891
|
+
"rollback",
|
|
1892
|
+
LogLevel.WARNING,
|
|
1893
|
+
{
|
|
1894
|
+
"checkpoint": checkpoint_id,
|
|
1895
|
+
"reason": rollback_reason,
|
|
1896
|
+
},
|
|
1897
|
+
)
|
|
1898
|
+
|
|
1899
|
+
# Store rollback metadata in report
|
|
1900
|
+
report.meta["rollback_reason"] = rollback_reason
|
|
1901
|
+
report.meta["rollback_checkpoint"] = checkpoint_id
|
|
1902
|
+
report.meta["guard_recovered"] = True
|
|
1903
|
+
|
|
1904
|
+
else:
|
|
1905
|
+
# Match test expectation: log without additional data payload
|
|
1906
|
+
self._log_event("finalize", "rollback_unavailable", LogLevel.ERROR)
|
|
1907
|
+
|
|
1908
|
+
return status
|
|
1909
|
+
|
|
1910
|
+
def _handle_error(self, error: Exception, report: RunReport) -> None:
|
|
1911
|
+
"""Handle pipeline errors."""
|
|
1912
|
+
report.status = RunStatus.FAILED.value
|
|
1913
|
+
report.error = str(error)
|
|
1914
|
+
report.meta["end_time"] = time.time()
|
|
1915
|
+
|
|
1916
|
+
if "start_time" in report.meta:
|
|
1917
|
+
report.meta["duration"] = (
|
|
1918
|
+
report.meta["end_time"] - report.meta["start_time"]
|
|
1919
|
+
)
|
|
1920
|
+
|
|
1921
|
+
self._log_event("runner", "error", LogLevel.ERROR, {"error": str(error)})
|
|
1922
|
+
|
|
1923
|
+
# Attempt rollback on error
|
|
1924
|
+
if self.checkpoint_manager and "initial_checkpoint" in report.meta:
|
|
1925
|
+
try:
|
|
1926
|
+
checkpoint_id = report.meta["initial_checkpoint"]
|
|
1927
|
+
# Would need model and adapter here for actual rollback
|
|
1928
|
+
self._log_event(
|
|
1929
|
+
"runner",
|
|
1930
|
+
"emergency_rollback",
|
|
1931
|
+
LogLevel.WARNING,
|
|
1932
|
+
{"checkpoint": checkpoint_id},
|
|
1933
|
+
)
|
|
1934
|
+
except Exception as rollback_error:
|
|
1935
|
+
self._log_event(
|
|
1936
|
+
"runner",
|
|
1937
|
+
"rollback_failed",
|
|
1938
|
+
LogLevel.CRITICAL,
|
|
1939
|
+
{"error": str(rollback_error)},
|
|
1940
|
+
)
|
|
1941
|
+
|
|
1942
|
+
def _resolve_guard_policies(
|
|
1943
|
+
self, report: RunReport, auto_config: dict[str, Any] | None = None
|
|
1944
|
+
) -> dict[str, dict[str, Any]]:
|
|
1945
|
+
"""Resolve tier-based guard policies from configuration."""
|
|
1946
|
+
# Use passed auto_config if available, otherwise extract from report meta
|
|
1947
|
+
if auto_config is None:
|
|
1948
|
+
config_meta = report.meta.get("config", {})
|
|
1949
|
+
|
|
1950
|
+
# Try to get auto config from various possible locations
|
|
1951
|
+
if hasattr(report, "auto_config"):
|
|
1952
|
+
auto_config = report.auto_config
|
|
1953
|
+
elif "auto" in config_meta:
|
|
1954
|
+
auto_config = config_meta["auto"]
|
|
1955
|
+
else:
|
|
1956
|
+
# Fallback to default balanced tier
|
|
1957
|
+
auto_config = {"tier": "balanced", "enabled": True}
|
|
1958
|
+
|
|
1959
|
+
# Extract tier and edit name
|
|
1960
|
+
tier = auto_config.get("tier", "balanced")
|
|
1961
|
+
edit_name = None
|
|
1962
|
+
if hasattr(report, "edit") and report.edit:
|
|
1963
|
+
edit_name = report.edit.get("name")
|
|
1964
|
+
|
|
1965
|
+
# Also try to get edit name from stored edit result in meta
|
|
1966
|
+
if not edit_name and "edit_name" in report.meta:
|
|
1967
|
+
edit_name = report.meta["edit_name"]
|
|
1968
|
+
|
|
1969
|
+
# Get explicit guard overrides from config
|
|
1970
|
+
config_meta = report.meta.get("config", {})
|
|
1971
|
+
explicit_overrides = config_meta.get("guards", {})
|
|
1972
|
+
|
|
1973
|
+
try:
|
|
1974
|
+
# Resolve tier policies
|
|
1975
|
+
policies = resolve_tier_policies(tier, edit_name, explicit_overrides)
|
|
1976
|
+
|
|
1977
|
+
self._log_event(
|
|
1978
|
+
"auto_tuning",
|
|
1979
|
+
"tier_resolved",
|
|
1980
|
+
LogLevel.INFO,
|
|
1981
|
+
{"tier": tier, "edit": edit_name, "policies_count": len(policies)},
|
|
1982
|
+
)
|
|
1983
|
+
|
|
1984
|
+
return policies
|
|
1985
|
+
|
|
1986
|
+
except Exception as e:
|
|
1987
|
+
self._log_event(
|
|
1988
|
+
"auto_tuning",
|
|
1989
|
+
"tier_resolution_failed",
|
|
1990
|
+
LogLevel.ERROR,
|
|
1991
|
+
{"tier": tier, "error": str(e)},
|
|
1992
|
+
)
|
|
1993
|
+
# Return empty policies dict on failure
|
|
1994
|
+
return {}
|
|
1995
|
+
|
|
1996
|
+
def _apply_guard_policy(self, guard: Guard, policy: dict[str, Any]) -> None:
|
|
1997
|
+
"""Apply resolved policy parameters to a guard instance."""
|
|
1998
|
+
try:
|
|
1999
|
+
# Apply policy parameters to guard
|
|
2000
|
+
for param_name, param_value in policy.items():
|
|
2001
|
+
if hasattr(guard, param_name):
|
|
2002
|
+
setattr(guard, param_name, param_value)
|
|
2003
|
+
elif hasattr(guard, "config") and isinstance(guard.config, dict):
|
|
2004
|
+
# Try to set in guard's config dict
|
|
2005
|
+
guard.config[param_name] = param_value
|
|
2006
|
+
elif hasattr(guard, "policy") and isinstance(guard.policy, dict):
|
|
2007
|
+
# Try to set in guard's policy dict
|
|
2008
|
+
guard.policy[param_name] = param_value
|
|
2009
|
+
else:
|
|
2010
|
+
# Last resort: add to guard as attribute
|
|
2011
|
+
setattr(guard, param_name, param_value)
|
|
2012
|
+
|
|
2013
|
+
except Exception as e:
|
|
2014
|
+
self._log_event(
|
|
2015
|
+
"auto_tuning",
|
|
2016
|
+
"policy_application_failed",
|
|
2017
|
+
LogLevel.WARNING,
|
|
2018
|
+
{"guard": guard.name, "policy": policy, "error": str(e)},
|
|
2019
|
+
)
|
|
2020
|
+
|
|
2021
|
+
def _log_event(
|
|
2022
|
+
self,
|
|
2023
|
+
component: str,
|
|
2024
|
+
operation: str,
|
|
2025
|
+
level: LogLevel,
|
|
2026
|
+
data: dict[str, Any] | None = None,
|
|
2027
|
+
) -> None:
|
|
2028
|
+
"""Log an event if event logger is available."""
|
|
2029
|
+
if self.event_logger:
|
|
2030
|
+
self.event_logger.log(component, operation, level, data)
|
|
2031
|
+
|
|
2032
|
+
def _serialize_config(self, config: RunConfig) -> dict[str, Any]:
|
|
2033
|
+
"""Serialize RunConfig for storage in report."""
|
|
2034
|
+
return {
|
|
2035
|
+
"device": config.device,
|
|
2036
|
+
"max_pm_ratio": config.max_pm_ratio,
|
|
2037
|
+
"checkpoint_interval": config.checkpoint_interval,
|
|
2038
|
+
"dry_run": config.dry_run,
|
|
2039
|
+
"verbose": config.verbose,
|
|
2040
|
+
"guards": config.context.get("guards", {}) if config.context else {},
|
|
2041
|
+
}
|