invarlock 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +33 -0
- invarlock/__main__.py +10 -0
- invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
- invarlock/_data/runtime/profiles/release.yaml +23 -0
- invarlock/_data/runtime/tiers.yaml +76 -0
- invarlock/adapters/__init__.py +102 -0
- invarlock/adapters/_capabilities.py +45 -0
- invarlock/adapters/auto.py +99 -0
- invarlock/adapters/base.py +530 -0
- invarlock/adapters/base_types.py +85 -0
- invarlock/adapters/hf_bert.py +852 -0
- invarlock/adapters/hf_gpt2.py +403 -0
- invarlock/adapters/hf_llama.py +485 -0
- invarlock/adapters/hf_mixin.py +383 -0
- invarlock/adapters/hf_onnx.py +112 -0
- invarlock/adapters/hf_t5.py +137 -0
- invarlock/adapters/py.typed +1 -0
- invarlock/assurance/__init__.py +43 -0
- invarlock/cli/__init__.py +8 -0
- invarlock/cli/__main__.py +8 -0
- invarlock/cli/_evidence.py +25 -0
- invarlock/cli/_json.py +75 -0
- invarlock/cli/adapter_auto.py +162 -0
- invarlock/cli/app.py +287 -0
- invarlock/cli/commands/__init__.py +26 -0
- invarlock/cli/commands/certify.py +403 -0
- invarlock/cli/commands/doctor.py +1358 -0
- invarlock/cli/commands/explain_gates.py +151 -0
- invarlock/cli/commands/export_html.py +100 -0
- invarlock/cli/commands/plugins.py +1331 -0
- invarlock/cli/commands/report.py +354 -0
- invarlock/cli/commands/run.py +4146 -0
- invarlock/cli/commands/verify.py +1040 -0
- invarlock/cli/config.py +396 -0
- invarlock/cli/constants.py +68 -0
- invarlock/cli/device.py +92 -0
- invarlock/cli/doctor_helpers.py +74 -0
- invarlock/cli/errors.py +6 -0
- invarlock/cli/overhead_utils.py +60 -0
- invarlock/cli/provenance.py +66 -0
- invarlock/cli/utils.py +41 -0
- invarlock/config.py +56 -0
- invarlock/core/__init__.py +62 -0
- invarlock/core/abi.py +15 -0
- invarlock/core/api.py +274 -0
- invarlock/core/auto_tuning.py +317 -0
- invarlock/core/bootstrap.py +226 -0
- invarlock/core/checkpoint.py +221 -0
- invarlock/core/contracts.py +73 -0
- invarlock/core/error_utils.py +64 -0
- invarlock/core/events.py +298 -0
- invarlock/core/exceptions.py +95 -0
- invarlock/core/registry.py +481 -0
- invarlock/core/retry.py +146 -0
- invarlock/core/runner.py +2041 -0
- invarlock/core/types.py +154 -0
- invarlock/edits/__init__.py +12 -0
- invarlock/edits/_edit_utils.py +249 -0
- invarlock/edits/_external_utils.py +268 -0
- invarlock/edits/noop.py +47 -0
- invarlock/edits/py.typed +1 -0
- invarlock/edits/quant_rtn.py +801 -0
- invarlock/edits/registry.py +166 -0
- invarlock/eval/__init__.py +23 -0
- invarlock/eval/bench.py +1207 -0
- invarlock/eval/bootstrap.py +50 -0
- invarlock/eval/data.py +2052 -0
- invarlock/eval/metrics.py +2167 -0
- invarlock/eval/primary_metric.py +767 -0
- invarlock/eval/probes/__init__.py +24 -0
- invarlock/eval/probes/fft.py +139 -0
- invarlock/eval/probes/mi.py +213 -0
- invarlock/eval/probes/post_attention.py +323 -0
- invarlock/eval/providers/base.py +67 -0
- invarlock/eval/providers/seq2seq.py +111 -0
- invarlock/eval/providers/text_lm.py +113 -0
- invarlock/eval/providers/vision_text.py +93 -0
- invarlock/eval/py.typed +1 -0
- invarlock/guards/__init__.py +18 -0
- invarlock/guards/_contracts.py +9 -0
- invarlock/guards/invariants.py +640 -0
- invarlock/guards/policies.py +805 -0
- invarlock/guards/py.typed +1 -0
- invarlock/guards/rmt.py +2097 -0
- invarlock/guards/spectral.py +1419 -0
- invarlock/guards/tier_config.py +354 -0
- invarlock/guards/variance.py +3298 -0
- invarlock/guards_ref/__init__.py +15 -0
- invarlock/guards_ref/rmt_ref.py +40 -0
- invarlock/guards_ref/spectral_ref.py +135 -0
- invarlock/guards_ref/variance_ref.py +60 -0
- invarlock/model_profile.py +353 -0
- invarlock/model_utils.py +221 -0
- invarlock/observability/__init__.py +10 -0
- invarlock/observability/alerting.py +535 -0
- invarlock/observability/core.py +546 -0
- invarlock/observability/exporters.py +565 -0
- invarlock/observability/health.py +588 -0
- invarlock/observability/metrics.py +457 -0
- invarlock/observability/py.typed +1 -0
- invarlock/observability/utils.py +553 -0
- invarlock/plugins/__init__.py +12 -0
- invarlock/plugins/hello_guard.py +33 -0
- invarlock/plugins/hf_awq_adapter.py +82 -0
- invarlock/plugins/hf_bnb_adapter.py +79 -0
- invarlock/plugins/hf_gptq_adapter.py +78 -0
- invarlock/plugins/py.typed +1 -0
- invarlock/py.typed +1 -0
- invarlock/reporting/__init__.py +7 -0
- invarlock/reporting/certificate.py +3221 -0
- invarlock/reporting/certificate_schema.py +244 -0
- invarlock/reporting/dataset_hashing.py +215 -0
- invarlock/reporting/guards_analysis.py +948 -0
- invarlock/reporting/html.py +32 -0
- invarlock/reporting/normalizer.py +235 -0
- invarlock/reporting/policy_utils.py +517 -0
- invarlock/reporting/primary_metric_utils.py +265 -0
- invarlock/reporting/render.py +1442 -0
- invarlock/reporting/report.py +903 -0
- invarlock/reporting/report_types.py +278 -0
- invarlock/reporting/utils.py +175 -0
- invarlock/reporting/validate.py +631 -0
- invarlock/security.py +176 -0
- invarlock/sparsity_utils.py +323 -0
- invarlock/utils/__init__.py +150 -0
- invarlock/utils/digest.py +45 -0
- invarlock-0.2.0.dist-info/METADATA +586 -0
- invarlock-0.2.0.dist-info/RECORD +132 -0
- invarlock-0.2.0.dist-info/WHEEL +5 -0
- invarlock-0.2.0.dist-info/entry_points.txt +20 -0
- invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
- invarlock-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,4146 @@
"""
InvarLock CLI Run Command
=====================

Run a guarded pipeline from a YAML config. Intended for local smokes,
plugin demos, and development. Advanced: for pairwise certification,
prefer Compare & Certify via `invarlock certify --baseline ... --subject ...`.
"""

import copy
import hashlib
import json
import math
import os
import random
import shutil
import sys as _sys
import types as _types
from array import array
from collections.abc import Iterable, Sequence
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any

import click
import numpy as np
import psutil
import typer
from rich.console import Console

try:
    import torch
except ImportError:
    torch = None  # type: ignore[assignment]

from invarlock.cli.errors import InvarlockError
from invarlock.cli.utils import (
    coerce_float as _coerce_float,
)
from invarlock.cli.utils import (
    coerce_int as _coerce_int,
)
from invarlock.cli.utils import (
    coerce_option as _coerce_option,
)
from invarlock.core.exceptions import (
    ConfigError as _CfgErr,
)
from invarlock.core.exceptions import (
    DataError as _DataErr,
)
from invarlock.core.exceptions import (
    ValidationError as _ValErr,
)
from invarlock.model_profile import detect_model_profile, resolve_tokenizer
from invarlock.model_utils import set_seed
from invarlock.reporting.validate import validate_guard_overhead

from ..config import (
    InvarLockConfig,
)
from ..overhead_utils import _extract_pm_snapshot_for_overhead

console = Console()
LIGHT_IMPORT = os.getenv("INVARLOCK_LIGHT_IMPORT", "").strip().lower() in {
    "1",
    "true",
    "yes",
}

# Release profile window planning constants
RELEASE_BUFFER_FRACTION = 0.12
RELEASE_MIN_WINDOWS_PER_ARM = 200
RELEASE_CALIBRATION_MIN = 16
RELEASE_CALIBRATION_MAX = 24
GUARD_OVERHEAD_THRESHOLD = 0.01


# Common dataset split aliases we probe in order when not explicitly set
SPLIT_ALIASES: tuple[str, ...] = ("validation", "val", "dev", "eval", "test")


def _choose_dataset_split(
    *, requested: str | None, available: list[str] | None
) -> tuple[str, bool]:
    """
    Choose a dataset split deterministically.

    Returns (split, used_fallback). If `requested` is provided, returns it verbatim.
    Else tries SPLIT_ALIASES in order; if none present, falls back to the first
    available split (sorted for determinism). If `available` is None/empty, returns
    ('validation', True) as a last resort so the run does not crash.
    """
    try:
        if isinstance(requested, str) and requested:
            return requested, False
    except Exception:
        pass
    avail = list(available) if isinstance(available, list) and available else []
    if avail:
        for cand in SPLIT_ALIASES:
            if cand in avail:
                return cand, True
        return sorted(avail)[0], True
    return "validation", True

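# --- Editor's note: illustrative sketch, not part of the packaged module ---
# Split resolution above is deterministic: an explicit request is returned
# verbatim, otherwise the first SPLIT_ALIASES hit wins, and the boolean flag
# reports whether a fallback was used. For example:
#
#   _choose_dataset_split(requested="test", available=["train", "test"])
#       -> ("test", False)
#   _choose_dataset_split(requested=None, available=["train", "test"])
#       -> ("test", True)        # first alias present in `available`
#   _choose_dataset_split(requested=None, available=None)
#       -> ("validation", True)  # last-resort default
# --- end editor's note ---
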
def _persist_ref_masks(core_report: Any, run_dir: Path) -> Path | None:
    """Persist reference keep indices to artifact if present."""

    edit_section = (
        core_report.get("edit")
        if isinstance(core_report, dict)
        else getattr(core_report, "edit", None)
    )
    if not isinstance(edit_section, dict):
        return None

    artifacts_section = edit_section.get("artifacts")
    if not isinstance(artifacts_section, dict):
        return None

    mask_payload = artifacts_section.get("mask_payload")
    if not isinstance(mask_payload, dict) or not mask_payload:
        return None

    payload_copy = copy.deepcopy(mask_payload)
    meta_section = payload_copy.setdefault("meta", {})
    meta_section.setdefault("generated_at", datetime.now().isoformat())

    target_dir = run_dir / "artifacts" / "edit_masks"
    target_dir.mkdir(parents=True, exist_ok=True)
    mask_path = target_dir / "masks.json"
    with mask_path.open("w", encoding="utf-8") as handle:
        json.dump(payload_copy, handle, indent=2, sort_keys=True)
        handle.write("\n")

    return mask_path


def _resolve_exit_code(exc: Exception, *, profile: str | None) -> int:
    """Resolve exit code based on exception type and profile.

    - ValueError("Invalid RunReport...") → 2 (schema/shape issue)
    - InvarlockError in CI/Release → 3 (hard abort)
    - All other cases → 1 (generic failure)
    """
    try:
        prof = (profile or "").strip().lower()
    except Exception:
        prof = ""
    # Schema/validation classes and known shapes → exit 2
    if isinstance(exc, _CfgErr | _ValErr | _DataErr):
        return 2
    if isinstance(exc, ValueError) and "Invalid RunReport" in str(exc):
        return 2
    if isinstance(exc, InvarlockError) and prof in {"ci", "release"}:
        return 3
    return 1

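# --- Editor's note: illustrative sketch, not part of the packaged module ---
# Exit-code mapping performed by the helper above, e.g.:
#
#   any ConfigError / ValidationError / DataError                        -> 2
#   _resolve_exit_code(ValueError("Invalid RunReport: ..."), profile=None)        -> 2
#   _resolve_exit_code(InvarlockError(code="E002", message="..."), profile="ci")  -> 3
#   _resolve_exit_code(RuntimeError("boom"), profile="release")                   -> 1
# --- end editor's note ---
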
## NOTE: Deprecated legacy helper `_check_pairability_or_abort` was removed.
## Provider parity and pairing guarantees are enforced via guard digests and
## invariant checks during run execution.


def _hash_sequences(seqs: Sequence[Sequence[int]] | Iterable[Sequence[int]]) -> str:
    """Compute a stable digest for a sequence of integer token sequences."""
    hasher = hashlib.blake2s(digest_size=16)
    for seq in seqs:
        arr = array("I", (int(token) & 0xFFFFFFFF for token in seq))
        hasher.update(arr.tobytes())
    return hasher.hexdigest()


def _compute_mask_positions_digest(windows: dict[str, Any]) -> str | None:
    """Compute a rolled hash of MLM mask positions across windows.

    Expects windows of the shape { 'preview': {...}, 'final': {...} } with
    'labels' and optional 'window_ids' in each section. Positions where
    labels != -100 are treated as masked.
    """
    try:
        # Simple, dependency-light digest of positions where labels != -100
        hasher = hashlib.blake2s(digest_size=16)
        any_masked = False
        for arm in ("preview", "final"):
            sec = windows.get(arm)
            if not isinstance(sec, dict):
                continue
            labels = sec.get("labels")
            if not isinstance(labels, list) or not labels:
                continue
            hasher.update(arm.encode("utf-8"))
            for row in labels:
                row_list = _tensor_or_list_to_ints(row)
                if not row_list:
                    continue
                found = False
                for idx, v in enumerate(row_list):
                    if int(v) != -100:
                        hasher.update(b"1")
                        hasher.update(idx.to_bytes(4, "little", signed=False))
                        found = True
                if found:
                    any_masked = True
                hasher.update(b"|")
        if not any_masked:
            return None
        digest = hasher.hexdigest()
        return digest if digest else None
    except Exception:
        return None


def _to_int_list(values: Sequence[int] | Iterable[int]) -> list[int]:
    return [int(v) for v in values]


def _tensor_or_list_to_ints(values: Any) -> list[int]:
    """Coerce possible tensor/list-like inputs to a list[int]."""
    try:
        # Torch tensors: `.tolist()` path
        if torch is not None and hasattr(values, "tolist"):
            raw = values.tolist()
            if isinstance(raw, list):
                return _to_int_list(raw)
            try:
                return _to_int_list(list(raw))
            except Exception:
                pass
        # Numpy arrays: treat as list-like
        if isinstance(values, np.ndarray | list | tuple):
            return _to_int_list(list(values))
        # Iterables of ints
        if isinstance(values, Iterable):
            return _to_int_list(values)
    except Exception:
        pass
    return []


def _safe_int(value: Any, default: int = 0) -> int:
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _derive_mlm_seed(base_seed: int, window_id: str | int, position: int) -> int:
    payload = f"{base_seed}:{window_id}:{position}".encode()
    digest = hashlib.blake2s(payload, digest_size=8).digest()
    return int.from_bytes(digest, "little", signed=False)

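# --- Editor's note: illustrative sketch, not part of the packaged module ---
# Mask decisions are keyed off (base_seed, window_id, position), so re-running
# the same schedule reproduces identical draws regardless of iteration order.
# A standalone reproduction using only the stdlib pieces the helper relies on:
#
#   import hashlib, random
#   payload = "1234:preview:0:7".encode()  # f"{base_seed}:{window_id}:{position}"
#   seed = int.from_bytes(
#       hashlib.blake2s(payload, digest_size=8).digest(), "little", signed=False
#   )
#   assert random.Random(seed).random() == random.Random(seed).random()
# --- end editor's note ---
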
def _apply_mlm_masks(
    records: list[dict[str, Any]],
    *,
    tokenizer: Any,
    mask_prob: float,
    seed: int,
    random_token_prob: float,
    original_token_prob: float,
    prefix: str,
) -> tuple[int, list[int]]:
    """Apply basic BERT-style MLM masking to tokenized records in-place."""
    if mask_prob <= 0.0:
        zeroed = []
        for record in records:
            length = len(record["input_ids"])
            record["labels"] = [-100] * length
            record["mlm_masked"] = 0
            zeroed.append(0)
        return 0, zeroed

    vocab_size = _safe_int(getattr(tokenizer, "vocab_size", 0))
    # Require an explicit mask token id for MLM
    mask_token_id = getattr(tokenizer, "mask_token_id", None)
    if mask_token_id is None:
        raise RuntimeError(
            "Tokenizer does not define mask_token_id; required for MLM evaluation."
        )
    try:
        mask_token_id = int(mask_token_id)
    except Exception:
        mask_token_id = _safe_int(mask_token_id, 0)

    # Build special token id set to avoid masking them
    special_ids = set()
    for attr in (
        "cls_token_id",
        "sep_token_id",
        "bos_token_id",
        "eos_token_id",
        "pad_token_id",
    ):
        val = getattr(tokenizer, attr, None)
        if val is not None:
            try:
                special_ids.add(int(val))
            except Exception:
                pass
    try:
        special_ids.update(
            int(t) for t in getattr(tokenizer, "all_special_ids", []) or []
        )
    except Exception:
        pass

    masked_total = 0
    masked_counts: list[int] = []
    for idx_record, record in enumerate(records):
        window_id = record.get("window_id", f"{prefix}:{idx_record}")
        input_ids = _tensor_or_list_to_ints(record.get("input_ids", []))
        attention = _tensor_or_list_to_ints(record.get("attention_mask", []))
        labels = [-100] * len(input_ids)

        masked = 0
        for pos, (tok, att) in enumerate(zip(input_ids, attention, strict=False)):
            if not att:
                continue
            if int(tok) in special_ids:
                continue
            if random.random() < mask_prob:
                rng = random.Random(_derive_mlm_seed(seed, window_id, pos))
                labels[pos] = int(tok)
                r = rng.random()
                if r < 1.0 - (random_token_prob + original_token_prob):
                    input_ids[pos] = mask_token_id
                elif r < 1.0 - original_token_prob and vocab_size > 0:
                    rng2 = random.Random(_derive_mlm_seed(seed + 17, window_id, pos))
                    input_ids[pos] = rng2.randint(0, max(1, vocab_size - 1))
                masked += 1

        # Ensure at least one masked token for stability
        if masked == 0:
            candidate_positions = [
                p
                for p, (tok, att) in enumerate(zip(input_ids, attention, strict=False))
                if att and int(tok) not in special_ids
            ]
            if candidate_positions:
                pos = candidate_positions[len(candidate_positions) // 2]
                rng = random.Random(_derive_mlm_seed(seed + 17, window_id, pos))
                labels[pos] = int(input_ids[pos])
                masked = 1
                r = rng.random()
                if r < 1.0 - (random_token_prob + original_token_prob):
                    input_ids[pos] = mask_token_id
                elif r < 1.0 - original_token_prob and vocab_size > 0:
                    input_ids[pos] = rng.randrange(vocab_size)

        record["input_ids"] = _to_int_list(input_ids)
        record["attention_mask"] = _to_int_list(attention)
        record["labels"] = _to_int_list(labels)
        record["mlm_masked"] = masked
        masked_total += masked
        masked_counts.append(masked)

    return masked_total, masked_counts

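# --- Editor's note: illustrative sketch, not part of the packaged module ---
# With caller-supplied probabilities of random_token_prob=0.1 and
# original_token_prob=0.1 (the BERT-style defaults; values are not fixed here),
# a selected position falls into three bands of a single uniform draw r:
#
#   r < 0.8         -> replaced with tokenizer.mask_token_id
#   0.8 <= r < 0.9  -> replaced with a random vocabulary token
#   r >= 0.9        -> kept as the original token
#
# In every case labels[pos] records the original id, so the loss is computed
# only on the selected positions (all other labels stay at -100).
# --- end editor's note ---
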
def _tokenizer_digest(tokenizer: Any) -> str:
    """Compute a stable digest for a tokenizer config.

    Tries, in order: get_vocab().items(), `vocab` attribute if list-like, else
    hashes a small set of informative attributes.
    """
    try:
        if hasattr(tokenizer, "get_vocab"):
            try:
                items = getattr(tokenizer.get_vocab(), "items", None)
                if callable(items):
                    pairs = list(items())
                    # Filter non-string keys for stability
                    pairs = [
                        (str(k), int(v)) for k, v in pairs if isinstance(k, str | int)
                    ]
                    payload = json.dumps(sorted(pairs), separators=(",", ":")).encode()
                    return hashlib.sha256(payload).hexdigest()
            except Exception:
                pass
        # Fallback to `vocab` attribute (e.g., list of pairs)
        vocab = getattr(tokenizer, "vocab", None)
        if isinstance(vocab, list):
            try:
                payload = json.dumps(
                    [(str(k), int(v)) for k, v in vocab], separators=(",", ":")
                ).encode()
                return hashlib.sha256(payload).hexdigest()
            except Exception:
                pass
        # Last resort: small attribute set
        attrs = {
            "name": getattr(tokenizer, "name_or_path", None),
            "eos": getattr(tokenizer, "eos_token", None),
            "pad": getattr(tokenizer, "pad_token", None),
            "size": _safe_int(getattr(tokenizer, "vocab_size", 0)),
        }
        return hashlib.sha256(json.dumps(attrs, sort_keys=True).encode()).hexdigest()
    except Exception:
        return "unknown-tokenizer"


def _extract_pairing_schedule(report: dict[str, Any] | None) -> dict[str, Any] | None:
    """Extract sanitized pairing schedule from a baseline-like report shape.

    Returns a dict with 'preview' and 'final' sections when present. Each section
    contains window_ids, input_ids, attention_masks, optional labels and counts.
    """
    if not isinstance(report, dict):
        return None
    windows = report.get("evaluation_windows")
    if not isinstance(windows, dict):
        return None

    def _sanitize(section_key: str) -> dict[str, Any] | None:
        section = windows.get(section_key)
        if not isinstance(section, dict):
            return None
        window_ids = list(section.get("window_ids", []))
        input_ids_raw = section.get("input_ids", [])
        if not isinstance(input_ids_raw, list):
            return None
        input_ids = [list(seq) for seq in input_ids_raw]
        attention_raw = section.get("attention_masks")
        if isinstance(attention_raw, list) and all(
            isinstance(mask, list) for mask in attention_raw
        ):
            attention_masks = [list(mask) for mask in attention_raw]
        else:
            attention_masks = [
                [1 if token != 0 else 0 for token in seq] for seq in input_ids
            ]

        labels_raw = section.get("labels")
        labels: list[list[int]] | None = None
        if isinstance(labels_raw, list) and labels_raw:
            labels = []
            for idx, raw_label in enumerate(labels_raw):
                label_list = _tensor_or_list_to_ints(raw_label)
                if idx < len(input_ids):
                    target_len = len(input_ids[idx])
                    if len(label_list) < target_len:
                        label_list = label_list + [-100] * (
                            target_len - len(label_list)
                        )
                    elif len(label_list) > target_len:
                        label_list = label_list[:target_len]
                labels.append(label_list)

        masked_counts = None
        if isinstance(section.get("masked_token_counts"), list):
            masked_counts = [int(v) for v in section["masked_token_counts"]]
        actual_counts = None
        if isinstance(section.get("actual_token_counts"), list):
            actual_counts = [int(v) for v in section["actual_token_counts"]]

        payload: dict[str, Any] = {
            "window_ids": window_ids,
            "input_ids": input_ids,
            "attention_masks": attention_masks,
        }
        if labels is not None:
            payload["labels"] = labels
        if masked_counts is not None:
            payload["masked_token_counts"] = masked_counts
        if actual_counts is not None:
            payload["actual_token_counts"] = actual_counts
        return payload

    preview = _sanitize("preview")
    final = _sanitize("final")
    if preview and final:
        return {"preview": preview, "final": final}
    return None


def _prepare_config_for_run(
    *,
    config_path: str,
    profile: str | None,
    edit: str | None,
    tier: str | None,
    probes: int | None,
    console: Console,
) -> InvarLockConfig:
    """Load InvarLock config and apply CLI/profile overrides deterministically."""
    # Local import to allow test monkeypatching of invarlock.cli.config functions
    from ..config import (
        apply_edit_override as _apply_edit_override,
    )
    from ..config import (
        apply_profile as _apply_profile,
    )
    from ..config import (
        load_config as _load_config,
    )
    from ..config import (
        resolve_edit_kind as _resolve_edit_kind,
    )

    console.print(f"📋 Loading configuration: {config_path}")
    cfg = _load_config(config_path)

    # Apply profile if specified (dev is a no-op)
    if profile and str(profile).lower() in {"ci", "release"}:
        console.print(f"🎯 Applying profile: {profile}")
        try:
            cfg = _apply_profile(cfg, profile)
        except Exception as exc:
            console.print(f"[red]❌ {exc}[/red]")
            raise typer.Exit(1) from exc

    # Apply edit override
    if edit:
        try:
            edit_name = _resolve_edit_kind(edit)
            console.print(f"✂️ Edit override: {edit} → {edit_name}")
            cfg = _apply_edit_override(cfg, edit)
        except ValueError as e:
            console.print(f"[red]❌ {e}[/red]")
            raise typer.Exit(1) from e

    # Apply CLI overrides for auto configuration
    if tier or probes is not None:
        if tier and tier not in ["conservative", "balanced", "aggressive", "none"]:
            console.print(
                f"[red]❌ Invalid tier '{tier}'. Valid options: conservative, balanced, aggressive, none[/red]"
            )
            raise typer.Exit(1)
        if probes is not None and (probes < 0 or probes > 10):
            console.print(
                f"[red]❌ Invalid probes '{probes}'. Must be between 0 and 10[/red]"
            )
            raise typer.Exit(1)

        cfg_dict = cfg.model_dump()
        auto_section = (
            cfg_dict.get("auto") if isinstance(cfg_dict.get("auto"), dict) else {}
        )
        cfg_dict["auto"] = auto_section
        if tier:
            auto_section["tier"] = tier
            console.print(f"🎛️ Auto tier override: {tier}")
        if probes is not None:
            auto_section["probes"] = probes
            console.print(f"🔬 Auto probes override: {probes}")
        cfg = InvarLockConfig(cfg_dict)

    # Resolve adapter:auto to a concrete built-in adapter if requested
    try:
        from ..adapter_auto import apply_auto_adapter_if_needed as _apply_auto

        cfg = _apply_auto(cfg)
    except Exception:
        pass

    return cfg

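# --- Editor's note: illustrative sketch, not part of the packaged module ---
# Override precedence as implemented above: the YAML config is loaded first,
# then --profile (ci|release) is applied, then --edit, then --tier/--probes are
# written into the `auto` section before the config is re-validated. A
# hypothetical invocation exercising all of them (paths and values are
# placeholders, not files shipped with the wheel) might look like:
#
#   invarlock run --config configs/gpt2.yaml --profile ci \
#       --edit quant --tier balanced --probes 4
# --- end editor's note ---
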
def _maybe_plan_release_windows(
    capacity_meta: dict[str, Any],
    *,
    requested_preview: int,
    requested_final: int,
    max_calibration: int,
    console: Console,
) -> dict[str, Any]:
    """Thin wrapper around _plan_release_windows to improve readability."""
    return _plan_release_windows(
        capacity_meta,
        requested_preview=requested_preview,
        requested_final=requested_final,
        max_calibration=max_calibration,
        console=console,
    )


def _print_pipeline_start(console: Console) -> None:
    console.print("🚀 Starting InvarLock pipeline...")


def _emit_run_artifacts(
    *, report: Any, out_dir: Path, filename_prefix: str, console: Console
) -> dict[str, str]:
    """Save run report and return emitted artifact paths."""
    from invarlock.reporting.report import save_report as _save_report

    console.print("💾 Saving run report...")
    return _save_report(
        report, out_dir, formats=["json"], filename_prefix=filename_prefix
    )


def _resolve_device_and_output(
    cfg: Any, *, device: str | None, out: str | None, console: Console
) -> tuple[str, Path]:
    """Resolve device and output directory with validation and logging."""
    from ..device import (
        resolve_device as _resolve_device,
    )
    from ..device import (
        validate_device_for_config as _validate,
    )

    try:
        cfg_device = getattr(cfg.model, "device", None)
    except Exception:
        cfg_device = None
    target_device = device or cfg_device or "auto"
    resolved_device = _resolve_device(target_device)
    console.print(
        f"Device: {resolved_device} (requested={target_device}, resolved={resolved_device})"
    )
    is_valid, error_msg = _validate(resolved_device)
    if not is_valid:
        console.print(f"[red]❌ Device validation failed: {error_msg}[/red]")
        raise typer.Exit(1)

    # Determine output directory (support both 'output.dir' and legacy 'out.dir')
    if out:
        output_dir = Path(out)
    else:
        try:
            output_dir = Path(cfg.output.dir)
        except Exception:
            try:
                output_dir = Path(cfg.out.dir)  # type: ignore[attr-defined]
            except Exception:
                output_dir = Path("runs")
    output_dir.mkdir(parents=True, exist_ok=True)
    return str(resolved_device), output_dir


def _resolve_provider_and_split(
    cfg: Any,
    model_profile: Any,
    *,
    get_provider_fn: Any,
    provider_kwargs: dict[str, Any] | None = None,
    console: Console,
    resolved_device: str | None = None,
) -> tuple[Any, str, bool]:
    """Resolve dataset provider and split, returning (provider, split, used_fallback)."""
    provider_name = None
    provider_kwargs = dict(provider_kwargs or {})
    try:
        provider_val = cfg.dataset.provider
    except Exception:
        provider_val = None
    if isinstance(provider_val, str) and provider_val:
        provider_name = provider_val
    else:
        try:
            provider_name = provider_val.kind  # type: ignore[attr-defined]
            try:
                for k, v in provider_val.items():  # type: ignore[attr-defined]
                    if k != "kind" and v is not None and v != "":
                        provider_kwargs[k] = v
            except Exception:
                pass
        except Exception:
            provider_name = None
    if not provider_name:
        provider_name = getattr(model_profile, "default_provider", None) or "wikitext2"
    # Pass device hint only to providers that understand it (currently WikiText-2)
    if resolved_device and provider_name == "wikitext2":
        provider_kwargs.setdefault("device_hint", resolved_device)
    data_provider = get_provider_fn(provider_name, **provider_kwargs)

    requested_split = None
    try:
        requested_split = getattr(cfg.dataset, "split", None)
    except Exception:
        requested_split = None
    available_splits = None
    if hasattr(data_provider, "available_splits"):
        try:
            available_splits = list(data_provider.available_splits())  # type: ignore[attr-defined]
        except Exception:
            available_splits = None
    resolved_split, used_fallback_split = _choose_dataset_split(
        requested=requested_split, available=available_splits
    )
    return data_provider, resolved_split, used_fallback_split


def _run_bare_control(
    *,
    adapter: Any,
    edit_op: Any,
    cfg: Any,
    model: Any,
    run_config: Any,
    calibration_data: list[Any],
    auto_config: Any,
    edit_config: Any,
    preview_count: int,
    final_count: int,
    seed_bundle: dict[str, int | None],
    resolved_device: str,
    restore_fn: Any | None,
    console: Console,
    resolved_loss_type: str,
    profile_normalized: str | None,
    skip_model_load: bool = False,
) -> dict[str, Any] | None:
    """Execute the bare-control run for overhead estimation and return payload."""
    from invarlock.core.runner import CoreRunner as _CoreRunner

    console.print("🧪 Running bare control (guards disabled) for overhead check")
    set_seed(seed_bundle["python"])  # type: ignore[arg-type]

    bare_runner = _CoreRunner()
    bare_config = copy.deepcopy(run_config)
    bare_config.event_path = None
    bare_context = copy.deepcopy(run_config.context)
    bare_context.setdefault("validation", {})["guard_overhead_mode"] = "bare"
    bare_config.context = bare_context

    if restore_fn and model is not None:
        restore_fn()
        bare_target_model = model
    elif skip_model_load:
        bare_target_model = model or SimpleNamespace(name="bare_stub_model")
    else:
        bare_target_model = adapter.load_model(cfg.model.id, device=resolved_device)

    bare_report = bare_runner.execute(
        model=bare_target_model,
        adapter=adapter,
        edit=edit_op,
        guards=[],
        config=bare_config,
        calibration_data=calibration_data,
        auto_config=auto_config,
        edit_config=edit_config,
        preview_n=preview_count,
        final_n=final_count,
    )

    bare_ppl_final = None
    bare_ppl_preview = None
    if hasattr(bare_report, "metrics") and bare_report.metrics:
        bare_pm = bare_report.metrics.get("primary_metric", {})
        bare_ppl_final = bare_pm.get("final") if isinstance(bare_pm, dict) else None
        bare_ppl_preview = bare_pm.get("preview") if isinstance(bare_pm, dict) else None

    if profile_normalized in {"ci", "release"}:

        def _finite(x: Any) -> bool:
            try:
                return isinstance(x, (int | float)) and math.isfinite(float(x))
            except Exception:
                return False

        if not (_finite(bare_ppl_preview) and _finite(bare_ppl_final)):
            console.print(
                "[yellow]⚠️ Primary metric non-finite during bare control; continuing with diagnostics.[/yellow]"
            )

    payload: dict[str, Any] = {
        "overhead_threshold": GUARD_OVERHEAD_THRESHOLD,
        "messages": [],
        "warnings": [],
        "errors": [],
        "checks": {},
        "source": f"{profile_normalized or 'ci'}_profile",
    }

    if getattr(bare_report, "status", "").lower() not in {"success", "completed", "ok"}:
        payload["warnings"].append(
            f"Bare run status: {getattr(bare_report, 'status', 'unknown')}"
        )

    try:
        lk = str(resolved_loss_type or "causal").lower()
        if lk == "mlm":
            pm_kind_bare = "ppl_mlm"
        elif lk in {"seq2seq", "s2s", "t5"}:
            pm_kind_bare = "ppl_seq2seq"
        else:
            pm_kind_bare = "ppl_causal"
        pm_bare = _extract_pm_snapshot_for_overhead(bare_report, kind=pm_kind_bare)
        if isinstance(pm_bare, dict) and pm_bare:
            payload["bare_report"] = {"metrics": {"primary_metric": pm_bare}}
    except Exception:
        pass

    set_seed(seed_bundle["python"])  # type: ignore[arg-type]
    return payload

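# --- Editor's note: illustrative sketch, not part of the packaged module ---
# The bare control replays the same edit with guards=[] so guard overhead can
# be compared against GUARD_OVERHEAD_THRESHOLD (1%). The returned payload
# starts out roughly as:
#
#   {
#       "overhead_threshold": 0.01,
#       "messages": [], "warnings": [], "errors": [], "checks": {},
#       "source": "ci_profile",   # or "release_profile"
#       "bare_report": {"metrics": {"primary_metric": {...}}},  # when available
#   }
#
# and is presumably consumed later alongside the guarded run by
# validate_guard_overhead (imported at the top of this module).
# --- end editor's note ---
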
def _execute_guarded_run(
    *,
    runner: Any,
    adapter: Any,
    model: Any,
    cfg: Any,
    edit_op: Any,
    run_config: Any,
    guards: list[Any],
    calibration_data: list[Any],
    auto_config: Any,
    edit_config: Any,
    preview_count: int,
    final_count: int,
    restore_fn: Any | None,
    resolved_device: str,
    console: Console,
    skip_model_load: bool = False,
) -> tuple[Any, Any]:
    """Restore or load model and execute the guarded CoreRunner."""
    if restore_fn and model is not None:
        restore_fn()
    elif skip_model_load:
        model = model or SimpleNamespace(name="guarded_stub_model")
    else:
        console.print(f"🔧 Loading model: {cfg.model.id} (attempt 1)")
        model = adapter.load_model(cfg.model.id, device=resolved_device)

    core_report = runner.execute(
        model=model,
        adapter=adapter,
        edit=edit_op,
        guards=guards,
        config=run_config,
        calibration_data=calibration_data,
        auto_config=auto_config,
        edit_config=edit_config,
        preview_n=preview_count,
        final_n=final_count,
    )
    return core_report, model


def _postprocess_and_summarize(
    *,
    report: dict[str, Any],
    run_dir: Path,
    run_config: Any,
    window_plan: dict[str, Any] | None,
    dataset_meta: dict[str, Any],
    match_fraction: float | None,
    overlap_fraction: float | None,
    console: Console,
) -> None:
    """Finalize report windows stats and print/save summary artifacts."""
    try:
        ds = report.setdefault("dataset", {}).setdefault("windows", {})
        stats = ds.setdefault("stats", {})
        if match_fraction is not None:
            stats["window_match_fraction"] = float(match_fraction)
        if overlap_fraction is not None:
            stats["window_overlap_fraction"] = float(overlap_fraction)
        try:
            if isinstance(window_plan, dict) and "coverage_ok" in window_plan:
                stats["coverage"] = bool(window_plan.get("coverage_ok"))
        except Exception:
            pass
    except Exception:
        pass

    saved_files = _emit_run_artifacts(
        report=report, out_dir=run_dir, filename_prefix="report", console=console
    )
    console.print("[green]✅ Run completed successfully![/green]")
    console.print(f"📄 Report: {saved_files['json']}")
    if run_config.event_path:
        console.print(f"📝 Events: {run_config.event_path}")


def _compute_provider_digest(report: dict[str, Any]) -> dict[str, str] | None:
    """Compute provider digest (ids/tokenizer/masking) from report context.

    Returns a dict with keys: ids_sha256, tokenizer_sha256, masking_sha256?
    """
    # Prefer centralized digest helpers
    from invarlock.utils.digest import hash_json as _hash_json

    windows = report.get("evaluation_windows") if isinstance(report, dict) else None
    if not isinstance(windows, dict) or not windows:
        return None
    # window_ids digest across preview+final (sorted for stability)
    all_ids: list = []
    for key in ("preview", "final"):
        sec = windows.get(key)
        if not isinstance(sec, dict):
            continue
        wids = sec.get("window_ids")
        if isinstance(wids, list):
            all_ids.extend(list(wids))
    ids_sha = _hash_json(sorted(all_ids)) if all_ids else None

    # tokenizer hash: prefer meta.tokenizer_hash then data.tokenizer_hash
    tok_hash = None
    meta = report.get("meta") if isinstance(report.get("meta"), dict) else None
    if isinstance(meta, dict):
        tok_hash = meta.get("tokenizer_hash")
    if not tok_hash and isinstance(report.get("data"), dict):
        tok_hash = report["data"].get("tokenizer_hash")

    # masking hash from mask positions
    masking = _compute_mask_positions_digest(windows)

    digest: dict[str, str] = {}
    if isinstance(ids_sha, str) and ids_sha:
        digest["ids_sha256"] = ids_sha
    if isinstance(tok_hash, str) and tok_hash:
        digest["tokenizer_sha256"] = str(tok_hash)
    if isinstance(masking, str) and masking:
        digest["masking_sha256"] = masking
    return digest or None

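# --- Editor's note: illustrative sketch, not part of the packaged module ---
# The provider digest harvested from a report is a small dict, e.g.:
#
#   {
#       "ids_sha256": "<hash of the sorted preview+final window_ids>",
#       "tokenizer_sha256": "<meta.tokenizer_hash or data.tokenizer_hash>",
#       "masking_sha256": "<hash of MLM mask positions, only when labels exist>",
#   }
#
# Any missing component is simply omitted; an empty dict collapses to None.
# --- end editor's note ---
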
def _validate_and_harvest_baseline_schedule(
    cfg: Any,
    pairing_schedule: dict[str, Any],
    baseline_report_data: dict[str, Any] | None,
    *,
    tokenizer_hash: str | None,
    resolved_loss_type: str,
    baseline_path_str: str | None = None,
    console: Console | None = None,
) -> dict[str, Any]:
    """Validate baseline schedule compatibility and harvest dataset metadata.

    Returns a mapping with keys: effective_preview, effective_final, preview_count,
    final_count, dataset_meta, window_plan, calibration_data.
    """

    # Helpers
    def _print(msg: str) -> None:
        if console is not None:
            console.print(msg)

    def _fail_schedule(reason: str) -> None:
        path = baseline_path_str or "baseline"
        _print(
            f"[red]❌ Baseline pairing schedule '{path}' is incompatible: {reason}[/red]"
        )
        raise typer.Exit(1)

    baseline_meta = (
        baseline_report_data.get("data")
        if isinstance(baseline_report_data, dict)
        else {}
    )
    if not isinstance(baseline_meta, dict):
        baseline_meta = {}

    def _extract_meta(field: str, default: Any = None) -> Any:
        value = baseline_meta.get(field)
        return value if value is not None else default

    # Adopt counts from the schedule, warning if they differ from cfg
    baseline_preview = len(pairing_schedule["preview"].get("input_ids") or [])
    baseline_final = len(pairing_schedule["final"].get("input_ids") or [])
    cfg_preview = getattr(cfg.dataset, "preview_n", None)
    cfg_final = getattr(cfg.dataset, "final_n", None)
    if (
        cfg_preview is not None
        and baseline_preview is not None
        and baseline_preview != cfg_preview
    ) or (
        cfg_final is not None
        and baseline_final is not None
        and baseline_final != cfg_final
    ):
        _print(
            "[yellow]⚠️ Adjusting evaluation window counts to match baseline schedule "
            f"({baseline_preview}/{baseline_final}).[/yellow]"
        )

    effective_preview = int(baseline_preview)
    effective_final = int(baseline_final)
    preview_count = effective_preview
    final_count = effective_final

    # Validate key dataset parameters
    cfg_seq_len = getattr(cfg.dataset, "seq_len", None)
    baseline_seq_len = _extract_meta("seq_len")
    if (
        cfg_seq_len is not None
        and baseline_seq_len is not None
        and baseline_seq_len != cfg_seq_len
    ):
        _fail_schedule(
            f"sequence length mismatch (baseline {baseline_seq_len} vs config {cfg_seq_len})"
        )

    cfg_stride = getattr(cfg.dataset, "stride", getattr(cfg.dataset, "seq_len", None))
    baseline_stride = _extract_meta("stride")
    if (
        baseline_stride is not None
        and cfg_stride is not None
        and baseline_stride != cfg_stride
    ):
        _fail_schedule(
            f"stride mismatch (baseline {baseline_stride} vs config {cfg_stride})"
        )

    cfg_dataset = getattr(cfg.dataset, "provider", None)
    if cfg_dataset is None:
        cfg_dataset = getattr(cfg.dataset, "dataset", None)
    baseline_dataset = _extract_meta("dataset")
    if (
        baseline_dataset is not None
        and cfg_dataset is not None
        and baseline_dataset != cfg_dataset
    ):
        _fail_schedule(
            f"dataset mismatch (baseline {baseline_dataset} vs config {cfg_dataset})"
        )

    cfg_split = getattr(cfg.dataset, "split", "validation")
    baseline_split = _extract_meta("split")
    if (
        baseline_split is not None
        and cfg_split is not None
        and baseline_split != cfg_split
    ):
        _fail_schedule(
            f"split mismatch (baseline {baseline_split} vs config {cfg_split})"
        )

    baseline_tokenizer_hash = baseline_meta.get("tokenizer_hash")
    if (
        baseline_tokenizer_hash
        and tokenizer_hash
        and baseline_tokenizer_hash != tokenizer_hash
    ):
        _fail_schedule(
            "tokenizer hash mismatch between baseline and current configuration"
        )

    dataset_meta = {
        key: baseline_meta.get(key)
        for key in (
            "tokenizer_hash",
            "tokenizer_name",
            "vocab_size",
            "bos_token",
            "eos_token",
            "pad_token",
            "add_prefix_space",
            "dataset_hash",
            "preview_hash",
            "final_hash",
            "preview_total_tokens",
            "final_total_tokens",
        )
        if baseline_meta.get(key) is not None
    }
    dataset_meta["loss_type"] = resolved_loss_type
    window_plan = baseline_meta.get("window_plan")
    calibration_data: list[Any] | None = []

    return {
        "effective_preview": effective_preview,
        "effective_final": effective_final,
        "preview_count": preview_count,
        "final_count": final_count,
        "dataset_meta": dataset_meta,
        "window_plan": window_plan,
        "calibration_data": calibration_data,
    }


def _enforce_provider_parity(
    subject_digest: dict | None, baseline_digest: dict | None, *, profile: str | None
) -> None:
    """Enforce tokenizer/masking parity rules for CI/Release profiles.

    - If tokenizers differ in CI/Release, abort.
    - If tokenizers match but masking digests differ (MLM), abort.
    No-ops outside CI/Release.
    """
    prof = (profile or "").strip().lower()
    if prof not in {"ci", "release"}:
        return
    sd = subject_digest or {}
    bd = baseline_digest or {}
    subj_tok = sd.get("tokenizer_sha256")
    base_tok = bd.get("tokenizer_sha256")
    subj_mask = sd.get("masking_sha256")
    base_mask = bd.get("masking_sha256")
    # Missing digest information in CI/Release → abort
    if not (
        isinstance(subj_tok, str)
        and isinstance(base_tok, str)
        and subj_tok
        and base_tok
    ):
        raise InvarlockError(
            code="E004",
            message="PROVIDER-DIGEST-MISSING: subject or baseline missing tokenizer digest",
        )
    # Tokenizer mismatch → abort with code
    if subj_tok != base_tok:
        raise InvarlockError(
            code="E002",
            message="TOKENIZER-DIGEST-MISMATCH: subject and baseline tokenizers differ",
        )
    # Masking mismatch under identical tokenizers → abort
    if (
        isinstance(subj_mask, str)
        and isinstance(base_mask, str)
        and subj_mask
        and base_mask
        and subj_mask != base_mask
    ):
        raise InvarlockError(
            code="E003",
            message="MASK-PARITY-MISMATCH: mask positions differ under matched tokenizers",
        )

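# --- Editor's note: illustrative sketch, not part of the packaged module ---
# Under --profile ci or release the parity check is strict. For example:
#
#   _enforce_provider_parity(
#       {"tokenizer_sha256": "aaa...", "masking_sha256": "m1..."},
#       {"tokenizer_sha256": "aaa...", "masking_sha256": "m2..."},
#       profile="ci",
#   )
#   # -> raises InvarlockError(code="E003"): mask positions differ under
#   #    matched tokenizers. The same call with profile="dev" is a no-op.
# --- end editor's note ---
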
def _resolve_metric_and_provider(
    cfg: Any,
    model_profile: Any,
    *,
    resolved_loss_type: str | None = None,
) -> tuple[str, str, dict[str, float]]:
    """Resolve metric kind, provider kind, and metric options from config with precedence.

    Precedence: CLI args (not handled here) → config → ModelProfile defaults → legacy fallback.
    Primary metric (metric‑v1) is canonical in dev‑phase; no env flag toggles.
    """
    # Provider kind
    provider_val = None
    try:
        provider_val = cfg.dataset.provider
    except Exception:
        provider_val = None
    provider_kind = None
    if isinstance(provider_val, str) and provider_val:
        provider_kind = provider_val
    else:
        # Support object-like config sections (e.g., InvarLockConfig _Obj)
        try:
            provider_kind = provider_val.kind
        except Exception:
            try:
                provider_kind = provider_val.get("kind")  # type: ignore[attr-defined]
            except Exception:
                provider_kind = None
    if not provider_kind and hasattr(model_profile, "default_provider"):
        provider_kind = model_profile.default_provider
    # Fallback to a known provider name supported by get_provider()
    if not provider_kind:
        provider_kind = "wikitext2"

    # Metric config
    metric_cfg = None
    try:
        eval_section = cfg.eval
        metric_cfg = getattr(eval_section, "metric", None)
    except Exception:
        metric_cfg = None

    metric_kind = None
    reps = None
    ci_level = None
    if metric_cfg is not None:
        try:
            metric_kind = (
                metric_cfg.get("kind")
                if isinstance(metric_cfg, dict)
                else metric_cfg.kind
            )
        except Exception:
            metric_kind = None
        try:
            reps = (
                metric_cfg.get("reps")
                if isinstance(metric_cfg, dict)
                else metric_cfg.reps
            )
        except Exception:
            reps = None
        try:
            ci_level = (
                metric_cfg.get("ci_level")
                if isinstance(metric_cfg, dict)
                else metric_cfg.ci_level
            )
        except Exception:
            ci_level = None

    # Resolve metric kind from config
    if isinstance(metric_kind, str) and metric_kind:
        mk = metric_kind.strip().lower()
        if mk == "auto":
            metric_kind = None
        else:
            metric_kind = mk
    else:
        metric_kind = None

    # Fallback to model profile default or legacy resolution by loss type
    if not metric_kind and hasattr(model_profile, "default_metric"):
        metric_kind = model_profile.default_metric
    if not metric_kind:
        # Legacy: map from loss kind
        lk = (resolved_loss_type or "causal").lower()
        if lk == "mlm":
            metric_kind = "ppl_mlm"
        elif lk in {"seq2seq", "s2s", "t5"}:
            metric_kind = "ppl_seq2seq"
        else:
            metric_kind = "ppl_causal"

    # Metric options dict if present
    opts: dict[str, float] = {}
    if reps is not None:
        try:
            opts["reps"] = float(int(reps))
        except Exception:
            pass
    if ci_level is not None:
        try:
            opts["ci_level"] = float(ci_level)
        except Exception:
            pass

    return str(metric_kind), str(provider_kind), opts

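# --- Editor's note: illustrative sketch, not part of the packaged module ---
# A config fragment selecting the metric explicitly (field names as read by the
# helper above; the surrounding YAML layout is an assumption):
#
#   eval:
#     metric:
#       kind: ppl_causal   # or "auto" to defer to the model profile
#       reps: 3
#       ci_level: 0.95
#
# With kind "auto" (or absent) the model profile's default_metric wins, and the
# final fallback maps the loss type to ppl_mlm / ppl_seq2seq / ppl_causal.
# --- end editor's note ---
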
def _plan_release_windows(
    capacity: dict[str, Any],
    *,
    requested_preview: int,
    requested_final: int,
    max_calibration: int,
    console: Console | None = None,
) -> dict[str, Any]:
    """Derive release-tier window plan based on dataset capacity."""
    available_unique = int(capacity.get("available_unique", 0))
    available_nonoverlap = int(capacity.get("available_nonoverlap", 0))
    total_tokens = int(capacity.get("total_tokens", 0))
    dedupe_rate = float(capacity.get("dedupe_rate", 0.0))
    candidate_unique = capacity.get("candidate_unique")
    if candidate_unique is not None and int(candidate_unique) > 0:
        effective_unique = int(candidate_unique)
    else:
        effective_unique = available_unique
    candidate_limit = capacity.get("candidate_limit")

    target_per_arm = int(min(requested_preview, requested_final))
    if target_per_arm <= 0:
        target_per_arm = requested_preview or requested_final or 1

    max_calibration = max(0, int(max_calibration or 0))
    if max_calibration > 0:
        calibration_windows = max(
            RELEASE_CALIBRATION_MIN,
            min(RELEASE_CALIBRATION_MAX, max_calibration // 10),
        )
    else:
        calibration_windows = RELEASE_CALIBRATION_MIN
    calibration_windows = min(calibration_windows, available_unique)

    buffer_windows = math.ceil(RELEASE_BUFFER_FRACTION * effective_unique)
    reserve_windows = min(effective_unique, calibration_windows + buffer_windows)
    available_for_eval = max(0, effective_unique - reserve_windows)
    actual_per_arm_raw = available_for_eval // 2

    coverage_ok = actual_per_arm_raw >= RELEASE_MIN_WINDOWS_PER_ARM
    if not coverage_ok:
        raise RuntimeError(
            "Release profile capacity insufficient: "
            f"available_unique={available_unique}, reserve={reserve_windows} "
            f"(calibration={calibration_windows}, buffer={buffer_windows}), "
            f"usable_per_arm={actual_per_arm_raw}, "
            f"requires ≥{RELEASE_MIN_WINDOWS_PER_ARM} per arm."
        )

    actual_per_arm = min(target_per_arm, actual_per_arm_raw)

    if console:
        candidate_msg = ""
        if candidate_unique is not None:
            candidate_msg = f", candidate_unique={int(candidate_unique)}" + (
                f"/{int(candidate_limit)}" if candidate_limit is not None else ""
            )
        console.print(
            "📏 Release window capacity:"
            f" unique={available_unique}, reserve={reserve_windows} "
            f"(calib {calibration_windows}, buffer {buffer_windows}), "
            f"usable={available_for_eval}, "
            f"per-arm raw={actual_per_arm_raw} → selected {actual_per_arm} "
            f"(target {target_per_arm}{candidate_msg})"
        )
        if actual_per_arm < target_per_arm:
            console.print(
                "[yellow]⚠️ Adjusted per-arm windows down from "
                f"{target_per_arm} to {actual_per_arm} based on capacity.[/yellow]"
            )

    plan = {
        "profile": "release",
        "requested_preview": int(requested_preview),
        "requested_final": int(requested_final),
        "target_per_arm": target_per_arm,
        "min_per_arm": RELEASE_MIN_WINDOWS_PER_ARM,
        "actual_preview": int(actual_per_arm),
        "actual_final": int(actual_per_arm),
        "actual_per_arm_raw": int(actual_per_arm_raw),
        "coverage_ok": coverage_ok,
        "capacity": {
            "total_tokens": total_tokens,
            "available_nonoverlap": available_nonoverlap,
            "available_unique": available_unique,
            "effective_unique": effective_unique,
            "dedupe_rate": dedupe_rate,
            "calibration": calibration_windows,
            "buffer_fraction": RELEASE_BUFFER_FRACTION,
            "buffer_windows": buffer_windows,
            "reserve_windows": reserve_windows,
            "usable_after_reserve": available_for_eval,
        },
    }
    if candidate_unique is not None:
        plan["capacity"]["candidate_unique"] = int(candidate_unique)
    if candidate_limit is not None:
        plan["capacity"]["candidate_limit"] = int(candidate_limit)
    return plan

# Check if core components are available
|
|
1334
|
+
try:
|
|
1335
|
+
from invarlock.core.api import RunConfig # noqa: F401
|
|
1336
|
+
from invarlock.core.registry import get_registry # noqa: F401
|
|
1337
|
+
|
|
1338
|
+
HAS_CORE_COMPONENTS = True
|
|
1339
|
+
except ImportError:
|
|
1340
|
+
HAS_CORE_COMPONENTS = False
|
|
1341
|
+
|
|
1342
|
+
|
|
1343
|
+
def run_command(
|
|
1344
|
+
config: str = typer.Option(
|
|
1345
|
+
..., "--config", "-c", help="Path to YAML configuration file"
|
|
1346
|
+
),
|
|
1347
|
+
device: str | None = typer.Option(
|
|
1348
|
+
None, "--device", help="Device override (auto|cuda|mps|cpu)"
|
|
1349
|
+
),
|
|
1350
|
+
profile: str | None = typer.Option(
|
|
1351
|
+
None, "--profile", help="Profile to apply (ci|release)"
|
|
1352
|
+
),
|
|
1353
|
+
out: str | None = typer.Option(None, "--out", help="Output directory override"),
|
|
1354
|
+
edit: str | None = typer.Option(None, "--edit", help="Edit kind (quant|mixed)"),
|
|
1355
|
+
tier: str | None = typer.Option(
|
|
1356
|
+
None,
|
|
1357
|
+
"--tier",
|
|
1358
|
+
help="Auto-tuning tier override (conservative|balanced|aggressive)",
|
|
1359
|
+
),
|
|
1360
|
+
probes: int | None = typer.Option(
|
|
1361
|
+
None, "--probes", help="Number of micro-probes (0=deterministic, >0=adaptive)"
|
|
1362
|
+
),
|
|
1363
|
+
until_pass: bool = typer.Option(
|
|
1364
|
+
False, "--until-pass", help="Retry until certificate passes (max 3 attempts)"
|
|
1365
|
+
),
|
|
1366
|
+
max_attempts: int = typer.Option(
|
|
1367
|
+
3, "--max-attempts", help="Maximum retry attempts for --until-pass mode"
|
|
1368
|
+
),
|
|
1369
|
+
timeout: int | None = typer.Option(
|
|
1370
|
+
None, "--timeout", help="Timeout in seconds for --until-pass mode"
|
|
1371
|
+
),
|
|
1372
|
+
baseline: str | None = typer.Option(
|
|
1373
|
+
None,
|
|
1374
|
+
"--baseline",
|
|
1375
|
+
help="Path to baseline report.json for certificate validation",
|
|
1376
|
+
),
|
|
1377
|
+
no_cleanup: bool = typer.Option(
|
|
1378
|
+
False, "--no-cleanup", help="Skip cleanup of temporary artifacts"
|
|
1379
|
+
),
|
|
1380
|
+
):
|
|
1381
|
+
"""
|
|
1382
|
+
Run InvarLock pipeline with the given configuration.
|
|
1383
|
+
|
|
1384
|
+
The command assembles non-overlapping preview/final windows, executes the
|
|
1385
|
+
GuardChain (invariants → spectral → RMT → variance), checks pairing/overlap
|
|
1386
|
+
invariants, enforces guard-overhead ≤1 %, and emits a run report plus JSONL
|
|
1387
|
+
events suitable for certificate generation.
|
|
1388
|
+
"""
|
|
1389
|
+
|
|
1390
|
+
try:
|
|
1391
|
+
from typer.models import OptionInfo as _TyperOptionInfo # noqa: F401
|
|
1392
|
+
except Exception: # pragma: no cover - typer internals may change
|
|
1393
|
+
_TyperOptionInfo = () # type: ignore[assignment]
|
|
1394
|
+
|
|
1395
|
+
config = _coerce_option(config)
|
|
1396
|
+
device = _coerce_option(device)
|
|
1397
|
+
profile = _coerce_option(profile)
|
|
1398
|
+
out = _coerce_option(out)
|
|
1399
|
+
edit = _coerce_option(edit)
|
|
1400
|
+
tier = _coerce_option(tier)
|
|
1401
|
+
probes = _coerce_option(probes)
|
|
1402
|
+
until_pass = bool(_coerce_option(until_pass, False))
|
|
1403
|
+
max_attempts = int(_coerce_option(max_attempts, 3))
|
|
1404
|
+
timeout = _coerce_option(timeout)
|
|
1405
|
+
baseline = _coerce_option(baseline)
|
|
1406
|
+
no_cleanup = bool(_coerce_option(no_cleanup, False))
|
|
1407
|
+
|
|
1408
|
+
# Use shared CLI coercers from invarlock.cli.utils
|
|
1409
|
+
|
|
1410
|
+
def _fail_run(message: str) -> None:
|
|
1411
|
+
console.print(f"[red]❌ {message}[/red]")
|
|
1412
|
+
# Generic failure path → exit 1 (InvarlockError paths handle code 3 separately)
|
|
1413
|
+
raise typer.Exit(1)
|
|
1414
|
+
|
|
1415
|
+
# Fail fast when torch is missing so users see a clear extras hint instead of
|
|
1416
|
+
# a raw ModuleNotFoundError from deeper imports.
|
|
1417
|
+
try:
|
|
1418
|
+
import torch as _torch # type: ignore[import]
|
|
1419
|
+
|
|
1420
|
+
_ = _torch # pragma: no cover
|
|
1421
|
+
except (ImportError, ModuleNotFoundError) as e:
|
|
1422
|
+
console.print(
|
|
1423
|
+
"❌ Torch is required for this command. "
|
|
1424
|
+
'Install extras with: pip install "invarlock[hf]" '
|
|
1425
|
+
'or "invarlock[adapters]".',
|
|
1426
|
+
style="red",
|
|
1427
|
+
markup=False,
|
|
1428
|
+
)
|
|
1429
|
+
raise typer.Exit(1) from e
|
|
1430
|
+
|
|
1431
|
+
# use module-level _extract_pairing_schedule
|
|
1432
|
+
|
|
1433
|
+
# use module-level _to_int_list, _tensor_or_list_to_ints, _safe_int
|
|
1434
|
+
|
|
1435
|
+
# Use the module-level _hash_sequences to avoid duplication
|
|
1436
|
+
|
|
1437
|
+
# use module-level _derive_mlm_seed
|
|
1438
|
+
|
|
1439
|
+
# use module-level _apply_mlm_masks
|
|
1440
|
+
|
|
1441
|
+
# use module-level _tokenizer_digest
|
|
1442
|
+
|
|
1443
|
+
try:
|
|
1444
|
+
# Import InvarLock components
|
|
1445
|
+
from invarlock.core.api import RunConfig
|
|
1446
|
+
from invarlock.core.registry import get_registry
|
|
1447
|
+
from invarlock.core.runner import CoreRunner
|
|
1448
|
+
from invarlock.eval.data import EvaluationWindow, get_provider
|
|
1449
|
+
from invarlock.reporting.report_types import create_empty_report
|
|
1450
|
+
|
|
1451
|
+
# Load and validate configuration via helper (preserves console prints)
|
|
1452
|
+
cfg = _prepare_config_for_run(
|
|
1453
|
+
config_path=config,
|
|
1454
|
+
profile=profile,
|
|
1455
|
+
edit=edit,
|
|
1456
|
+
tier=tier,
|
|
1457
|
+
probes=probes,
|
|
1458
|
+
console=console,
|
|
1459
|
+
)
|
|
1460
|
+
|
|
1461
|
+
# cfg prepared by helper above
|
|
1462
|
+
|
|
1463
|
+
adapter_name = str(getattr(cfg.model, "adapter", "")).lower()
|
|
1464
|
+
model_id_raw = str(getattr(cfg.model, "id", ""))
|
|
1465
|
+
model_profile = detect_model_profile(
|
|
1466
|
+
model_id=model_id_raw, adapter=adapter_name
|
|
1467
|
+
)
|
|
1468
|
+
tokenizer_hash: str | None = None
|
|
1469
|
+
tokenizer: Any | None = None
|
|
1470
|
+
|
|
1471
|
+
loss_cfg = getattr(cfg.eval, "loss", None)
|
|
1472
|
+
resolved_loss_type = (
|
|
1473
|
+
str(getattr(loss_cfg, "type", "auto")).lower() if loss_cfg else "auto"
|
|
1474
|
+
)
|
|
1475
|
+
if resolved_loss_type == "auto":
|
|
1476
|
+
resolved_loss_type = model_profile.default_loss
|
|
1477
|
+
use_mlm = resolved_loss_type == "mlm"
|
|
1478
|
+
mask_prob = _coerce_float(getattr(loss_cfg, "mask_prob", None), 0.15)
|
|
1479
|
+
mask_seed = _coerce_int(getattr(loss_cfg, "seed", None), 42)
|
|
1480
|
+
random_token_prob = _coerce_float(
|
|
1481
|
+
getattr(loss_cfg, "random_token_prob", None), 0.1
|
|
1482
|
+
)
|
|
1483
|
+
original_token_prob = _coerce_float(
|
|
1484
|
+
getattr(loss_cfg, "original_token_prob", None), 0.1
|
|
1485
|
+
)
|
|
1486
|
+
if loss_cfg is not None and getattr(loss_cfg, "type", None) == "auto":
|
|
1487
|
+
try:
|
|
1488
|
+
loss_cfg.type = resolved_loss_type # type: ignore[assignment]
|
|
1489
|
+
except Exception:
|
|
1490
|
+
pass
|
|
1491
|
+
|
|
1492
|
+
# Set deterministic seeds for Python/NumPy/Torch and record provenance
|
|
1493
|
+
raw_seed_value = 42
|
|
1494
|
+
if hasattr(cfg, "dataset"):
|
|
1495
|
+
try:
|
|
1496
|
+
raw_seed_value = getattr(cfg.dataset, "seed", 42)
|
|
1497
|
+
except Exception:
|
|
1498
|
+
raw_seed_value = 42
|
|
1499
|
+
try:
|
|
1500
|
+
seed_value = int(raw_seed_value)
|
|
1501
|
+
except (TypeError, ValueError, OverflowError):
|
|
1502
|
+
seed_value = 42
|
|
1503
|
+
set_seed(seed_value)
|
|
1504
|
+
# Enforce deterministic algorithms in CI/Release profiles when torch is available
|
|
1505
|
+
profile_label = (str(profile or "").lower()) if profile else None
|
|
1506
|
+
if torch is not None and profile_label in {"ci", "release"}:
|
|
1507
|
+
try: # pragma: no cover - behavior depends on torch availability
|
|
1508
|
+
if hasattr(torch, "use_deterministic_algorithms"):
|
|
1509
|
+
torch.use_deterministic_algorithms(True, warn_only=False)
|
|
1510
|
+
if hasattr(torch.backends, "cudnn"):
|
|
1511
|
+
torch.backends.cudnn.benchmark = False
|
|
1512
|
+
try:
|
|
1513
|
+
torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
|
|
1514
|
+
except Exception:
|
|
1515
|
+
pass
|
|
1516
|
+
except Exception:
|
|
1517
|
+
# If we cannot enforce determinism here, we will rely on core checks
|
|
1518
|
+
pass
|
|
1519
|
+
try:
|
|
1520
|
+
numpy_seed = int(np.random.get_state()[1][0])
|
|
1521
|
+
except Exception:
|
|
1522
|
+
numpy_seed = seed_value
|
|
1523
|
+
torch_seed = None
|
|
1524
|
+
if torch is not None:
|
|
1525
|
+
try:
|
|
1526
|
+
torch_seed = int(torch.initial_seed())
|
|
1527
|
+
except Exception:
|
|
1528
|
+
torch_seed = seed_value
|
|
1529
|
+
seed_bundle = {
|
|
1530
|
+
"python": int(seed_value),
|
|
1531
|
+
"numpy": int(numpy_seed),
|
|
1532
|
+
"torch": int(torch_seed) if torch_seed is not None else None,
|
|
1533
|
+
}
|
|
1534
|
+
console.print(
|
|
1535
|
+
"🎲 Deterministic seeds → "
|
|
1536
|
+
f"python={seed_bundle['python']}, numpy={seed_bundle['numpy']}, "
|
|
1537
|
+
f"torch={seed_bundle['torch'] if seed_bundle['torch'] is not None else 'N/A'}"
|
|
1538
|
+
)
|
|
1539
|
+
|
|
1540
|
+
# Resolve device and output directory
|
|
1541
|
+
resolved_device, output_dir = _resolve_device_and_output(
|
|
1542
|
+
cfg, device=device, out=out, console=console
|
|
1543
|
+
)
|
|
1544
|
+
|
|
1545
|
+
# Create run directory with timestamp
|
|
1546
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
1547
|
+
run_dir = output_dir / timestamp
|
|
1548
|
+
run_dir.mkdir(parents=True, exist_ok=True)
|
|
1549
|
+
|
|
1550
|
+
run_id = f"{output_dir.name}-{timestamp}" if output_dir.name else timestamp
|
|
1551
|
+
|
|
1552
|
+
console.print(f"📁 Output directory: {run_dir}")
|
|
1553
|
+
console.print(f"🆔 Run ID: {run_id}")
|
|
1554
|
+
|
|
1555
|
+
# Initialize retry controller if --until-pass mode enabled
|
|
1556
|
+
retry_controller = _init_retry_controller(
|
|
1557
|
+
until_pass=until_pass,
|
|
1558
|
+
max_attempts=max_attempts,
|
|
1559
|
+
timeout=timeout,
|
|
1560
|
+
baseline=baseline,
|
|
1561
|
+
console=console,
|
|
1562
|
+
)
|
|
1563
|
+
|
|
1564
|
+
baseline_report_data: dict[str, Any] | None = None
|
|
1565
|
+
pairing_schedule: dict[str, Any] | None = None
|
|
1566
|
+
if baseline:
|
|
1567
|
+
baseline_path = Path(baseline)
|
|
1568
|
+
if baseline_path.exists():
|
|
1569
|
+
try:
|
|
1570
|
+
with baseline_path.open(encoding="utf-8") as f:
|
|
1571
|
+
baseline_report_data = json.load(f)
|
|
1572
|
+
pairing_schedule = _extract_pairing_schedule(baseline_report_data)
|
|
1573
|
+
if pairing_schedule:
|
|
1574
|
+
console.print(
|
|
1575
|
+
"🧬 Loaded baseline evaluation schedule for pairing"
|
|
1576
|
+
)
|
|
1577
|
+
elif (profile or "").lower() == "release":
|
|
1578
|
+
console.print(
|
|
1579
|
+
f"[red]❌ Baseline report '{baseline}' does not contain evaluation_windows required for pairing.[/red]"
|
|
1580
|
+
)
|
|
1581
|
+
raise typer.Exit(1)
|
|
1582
|
+
else:
|
|
1583
|
+
console.print(
|
|
1584
|
+
f"[yellow]⚠️ Baseline report '{baseline}' lacks evaluation_windows; falling back to dataset schedule.[/yellow]"
|
|
1585
|
+
)
|
|
1586
|
+
baseline_report_data = None
|
|
1587
|
+
pairing_schedule = None
|
|
1588
|
+
except typer.Exit:
|
|
1589
|
+
raise
|
|
1590
|
+
except Exception as exc: # noqa: BLE001
|
|
1591
|
+
console.print(
|
|
1592
|
+
f"[yellow]⚠️ Failed to load baseline report '{baseline}': {exc}. Falling back to dataset schedule.[/yellow]"
|
|
1593
|
+
)
|
|
1594
|
+
baseline_report_data = None
|
|
1595
|
+
pairing_schedule = None
|
|
1596
|
+
else:
|
|
1597
|
+
console.print(
|
|
1598
|
+
f"[yellow]⚠️ Baseline report '{baseline}' not found. Falling back to dataset schedule.[/yellow]"
|
|
1599
|
+
)
|
|
1600
|
+
|
|
1601
|
+
requested_preview = int(getattr(cfg.dataset, "preview_n", 0))
|
|
1602
|
+
requested_final = int(getattr(cfg.dataset, "final_n", 0))
|
|
1603
|
+
effective_preview = requested_preview
|
|
1604
|
+
effective_final = requested_final
|
|
1605
|
+
preview_count = effective_preview
|
|
1606
|
+
final_count = effective_final
|
|
1607
|
+
# Default split prior to provider resolution; updated if provider exposes splits
|
|
1608
|
+
try:
|
|
1609
|
+
resolved_split = getattr(cfg.dataset, "split", None) or "validation"
|
|
1610
|
+
except Exception:
|
|
1611
|
+
resolved_split = "validation"
|
|
1612
|
+
used_fallback_split: bool = False
|
|
1613
|
+
|
|
1614
|
+
# Execute the pipeline using CoreRunner
|
|
1615
|
+
_print_pipeline_start(console)
|
|
1616
|
+
|
|
1617
|
+
# Get registry and create components
|
|
1618
|
+
registry = get_registry()
|
|
1619
|
+
adapter = registry.get_adapter(cfg.model.adapter)
|
|
1620
|
+
edit_name = getattr(getattr(cfg, "edit", None), "name", None)
|
|
1621
|
+
if not isinstance(edit_name, str) or not edit_name.strip():
|
|
1622
|
+
console.print(
|
|
1623
|
+
"[red]❌ Edit configuration must specify a non-empty `edit.name`.[/red]"
|
|
1624
|
+
)
|
|
1625
|
+
raise typer.Exit(1)
|
|
1626
|
+
try:
|
|
1627
|
+
edit_op = registry.get_edit(edit_name.strip())
|
|
1628
|
+
except Exception:
|
|
1629
|
+
console.print(
|
|
1630
|
+
f"[yellow]⚠️ Unknown edit '{edit_name.strip()}'. Using pass-through shim.[/yellow]"
|
|
1631
|
+
)
|
|
1632
|
+
edit_op = SimpleNamespace(name=edit_name.strip())
|
|
1633
|
+
|
|
1634
|
+
adapter_meta = registry.get_plugin_metadata(cfg.model.adapter, "adapters")
|
|
1635
|
+
try:
|
|
1636
|
+
from invarlock.cli.provenance import (
|
|
1637
|
+
extract_adapter_provenance,
|
|
1638
|
+
) # local import to avoid CLI import cycles
|
|
1639
|
+
|
|
1640
|
+
prov = extract_adapter_provenance(cfg.model.adapter)
|
|
1641
|
+
# Attach a small, stable provenance dict under adapter plugin metadata
|
|
1642
|
+
adapter_meta["provenance"] = prov.to_dict()
|
|
1643
|
+
except Exception:
|
|
1644
|
+
# Best-effort only; absence should not break runs
|
|
1645
|
+
pass
|
|
1646
|
+
try:
|
|
1647
|
+
edit_meta = registry.get_plugin_metadata(edit_name.strip(), "edits")
|
|
1648
|
+
except Exception:
|
|
1649
|
+
edit_meta = {
|
|
1650
|
+
"name": edit_name.strip(),
|
|
1651
|
+
"module": "edits.unknown",
|
|
1652
|
+
"version": "unknown",
|
|
1653
|
+
}
|
|
1654
|
+
|
|
1655
|
+
guards = []
|
|
1656
|
+
guard_metadata: list[dict[str, Any]] = []
|
|
1657
|
+
for guard_name in cfg.guards.order:
|
|
1658
|
+
if guard_name != "noop":
|
|
1659
|
+
try:
|
|
1660
|
+
guard = registry.get_guard(guard_name)
|
|
1661
|
+
guards.append(guard)
|
|
1662
|
+
guard_metadata.append(
|
|
1663
|
+
registry.get_plugin_metadata(guard_name, "guards")
|
|
1664
|
+
)
|
|
1665
|
+
except KeyError:
|
|
1666
|
+
console.print(
|
|
1667
|
+
f"[yellow]⚠️ Guard '{guard_name}' not found, skipping[/yellow]"
|
|
1668
|
+
)
|
|
1669
|
+
plugin_provenance = {
|
|
1670
|
+
"adapter": adapter_meta,
|
|
1671
|
+
"edit": edit_meta,
|
|
1672
|
+
"guards": guard_metadata,
|
|
1673
|
+
}
|
|
1674
|
+
|
|
1675
|
+
console.print(f"🔌 Adapter: {adapter.name}")
|
|
1676
|
+
|
|
1677
|
+
# Create run configuration
|
|
1678
|
+
def _to_serialisable_dict(section: object) -> dict[str, Any]:
|
|
1679
|
+
"""Coerce config fragments to plain dicts.
|
|
1680
|
+
|
|
1681
|
+
Handles InvarLockConfig sections (which wrap dicts in a private `_Obj` with
|
|
1682
|
+
`_data`) so downstream components (core.runner) see canonical mappings,
|
|
1683
|
+
e.g. `eval.bootstrap.replicates`.
|
|
1684
|
+
"""
|
|
1685
|
+
# Prefer native dump methods
|
|
1686
|
+
if hasattr(section, "model_dump"):
|
|
1687
|
+
return section.model_dump() # type: ignore[return-value]
|
|
1688
|
+
if hasattr(section, "dict"):
|
|
1689
|
+
try:
|
|
1690
|
+
return section.dict() # type: ignore[return-value]
|
|
1691
|
+
except Exception:
|
|
1692
|
+
pass
|
|
1693
|
+
# Unwrap CLI _Obj wrapper used by InvarLockConfig for attribute access
|
|
1694
|
+
try:
|
|
1695
|
+
raw = getattr(section, "_data", None)
|
|
1696
|
+
if isinstance(raw, dict):
|
|
1697
|
+
return raw
|
|
1698
|
+
except Exception:
|
|
1699
|
+
pass
|
|
1700
|
+
# Already a mapping
|
|
1701
|
+
if isinstance(section, dict):
|
|
1702
|
+
return section
|
|
1703
|
+
# Best-effort attribute dump
|
|
1704
|
+
try:
|
|
1705
|
+
data = vars(section)
|
|
1706
|
+
# Common case: {'_data': {...}}
|
|
1707
|
+
if isinstance(data, dict) and isinstance(data.get("_data"), dict):
|
|
1708
|
+
return data["_data"]
|
|
1709
|
+
return data # type: ignore[return-value]
|
|
1710
|
+
except TypeError:
|
|
1711
|
+
return {}
|
|
1712
|
+
|
|
1713
|
+
def _dump_guard(section: object) -> dict[str, Any]:
|
|
1714
|
+
data = _to_serialisable_dict(section)
|
|
1715
|
+
return data if isinstance(data, dict) else {}
|
|
1716
|
+
|
|
1717
|
+
guard_overrides = {
|
|
1718
|
+
"spectral": _dump_guard(getattr(cfg.guards, "spectral", {})),
|
|
1719
|
+
"rmt": _dump_guard(getattr(cfg.guards, "rmt", {})),
|
|
1720
|
+
"variance": _dump_guard(getattr(cfg.guards, "variance", {})),
|
|
1721
|
+
"invariants": _dump_guard(getattr(cfg.guards, "invariants", {})),
|
|
1722
|
+
}
|
|
1723
|
+
|
|
1724
|
+
if model_profile.invariants:
|
|
1725
|
+
invariants_policy = guard_overrides.setdefault("invariants", {})
|
|
1726
|
+
existing_checks = invariants_policy.get("profile_checks", [])
|
|
1727
|
+
if isinstance(existing_checks, list | tuple | set):
|
|
1728
|
+
checks_list = [str(item) for item in existing_checks]
|
|
1729
|
+
elif existing_checks:
|
|
1730
|
+
checks_list = [str(existing_checks)]
|
|
1731
|
+
else:
|
|
1732
|
+
checks_list = []
|
|
1733
|
+
for invariant in model_profile.invariants:
|
|
1734
|
+
invariant_name = str(invariant)
|
|
1735
|
+
if invariant_name not in checks_list:
|
|
1736
|
+
checks_list.append(invariant_name)
|
|
1737
|
+
invariants_policy["profile_checks"] = checks_list
|
|
1738
|
+
|
|
1739
|
+
run_context = {
|
|
1740
|
+
"eval": _to_serialisable_dict(cfg.eval),
|
|
1741
|
+
"dataset": _to_serialisable_dict(cfg.dataset),
|
|
1742
|
+
"guards": guard_overrides,
|
|
1743
|
+
"profile": profile if profile else "",
|
|
1744
|
+
"pairing_baseline": pairing_schedule,
|
|
1745
|
+
"seeds": seed_bundle,
|
|
1746
|
+
"plugins": plugin_provenance,
|
|
1747
|
+
"run_id": run_id,
|
|
1748
|
+
}
|
|
1749
|
+
run_context["model_profile"] = {
|
|
1750
|
+
"family": model_profile.family,
|
|
1751
|
+
"default_loss": model_profile.default_loss,
|
|
1752
|
+
"module_selectors": model_profile.module_selectors,
|
|
1753
|
+
"invariants": model_profile.invariants,
|
|
1754
|
+
"cert_lints": model_profile.cert_lints,
|
|
1755
|
+
}
|
|
1756
|
+
extra_context = _to_serialisable_dict(getattr(cfg, "context", {}))
|
|
1757
|
+
if isinstance(extra_context, dict):
|
|
1758
|
+
run_context.update(extra_context)
|
|
1759
|
+
try:
|
|
1760
|
+
run_context.setdefault("eval", {}).setdefault("loss", {})[
|
|
1761
|
+
"resolved_type"
|
|
1762
|
+
] = resolved_loss_type
|
|
1763
|
+
except Exception:
|
|
1764
|
+
pass
|
|
1765
|
+
run_config = RunConfig(
|
|
1766
|
+
device=resolved_device,
|
|
1767
|
+
max_pm_ratio=getattr(cfg.eval, "max_pm_ratio", 1.5),
|
|
1768
|
+
event_path=run_dir / "events.jsonl",
|
|
1769
|
+
context=run_context,
|
|
1770
|
+
)
|
|
1771
|
+
skip_model_load = False
|
|
1772
|
+
|
|
1773
|
+
# Load model using adapter
|
|
1774
|
+
# Load calibration data if dataset is configured
|
|
1775
|
+
calibration_data = None
|
|
1776
|
+
dataset_meta: dict[str, Any] = {}
|
|
1777
|
+
baseline_meta: dict[str, Any] = {}
|
|
1778
|
+
window_plan: dict[str, Any] | None = None
|
|
1779
|
+
if pairing_schedule:
|
|
1780
|
+
harvested = _validate_and_harvest_baseline_schedule(
|
|
1781
|
+
cfg,
|
|
1782
|
+
pairing_schedule,
|
|
1783
|
+
baseline_report_data,
|
|
1784
|
+
tokenizer_hash=tokenizer_hash,
|
|
1785
|
+
resolved_loss_type=resolved_loss_type,
|
|
1786
|
+
baseline_path_str=str(baseline) if baseline else None,
|
|
1787
|
+
console=console,
|
|
1788
|
+
)
|
|
1789
|
+
effective_preview = harvested["effective_preview"]
|
|
1790
|
+
effective_final = harvested["effective_final"]
|
|
1791
|
+
preview_count = harvested["preview_count"]
|
|
1792
|
+
final_count = harvested["final_count"]
|
|
1793
|
+
dataset_meta = harvested["dataset_meta"]
|
|
1794
|
+
window_plan = harvested["window_plan"]
|
|
1795
|
+
calibration_data = harvested["calibration_data"]
|
|
1796
|
+
if use_mlm and tokenizer is None:
|
|
1797
|
+
try:
|
|
1798
|
+
tokenizer, tokenizer_hash = resolve_tokenizer(model_profile)
|
|
1799
|
+
except Exception as exc:
|
|
1800
|
+
console.print(f"[red]❌ {exc}[/red]")
|
|
1801
|
+
raise typer.Exit(1) from exc
|
|
1802
|
+
preview_window_ids = pairing_schedule["preview"].get("window_ids")
|
|
1803
|
+
preview_labels = pairing_schedule["preview"].get("labels")
|
|
1804
|
+
for idx, (input_ids, attention_mask) in enumerate(
|
|
1805
|
+
zip(
|
|
1806
|
+
pairing_schedule["preview"]["input_ids"],
|
|
1807
|
+
pairing_schedule["preview"]["attention_masks"],
|
|
1808
|
+
strict=False,
|
|
1809
|
+
)
|
|
1810
|
+
):
|
|
1811
|
+
window_id = (
|
|
1812
|
+
preview_window_ids[idx]
|
|
1813
|
+
if preview_window_ids and idx < len(preview_window_ids)
|
|
1814
|
+
else idx
|
|
1815
|
+
)
|
|
1816
|
+
entry = {
|
|
1817
|
+
"input_ids": input_ids,
|
|
1818
|
+
"attention_mask": attention_mask,
|
|
1819
|
+
"window_id": f"preview::{window_id}",
|
|
1820
|
+
}
|
|
1821
|
+
if use_mlm:
|
|
1822
|
+
labels_list: list[int] = []
|
|
1823
|
+
if isinstance(preview_labels, list) and idx < len(preview_labels):
|
|
1824
|
+
labels_list = _tensor_or_list_to_ints(preview_labels[idx])
|
|
1825
|
+
if labels_list and any(token != -100 for token in labels_list):
|
|
1826
|
+
entry["labels"] = labels_list
|
|
1827
|
+
entry["mlm_masked"] = sum(
|
|
1828
|
+
1 for token in labels_list if token != -100
|
|
1829
|
+
)
|
|
1830
|
+
else:
|
|
1831
|
+
entry["labels"] = []
|
|
1832
|
+
entry["mlm_masked"] = 0
|
|
1833
|
+
# Prefer masked_token_counts if present in schedule
|
|
1834
|
+
mtc = pairing_schedule["preview"].get("masked_token_counts")
|
|
1835
|
+
if isinstance(mtc, list) and idx < len(mtc):
|
|
1836
|
+
try:
|
|
1837
|
+
entry["mlm_masked"] = int(mtc[idx])
|
|
1838
|
+
except Exception:
|
|
1839
|
+
pass
|
|
1840
|
+
calibration_data.append(entry)
|
|
1841
|
+
final_window_ids = pairing_schedule["final"].get("window_ids")
|
|
1842
|
+
final_labels = pairing_schedule["final"].get("labels")
|
|
1843
|
+
for idx, (input_ids, attention_mask) in enumerate(
|
|
1844
|
+
zip(
|
|
1845
|
+
pairing_schedule["final"]["input_ids"],
|
|
1846
|
+
pairing_schedule["final"]["attention_masks"],
|
|
1847
|
+
strict=False,
|
|
1848
|
+
)
|
|
1849
|
+
):
|
|
1850
|
+
window_id = (
|
|
1851
|
+
final_window_ids[idx]
|
|
1852
|
+
if final_window_ids and idx < len(final_window_ids)
|
|
1853
|
+
else idx
|
|
1854
|
+
)
|
|
1855
|
+
entry = {
|
|
1856
|
+
"input_ids": input_ids,
|
|
1857
|
+
"attention_mask": attention_mask,
|
|
1858
|
+
"window_id": f"final::{window_id}",
|
|
1859
|
+
}
|
|
1860
|
+
if use_mlm:
|
|
1861
|
+
labels_list: list[int] = []
|
|
1862
|
+
if isinstance(final_labels, list) and idx < len(final_labels):
|
|
1863
|
+
labels_list = _tensor_or_list_to_ints(final_labels[idx])
|
|
1864
|
+
if labels_list and any(token != -100 for token in labels_list):
|
|
1865
|
+
entry["labels"] = labels_list
|
|
1866
|
+
entry["mlm_masked"] = sum(
|
|
1867
|
+
1 for token in labels_list if token != -100
|
|
1868
|
+
)
|
|
1869
|
+
else:
|
|
1870
|
+
entry["labels"] = []
|
|
1871
|
+
entry["mlm_masked"] = 0
|
|
1872
|
+
# Prefer masked_token_counts if present in schedule
|
|
1873
|
+
mtc = pairing_schedule["final"].get("masked_token_counts")
|
|
1874
|
+
if isinstance(mtc, list) and idx < len(mtc):
|
|
1875
|
+
try:
|
|
1876
|
+
entry["mlm_masked"] = int(mtc[idx])
|
|
1877
|
+
except Exception:
|
|
1878
|
+
pass
|
|
1879
|
+
calibration_data.append(entry)
|
|
1880
|
+
preview_count = len(pairing_schedule["preview"]["input_ids"])
|
|
1881
|
+
final_count = len(pairing_schedule["final"]["input_ids"])
|
|
1882
|
+
effective_preview = int(preview_count)
|
|
1883
|
+
effective_final = int(final_count)
|
|
1884
|
+
preview_mask_total = 0
|
|
1885
|
+
final_mask_total = 0
|
|
1886
|
+
preview_mask_counts: list[int] = []
|
|
1887
|
+
final_mask_counts: list[int] = []
|
|
1888
|
+
if use_mlm:
|
|
1889
|
+
preview_entries = calibration_data[:preview_count]
|
|
1890
|
+
final_entries = calibration_data[preview_count:]
|
|
1891
|
+
|
|
1892
|
+
def _needs_masks(entries):
|
|
1893
|
+
missing_any = False
|
|
1894
|
+
counts = []
|
|
1895
|
+
for entry in entries:
|
|
1896
|
+
labels_val = entry.get("labels")
|
|
1897
|
+
has_label_masks = bool(
|
|
1898
|
+
isinstance(labels_val, list)
|
|
1899
|
+
and any(token != -100 for token in labels_val)
|
|
1900
|
+
)
|
|
1901
|
+
existing_count = int(entry.get("mlm_masked", 0))
|
|
1902
|
+
if not has_label_masks and existing_count <= 0:
|
|
1903
|
+
missing_any = True
|
|
1904
|
+
counts.append(int(entry.get("mlm_masked", 0)))
|
|
1905
|
+
return missing_any, counts
|
|
1906
|
+
|
|
1907
|
+
preview_missing, preview_counts_existing = _needs_masks(preview_entries)
|
|
1908
|
+
final_missing, final_counts_existing = _needs_masks(final_entries)
|
|
1909
|
+
|
|
1910
|
+
if preview_missing:
|
|
1911
|
+
preview_mask_total, preview_mask_counts = _apply_mlm_masks(
|
|
1912
|
+
preview_entries,
|
|
1913
|
+
tokenizer=tokenizer,
|
|
1914
|
+
mask_prob=mask_prob,
|
|
1915
|
+
seed=mask_seed,
|
|
1916
|
+
random_token_prob=random_token_prob,
|
|
1917
|
+
original_token_prob=original_token_prob,
|
|
1918
|
+
prefix="preview",
|
|
1919
|
+
)
|
|
1920
|
+
else:
|
|
1921
|
+
preview_mask_counts = preview_counts_existing
|
|
1922
|
+
preview_mask_total = sum(preview_mask_counts)
|
|
1923
|
+
|
|
1924
|
+
if final_missing:
|
|
1925
|
+
final_mask_total, final_mask_counts = _apply_mlm_masks(
|
|
1926
|
+
final_entries,
|
|
1927
|
+
tokenizer=tokenizer,
|
|
1928
|
+
mask_prob=mask_prob,
|
|
1929
|
+
seed=mask_seed,
|
|
1930
|
+
random_token_prob=random_token_prob,
|
|
1931
|
+
original_token_prob=original_token_prob,
|
|
1932
|
+
prefix="final",
|
|
1933
|
+
)
|
|
1934
|
+
else:
|
|
1935
|
+
final_mask_counts = final_counts_existing
|
|
1936
|
+
final_mask_total = sum(final_mask_counts)
|
|
1937
|
+
|
|
1938
|
+
# Ensure counts and labels set on entries
|
|
1939
|
+
if preview_mask_counts:
|
|
1940
|
+
for entry, count in zip(
|
|
1941
|
+
preview_entries, preview_mask_counts, strict=False
|
|
1942
|
+
):
|
|
1943
|
+
entry["mlm_masked"] = int(count)
|
|
1944
|
+
if final_mask_counts:
|
|
1945
|
+
for entry, count in zip(
|
|
1946
|
+
final_entries, final_mask_counts, strict=False
|
|
1947
|
+
):
|
|
1948
|
+
entry["mlm_masked"] = int(count)
|
|
1949
|
+
|
|
1950
|
+
if preview_count > 0 and preview_mask_total <= 0:
|
|
1951
|
+
_fail_run(
|
|
1952
|
+
"Baseline pairing schedule provided no masked tokens for preview windows; "
|
|
1953
|
+
"ensure MLM labels are present in the baseline report."
|
|
1954
|
+
)
|
|
1955
|
+
if final_count > 0 and final_mask_total <= 0:
|
|
1956
|
+
_fail_run(
|
|
1957
|
+
"Baseline pairing schedule provided no masked tokens for final windows; "
|
|
1958
|
+
"ensure MLM labels are present in the baseline report."
|
|
1959
|
+
)
|
|
1960
|
+
|
|
1961
|
+
dataset_meta["masked_tokens_preview"] = int(preview_mask_total)
|
|
1962
|
+
dataset_meta["masked_tokens_final"] = int(final_mask_total)
|
|
1963
|
+
dataset_meta["masked_tokens_total"] = int(
|
|
1964
|
+
preview_mask_total + final_mask_total
|
|
1965
|
+
)
|
|
1966
|
+
if os.environ.get("INVARLOCK_DEBUG_TRACE"):
|
|
1967
|
+
console.print(
|
|
1968
|
+
f"[debug] MLM pairing masks → preview={preview_mask_total}, final={final_mask_total}"
|
|
1969
|
+
)
|
|
1970
|
+
if "preview_total_tokens" not in dataset_meta:
|
|
1971
|
+
dataset_meta["preview_total_tokens"] = sum(
|
|
1972
|
+
len(_tensor_or_list_to_ints(seq))
|
|
1973
|
+
for seq in pairing_schedule["preview"]["input_ids"]
|
|
1974
|
+
)
|
|
1975
|
+
if "final_total_tokens" not in dataset_meta:
|
|
1976
|
+
dataset_meta["final_total_tokens"] = sum(
|
|
1977
|
+
len(_tensor_or_list_to_ints(seq))
|
|
1978
|
+
for seq in pairing_schedule["final"]["input_ids"]
|
|
1979
|
+
)
|
|
1980
|
+
if "preview_hash" not in dataset_meta:
|
|
1981
|
+
preview_hash = _hash_sequences(
|
|
1982
|
+
_tensor_or_list_to_ints(seq)
|
|
1983
|
+
for seq in pairing_schedule["preview"]["input_ids"]
|
|
1984
|
+
)
|
|
1985
|
+
dataset_meta["preview_hash"] = preview_hash
|
|
1986
|
+
else:
|
|
1987
|
+
preview_hash = dataset_meta["preview_hash"]
|
|
1988
|
+
if "final_hash" not in dataset_meta:
|
|
1989
|
+
final_hash = _hash_sequences(
|
|
1990
|
+
_tensor_or_list_to_ints(seq)
|
|
1991
|
+
for seq in pairing_schedule["final"]["input_ids"]
|
|
1992
|
+
)
|
|
1993
|
+
dataset_meta["final_hash"] = final_hash
|
|
1994
|
+
else:
|
|
1995
|
+
final_hash = dataset_meta["final_hash"]
|
|
1996
|
+
if "dataset_hash" not in dataset_meta:
|
|
1997
|
+
dataset_meta["dataset_hash"] = hashlib.blake2s(
|
|
1998
|
+
(str(preview_hash) + str(final_hash)).encode("utf-8"),
|
|
1999
|
+
digest_size=16,
|
|
2000
|
+
).hexdigest()
|
|
2001
|
+
if not window_plan:
|
|
2002
|
+
window_capacity = (
|
|
2003
|
+
baseline_meta.get("window_capacity")
|
|
2004
|
+
if isinstance(baseline_meta, dict)
|
|
2005
|
+
else {}
|
|
2006
|
+
)
|
|
2007
|
+
window_plan = {
|
|
2008
|
+
"profile": (profile or "").lower() or "baseline",
|
|
2009
|
+
"requested_preview": int(preview_count),
|
|
2010
|
+
"requested_final": int(final_count),
|
|
2011
|
+
"actual_preview": int(preview_count),
|
|
2012
|
+
"actual_final": int(final_count),
|
|
2013
|
+
"coverage_ok": True,
|
|
2014
|
+
"capacity": window_capacity or {},
|
|
2015
|
+
}
|
|
2016
|
+
if isinstance(window_plan, dict):
|
|
2017
|
+
dataset_meta.setdefault("window_plan", window_plan)
|
|
2018
|
+
capacity_meta = window_plan.get("capacity")
|
|
2019
|
+
if capacity_meta and "window_capacity" not in dataset_meta:
|
|
2020
|
+
dataset_meta["window_capacity"] = capacity_meta
|
|
2021
|
+
elif cfg.dataset.provider:
|
|
2022
|
+
console.print(f"📊 Loading dataset: {cfg.dataset.provider}")
|
|
2023
|
+
# Pass through provider-specific kwargs when available
|
|
2024
|
+
provider_kwargs = {}
|
|
2025
|
+
for key in (
|
|
2026
|
+
"dataset_name",
|
|
2027
|
+
"config_name",
|
|
2028
|
+
"text_field",
|
|
2029
|
+
"src_field",
|
|
2030
|
+
"tgt_field",
|
|
2031
|
+
"cache_dir",
|
|
2032
|
+
"max_samples",
|
|
2033
|
+
# Local providers (e.g., local_jsonl)
|
|
2034
|
+
"file",
|
|
2035
|
+
"path",
|
|
2036
|
+
"data_files",
|
|
2037
|
+
):
|
|
2038
|
+
try:
|
|
2039
|
+
val = getattr(cfg.dataset, key)
|
|
2040
|
+
except Exception:
|
|
2041
|
+
val = None
|
|
2042
|
+
if val is not None and val != "":
|
|
2043
|
+
provider_kwargs[key] = val
|
|
2044
|
+
# Resolve provider kind from config (supports string or mapping with kind)
|
|
2045
|
+
provider_val = getattr(cfg.dataset, "provider", None)
|
|
2046
|
+
provider_name = None
|
|
2047
|
+
if isinstance(provider_val, dict):
|
|
2048
|
+
provider_name = provider_val.get("kind")
|
|
2049
|
+
# Include nested provider-specific kwargs
|
|
2050
|
+
for k, v in provider_val.items():
|
|
2051
|
+
if k != "kind" and v is not None and v != "":
|
|
2052
|
+
provider_kwargs[k] = v
|
|
2053
|
+
elif isinstance(provider_val, str):
|
|
2054
|
+
provider_name = provider_val # noqa: F841
|
|
2055
|
+
else:
|
|
2056
|
+
# Support mapping-like provider configs (e.g., _Obj with .get)
|
|
2057
|
+
try:
|
|
2058
|
+
_ = provider_val.get("kind") # type: ignore[attr-defined]
|
|
2059
|
+
# Try to expose nested entries
|
|
2060
|
+
try:
|
|
2061
|
+
for k, v in provider_val._data.items(): # type: ignore[attr-defined]
|
|
2062
|
+
if k != "kind" and v is not None and v != "":
|
|
2063
|
+
provider_kwargs[k] = v
|
|
2064
|
+
except Exception:
|
|
2065
|
+
# Fallback: if items() exists
|
|
2066
|
+
try:
|
|
2067
|
+
for k, v in provider_val.items(): # type: ignore[attr-defined]
|
|
2068
|
+
if k != "kind" and v is not None and v != "":
|
|
2069
|
+
provider_kwargs[k] = v
|
|
2070
|
+
except Exception:
|
|
2071
|
+
pass
|
|
2072
|
+
except Exception:
|
|
2073
|
+
_ = None
|
|
2074
|
+
data_provider, resolved_split, used_fallback_split = (
|
|
2075
|
+
_resolve_provider_and_split(
|
|
2076
|
+
cfg,
|
|
2077
|
+
model_profile,
|
|
2078
|
+
get_provider_fn=get_provider,
|
|
2079
|
+
provider_kwargs=provider_kwargs,
|
|
2080
|
+
console=console,
|
|
2081
|
+
resolved_device=resolved_device,
|
|
2082
|
+
)
|
|
2083
|
+
)
|
|
2084
|
+
|
|
2085
|
+
# Load tokenizer for dataset processing
|
|
2086
|
+
try:
|
|
2087
|
+
tokenizer, tokenizer_hash = resolve_tokenizer(model_profile)
|
|
2088
|
+
except Exception as exc:
|
|
2089
|
+
console.print(f"[red]❌ {exc}[/red]")
|
|
2090
|
+
raise typer.Exit(1) from exc
|
|
2091
|
+
|
|
2092
|
+
dataset_stride = getattr(
|
|
2093
|
+
cfg.dataset, "stride", getattr(cfg.dataset, "seq_len", 0) // 2
|
|
2094
|
+
)
|
|
2095
|
+
release_profile = (profile or "").lower() == "release"
|
|
2096
|
+
if release_profile and not pairing_schedule:
|
|
2097
|
+
estimate_fn = getattr(data_provider, "estimate_capacity", None)
|
|
2098
|
+
if callable(estimate_fn):
|
|
2099
|
+
capacity_fast = bool(getattr(cfg.eval, "capacity_fast", False))
|
|
2100
|
+
capacity_meta = estimate_fn(
|
|
2101
|
+
tokenizer=tokenizer,
|
|
2102
|
+
seq_len=cfg.dataset.seq_len,
|
|
2103
|
+
stride=dataset_stride,
|
|
2104
|
+
split=resolved_split,
|
|
2105
|
+
target_total=requested_preview + requested_final,
|
|
2106
|
+
fast_mode=capacity_fast,
|
|
2107
|
+
)
|
|
2108
|
+
variance_policy = getattr(cfg.guards, "variance", None)
|
|
2109
|
+
max_calibration = (
|
|
2110
|
+
getattr(variance_policy, "max_calib", 0)
|
|
2111
|
+
if variance_policy is not None
|
|
2112
|
+
else 0
|
|
2113
|
+
)
|
|
2114
|
+
try:
|
|
2115
|
+
window_plan = _maybe_plan_release_windows(
|
|
2116
|
+
capacity_meta,
|
|
2117
|
+
requested_preview=requested_preview,
|
|
2118
|
+
requested_final=requested_final,
|
|
2119
|
+
max_calibration=max_calibration,
|
|
2120
|
+
console=console,
|
|
2121
|
+
)
|
|
2122
|
+
except RuntimeError as err:
|
|
2123
|
+
console.print(f"[red]❌ {err}[/red]")
|
|
2124
|
+
raise typer.Exit(1) from err
|
|
2125
|
+
|
|
2126
|
+
actual_per_arm = int(window_plan["actual_preview"])
|
|
2127
|
+
effective_preview = actual_per_arm
|
|
2128
|
+
effective_final = actual_per_arm
|
|
2129
|
+
preview_count = effective_preview
|
|
2130
|
+
final_count = effective_final
|
|
2131
|
+
dataset_stride = getattr(
|
|
2132
|
+
cfg.dataset, "stride", getattr(cfg.dataset, "seq_len", 0)
|
|
2133
|
+
)
|
|
2134
|
+
else:
|
|
2135
|
+
console.print(
|
|
2136
|
+
"[yellow]⚠️ Release profile requested but dataset provider "
|
|
2137
|
+
"does not expose capacity estimation; using configured window counts.[/yellow]"
|
|
2138
|
+
)
|
|
2139
|
+
|
|
2140
|
+
preview_records: list[tuple[list[int], list[int]]] = []
|
|
2141
|
+
final_records: list[tuple[list[int], list[int]]] = []
|
|
2142
|
+
|
|
2143
|
+
while True:
|
|
2144
|
+
preview_window, final_window = data_provider.windows(
|
|
2145
|
+
tokenizer=tokenizer,
|
|
2146
|
+
seq_len=cfg.dataset.seq_len,
|
|
2147
|
+
stride=getattr(cfg.dataset, "stride", cfg.dataset.seq_len // 2),
|
|
2148
|
+
preview_n=effective_preview,
|
|
2149
|
+
final_n=effective_final,
|
|
2150
|
+
seed=getattr(cfg.dataset, "seed", 42),
|
|
2151
|
+
split=resolved_split,
|
|
2152
|
+
)
|
|
2153
|
+
|
|
2154
|
+
preview_count = len(getattr(preview_window, "input_ids", []))
|
|
2155
|
+
final_count = len(getattr(final_window, "input_ids", []))
|
|
2156
|
+
is_eval_window = isinstance(
|
|
2157
|
+
preview_window, EvaluationWindow
|
|
2158
|
+
) and isinstance(final_window, EvaluationWindow)
|
|
2159
|
+
if is_eval_window:
|
|
2160
|
+
if (
|
|
2161
|
+
preview_count != effective_preview
|
|
2162
|
+
or final_count != effective_final
|
|
2163
|
+
):
|
|
2164
|
+
_fail_run(
|
|
2165
|
+
"Dataset provider returned mismatched preview/final counts "
|
|
2166
|
+
f"({preview_count}/{final_count}) "
|
|
2167
|
+
f"expected ({effective_preview}/{effective_final}). "
|
|
2168
|
+
"CI/Release profiles require exact parity."
|
|
2169
|
+
)
|
|
2170
|
+
else:
|
|
2171
|
+
preview_count = effective_preview
|
|
2172
|
+
final_count = effective_final
|
|
2173
|
+
|
|
2174
|
+
# Optional: provider-supplied labels for seq2seq
|
|
2175
|
+
provider_labels_prev = None
|
|
2176
|
+
provider_labels_fin = None
|
|
2177
|
+
try:
|
|
2178
|
+
provider_labels_prev = getattr(
|
|
2179
|
+
data_provider, "last_preview_labels", None
|
|
2180
|
+
)
|
|
2181
|
+
provider_labels_fin = getattr(
|
|
2182
|
+
data_provider, "last_final_labels", None
|
|
2183
|
+
)
|
|
2184
|
+
except Exception:
|
|
2185
|
+
provider_labels_prev = None
|
|
2186
|
+
provider_labels_fin = None
|
|
2187
|
+
|
|
2188
|
+
preview_records = []
|
|
2189
|
+
preview_indices_raw = getattr(preview_window, "indices", [])
|
|
2190
|
+
if isinstance(preview_indices_raw, list):
|
|
2191
|
+
preview_indices = preview_indices_raw
|
|
2192
|
+
else:
|
|
2193
|
+
try:
|
|
2194
|
+
preview_indices = list(preview_indices_raw)
|
|
2195
|
+
except TypeError:
|
|
2196
|
+
preview_indices = []
|
|
2197
|
+
for idx_local, (input_ids, attention_mask) in enumerate(
|
|
2198
|
+
zip(
|
|
2199
|
+
preview_window.input_ids,
|
|
2200
|
+
preview_window.attention_masks,
|
|
2201
|
+
strict=False,
|
|
2202
|
+
)
|
|
2203
|
+
):
|
|
2204
|
+
input_ids_list = _tensor_or_list_to_ints(input_ids)
|
|
2205
|
+
attention_mask_list = (
|
|
2206
|
+
_tensor_or_list_to_ints(attention_mask)
|
|
2207
|
+
if attention_mask is not None
|
|
2208
|
+
else [1] * len(input_ids_list)
|
|
2209
|
+
)
|
|
2210
|
+
dataset_index = (
|
|
2211
|
+
_safe_int(preview_indices[idx_local])
|
|
2212
|
+
if idx_local < len(preview_indices)
|
|
2213
|
+
else idx_local
|
|
2214
|
+
)
|
|
2215
|
+
rec = {
|
|
2216
|
+
"input_ids": input_ids_list,
|
|
2217
|
+
"attention_mask": attention_mask_list,
|
|
2218
|
+
"dataset_index": dataset_index,
|
|
2219
|
+
}
|
|
2220
|
+
# Attach provider labels for seq2seq if available
|
|
2221
|
+
if provider_labels_prev is not None and idx_local < len(
|
|
2222
|
+
provider_labels_prev
|
|
2223
|
+
):
|
|
2224
|
+
rec["labels"] = _tensor_or_list_to_ints(
|
|
2225
|
+
provider_labels_prev[idx_local]
|
|
2226
|
+
)
|
|
2227
|
+
preview_records.append(rec)
|
|
2228
|
+
|
|
2229
|
+
final_records = []
|
|
2230
|
+
final_indices_raw = getattr(final_window, "indices", [])
|
|
2231
|
+
if isinstance(final_indices_raw, list):
|
|
2232
|
+
final_indices = final_indices_raw
|
|
2233
|
+
else:
|
|
2234
|
+
try:
|
|
2235
|
+
final_indices = list(final_indices_raw)
|
|
2236
|
+
except TypeError:
|
|
2237
|
+
final_indices = []
|
|
2238
|
+
for idx_local, (input_ids, attention_mask) in enumerate(
|
|
2239
|
+
zip(
|
|
2240
|
+
final_window.input_ids,
|
|
2241
|
+
final_window.attention_masks,
|
|
2242
|
+
strict=False,
|
|
2243
|
+
)
|
|
2244
|
+
):
|
|
2245
|
+
input_ids_list = _tensor_or_list_to_ints(input_ids)
|
|
2246
|
+
attention_mask_list = (
|
|
2247
|
+
_tensor_or_list_to_ints(attention_mask)
|
|
2248
|
+
if attention_mask is not None
|
|
2249
|
+
else [1] * len(input_ids_list)
|
|
2250
|
+
)
|
|
2251
|
+
dataset_index = (
|
|
2252
|
+
_safe_int(final_indices[idx_local])
|
|
2253
|
+
if idx_local < len(final_indices)
|
|
2254
|
+
else idx_local
|
|
2255
|
+
)
|
|
2256
|
+
final_records.append(
|
|
2257
|
+
{
|
|
2258
|
+
"input_ids": input_ids_list,
|
|
2259
|
+
"attention_mask": attention_mask_list,
|
|
2260
|
+
"dataset_index": dataset_index,
|
|
2261
|
+
}
|
|
2262
|
+
)
|
|
2263
|
+
|
|
2264
|
+
if use_mlm:
|
|
2265
|
+
temp_preview_records = [
|
|
2266
|
+
{
|
|
2267
|
+
"input_ids": list(rec["input_ids"]),
|
|
2268
|
+
"attention_mask": list(rec["attention_mask"]),
|
|
2269
|
+
"dataset_index": rec.get("dataset_index"),
|
|
2270
|
+
"window_id": rec.get("window_id"),
|
|
2271
|
+
}
|
|
2272
|
+
for rec in preview_records
|
|
2273
|
+
]
|
|
2274
|
+
temp_final_records = [
|
|
2275
|
+
{
|
|
2276
|
+
"input_ids": list(rec["input_ids"]),
|
|
2277
|
+
"attention_mask": list(rec["attention_mask"]),
|
|
2278
|
+
"dataset_index": rec.get("dataset_index"),
|
|
2279
|
+
"window_id": rec.get("window_id"),
|
|
2280
|
+
}
|
|
2281
|
+
for rec in final_records
|
|
2282
|
+
]
|
|
2283
|
+
_apply_mlm_masks(
|
|
2284
|
+
temp_preview_records,
|
|
2285
|
+
tokenizer=tokenizer,
|
|
2286
|
+
mask_prob=mask_prob,
|
|
2287
|
+
seed=mask_seed,
|
|
2288
|
+
random_token_prob=random_token_prob,
|
|
2289
|
+
original_token_prob=original_token_prob,
|
|
2290
|
+
prefix="preview",
|
|
2291
|
+
)
|
|
2292
|
+
_apply_mlm_masks(
|
|
2293
|
+
temp_final_records,
|
|
2294
|
+
tokenizer=tokenizer,
|
|
2295
|
+
mask_prob=mask_prob,
|
|
2296
|
+
seed=mask_seed,
|
|
2297
|
+
random_token_prob=random_token_prob,
|
|
2298
|
+
original_token_prob=original_token_prob,
|
|
2299
|
+
prefix="final",
|
|
2300
|
+
)
|
|
2301
|
+
records_for_signatures = temp_preview_records + temp_final_records
|
|
2302
|
+
else:
|
|
2303
|
+
records_for_signatures = preview_records + final_records
|
|
2304
|
+
|
|
2305
|
+
signatures = []
|
|
2306
|
+
for record in records_for_signatures:
|
|
2307
|
+
tokens = record["input_ids"]
|
|
2308
|
+
masks = record["attention_mask"]
|
|
2309
|
+
signatures.append(
|
|
2310
|
+
tuple(
|
|
2311
|
+
tok
|
|
2312
|
+
for tok, mask in zip(tokens, masks, strict=False)
|
|
2313
|
+
if mask
|
|
2314
|
+
)
|
|
2315
|
+
)
|
|
2316
|
+
|
|
2317
|
+
unique_sequences = len(set(signatures))
|
|
2318
|
+
combined_total = len(signatures)
|
|
2319
|
+
if unique_sequences == combined_total:
|
|
2320
|
+
break
|
|
2321
|
+
|
|
2322
|
+
deficit = combined_total - unique_sequences
|
|
2323
|
+
reduction = max(5, int(deficit) if deficit > 0 else 1)
|
|
2324
|
+
proposed_per_arm = preview_count - reduction
|
|
2325
|
+
if proposed_per_arm >= preview_count:
|
|
2326
|
+
proposed_per_arm = preview_count - 1
|
|
2327
|
+
min_per_arm_floor = RELEASE_MIN_WINDOWS_PER_ARM
|
|
2328
|
+
if window_plan is None or window_plan.get("profile") != "release":
|
|
2329
|
+
min_per_arm_floor = max(
|
|
2330
|
+
10,
|
|
2331
|
+
min(
|
|
2332
|
+
int(requested_preview or 0) or RELEASE_MIN_WINDOWS_PER_ARM,
|
|
2333
|
+
int(requested_final or 0) or RELEASE_MIN_WINDOWS_PER_ARM,
|
|
2334
|
+
)
|
|
2335
|
+
// 2,
|
|
2336
|
+
)
|
|
2337
|
+
if proposed_per_arm < min_per_arm_floor:
|
|
2338
|
+
raise RuntimeError(
|
|
2339
|
+
"Unable to construct non-overlapping windows within minimum window floor."
|
|
2340
|
+
)
|
|
2341
|
+
console.print(
|
|
2342
|
+
f"[yellow]⚠️ Detected {deficit} duplicate windows; reducing per-arm windows to {proposed_per_arm} and retrying stratification.[/yellow]"
|
|
2343
|
+
)
|
|
2344
|
+
|
|
2345
|
+
effective_preview = proposed_per_arm
|
|
2346
|
+
effective_final = proposed_per_arm
|
|
2347
|
+
preview_count = effective_preview
|
|
2348
|
+
final_count = effective_final
|
|
2349
|
+
if window_plan is not None:
|
|
2350
|
+
window_plan.setdefault("dedupe_adjustments", []).append(
|
|
2351
|
+
{
|
|
2352
|
+
"deficit": int(deficit),
|
|
2353
|
+
"proposed_per_arm": int(proposed_per_arm),
|
|
2354
|
+
}
|
|
2355
|
+
)
|
|
2356
|
+
window_plan["actual_preview"] = proposed_per_arm
|
|
2357
|
+
window_plan["actual_final"] = proposed_per_arm
|
|
2358
|
+
continue
|
|
2359
|
+
|
|
2360
|
+
if window_plan is None:
|
|
2361
|
+
window_plan = {
|
|
2362
|
+
"profile": (profile or "").lower() or "default",
|
|
2363
|
+
"requested_preview": int(requested_preview),
|
|
2364
|
+
"requested_final": int(requested_final),
|
|
2365
|
+
"actual_preview": int(preview_count),
|
|
2366
|
+
"actual_final": int(final_count),
|
|
2367
|
+
"coverage_ok": preview_count == final_count,
|
|
2368
|
+
"capacity": {},
|
|
2369
|
+
}
|
|
2370
|
+
else:
|
|
2371
|
+
window_plan["actual_preview"] = int(preview_count)
|
|
2372
|
+
window_plan["actual_final"] = int(final_count)
|
|
2373
|
+
window_plan["coverage_ok"] = (
|
|
2374
|
+
window_plan.get("coverage_ok", True)
|
|
2375
|
+
and preview_count == final_count
|
|
2376
|
+
)
|
|
2377
|
+
|
|
2378
|
+
calibration_data: list[dict[str, Any]] = []
|
|
2379
|
+
preview_mask_total = 0
|
|
2380
|
+
final_mask_total = 0
|
|
2381
|
+
preview_mask_counts: list[int] = []
|
|
2382
|
+
final_mask_counts: list[int] = []
|
|
2383
|
+
if use_mlm:
|
|
2384
|
+
preview_mask_total, preview_mask_counts = _apply_mlm_masks(
|
|
2385
|
+
preview_records,
|
|
2386
|
+
tokenizer=tokenizer,
|
|
2387
|
+
mask_prob=mask_prob,
|
|
2388
|
+
seed=mask_seed,
|
|
2389
|
+
random_token_prob=random_token_prob,
|
|
2390
|
+
original_token_prob=original_token_prob,
|
|
2391
|
+
prefix="preview",
|
|
2392
|
+
)
|
|
2393
|
+
final_mask_total, final_mask_counts = _apply_mlm_masks(
|
|
2394
|
+
final_records,
|
|
2395
|
+
tokenizer=tokenizer,
|
|
2396
|
+
mask_prob=mask_prob,
|
|
2397
|
+
seed=mask_seed,
|
|
2398
|
+
random_token_prob=random_token_prob,
|
|
2399
|
+
original_token_prob=original_token_prob,
|
|
2400
|
+
prefix="final",
|
|
2401
|
+
)
|
|
2402
|
+
else:
|
|
2403
|
+
preview_mask_counts = [0] * len(preview_records)
|
|
2404
|
+
final_mask_counts = [0] * len(final_records)
|
|
2405
|
+
|
|
2406
|
+
preview_sequences = [record["input_ids"] for record in preview_records]
|
|
2407
|
+
for idx, record in enumerate(preview_records):
|
|
2408
|
+
entry = {
|
|
2409
|
+
"input_ids": record["input_ids"],
|
|
2410
|
+
"attention_mask": record["attention_mask"],
|
|
2411
|
+
"window_id": f"preview::{idx}",
|
|
2412
|
+
"dataset_index": record.get("dataset_index"),
|
|
2413
|
+
"mlm_masked": record.get("mlm_masked", 0),
|
|
2414
|
+
}
|
|
2415
|
+
if use_mlm:
|
|
2416
|
+
entry["labels"] = record.get(
|
|
2417
|
+
"labels", [-100] * len(record["input_ids"])
|
|
2418
|
+
)
|
|
2419
|
+
calibration_data.append(entry)
|
|
2420
|
+
|
|
2421
|
+
final_sequences = [record["input_ids"] for record in final_records]
|
|
2422
|
+
for idx, record in enumerate(final_records):
|
|
2423
|
+
entry = {
|
|
2424
|
+
"input_ids": record["input_ids"],
|
|
2425
|
+
"attention_mask": record["attention_mask"],
|
|
2426
|
+
"window_id": f"final::{idx}",
|
|
2427
|
+
"dataset_index": record.get("dataset_index"),
|
|
2428
|
+
"mlm_masked": record.get("mlm_masked", 0),
|
|
2429
|
+
}
|
|
2430
|
+
if use_mlm:
|
|
2431
|
+
entry["labels"] = record.get(
|
|
2432
|
+
"labels", [-100] * len(record["input_ids"])
|
|
2433
|
+
)
|
|
2434
|
+
elif provider_labels_fin is not None and idx < len(provider_labels_fin):
|
|
2435
|
+
entry["labels"] = _tensor_or_list_to_ints(provider_labels_fin[idx])
|
|
2436
|
+
calibration_data.append(entry)
|
|
2437
|
+
|
|
2438
|
+
masked_tokens_total = preview_mask_total + final_mask_total
|
|
2439
|
+
preview_hash = _hash_sequences(preview_sequences)
|
|
2440
|
+
final_hash = _hash_sequences(final_sequences)
|
|
2441
|
+
dataset_meta = {
|
|
2442
|
+
"tokenizer_name": getattr(tokenizer, "name_or_path", "unknown"),
|
|
2443
|
+
"tokenizer_hash": tokenizer_hash
|
|
2444
|
+
if tokenizer_hash is not None
|
|
2445
|
+
else _tokenizer_digest(tokenizer),
|
|
2446
|
+
"vocab_size": _safe_int(getattr(tokenizer, "vocab_size", 0)),
|
|
2447
|
+
"bos_token": getattr(tokenizer, "bos_token", None),
|
|
2448
|
+
"eos_token": getattr(tokenizer, "eos_token", None),
|
|
2449
|
+
"pad_token": getattr(tokenizer, "pad_token", None),
|
|
2450
|
+
"add_prefix_space": getattr(tokenizer, "add_prefix_space", None),
|
|
2451
|
+
"dataset_hash": hashlib.blake2s(
|
|
2452
|
+
(preview_hash + final_hash).encode("utf-8"), digest_size=16
|
|
2453
|
+
).hexdigest(),
|
|
2454
|
+
"preview_hash": preview_hash,
|
|
2455
|
+
"final_hash": final_hash,
|
|
2456
|
+
"preview_total_tokens": sum(len(seq) for seq in preview_sequences),
|
|
2457
|
+
"final_total_tokens": sum(len(seq) for seq in final_sequences),
|
|
2458
|
+
}
|
|
2459
|
+
dataset_meta["loss_type"] = resolved_loss_type
|
|
2460
|
+
if use_mlm:
|
|
2461
|
+
dataset_meta["masked_tokens_preview"] = int(preview_mask_total)
|
|
2462
|
+
dataset_meta["masked_tokens_final"] = int(final_mask_total)
|
|
2463
|
+
dataset_meta["masked_tokens_total"] = int(masked_tokens_total)
|
|
2464
|
+
if window_plan:
|
|
2465
|
+
dataset_meta["window_plan"] = window_plan
|
|
2466
|
+
capacity_meta = window_plan.get("capacity")
|
|
2467
|
+
if capacity_meta:
|
|
2468
|
+
dataset_meta["window_capacity"] = capacity_meta
|
|
2469
|
+
strat_stats = getattr(data_provider, "stratification_stats", None)
|
|
2470
|
+
if strat_stats:
|
|
2471
|
+
dataset_meta["stratification"] = strat_stats
|
|
2472
|
+
scorer_profile = getattr(data_provider, "scorer_profile", None)
|
|
2473
|
+
if scorer_profile:
|
|
2474
|
+
dataset_meta["scorer_profile"] = scorer_profile
|
|
2475
|
+
|
|
2476
|
+
try:
|
|
2477
|
+
run_context["dataset"]["preview_n"] = preview_count
|
|
2478
|
+
run_context["dataset"]["final_n"] = final_count
|
|
2479
|
+
except Exception:
|
|
2480
|
+
pass
|
|
2481
|
+
run_context["dataset_meta"] = dataset_meta
|
|
2482
|
+
if window_plan:
|
|
2483
|
+
run_context["window_plan"] = window_plan
|
|
2484
|
+
|
|
2485
|
+
if os.environ.get("INVARLOCK_DEBUG_TRACE"):
|
|
2486
|
+
console.print(
|
|
2487
|
+
"[debug] calibration batch size => preview="
|
|
2488
|
+
f"{preview_count} final={final_count} total={len(calibration_data)}"
|
|
2489
|
+
)
|
|
2490
|
+
if use_mlm and calibration_data:
|
|
2491
|
+
masked_preview = sum(
|
|
2492
|
+
entry.get("mlm_masked", 0)
|
|
2493
|
+
for entry in calibration_data[:preview_count]
|
|
2494
|
+
)
|
|
2495
|
+
masked_final = sum(
|
|
2496
|
+
entry.get("mlm_masked", 0)
|
|
2497
|
+
for entry in calibration_data[preview_count:]
|
|
2498
|
+
)
|
|
2499
|
+
console.print(
|
|
2500
|
+
f"[debug] masked tokens (preview/final) = {masked_preview}/{masked_final}"
|
|
2501
|
+
)
|
|
2502
|
+
console.print(
|
|
2503
|
+
f"[debug] sample labels first preview entry (first 10) = {calibration_data[0]['labels'][:10]}"
|
|
2504
|
+
)
|
|
2505
|
+
|
|
2506
|
+
# Execute the real pipeline using CoreRunner
|
|
2507
|
+
console.print(f"⚙️ Executing pipeline with {len(guards)} guards...")
|
|
2508
|
+
runner = CoreRunner()
|
|
2509
|
+
|
|
2510
|
+
# Prepare auto configuration for tier resolution
|
|
2511
|
+
# Build auto configuration with safe fallbacks when section/keys are absent
|
|
2512
|
+
try:
|
|
2513
|
+
auto_enabled = bool(cfg.auto.enabled)
|
|
2514
|
+
except Exception:
|
|
2515
|
+
auto_enabled = False
|
|
2516
|
+
try:
|
|
2517
|
+
auto_tier = cfg.auto.tier
|
|
2518
|
+
except Exception:
|
|
2519
|
+
auto_tier = "balanced"
|
|
2520
|
+
try:
|
|
2521
|
+
auto_probes = int(cfg.auto.probes)
|
|
2522
|
+
except Exception:
|
|
2523
|
+
auto_probes = 0
|
|
2524
|
+
try:
|
|
2525
|
+
auto_target_ratio = float(cfg.auto.target_pm_ratio)
|
|
2526
|
+
except Exception:
|
|
2527
|
+
auto_target_ratio = 2.0
|
|
2528
|
+
|
|
2529
|
+
auto_config = {
|
|
2530
|
+
"enabled": auto_enabled,
|
|
2531
|
+
"tier": auto_tier,
|
|
2532
|
+
"probes": auto_probes,
|
|
2533
|
+
"target_pm_ratio": auto_target_ratio,
|
|
2534
|
+
}
|
|
2535
|
+
|
|
2536
|
+
# Extract edit configuration parameters
|
|
2537
|
+
edit_config = {}
|
|
2538
|
+
if hasattr(cfg.edit, "plan") and cfg.edit.plan:
|
|
2539
|
+
try:
|
|
2540
|
+
# Accept plain dicts, dict-like wrappers, or nested objects
|
|
2541
|
+
plan_obj = getattr(cfg.edit, "plan", {})
|
|
2542
|
+
if isinstance(plan_obj, dict):
|
|
2543
|
+
edit_config = dict(plan_obj)
|
|
2544
|
+
else:
|
|
2545
|
+
# Best-effort unwrap for InvarLockConfig _Obj wrapper
|
|
2546
|
+
plan_data = getattr(plan_obj, "_data", None)
|
|
2547
|
+
if isinstance(plan_data, dict):
|
|
2548
|
+
edit_config = dict(plan_data)
|
|
2549
|
+
elif hasattr(plan_obj, "items"):
|
|
2550
|
+
edit_config = dict(plan_obj) # type: ignore[arg-type]
|
|
2551
|
+
except (TypeError, AttributeError):
|
|
2552
|
+
pass
|
|
2553
|
+
elif hasattr(cfg.edit, "parameters") and cfg.edit.parameters:
|
|
2554
|
+
try:
|
|
2555
|
+
if hasattr(cfg.edit.parameters, "items"):
|
|
2556
|
+
edit_config = dict(cfg.edit.parameters)
|
|
2557
|
+
elif isinstance(cfg.edit.parameters, dict):
|
|
2558
|
+
edit_config = cfg.edit.parameters
|
|
2559
|
+
except (TypeError, AttributeError):
|
|
2560
|
+
pass
|
|
2561
|
+
|
|
2562
|
+
if (
|
|
2563
|
+
model_profile.module_selectors
|
|
2564
|
+
and "module_selectors" not in edit_config
|
|
2565
|
+
and isinstance(model_profile.module_selectors, dict)
|
|
2566
|
+
):
|
|
2567
|
+
edit_config["module_selectors"] = {
|
|
2568
|
+
key: list(values)
|
|
2569
|
+
for key, values in model_profile.module_selectors.items()
|
|
2570
|
+
}
|
|
2571
|
+
|
|
2572
|
+
console.print(f"✂️ Edit: {edit_op.name}")
|
|
2573
|
+
console.print(f"🛡️ Guards: {[g.name for g in guards]}")
|
|
2574
|
+
|
|
2575
|
+
# Model load/snapshot strategy
|
|
2576
|
+
model = None
|
|
2577
|
+
restore_fn = None
|
|
2578
|
+
snapshot_tmpdir: str | None = None
|
|
2579
|
+
|
|
2580
|
+
# Try single-load with snapshot/restore if adapter supports it; fallback to reload per attempt
|
|
2581
|
+
try:
|
|
2582
|
+
# Load once
|
|
2583
|
+
console.print(f"🔧 Loading model once: {cfg.model.id}")
|
|
2584
|
+
model = adapter.load_model(cfg.model.id, device=resolved_device)
|
|
2585
|
+
|
|
2586
|
+
# No edit-specific bootstrap logic
|
|
2587
|
+
|
|
2588
|
+
def _estimate_model_bytes(m: Any) -> int:
|
|
2589
|
+
total = 0
|
|
2590
|
+
try:
|
|
2591
|
+
for _, p in getattr(m, "named_parameters", lambda: [])():
|
|
2592
|
+
try:
|
|
2593
|
+
total += int(p.element_size() * p.nelement())
|
|
2594
|
+
except Exception:
|
|
2595
|
+
pass
|
|
2596
|
+
for _, b in getattr(m, "named_buffers", lambda: [])():
|
|
2597
|
+
try:
|
|
2598
|
+
total += int(b.element_size() * b.nelement())
|
|
2599
|
+
except Exception:
|
|
2600
|
+
pass
|
|
2601
|
+
except Exception:
|
|
2602
|
+
return 0
|
|
2603
|
+
return total
|
|
2604
|
+
|
|
2605
|
+
# Load snapshot config from config.context.snapshot (highest precedence)
|
|
2606
|
+
cfg_snapshot = {}
|
|
2607
|
+
try:
|
|
2608
|
+
cfg_context = _to_serialisable_dict(getattr(cfg, "context", {}))
|
|
2609
|
+
if isinstance(cfg_context, dict):
|
|
2610
|
+
cfg_snapshot = _to_serialisable_dict(
|
|
2611
|
+
cfg_context.get("snapshot", {})
|
|
2612
|
+
)
|
|
2613
|
+
if not isinstance(cfg_snapshot, dict):
|
|
2614
|
+
cfg_snapshot = {}
|
|
2615
|
+
except Exception:
|
|
2616
|
+
cfg_snapshot = {}
|
|
2617
|
+
|
|
2618
|
+
def _choose_snapshot_mode() -> str:
|
|
2619
|
+
# Precedence: config > env > auto
|
|
2620
|
+
cfg_mode = (
|
|
2621
|
+
str(cfg_snapshot.get("mode", "")).lower()
|
|
2622
|
+
if isinstance(cfg_snapshot, dict)
|
|
2623
|
+
else ""
|
|
2624
|
+
)
|
|
2625
|
+
mode_env = str(
|
|
2626
|
+
os.environ.get("INVARLOCK_SNAPSHOT_MODE", "auto")
|
|
2627
|
+
).lower()
|
|
2628
|
+
supports_chunked = hasattr(adapter, "snapshot_chunked") and hasattr(
|
|
2629
|
+
adapter, "restore_chunked"
|
|
2630
|
+
)
|
|
2631
|
+
supports_bytes = hasattr(adapter, "snapshot") and hasattr(
|
|
2632
|
+
adapter, "restore"
|
|
2633
|
+
)
|
|
2634
|
+
if cfg_mode in {"bytes", "chunked"}:
|
|
2635
|
+
if cfg_mode == "bytes" and supports_bytes:
|
|
2636
|
+
return "bytes"
|
|
2637
|
+
if cfg_mode == "chunked" and supports_chunked:
|
|
2638
|
+
return "chunked"
|
|
2639
|
+
# fallback preference
|
|
2640
|
+
if supports_bytes:
|
|
2641
|
+
return "bytes"
|
|
2642
|
+
if supports_chunked:
|
|
2643
|
+
return "chunked"
|
|
2644
|
+
return "reload"
|
|
2645
|
+
if mode_env in {"bytes", "chunked"}:
|
|
2646
|
+
if mode_env == "bytes" and supports_bytes:
|
|
2647
|
+
return "bytes"
|
|
2648
|
+
if mode_env == "chunked" and supports_chunked:
|
|
2649
|
+
return "chunked"
|
|
2650
|
+
# fallback preference
|
|
2651
|
+
if supports_bytes:
|
|
2652
|
+
return "bytes"
|
|
2653
|
+
if supports_chunked:
|
|
2654
|
+
return "chunked"
|
|
2655
|
+
return "reload"
|
|
2656
|
+
# auto
|
|
2657
|
+
est_mb = _estimate_model_bytes(model) / (1024.0 * 1024.0)
|
|
2658
|
+
# RAM-based heuristic
|
|
2659
|
+
try:
|
|
2660
|
+
ram = psutil.virtual_memory()
|
|
2661
|
+
avail_mb = float(getattr(ram, "available", 0)) / (1024.0 * 1024.0)
|
|
2662
|
+
except Exception:
|
|
2663
|
+
avail_mb = 0.0
|
|
2664
|
+
# fraction: config override > env > default 0.4
|
|
2665
|
+
frac = 0.4
|
|
2666
|
+
try:
|
|
2667
|
+
if (
|
|
2668
|
+
isinstance(cfg_snapshot, dict)
|
|
2669
|
+
and cfg_snapshot.get("ram_fraction") is not None
|
|
2670
|
+
):
|
|
2671
|
+
frac = float(cfg_snapshot.get("ram_fraction"))
|
|
2672
|
+
else:
|
|
2673
|
+
frac = float(
|
|
2674
|
+
os.environ.get("INVARLOCK_SNAPSHOT_AUTO_RAM_FRACTION", frac)
|
|
2675
|
+
)
|
|
2676
|
+
except Exception:
|
|
2677
|
+
pass
|
|
2678
|
+
# threshold mb: if no RAM info, use config threshold_mb or env fallback; else derive from avail*frac
|
|
2679
|
+
if avail_mb > 0:
|
|
2680
|
+
threshold_mb = avail_mb * max(0.0, min(frac, 1.0))
|
|
2681
|
+
else:
|
|
2682
|
+
try:
|
|
2683
|
+
if (
|
|
2684
|
+
isinstance(cfg_snapshot, dict)
|
|
2685
|
+
and cfg_snapshot.get("threshold_mb") is not None
|
|
2686
|
+
):
|
|
2687
|
+
threshold_mb = float(cfg_snapshot.get("threshold_mb"))
|
|
2688
|
+
else:
|
|
2689
|
+
threshold_mb = float(
|
|
2690
|
+
os.environ.get("INVARLOCK_SNAPSHOT_THRESHOLD_MB", "768")
|
|
2691
|
+
)
|
|
2692
|
+
except Exception:
|
|
2693
|
+
threshold_mb = 768.0
|
|
2694
|
+
# Disk availability for chunked
|
|
2695
|
+
try:
|
|
2696
|
+
tmpdir = None
|
|
2697
|
+
if isinstance(cfg_snapshot, dict):
|
|
2698
|
+
tmpdir = cfg_snapshot.get("temp_dir") or None
|
|
2699
|
+
if not tmpdir:
|
|
2700
|
+
tmpdir = (
|
|
2701
|
+
os.environ.get("TMPDIR") or os.environ.get("TMP") or "/tmp"
|
|
2702
|
+
)
|
|
2703
|
+
du = shutil.disk_usage(tmpdir)
|
|
2704
|
+
free_mb = float(du.free) / (1024.0 * 1024.0)
|
|
2705
|
+
except Exception:
|
|
2706
|
+
free_mb = 0.0
|
|
2707
|
+
# Disk margin ratio: config > default 1.2
|
|
2708
|
+
margin = 1.2
|
|
2709
|
+
try:
|
|
2710
|
+
if (
|
|
2711
|
+
isinstance(cfg_snapshot, dict)
|
|
2712
|
+
and cfg_snapshot.get("disk_free_margin_ratio") is not None
|
|
2713
|
+
):
|
|
2714
|
+
margin = float(cfg_snapshot.get("disk_free_margin_ratio"))
|
|
2715
|
+
except Exception:
|
|
2716
|
+
pass
|
|
2717
|
+
# Choose chunked if model snapshot is a large fraction of available RAM and disk has room
|
|
2718
|
+
if (
|
|
2719
|
+
supports_chunked
|
|
2720
|
+
and est_mb >= threshold_mb
|
|
2721
|
+
and (free_mb <= 0.0 or est_mb * margin <= free_mb)
|
|
2722
|
+
):
|
|
2723
|
+
return "chunked"
|
|
2724
|
+
# Otherwise prefer bytes when supported
|
|
2725
|
+
if supports_bytes:
|
|
2726
|
+
# If RAM is extremely low and even bytes snapshot likely risky, fallback to chunked when possible
|
|
2727
|
+
if (
|
|
2728
|
+
supports_chunked
|
|
2729
|
+
and avail_mb > 0
|
|
2730
|
+
and est_mb >= max(64.0, avail_mb * 0.8)
|
|
2731
|
+
and (free_mb <= 0.0 or est_mb * margin <= free_mb)
|
|
2732
|
+
):
|
|
2733
|
+
return "chunked"
|
|
2734
|
+
return "bytes"
|
|
2735
|
+
if supports_chunked:
|
|
2736
|
+
return "chunked"
|
|
2737
|
+
return "reload"
|
|
2738
|
+
|
|
2739
|
+
mode = _choose_snapshot_mode()
|
|
2740
|
+
# Emit deterministic snapshot mode status line
|
|
2741
|
+
console.print(
|
|
2742
|
+
f"snapshot_mode: {'enabled' if mode in {'bytes', 'chunked'} else 'disabled'}"
|
|
2743
|
+
)
|
|
2744
|
+
if mode == "chunked":
|
|
2745
|
+
snapshot_tmpdir = adapter.snapshot_chunked(model) # type: ignore[attr-defined]
|
|
2746
|
+
|
|
2747
|
+
def _restore():
|
|
2748
|
+
adapter.restore_chunked(model, snapshot_tmpdir) # type: ignore[attr-defined]
|
|
2749
|
+
|
|
2750
|
+
restore_fn = _restore
|
|
2751
|
+
elif mode == "bytes":
|
|
2752
|
+
base_blob = adapter.snapshot(model) # type: ignore[attr-defined]
|
|
2753
|
+
|
|
2754
|
+
def _restore2():
|
|
2755
|
+
adapter.restore(model, base_blob) # type: ignore[attr-defined]
|
|
2756
|
+
|
|
2757
|
+
restore_fn = _restore2
|
|
2758
|
+
else:
|
|
2759
|
+
# reload path
|
|
2760
|
+
model = None
|
|
2761
|
+
restore_fn = None
|
|
2762
|
+
except Exception:
|
|
2763
|
+
# On any failure, fall back to reload-per-attempt path
|
|
2764
|
+
model = None
|
|
2765
|
+
restore_fn = None
|
|
2766
|
+
|
|
2767
|
+
# RETRY LOOP - All report processing inside loop
|
|
2768
|
+
attempt = 1
|
|
2769
|
+
profile_normalized = (profile or "").lower()
|
|
2770
|
+
measure_guard_overhead = profile_normalized in {"ci", "release"}
|
|
2771
|
+
|
|
2772
|
+
while True:
|
|
2773
|
+
# Reset RNG streams each attempt to guarantee determinism across retries
|
|
2774
|
+
set_seed(seed_bundle["python"])
|
|
2775
|
+
|
|
2776
|
+
if retry_controller:
|
|
2777
|
+
console.print(f"\n🚀 Attempt {attempt}/{max_attempts}")
|
|
2778
|
+
if attempt > 1:
|
|
2779
|
+
console.print(f"🔄 Retry attempt {attempt}/{max_attempts}")
|
|
2780
|
+
else:
|
|
2781
|
+
if attempt > 1:
|
|
2782
|
+
console.print(f"\n🚀 Attempt {attempt}")
|
|
2783
|
+
|
|
2784
|
+
# Adjust parameters for retry attempts
|
|
2785
|
+
if retry_controller and attempt > 1:
|
|
2786
|
+
from invarlock.core.retry import adjust_edit_params
|
|
2787
|
+
|
|
2788
|
+
edit_config = adjust_edit_params(
|
|
2789
|
+
edit_op.name, edit_config, attempt, None
|
|
2790
|
+
)
|
|
2791
|
+
|
|
2792
|
+
guard_overhead_payload: dict[str, Any] | None = None
|
|
2793
|
+
if measure_guard_overhead:
|
|
2794
|
+
guard_overhead_payload = _run_bare_control(
|
|
2795
|
+
adapter=adapter,
|
|
2796
|
+
edit_op=edit_op,
|
|
2797
|
+
cfg=cfg,
|
|
2798
|
+
model=model,
|
|
2799
|
+
run_config=run_config,
|
|
2800
|
+
calibration_data=calibration_data,
|
|
2801
|
+
auto_config=auto_config,
|
|
2802
|
+
edit_config=edit_config,
|
|
2803
|
+
preview_count=preview_count,
|
|
2804
|
+
final_count=final_count,
|
|
2805
|
+
seed_bundle=seed_bundle,
|
|
2806
|
+
resolved_device=resolved_device,
|
|
2807
|
+
restore_fn=restore_fn,
|
|
2808
|
+
console=console,
|
|
2809
|
+
resolved_loss_type=resolved_loss_type,
|
|
2810
|
+
profile_normalized=profile_normalized,
|
|
2811
|
+
skip_model_load=skip_model_load,
|
|
2812
|
+
)
|
|
2813
|
+
|
|
2814
|
+
# Ensure clean state for guarded run
|
|
2815
|
+
core_report, model = _execute_guarded_run(
|
|
2816
|
+
runner=runner,
|
|
2817
|
+
adapter=adapter,
|
|
2818
|
+
model=model,
|
|
2819
|
+
cfg=cfg,
|
|
2820
|
+
edit_op=edit_op,
|
|
2821
|
+
run_config=run_config,
|
|
2822
|
+
guards=guards,
|
|
2823
|
+
calibration_data=calibration_data,
|
|
2824
|
+
auto_config=auto_config,
|
|
2825
|
+
edit_config=edit_config,
|
|
2826
|
+
preview_count=preview_count,
|
|
2827
|
+
final_count=final_count,
|
|
2828
|
+
restore_fn=restore_fn,
|
|
2829
|
+
resolved_device=resolved_device,
|
|
2830
|
+
console=console,
|
|
2831
|
+
skip_model_load=skip_model_load,
|
|
2832
|
+
)
|
|
2833
|
+
|
|
2834
|
+
if not hasattr(core_report, "context") or core_report.context is None:
|
|
2835
|
+
core_report.context = {}
|
|
2836
|
+
|
|
2837
|
+
# Convert CoreRunner report to evaluation report
|
|
2838
|
+
report = create_empty_report()
|
|
2839
|
+
|
|
2840
|
+
# Code provenance: commit hash and InvarLock version
|
|
2841
|
+
commit_value = (
|
|
2842
|
+
getattr(cfg.meta, "commit", "") if hasattr(cfg, "meta") else ""
|
|
2843
|
+
)
|
|
2844
|
+
if not commit_value:
|
|
2845
|
+
try:
|
|
2846
|
+
import subprocess
|
|
2847
|
+
|
|
2848
|
+
commit_value = (
|
|
2849
|
+
subprocess.check_output(
|
|
2850
|
+
["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL
|
|
2851
|
+
)
|
|
2852
|
+
.decode("utf-8", "ignore")
|
|
2853
|
+
.strip()
|
|
2854
|
+
)
|
|
2855
|
+
except Exception:
|
|
2856
|
+
commit_value = ""
|
|
2857
|
+
invarlock_version = None
|
|
2858
|
+
try:
|
|
2859
|
+
from invarlock import __version__ as _invarlock_version
|
|
2860
|
+
|
|
2861
|
+
invarlock_version = _invarlock_version
|
|
2862
|
+
except Exception:
|
|
2863
|
+
invarlock_version = None
|
|
2864
|
+
|
|
2865
|
+
# Collect determinism/env flags
|
|
2866
|
+
env_flags: dict[str, object] = {}
|
|
2867
|
+
try:
|
|
2868
|
+
import os as _os
|
|
2869
|
+
|
|
2870
|
+
if torch is not None:
|
|
2871
|
+
try:
|
|
2872
|
+
det_enabled = getattr(
|
|
2873
|
+
torch, "are_deterministic_algorithms_enabled", None
|
|
2874
|
+
)
|
|
2875
|
+
if callable(det_enabled):
|
|
2876
|
+
env_flags["torch_deterministic_algorithms"] = bool(
|
|
2877
|
+
det_enabled()
|
|
2878
|
+
)
|
|
2879
|
+
except Exception:
|
|
2880
|
+
pass
|
|
2881
|
+
try:
|
|
2882
|
+
tf32_matmul = getattr(
|
|
2883
|
+
getattr(torch.backends, "cuda", object()), "matmul", None
|
|
2884
|
+
)
|
|
2885
|
+
if tf32_matmul is not None and hasattr(
|
|
2886
|
+
tf32_matmul, "allow_tf32"
|
|
2887
|
+
):
|
|
2888
|
+
env_flags["cuda_matmul_allow_tf32"] = bool(
|
|
2889
|
+
tf32_matmul.allow_tf32
|
|
2890
|
+
)
|
|
2891
|
+
except Exception:
|
|
2892
|
+
pass
|
|
2893
|
+
try:
|
|
2894
|
+
cudnn_mod = getattr(torch.backends, "cudnn", None)
|
|
2895
|
+
if cudnn_mod is not None:
|
|
2896
|
+
env_flags["cudnn_allow_tf32"] = bool(
|
|
2897
|
+
getattr(cudnn_mod, "allow_tf32", None)
|
|
2898
|
+
)
|
|
2899
|
+
env_flags["cudnn_deterministic"] = bool(
|
|
2900
|
+
getattr(cudnn_mod, "deterministic", None)
|
|
2901
|
+
)
|
|
2902
|
+
env_flags["cudnn_benchmark"] = bool(
|
|
2903
|
+
getattr(cudnn_mod, "benchmark", None)
|
|
2904
|
+
)
|
|
2905
|
+
except Exception:
|
|
2906
|
+
pass
|
|
2907
|
+
try:
|
|
2908
|
+
env_flags["mps_available"] = bool(
|
|
2909
|
+
getattr(torch.backends, "mps", None)
|
|
2910
|
+
and torch.backends.mps.is_available()
|
|
2911
|
+
)
|
|
2912
|
+
except Exception:
|
|
2913
|
+
pass
|
|
2914
|
+
# Common environment variables for determinism
|
|
2915
|
+
env_flags["CUBLAS_WORKSPACE_CONFIG"] = _os.environ.get(
|
|
2916
|
+
"CUBLAS_WORKSPACE_CONFIG"
|
|
2917
|
+
)
|
|
2918
|
+
except Exception:
|
|
2919
|
+
env_flags = {}
|
|
2920
|
+
|
|
2921
|
+
meta_payload = {
|
|
2922
|
+
"model_id": cfg.model.id,
|
|
2923
|
+
"adapter": cfg.model.adapter,
|
|
2924
|
+
"device": str(resolved_device),
|
|
2925
|
+
"commit": commit_value,
|
|
2926
|
+
"seed": seed_bundle["python"],
|
|
2927
|
+
"seeds": seed_bundle,
|
|
2928
|
+
"ts": datetime.now().isoformat(),
|
|
2929
|
+
"auto": auto_config,
|
|
2930
|
+
}
|
|
2931
|
+
if invarlock_version:
|
|
2932
|
+
meta_payload["invarlock_version"] = invarlock_version
|
|
2933
|
+
if env_flags:
|
|
2934
|
+
meta_payload["env_flags"] = env_flags
|
|
2935
|
+
report["meta"].update(meta_payload)
|
|
2936
|
+
report["meta"]["model_profile"] = {
|
|
2937
|
+
"family": model_profile.family,
|
|
2938
|
+
"default_loss": model_profile.default_loss,
|
|
2939
|
+
"module_selectors": model_profile.module_selectors,
|
|
2940
|
+
"invariants": list(model_profile.invariants),
|
|
2941
|
+
"cert_lints": [dict(lint) for lint in model_profile.cert_lints],
|
|
2942
|
+
}
|
|
2943
|
+
|
|
2944
|
+
report["data"].update(
|
|
2945
|
+
{
|
|
2946
|
+
"dataset": cfg.dataset.provider,
|
|
2947
|
+
# Resolved split (explicit or inferred)
|
|
2948
|
+
"split": resolved_split,
|
|
2949
|
+
"seq_len": cfg.dataset.seq_len,
|
|
2950
|
+
"stride": getattr(cfg.dataset, "stride", cfg.dataset.seq_len // 2),
|
|
2951
|
+
"preview_n": _safe_int(preview_count),
|
|
2952
|
+
"final_n": _safe_int(final_count),
|
|
2953
|
+
}
|
|
2954
|
+
)
|
|
2955
|
+
dataset_meta_context = core_report.context.get("dataset_meta", {})
|
|
2956
|
+
if isinstance(dataset_meta_context, dict):
|
|
2957
|
+
report["data"].update(dataset_meta_context)
|
|
2958
|
+
dataset_tokenizer_hash = dataset_meta_context.get("tokenizer_hash")
|
|
2959
|
+
if (
|
|
2960
|
+
not tokenizer_hash
|
|
2961
|
+
and isinstance(dataset_tokenizer_hash, str)
|
|
2962
|
+
and dataset_tokenizer_hash
|
|
2963
|
+
):
|
|
2964
|
+
tokenizer_hash = dataset_tokenizer_hash
|
|
2965
|
+
|
|
2966
|
+
if tokenizer_hash:
|
|
2967
|
+
report["meta"]["tokenizer_hash"] = tokenizer_hash
|
|
2968
|
+
|
|
2969
|
+
# Transfer edit information
|
|
2970
|
+
if hasattr(core_report, "edit") and core_report.edit:
|
|
2971
|
+
edit_deltas = core_report.edit.get("deltas", {})
|
|
2972
|
+
report["edit"].update(
|
|
2973
|
+
{
|
|
2974
|
+
"name": edit_op.name,
|
|
2975
|
+
"plan_digest": core_report.edit.get(
|
|
2976
|
+
"plan_digest", str(hash(str(core_report.edit)))
|
|
2977
|
+
),
|
|
2978
|
+
"deltas": {
|
|
2979
|
+
"params_changed": edit_deltas.get("params_changed", 0),
|
|
2980
|
+
"sparsity": edit_deltas.get("sparsity", None),
|
|
2981
|
+
"bitwidth_map": edit_deltas.get("bitwidth_map", None),
|
|
2982
|
+
"layers_modified": edit_deltas.get("layers_modified", 0),
|
|
2983
|
+
},
|
|
2984
|
+
}
|
|
2985
|
+
)
|
|
2986
|
+
for key in (
|
|
2987
|
+
"algorithm",
|
|
2988
|
+
"algorithm_version",
|
|
2989
|
+
"implementation",
|
|
2990
|
+
"scope",
|
|
2991
|
+
"ranking",
|
|
2992
|
+
"grouping",
|
|
2993
|
+
"budgets",
|
|
2994
|
+
"seed",
|
|
2995
|
+
"mask_digest",
|
|
2996
|
+
):
|
|
2997
|
+
if key in core_report.edit:
|
|
2998
|
+
report["edit"][key] = copy.deepcopy(core_report.edit[key])
|
|
2999
|
+
if isinstance(core_report.context, dict):
|
|
3000
|
+
core_report.context.setdefault("edit", {})
|
|
3001
|
+
core_report.context["edit"].update(
|
|
3002
|
+
{
|
|
3003
|
+
"name": edit_op.name,
|
|
3004
|
+
"params_changed": edit_deltas.get("params_changed", 0),
|
|
3005
|
+
"layers_modified": edit_deltas.get("layers_modified", 0),
|
|
3006
|
+
}
|
|
3007
|
+
)
|
|
3008
|
+
|
|
3009
|
+
mask_artifact_path = _persist_ref_masks(core_report, run_dir)
|
|
3010
|
+
if mask_artifact_path:
|
|
3011
|
+
report.setdefault("artifacts", {})
|
|
3012
|
+
report["artifacts"]["masks_path"] = str(mask_artifact_path)
|
|
3013
|
+
|
|
3014
|
+
# Transfer metrics (PM-only: do not write ppl_* fields)
|
|
3015
|
+
if hasattr(core_report, "metrics") and core_report.metrics:
|
|
3016
|
+
metrics_payload = {
|
|
3017
|
+
"latency_ms_per_tok": core_report.metrics.get(
|
|
3018
|
+
"latency_ms_per_tok", 0.0
|
|
3019
|
+
),
|
|
3020
|
+
"memory_mb_peak": core_report.metrics.get("memory_mb_peak", 0.0),
|
|
3021
|
+
"spectral": {},
|
|
3022
|
+
"rmt": {},
|
|
3023
|
+
"invariants": {},
|
|
3024
|
+
}
|
|
3025
|
+
window_plan_ctx = core_report.context.get("window_plan")
|
|
3026
|
+
if isinstance(window_plan_ctx, dict):
|
|
3027
|
+
metrics_payload["window_plan"] = window_plan_ctx
|
|
3028
|
+
capacity_meta = window_plan_ctx.get("capacity")
|
|
3029
|
+
if isinstance(capacity_meta, dict):
|
|
3030
|
+
metrics_payload["window_capacity"] = capacity_meta
|
|
3031
|
+
stats_section = metrics_payload.setdefault("stats", {})
|
|
3032
|
+
if isinstance(stats_section, dict):
|
|
3033
|
+
stats_section.update(
|
|
3034
|
+
{
|
|
3035
|
+
"requested_preview": window_plan_ctx.get(
|
|
3036
|
+
"requested_preview"
|
|
3037
|
+
),
|
|
3038
|
+
"requested_final": window_plan_ctx.get(
|
|
3039
|
+
"requested_final"
|
|
3040
|
+
),
|
|
3041
|
+
"actual_preview": window_plan_ctx.get("actual_preview"),
|
|
3042
|
+
"actual_final": window_plan_ctx.get("actual_final"),
|
|
3043
|
+
"coverage_ok": window_plan_ctx.get("coverage_ok"),
|
|
3044
|
+
}
|
|
3045
|
+
)
|
|
3046
|
+
optional_keys = [
|
|
3047
|
+
"logloss_preview",
|
|
3048
|
+
"logloss_final",
|
|
3049
|
+
"logloss_delta",
|
|
3050
|
+
"logloss_preview_ci",
|
|
3051
|
+
"logloss_final_ci",
|
|
3052
|
+
"logloss_delta_ci",
|
|
3053
|
+
"bootstrap",
|
|
3054
|
+
"window_overlap_fraction",
|
|
3055
|
+
"window_match_fraction",
|
|
3056
|
+
"window_pairing_reason",
|
|
3057
|
+
"window_pairing_preview",
|
|
3058
|
+
"window_pairing_final",
|
|
3059
|
+
"paired_windows",
|
|
3060
|
+
"paired_delta_summary",
|
|
3061
|
+
"preview_total_tokens",
|
|
3062
|
+
"final_total_tokens",
|
|
3063
|
+
"masked_tokens_total",
|
|
3064
|
+
"masked_tokens_preview",
|
|
3065
|
+
"masked_tokens_final",
|
|
3066
|
+
"reduction",
|
|
3067
|
+
]
|
|
3068
|
+
for key in optional_keys:
|
|
3069
|
+
if key in core_report.metrics:
|
|
3070
|
+
metrics_payload[key] = core_report.metrics[key]
|
|
3071
|
+
metrics_payload["loss_type"] = resolved_loss_type
|
|
3072
|
+
if metrics_payload.get("loss_type") is None and isinstance(
|
|
3073
|
+
dataset_meta_context, dict
|
|
3074
|
+
):
|
|
3075
|
+
metrics_payload["loss_type"] = dataset_meta_context.get(
|
|
3076
|
+
"loss_type", resolved_loss_type
|
|
3077
|
+
)
|
|
3078
|
+
if isinstance(dataset_meta_context, dict):
|
|
3079
|
+
for meta_key in (
|
|
3080
|
+
"masked_tokens_total",
|
|
3081
|
+
"masked_tokens_preview",
|
|
3082
|
+
"masked_tokens_final",
|
|
3083
|
+
):
|
|
3084
|
+
if (
|
|
3085
|
+
meta_key not in metrics_payload
|
|
3086
|
+
and dataset_meta_context.get(meta_key) is not None
|
|
3087
|
+
):
|
|
3088
|
+
metrics_payload[meta_key] = dataset_meta_context[meta_key]
|
|
3089
|
+
report["metrics"].update(metrics_payload)
|
|
3090
|
+
|
|
3091
|
+
if guard_overhead_payload is not None:
|
|
3092
|
+
# Compute guarded primary-metric snapshot; pass structured reports into validator
|
|
3093
|
+
try:
|
|
3094
|
+
# Map loss type to ppl family kind
|
|
3095
|
+
lk = str(resolved_loss_type or "causal").lower()
|
|
3096
|
+
if lk == "mlm":
|
|
3097
|
+
pm_kind_for_overhead = "ppl_mlm"
|
|
3098
|
+
elif lk in {"seq2seq", "s2s", "t5"}:
|
|
3099
|
+
pm_kind_for_overhead = "ppl_seq2seq"
|
|
3100
|
+
else:
|
|
3101
|
+
pm_kind_for_overhead = "ppl_causal"
|
|
3102
|
+
|
|
3103
|
+
# Prefer computing from the in-memory core_report windows to avoid ordering issues
|
|
3104
|
+
pm_guarded = _extract_pm_snapshot_for_overhead(
|
|
3105
|
+
core_report, kind=pm_kind_for_overhead
|
|
3106
|
+
)
|
|
3107
|
+
if not isinstance(pm_guarded, dict) or not pm_guarded:
|
|
3108
|
+
pm_guarded = _extract_pm_snapshot_for_overhead(
|
|
3109
|
+
report, kind=pm_kind_for_overhead
|
|
3110
|
+
)
|
|
3111
|
+
|
|
3112
|
+
guard_overhead_payload["guarded_report"] = (
|
|
3113
|
+
{"metrics": {"primary_metric": pm_guarded}}
|
|
3114
|
+
if isinstance(pm_guarded, dict) and pm_guarded
|
|
3115
|
+
else None
|
|
3116
|
+
)
|
|
3117
|
+
except Exception:
|
|
3118
|
+
guard_overhead_payload["guarded_report"] = None
|
|
3119
|
+
bare_struct = guard_overhead_payload.get("bare_report") or {}
|
|
3120
|
+
guarded_struct = guard_overhead_payload.get("guarded_report") or {}
|
|
3121
|
+
# Be robust to mocks or minimal objects returned by validators
|
|
3122
|
+
result = validate_guard_overhead(
|
|
3123
|
+
bare_struct,
|
|
3124
|
+
guarded_struct,
|
|
3125
|
+
overhead_threshold=guard_overhead_payload.get(
|
|
3126
|
+
"overhead_threshold", GUARD_OVERHEAD_THRESHOLD
|
|
3127
|
+
),
|
|
3128
|
+
)
|
|
3129
|
+
try:
|
|
3130
|
+
messages = list(getattr(result, "messages", []))
|
|
3131
|
+
except Exception: # pragma: no cover - defensive
|
|
3132
|
+
messages = []
|
|
3133
|
+
try:
|
|
3134
|
+
warnings = list(getattr(result, "warnings", []))
|
|
3135
|
+
except Exception: # pragma: no cover - defensive
|
|
3136
|
+
warnings = []
|
|
3137
|
+
try:
|
|
3138
|
+
errors = list(getattr(result, "errors", []))
|
|
3139
|
+
except Exception: # pragma: no cover - defensive
|
|
3140
|
+
errors = []
|
|
3141
|
+
try:
|
|
3142
|
+
checks = dict(getattr(result, "checks", {}))
|
|
3143
|
+
except Exception: # pragma: no cover - defensive
|
|
3144
|
+
checks = {}
|
|
3145
|
+
metrics_obj = getattr(result, "metrics", {})
|
|
3146
|
+
if not isinstance(metrics_obj, dict):
|
|
3147
|
+
metrics_obj = {}
|
|
3148
|
+
overhead_ratio = metrics_obj.get("overhead_ratio")
|
|
3149
|
+
if overhead_ratio is None:
|
|
3150
|
+
overhead_ratio = getattr(result, "overhead_ratio", None)
|
|
3151
|
+
overhead_percent = metrics_obj.get("overhead_percent")
|
|
3152
|
+
if overhead_percent is None:
|
|
3153
|
+
overhead_percent = getattr(result, "overhead_percent", None)
|
|
3154
|
+
passed_flag = bool(getattr(result, "passed", False))
|
|
3155
|
+
|
|
3156
|
+
guard_overhead_payload.update(
|
|
3157
|
+
{
|
|
3158
|
+
"messages": messages,
|
|
3159
|
+
"warnings": warnings,
|
|
3160
|
+
"errors": errors,
|
|
3161
|
+
"checks": checks,
|
|
3162
|
+
"overhead_ratio": overhead_ratio,
|
|
3163
|
+
"overhead_percent": overhead_percent,
|
|
3164
|
+
"passed": passed_flag,
|
|
3165
|
+
"evaluated": True,
|
|
3166
|
+
}
|
|
3167
|
+
)
|
|
3168
|
+
# Normalize for non-finite/degenerate cases
|
|
3169
|
+
guard_overhead_payload = _normalize_overhead_result(
|
|
3170
|
+
guard_overhead_payload, profile=profile_normalized
|
|
3171
|
+
)
|
|
3172
|
+
report["guard_overhead"] = guard_overhead_payload
|
|
3173
|
+
|
|
3174
|
+
had_baseline = bool(baseline and Path(baseline).exists())
|
|
3175
|
+
if (
|
|
3176
|
+
hasattr(core_report, "evaluation_windows")
|
|
3177
|
+
and core_report.evaluation_windows
|
|
3178
|
+
):
|
|
3179
|
+
preview_windows = core_report.evaluation_windows.get("preview", {})
|
|
3180
|
+
final_windows = core_report.evaluation_windows.get("final", {})
|
|
3181
|
+
report["evaluation_windows"] = {
|
|
3182
|
+
"preview": {
|
|
3183
|
+
"window_ids": list(preview_windows.get("window_ids", [])),
|
|
3184
|
+
"logloss": list(preview_windows.get("logloss", [])),
|
|
3185
|
+
"input_ids": [
|
|
3186
|
+
list(seq) for seq in preview_windows.get("input_ids", [])
|
|
3187
|
+
],
|
|
3188
|
+
"attention_masks": [
|
|
3189
|
+
list(mask)
|
|
3190
|
+
for mask in preview_windows.get("attention_masks", [])
|
|
3191
|
+
],
|
|
3192
|
+
"token_counts": list(preview_windows.get("token_counts", [])),
|
|
3193
|
+
"masked_token_counts": list(
|
|
3194
|
+
preview_windows.get("masked_token_counts", [])
|
|
3195
|
+
),
|
|
3196
|
+
"actual_token_counts": list(
|
|
3197
|
+
preview_windows.get("actual_token_counts", [])
|
|
3198
|
+
),
|
|
3199
|
+
"labels": [
|
|
3200
|
+
list(seq) for seq in preview_windows.get("labels", [])
|
|
3201
|
+
],
|
|
3202
|
+
},
|
|
3203
|
+
"final": {
|
|
3204
|
+
"window_ids": list(final_windows.get("window_ids", [])),
|
|
3205
|
+
"logloss": list(final_windows.get("logloss", [])),
|
|
3206
|
+
"input_ids": [
|
|
3207
|
+
list(seq) for seq in final_windows.get("input_ids", [])
|
|
3208
|
+
],
|
|
3209
|
+
"attention_masks": [
|
|
3210
|
+
list(mask)
|
|
3211
|
+
for mask in final_windows.get("attention_masks", [])
|
|
3212
|
+
],
|
|
3213
|
+
"token_counts": list(final_windows.get("token_counts", [])),
|
|
3214
|
+
"masked_token_counts": list(
|
|
3215
|
+
final_windows.get("masked_token_counts", [])
|
|
3216
|
+
),
|
|
3217
|
+
"actual_token_counts": list(
|
|
3218
|
+
final_windows.get("actual_token_counts", [])
|
|
3219
|
+
),
|
|
3220
|
+
"labels": [
|
|
3221
|
+
list(seq) for seq in final_windows.get("labels", [])
|
|
3222
|
+
],
|
|
3223
|
+
},
|
|
3224
|
+
}
|
|
3225
|
+
elif had_baseline and (profile or "").lower() == "release":
|
|
3226
|
+
console.print(
|
|
3227
|
+
"[red]❌ [INVARLOCK:E001] PAIRING-SCHEDULE-MISMATCH: baseline pairing requested but evaluation windows were not produced. Check capacity/pairing config.[/red]"
|
|
3228
|
+
)
|
|
3229
|
+
raise typer.Exit(3)
|
|
3230
|
+
else:
|
|
3231
|
+
# Populate evaluation_windows directly from assembled records when the
|
|
3232
|
+
# runner did not provide a structured window payload. This ensures
|
|
3233
|
+
# provenance (provider_digest) can be computed even in lightweight/dev
|
|
3234
|
+
# runs and unit tests that stub the runner.
|
|
3235
|
+
try:
|
|
3236
|
+
|
|
3237
|
+
def _tokens(rec: dict[str, Any]) -> int:
|
|
3238
|
+
try:
|
|
3239
|
+
return int(len(rec.get("input_ids", []) or []))
|
|
3240
|
+
except Exception:
|
|
3241
|
+
return 0
|
|
3242
|
+
|
|
3243
|
+
report["evaluation_windows"] = {
|
|
3244
|
+
"preview": {
|
|
3245
|
+
"window_ids": [
|
|
3246
|
+
f"preview::{i}" for i in range(len(preview_records))
|
|
3247
|
+
],
|
|
3248
|
+
"input_ids": [
|
|
3249
|
+
list(r["input_ids"]) for r in preview_records
|
|
3250
|
+
],
|
|
3251
|
+
"attention_masks": [
|
|
3252
|
+
list(r["attention_mask"]) for r in preview_records
|
|
3253
|
+
],
|
|
3254
|
+
"token_counts": [_tokens(r) for r in preview_records],
|
|
3255
|
+
**(
|
|
3256
|
+
{
|
|
3257
|
+
"masked_token_counts": list(preview_mask_counts),
|
|
3258
|
+
"labels": [
|
|
3259
|
+
r.get("labels", [-100] * len(r["input_ids"]))
|
|
3260
|
+
for r in preview_records
|
|
3261
|
+
],
|
|
3262
|
+
}
|
|
3263
|
+
if use_mlm
|
|
3264
|
+
else {}
|
|
3265
|
+
),
|
|
3266
|
+
},
|
|
3267
|
+
"final": {
|
|
3268
|
+
"window_ids": [
|
|
3269
|
+
f"final::{i}" for i in range(len(final_records))
|
|
3270
|
+
],
|
|
3271
|
+
"input_ids": [list(r["input_ids"]) for r in final_records],
|
|
3272
|
+
"attention_masks": [
|
|
3273
|
+
list(r["attention_mask"]) for r in final_records
|
|
3274
|
+
],
|
|
3275
|
+
"token_counts": [_tokens(r) for r in final_records],
|
|
3276
|
+
**(
|
|
3277
|
+
{
|
|
3278
|
+
"masked_token_counts": list(final_mask_counts),
|
|
3279
|
+
"labels": [
|
|
3280
|
+
r.get("labels", [-100] * len(r["input_ids"]))
|
|
3281
|
+
for r in final_records
|
|
3282
|
+
],
|
|
3283
|
+
}
|
|
3284
|
+
if use_mlm
|
|
3285
|
+
else {}
|
|
3286
|
+
),
|
|
3287
|
+
},
|
|
3288
|
+
}
|
|
3289
|
+
except Exception:
|
|
3290
|
+
# Best-effort: provenance digest will be skipped if windows cannot be built
|
|
3291
|
+
pass
|
|
3292
|
+
|
|
3293
|
+
# Attach provider digest and dataset split provenance when available
|
|
3294
|
+
try:
|
|
3295
|
+
prov = report.setdefault("provenance", {})
|
|
3296
|
+
# Always record dataset split provenance for visibility
|
|
3297
|
+
try:
|
|
3298
|
+
prov["dataset_split"] = str(resolved_split)
|
|
3299
|
+
prov["split_fallback"] = bool(used_fallback_split)
|
|
3300
|
+
except Exception:
|
|
3301
|
+
pass
|
|
3302
|
+
provider_digest = _compute_provider_digest(report)
|
|
3303
|
+
if provider_digest:
|
|
3304
|
+
prov["provider_digest"] = provider_digest
|
|
3305
|
+
# Attach digest version for future evolution
|
|
3306
|
+
prov["digest_version"] = 1
|
|
3307
|
+
# Strict parity checks in CI/Release when baseline present
|
|
3308
|
+
try:
|
|
3309
|
+
if isinstance(baseline_report_data, dict):
|
|
3310
|
+
base_prov = (
|
|
3311
|
+
baseline_report_data.get("provenance", {})
|
|
3312
|
+
if isinstance(
|
|
3313
|
+
baseline_report_data.get("provenance"), dict
|
|
3314
|
+
)
|
|
3315
|
+
else {}
|
|
3316
|
+
)
|
|
3317
|
+
base_digest = (
|
|
3318
|
+
base_prov.get("provider_digest")
|
|
3319
|
+
if isinstance(base_prov, dict)
|
|
3320
|
+
else None
|
|
3321
|
+
)
|
|
3322
|
+
_enforce_provider_parity(
|
|
3323
|
+
provider_digest,
|
|
3324
|
+
base_digest,
|
|
3325
|
+
profile=(str(profile).lower() if profile else None),
|
|
3326
|
+
)
|
|
3327
|
+
except InvarlockError as ce:
|
|
3328
|
+
console.print(str(ce))
|
|
3329
|
+
# Map to profile-aware exit code: dev→1, ci/release→3
|
|
3330
|
+
raise typer.Exit(
|
|
3331
|
+
_resolve_exit_code(ce, profile=profile)
|
|
3332
|
+
) from None
|
|
3333
|
+
except RuntimeError as _e:
|
|
3334
|
+
_fail_run(str(_e))
|
|
3335
|
+
except Exception:
|
|
3336
|
+
pass
|
|
3337
|
+
except Exception:
|
|
3338
|
+
pass
|
|
3339
|
+
|
|
3340
|
+
# Transfer guard results
|
|
3341
|
+
if hasattr(core_report, "guards") and core_report.guards:
|
|
3342
|
+
for guard_name, guard_result in core_report.guards.items():
|
|
3343
|
+
guard_entry = {
|
|
3344
|
+
"name": guard_name,
|
|
3345
|
+
"passed": guard_result.get("passed"),
|
|
3346
|
+
"action": guard_result.get("action"),
|
|
3347
|
+
"policy": guard_result.get("policy", {}),
|
|
3348
|
+
"metrics": guard_result.get("metrics", {}),
|
|
3349
|
+
"actions": guard_result.get("actions", []),
|
|
3350
|
+
"violations": guard_result.get("violations", []),
|
|
3351
|
+
"warnings": guard_result.get("warnings", []),
|
|
3352
|
+
"errors": guard_result.get("errors", []),
|
|
3353
|
+
"details": guard_result.get("details", {}),
|
|
3354
|
+
}
|
|
3355
|
+
for extra_key in ("final_z_scores", "module_family_map"):
|
|
3356
|
+
if extra_key in guard_result:
|
|
3357
|
+
guard_entry[extra_key] = guard_result[extra_key]
|
|
3358
|
+
report["guards"].append(guard_entry)
|
|
3359
|
+
|
|
3360
|
+
# Set artifacts
|
|
3361
|
+
report["artifacts"].update(
|
|
3362
|
+
{
|
|
3363
|
+
"events_path": str(run_config.event_path)
|
|
3364
|
+
if run_config.event_path
|
|
3365
|
+
else "",
|
|
3366
|
+
"logs_path": "",
|
|
3367
|
+
"checkpoint_path": None,
|
|
3368
|
+
}
|
|
3369
|
+
)
|
|
3370
|
+
|
|
3371
|
+
# Optional: export HF-loadable model snapshot when requested
|
|
3372
|
+
export_env = str(
|
|
3373
|
+
os.environ.get("INVARLOCK_EXPORT_MODEL", "")
|
|
3374
|
+
).strip().lower() in {
|
|
3375
|
+
"1",
|
|
3376
|
+
"true",
|
|
3377
|
+
"yes",
|
|
3378
|
+
"on",
|
|
3379
|
+
}
|
|
3380
|
+
save_model_cfg = False
|
|
3381
|
+
try:
|
|
3382
|
+
save_model_cfg = bool(
|
|
3383
|
+
getattr(getattr(cfg, "output", {}), "save_model", False)
|
|
3384
|
+
)
|
|
3385
|
+
except Exception:
|
|
3386
|
+
save_model_cfg = False
|
|
3387
|
+
if export_env or save_model_cfg:
|
|
3388
|
+
try:
|
|
3389
|
+
# Resolve destination with precedence:
|
|
3390
|
+
# 1) cfg.output.model_dir (absolute or relative to run_dir)
|
|
3391
|
+
# 2) env INVARLOCK_EXPORT_DIR (absolute or relative)
|
|
3392
|
+
# 3) cfg.output.model_subdir (under run_dir)
|
|
3393
|
+
# 4) default: run_dir / "model"
|
|
3394
|
+
export_dir: Path | None = None
|
|
3395
|
+
# (1) explicit model_dir in config
|
|
3396
|
+
try:
|
|
3397
|
+
out_cfg = getattr(cfg, "output", None)
|
|
3398
|
+
model_dir_cfg = None
|
|
3399
|
+
if out_cfg is not None:
|
|
3400
|
+
model_dir_cfg = getattr(
|
|
3401
|
+
out_cfg, "model_dir", None
|
|
3402
|
+
) or getattr(out_cfg, "model_path", None)
|
|
3403
|
+
if model_dir_cfg:
|
|
3404
|
+
p = Path(str(model_dir_cfg))
|
|
3405
|
+
export_dir = p if p.is_absolute() else (run_dir / p)
|
|
3406
|
+
except Exception:
|
|
3407
|
+
export_dir = None
|
|
3408
|
+
# (2) env override
|
|
3409
|
+
if export_dir is None:
|
|
3410
|
+
env_dir_raw = os.environ.get("INVARLOCK_EXPORT_DIR", "")
|
|
3411
|
+
if isinstance(env_dir_raw, str) and env_dir_raw.strip():
|
|
3412
|
+
p = Path(env_dir_raw.strip())
|
|
3413
|
+
export_dir = p if p.is_absolute() else (run_dir / p)
|
|
3414
|
+
# (3) config subdir
|
|
3415
|
+
if export_dir is None:
|
|
3416
|
+
export_subdir = "model"
|
|
3417
|
+
try:
|
|
3418
|
+
export_subdir = str(
|
|
3419
|
+
getattr(
|
|
3420
|
+
getattr(cfg, "output", {}), "model_subdir", "model"
|
|
3421
|
+
)
|
|
3422
|
+
)
|
|
3423
|
+
except Exception:
|
|
3424
|
+
export_subdir = "model"
|
|
3425
|
+
export_dir = run_dir / export_subdir
|
|
3426
|
+
|
|
3427
|
+
# Ensure directory exists
|
|
3428
|
+
ok = False
|
|
3429
|
+
if hasattr(adapter, "save_pretrained") and model is not None:
|
|
3430
|
+
ok = bool(adapter.save_pretrained(model, export_dir)) # type: ignore[attr-defined]
|
|
3431
|
+
if ok:
|
|
3432
|
+
report["artifacts"]["checkpoint_path"] = str(export_dir)
|
|
3433
|
+
else:
|
|
3434
|
+
console.print(
|
|
3435
|
+
"[yellow]⚠️ Model export requested but adapter did not save a HF directory.[/yellow]"
|
|
3436
|
+
)
|
|
3437
|
+
except Exception:
|
|
3438
|
+
console.print(
|
|
3439
|
+
"[yellow]⚠️ Model export requested but failed due to an unexpected error.[/yellow]"
|
|
3440
|
+
)
|
|
3441
|
+
|
|
3442
|
+
# Set flags
|
|
3443
|
+
report["flags"].update(
|
|
3444
|
+
{
|
|
3445
|
+
"guard_recovered": any(
|
|
3446
|
+
not g.get("passed", True)
|
|
3447
|
+
for g in core_report.guards.values()
|
|
3448
|
+
if hasattr(core_report, "guards") and core_report.guards
|
|
3449
|
+
),
|
|
3450
|
+
"rollback_reason": None,
|
|
3451
|
+
}
|
|
3452
|
+
)
|
|
3453
|
+
|
|
3454
|
+
metrics_section = report.get("metrics", {}) or {}
|
|
3455
|
+
data_section = report.get("data", {}) or {}
|
|
3456
|
+
preview_count_report = data_section.get("preview_n")
|
|
3457
|
+
final_count_report = data_section.get("final_n")
|
|
3458
|
+
|
|
3459
|
+
# Classification metric (accuracy) — deterministic smoke path
|
|
3460
|
+
# If loss type is explicitly 'classification', derive accuracy
|
|
3461
|
+
# counts from evaluation windows using a deterministic label rule.
|
|
3462
|
+
try:
|
|
3463
|
+
loss_type_ctx = (
|
|
3464
|
+
run_config.context.get("eval", {})
|
|
3465
|
+
.get("loss", {})
|
|
3466
|
+
.get("resolved_type")
|
|
3467
|
+
)
|
|
3468
|
+
except Exception:
|
|
3469
|
+
loss_type_ctx = None
|
|
3470
|
+
if str(loss_type_ctx).lower() == "classification":
|
|
3471
|
+
try:
|
|
3472
|
+
from invarlock.eval.primary_metric import compute_accuracy_counts
|
|
3473
|
+
|
|
3474
|
+
# Prefer in-memory core_report.evaluation_windows (includes input_ids)
|
|
3475
|
+
ew = {}
|
|
3476
|
+
try:
|
|
3477
|
+
if hasattr(core_report, "evaluation_windows") and isinstance(
|
|
3478
|
+
core_report.evaluation_windows, dict
|
|
3479
|
+
):
|
|
3480
|
+
ew = core_report.evaluation_windows # type: ignore[assignment]
|
|
3481
|
+
except Exception:
|
|
3482
|
+
ew = {}
|
|
3483
|
+
if not ew:
|
|
3484
|
+
# Fallback to the soon-to-be persisted report windows (may lack input_ids)
|
|
3485
|
+
ew = (
|
|
3486
|
+
report.get("evaluation_windows", {})
|
|
3487
|
+
if isinstance(report.get("evaluation_windows"), dict)
|
|
3488
|
+
else {}
|
|
3489
|
+
)
|
|
3490
|
+
prev_rec = []
|
|
3491
|
+
fin_rec = []
|
|
3492
|
+
if isinstance(ew, dict):
|
|
3493
|
+
prev = ew.get("preview", {})
|
|
3494
|
+
fin = ew.get("final", {})
|
|
3495
|
+
if isinstance(prev, dict):
|
|
3496
|
+
prev_rec = [
|
|
3497
|
+
{"input_ids": seq}
|
|
3498
|
+
for seq in prev.get("input_ids", []) or []
|
|
3499
|
+
if isinstance(seq, list)
|
|
3500
|
+
]
|
|
3501
|
+
if isinstance(fin, dict):
|
|
3502
|
+
fin_rec = [
|
|
3503
|
+
{"input_ids": seq}
|
|
3504
|
+
for seq in fin.get("input_ids", []) or []
|
|
3505
|
+
if isinstance(seq, list)
|
|
3506
|
+
]
|
|
3507
|
+
c_prev, n_prev = compute_accuracy_counts(prev_rec)
|
|
3508
|
+
c_fin, n_fin = compute_accuracy_counts(fin_rec)
|
|
3509
|
+
# If we could not derive counts (no windows persisted), fall back to
|
|
3510
|
+
# deterministic pseudo-accuracy based on configured window counts.
|
|
3511
|
+
used_pseudo_counts = False
|
|
3512
|
+
if n_prev == 0 and n_fin == 0:
|
|
3513
|
+
try:
|
|
3514
|
+
prev_n_cfg = getattr(cfg.dataset, "preview_n", None)
|
|
3515
|
+
fin_n_cfg = getattr(cfg.dataset, "final_n", None)
|
|
3516
|
+
except Exception:
|
|
3517
|
+
prev_n_cfg = None
|
|
3518
|
+
fin_n_cfg = None
|
|
3519
|
+
try:
|
|
3520
|
+
prev_n = int(preview_count_report or prev_n_cfg or 0)
|
|
3521
|
+
fin_n = int(final_count_report or fin_n_cfg or 0)
|
|
3522
|
+
except Exception:
|
|
3523
|
+
prev_n = 0
|
|
3524
|
+
fin_n = 0
|
|
3525
|
+
c_prev, n_prev = (prev_n, prev_n) if prev_n > 0 else (0, 0)
|
|
3526
|
+
c_fin, n_fin = (fin_n, fin_n) if fin_n > 0 else (0, 0)
|
|
3527
|
+
used_pseudo_counts = prev_n > 0 or fin_n > 0
|
|
3528
|
+
classification_metrics = {
|
|
3529
|
+
"preview": {"correct_total": int(c_prev), "total": int(n_prev)},
|
|
3530
|
+
"final": {"correct_total": int(c_fin), "total": int(n_fin)},
|
|
3531
|
+
}
|
|
3532
|
+
# Tag source of counts for downstream rendering/doctor
|
|
3533
|
+
if used_pseudo_counts:
|
|
3534
|
+
classification_metrics["counts_source"] = "pseudo_config"
|
|
3535
|
+
# Add a provenance crumb for transparency
|
|
3536
|
+
try:
|
|
3537
|
+
prov = report.setdefault("provenance", {})
|
|
3538
|
+
notes = prov.setdefault("metric_notes", [])
|
|
3539
|
+
if isinstance(notes, list):
|
|
3540
|
+
notes.append(
|
|
3541
|
+
"accuracy: pseudo counts from preview_n/final_n"
|
|
3542
|
+
)
|
|
3543
|
+
except Exception:
|
|
3544
|
+
pass
|
|
3545
|
+
else:
|
|
3546
|
+
classification_metrics["counts_source"] = "measured"
|
|
3547
|
+
report.setdefault("metrics", {})["classification"] = (
|
|
3548
|
+
classification_metrics
|
|
3549
|
+
)
|
|
3550
|
+
# Convenience: top-level accuracy (final)
|
|
3551
|
+
if n_fin > 0:
|
|
3552
|
+
report["metrics"]["accuracy"] = float(c_fin / n_fin)
|
|
3553
|
+
except Exception:
|
|
3554
|
+
pass
|
|
3555
|
+
|
|
3556
|
+
match_fraction = metrics_section.get("window_match_fraction")
|
|
3557
|
+
if match_fraction is not None and not math.isclose(
|
|
3558
|
+
match_fraction, 1.0, rel_tol=0.0, abs_tol=1e-9
|
|
3559
|
+
):
|
|
3560
|
+
err = InvarlockError(
|
|
3561
|
+
code="E001",
|
|
3562
|
+
message=(
|
|
3563
|
+
f"PAIRING-SCHEDULE-MISMATCH: window_match_fraction={match_fraction:.3f}"
|
|
3564
|
+
),
|
|
3565
|
+
details={"window_match_fraction": float(match_fraction)},
|
|
3566
|
+
)
|
|
3567
|
+
code = _resolve_exit_code(err, profile=profile_normalized)
|
|
3568
|
+
console.print(f"[red]{err}[/red]")
|
|
3569
|
+
raise typer.Exit(code)
|
|
3570
|
+
|
|
3571
|
+
overlap_fraction = metrics_section.get("window_overlap_fraction")
|
|
3572
|
+
if overlap_fraction is not None and overlap_fraction > 1e-9:
|
|
3573
|
+
err = InvarlockError(
|
|
3574
|
+
code="E001",
|
|
3575
|
+
message=(
|
|
3576
|
+
f"PAIRING-SCHEDULE-MISMATCH: window_overlap_fraction={overlap_fraction:.3f}"
|
|
3577
|
+
),
|
|
3578
|
+
details={"window_overlap_fraction": float(overlap_fraction)},
|
|
3579
|
+
)
|
|
3580
|
+
code = _resolve_exit_code(err, profile=profile_normalized)
|
|
3581
|
+
console.print(f"[red]{err}[/red]")
|
|
3582
|
+
raise typer.Exit(code)
|
|
3583
|
+
|
|
3584
|
+
# Additional guard: paired_windows collapse (0) in CI/Release
|
|
3585
|
+
try:
|
|
3586
|
+
paired_windows_val = metrics_section.get("paired_windows")
|
|
3587
|
+
if (
|
|
3588
|
+
profile_normalized in {"ci", "release"}
|
|
3589
|
+
and isinstance(paired_windows_val, (int | float))
|
|
3590
|
+
and int(paired_windows_val) == 0
|
|
3591
|
+
):
|
|
3592
|
+
err = InvarlockError(
|
|
3593
|
+
code="E001",
|
|
3594
|
+
message=(
|
|
3595
|
+
"PAIRED-WINDOWS-COLLAPSED: paired_windows=0 under paired schedule. "
|
|
3596
|
+
"Check device stability, dataset windows, or edit scope."
|
|
3597
|
+
),
|
|
3598
|
+
details={
|
|
3599
|
+
"paired_windows": int(paired_windows_val),
|
|
3600
|
+
"profile": profile_normalized,
|
|
3601
|
+
},
|
|
3602
|
+
)
|
|
3603
|
+
code = _resolve_exit_code(err, profile=profile_normalized)
|
|
3604
|
+
console.print(f"[red]{err}[/red]")
|
|
3605
|
+
raise typer.Exit(code)
|
|
3606
|
+
except Exception:
|
|
3607
|
+
pass
|
|
3608
|
+
|
|
3609
|
+
expected_preview = effective_preview or getattr(
|
|
3610
|
+
cfg.dataset, "preview_n", preview_count_report
|
|
3611
|
+
)
|
|
3612
|
+
expected_final = effective_final or getattr(
|
|
3613
|
+
cfg.dataset, "final_n", final_count_report
|
|
3614
|
+
)
|
|
3615
|
+
if (
|
|
3616
|
+
preview_count_report is not None
|
|
3617
|
+
and expected_preview is not None
|
|
3618
|
+
and int(preview_count_report) != int(expected_preview)
|
|
3619
|
+
) or (
|
|
3620
|
+
final_count_report is not None
|
|
3621
|
+
and expected_final is not None
|
|
3622
|
+
and int(final_count_report) != int(expected_final)
|
|
3623
|
+
):
|
|
3624
|
+
err = InvarlockError(
|
|
3625
|
+
code="E001",
|
|
3626
|
+
message=(
|
|
3627
|
+
"PAIRING-SCHEDULE-MISMATCH: counts do not match configuration after stratification"
|
|
3628
|
+
),
|
|
3629
|
+
details={
|
|
3630
|
+
"preview_used": int(preview_count_report or -1),
|
|
3631
|
+
"preview_expected": int(expected_preview or -1),
|
|
3632
|
+
"final_used": int(final_count_report or -1),
|
|
3633
|
+
"final_expected": int(expected_final or -1),
|
|
3634
|
+
},
|
|
3635
|
+
)
|
|
3636
|
+
code = _resolve_exit_code(err, profile=profile_normalized)
|
|
3637
|
+
console.print(f"[red]{err}[/red]")
|
|
3638
|
+
raise typer.Exit(code)
|
|
3639
|
+
|
|
3640
|
+
# Compute metric-v1 snapshot (primary_metric) — canonical path
|
|
3641
|
+
try:
|
|
3642
|
+
metric_kind_resolved, _provider_kind, metric_opts = (
|
|
3643
|
+
_resolve_metric_and_provider(
|
|
3644
|
+
cfg, model_profile, resolved_loss_type=resolved_loss_type
|
|
3645
|
+
)
|
|
3646
|
+
)
|
|
3647
|
+
if metric_kind_resolved:
|
|
3648
|
+
from invarlock.eval.primary_metric import (
|
|
3649
|
+
compute_primary_metric_from_report,
|
|
3650
|
+
)
|
|
3651
|
+
|
|
3652
|
+
pm = compute_primary_metric_from_report(
|
|
3653
|
+
report, kind=metric_kind_resolved, baseline=baseline_report_data
|
|
3654
|
+
)
|
|
3655
|
+
report.setdefault("metrics", {})["primary_metric"] = pm
|
|
3656
|
+
# Attach configured reps/ci_level when provided
|
|
3657
|
+
if metric_opts:
|
|
3658
|
+
try:
|
|
3659
|
+
if "reps" in metric_opts:
|
|
3660
|
+
report["metrics"]["primary_metric"]["reps"] = int(
|
|
3661
|
+
metric_opts["reps"]
|
|
3662
|
+
) # type: ignore[index]
|
|
3663
|
+
if "ci_level" in metric_opts:
|
|
3664
|
+
report["metrics"]["primary_metric"]["ci_level"] = float(
|
|
3665
|
+
metric_opts["ci_level"]
|
|
3666
|
+
) # type: ignore[index]
|
|
3667
|
+
except Exception:
|
|
3668
|
+
pass
|
|
3669
|
+
# Shadow parity check against legacy ppl fields (best-effort)
|
|
3670
|
+
try:
|
|
3671
|
+
pm_blk = report.get("metrics", {}).get("primary_metric", {})
|
|
3672
|
+
ppl_final_v1 = float(pm_blk.get("final"))
|
|
3673
|
+
ppl_final_v2 = float(pm.get("final", float("nan")))
|
|
3674
|
+
if math.isfinite(ppl_final_v1) and math.isfinite(ppl_final_v2):
|
|
3675
|
+
if not math.isclose(
|
|
3676
|
+
ppl_final_v1, ppl_final_v2, rel_tol=1e-9, abs_tol=1e-9
|
|
3677
|
+
):
|
|
3678
|
+
report.setdefault("metrics", {}).setdefault(
|
|
3679
|
+
"_metric_v1_mismatch", {}
|
|
3680
|
+
)["ppl_final_diff"] = ppl_final_v2 - ppl_final_v1
|
|
3681
|
+
# Optional: dual-write diffs logging for ppl_* metrics
|
|
3682
|
+
debug_diffs = str(
|
|
3683
|
+
os.environ.get("DEBUG_METRIC_DIFFS", "")
|
|
3684
|
+
).strip().lower() in {"1", "true", "yes", "on"}
|
|
3685
|
+
if debug_diffs and str(pm.get("kind", "")).startswith("ppl"):
|
|
3686
|
+
diffs_line = _format_debug_metric_diffs(
|
|
3687
|
+
pm, report.get("metrics", {}), baseline_report_data
|
|
3688
|
+
)
|
|
3689
|
+
if diffs_line:
|
|
3690
|
+
console.print(
|
|
3691
|
+
"[dim]DEBUG_METRIC_DIFFS: " + diffs_line + "[/dim]"
|
|
3692
|
+
)
|
|
3693
|
+
except Exception:
|
|
3694
|
+
pass
|
|
3695
|
+
except Exception:
|
|
3696
|
+
# Non-fatal: metric-v1 snapshot should not break runs
|
|
3697
|
+
pass
|
|
3698
|
+
|
|
3699
|
+
# No deprecation notices in dev-phase: primary_metric is canonical.
|
|
3700
|
+
|
|
3701
|
+
# Derive dataset.windows.stats (PM-only surface)
|
|
3702
|
+
try:
|
|
3703
|
+
ds = report.setdefault("dataset", {}).setdefault("windows", {})
|
|
3704
|
+
stats = ds.setdefault("stats", {})
|
|
3705
|
+
if match_fraction is not None:
|
|
3706
|
+
stats["window_match_fraction"] = float(match_fraction)
|
|
3707
|
+
if overlap_fraction is not None:
|
|
3708
|
+
stats["window_overlap_fraction"] = float(overlap_fraction)
|
|
3709
|
+
try:
|
|
3710
|
+
if isinstance(window_plan, dict) and "coverage_ok" in window_plan:
|
|
3711
|
+
stats["coverage"] = bool(window_plan.get("coverage_ok"))
|
|
3712
|
+
except Exception:
|
|
3713
|
+
pass
|
|
3714
|
+
except Exception:
|
|
3715
|
+
pass
|
|
3716
|
+
|
|
3717
|
+
_postprocess_and_summarize(
|
|
3718
|
+
report=report,
|
|
3719
|
+
run_dir=run_dir,
|
|
3720
|
+
run_config=run_config,
|
|
3721
|
+
window_plan=window_plan,
|
|
3722
|
+
dataset_meta=dataset_meta,
|
|
3723
|
+
match_fraction=match_fraction,
|
|
3724
|
+
overlap_fraction=overlap_fraction,
|
|
3725
|
+
console=console,
|
|
3726
|
+
)
|
|
3727
|
+
|
|
3728
|
+
# Metrics display
|
|
3729
|
+
pm_obj = None
|
|
3730
|
+
try:
|
|
3731
|
+
pm_obj = report.get("metrics", {}).get("primary_metric")
|
|
3732
|
+
except Exception:
|
|
3733
|
+
pm_obj = None
|
|
3734
|
+
if isinstance(pm_obj, dict) and pm_obj:
|
|
3735
|
+
try:
|
|
3736
|
+
pm_kind = str(pm_obj.get("kind", "primary")).lower()
|
|
3737
|
+
pm_prev = pm_obj.get("preview")
|
|
3738
|
+
pm_fin = pm_obj.get("final")
|
|
3739
|
+
if isinstance(pm_prev, (int | float)) and isinstance(
|
|
3740
|
+
pm_fin, (int | float)
|
|
3741
|
+
):
|
|
3742
|
+
console.print(
|
|
3743
|
+
f"📌 Primary Metric [{pm_kind}] — preview: {pm_prev:.3f}, final: {pm_fin:.3f}"
|
|
3744
|
+
)
|
|
3745
|
+
ratio_vs_base = pm_obj.get("ratio_vs_baseline")
|
|
3746
|
+
if isinstance(ratio_vs_base, (int | float)) and math.isfinite(
|
|
3747
|
+
ratio_vs_base
|
|
3748
|
+
):
|
|
3749
|
+
console.print(
|
|
3750
|
+
f"🔗 Ratio vs baseline [{pm_kind}]: {ratio_vs_base:.3f}"
|
|
3751
|
+
)
|
|
3752
|
+
except Exception:
|
|
3753
|
+
pass
|
|
3754
|
+
# Legacy ppl_* console block removed in favor of primary_metric summary
|
|
3755
|
+
|
|
3756
|
+
guard_overhead_info = report.get("guard_overhead")
|
|
3757
|
+
if guard_overhead_info:
|
|
3758
|
+
threshold_fraction = _print_guard_overhead_summary(
|
|
3759
|
+
console, guard_overhead_info
|
|
3760
|
+
)
|
|
3761
|
+
if not guard_overhead_info.get("passed", True):
|
|
3762
|
+
console.print(
|
|
3763
|
+
"[red]⚠️ Guard overhead gate FAILED: Guards add more than the permitted budget[/red]"
|
|
3764
|
+
)
|
|
3765
|
+
# Only fail hard when the overhead check was actually evaluated
|
|
3766
|
+
# (e.g., for causal LMs with available bare/guarded PM). For
|
|
3767
|
+
# masked LM flows where ppl-like PM is undefined, record as not evaluated
|
|
3768
|
+
# and continue without aborting the run.
|
|
3769
|
+
loss_type_ctx = None
|
|
3770
|
+
try:
|
|
3771
|
+
loss_type_ctx = (
|
|
3772
|
+
run_config.context.get("eval", {})
|
|
3773
|
+
.get("loss", {})
|
|
3774
|
+
.get("resolved_type")
|
|
3775
|
+
)
|
|
3776
|
+
except Exception:
|
|
3777
|
+
loss_type_ctx = None
|
|
3778
|
+
if (
|
|
3779
|
+
measure_guard_overhead
|
|
3780
|
+
and guard_overhead_info.get("evaluated", False)
|
|
3781
|
+
and str(loss_type_ctx).lower() != "mlm"
|
|
3782
|
+
):
|
|
3783
|
+
_fail_run(
|
|
3784
|
+
"Guard overhead gate exceeded the configured budget "
|
|
3785
|
+
f"(>{threshold_fraction * 100:.1f}% increase)"
|
|
3786
|
+
)
|
|
3787
|
+
|
|
3788
|
+
# Drift gate status is no longer surfaced in console; rely on certificate gates
|
|
3789
|
+
|
|
3790
|
+
# Certificate validation for --until-pass mode
|
|
3791
|
+
if retry_controller and baseline:
|
|
3792
|
+
from invarlock.reporting.certificate import make_certificate
|
|
3793
|
+
|
|
3794
|
+
try:
|
|
3795
|
+
baseline_report = baseline_report_data
|
|
3796
|
+
if baseline_report is None and baseline:
|
|
3797
|
+
baseline_path = Path(baseline)
|
|
3798
|
+
with baseline_path.open(encoding="utf-8") as f:
|
|
3799
|
+
baseline_report = json.load(f)
|
|
3800
|
+
|
|
3801
|
+
if baseline_report is None:
|
|
3802
|
+
raise FileNotFoundError("Baseline report unavailable")
|
|
3803
|
+
|
|
3804
|
+
console.print("📜 Generating safety certificate...")
|
|
3805
|
+
certificate = make_certificate(report, baseline_report)
|
|
3806
|
+
|
|
3807
|
+
validation = certificate.get("validation", {})
|
|
3808
|
+
certificate_passed = all(validation.values())
|
|
3809
|
+
|
|
3810
|
+
failed_gates = [k for k, v in validation.items() if not v]
|
|
3811
|
+
result_summary = {
|
|
3812
|
+
"passed": certificate_passed,
|
|
3813
|
+
"failures": failed_gates,
|
|
3814
|
+
"validation": validation,
|
|
3815
|
+
}
|
|
3816
|
+
retry_controller.record_attempt(
|
|
3817
|
+
attempt, result_summary, edit_config
|
|
3818
|
+
)
|
|
3819
|
+
|
|
3820
|
+
if certificate_passed:
|
|
3821
|
+
console.print("[green]✅ Certificate PASSED all gates![/green]")
|
|
3822
|
+
break
|
|
3823
|
+
else:
|
|
3824
|
+
console.print(
|
|
3825
|
+
f"[yellow]⚠️ Certificate FAILED gates: {', '.join(failed_gates)}[/yellow]"
|
|
3826
|
+
)
|
|
3827
|
+
|
|
3828
|
+
# Auto-tune mask-only heads (binary search on keep count)
|
|
3829
|
+
try:
|
|
3830
|
+
head_section = None
|
|
3831
|
+
for k in ("heads", "head_budget", "head_budgets"):
|
|
3832
|
+
if isinstance(edit_config.get(k), dict):
|
|
3833
|
+
head_section = edit_config[k]
|
|
3834
|
+
break
|
|
3835
|
+
search = (
|
|
3836
|
+
head_section.get("_auto_search")
|
|
3837
|
+
if isinstance(head_section, dict)
|
|
3838
|
+
else None
|
|
3839
|
+
)
|
|
3840
|
+
if isinstance(search, dict) and head_section.get(
|
|
3841
|
+
"mask_only"
|
|
3842
|
+
):
|
|
3843
|
+
keep_low = int(search.get("keep_low", 0))
|
|
3844
|
+
keep_high = int(
|
|
3845
|
+
search.get(
|
|
3846
|
+
"keep_high", search.get("total_heads", 0)
|
|
3847
|
+
)
|
|
3848
|
+
)
|
|
3849
|
+
keep_current = int(
|
|
3850
|
+
search.get("keep_current", keep_high)
|
|
3851
|
+
)
|
|
3852
|
+
# If the quality gate (PM) is unacceptable, increase keep (less pruning); if only other gates failed, be conservative and increase keep slightly
|
|
3853
|
+
pm_ok = bool(
|
|
3854
|
+
validation.get("primary_metric_acceptable", False)
|
|
3855
|
+
)
|
|
3856
|
+
if not pm_ok:
|
|
3857
|
+
keep_low = max(keep_low, keep_current)
|
|
3858
|
+
else:
|
|
3859
|
+
# drift/spectral/etc failed: ease pruning
|
|
3860
|
+
keep_low = max(keep_low, keep_current)
|
|
3861
|
+
next_keep = int((keep_low + keep_high + 1) // 2)
|
|
3862
|
+
search.update(
|
|
3863
|
+
{
|
|
3864
|
+
"keep_low": keep_low,
|
|
3865
|
+
"keep_high": keep_high,
|
|
3866
|
+
"keep_current": next_keep,
|
|
3867
|
+
}
|
|
3868
|
+
)
|
|
3869
|
+
head_section["global_k"] = next_keep
|
|
3870
|
+
console.print(
|
|
3871
|
+
f"🔧 Auto-tune adjust: global_k → {next_keep} (bounds {keep_low}-{keep_high})"
|
|
3872
|
+
)
|
|
3873
|
+
except Exception:
|
|
3874
|
+
pass
|
|
3875
|
+
|
|
3876
|
+
if retry_controller.should_retry(certificate_passed):
|
|
3877
|
+
attempt += 1
|
|
3878
|
+
continue
|
|
3879
|
+
else:
|
|
3880
|
+
console.print(
|
|
3881
|
+
f"[red]❌ Exhausted retry budget after {attempt} attempts[/red]"
|
|
3882
|
+
)
|
|
3883
|
+
break
|
|
3884
|
+
|
|
3885
|
+
except Exception as cert_error:
|
|
3886
|
+
console.print(
|
|
3887
|
+
f"[yellow]⚠️ Certificate validation failed: {cert_error}[/yellow]"
|
|
3888
|
+
)
|
|
3889
|
+
if retry_controller:
|
|
3890
|
+
retry_controller.record_attempt(
|
|
3891
|
+
attempt,
|
|
3892
|
+
{
|
|
3893
|
+
"passed": False,
|
|
3894
|
+
"failures": ["certificate_error"],
|
|
3895
|
+
"validation": {},
|
|
3896
|
+
},
|
|
3897
|
+
edit_config,
|
|
3898
|
+
)
|
|
3899
|
+
break
|
|
3900
|
+
else:
|
|
3901
|
+
if retry_controller:
|
|
3902
|
+
retry_controller.record_attempt(
|
|
3903
|
+
attempt,
|
|
3904
|
+
{"passed": True, "failures": [], "validation": {}},
|
|
3905
|
+
edit_config,
|
|
3906
|
+
)
|
|
3907
|
+
# No retry mode - single run
|
|
3908
|
+
break
|
|
3909
|
+
|
|
3910
|
+
# Show retry summary if applicable
|
|
3911
|
+
_print_retry_summary(console, retry_controller)
|
|
3912
|
+
|
|
3913
|
+
# (moved) Cleanup printing occurs after loop to guarantee execution
|
|
3914
|
+
pass
|
|
3915
|
+
|
|
3916
|
+
# Normal path falls through; cleanup handled below in finally
|
|
3917
|
+
|
|
3918
|
+
except FileNotFoundError as e:
|
|
3919
|
+
console.print(f"[red]❌ Configuration file not found: {e}[/red]")
|
|
3920
|
+
raise typer.Exit(1) from e
|
|
3921
|
+
except InvarlockError as ce:
|
|
3922
|
+
# InvarlockError → code 3 only in CI/Release; dev → 1
|
|
3923
|
+
console.print(str(ce))
|
|
3924
|
+
raise typer.Exit(_resolve_exit_code(ce, profile=profile)) from ce
|
|
3925
|
+
except (typer.Exit, SystemExit, click.exceptions.Exit):
|
|
3926
|
+
# Preserve explicit exit codes (e.g., parity checks, user-triggered exits)
|
|
3927
|
+
raise
|
|
3928
|
+
except Exception as e:
|
|
3929
|
+
if os.environ.get("INVARLOCK_DEBUG_TRACE"):
|
|
3930
|
+
import traceback
|
|
3931
|
+
|
|
3932
|
+
traceback.print_exc()
|
|
3933
|
+
# Emit a clearer message for schema failures (exit 2)
|
|
3934
|
+
if isinstance(e, ValueError) and "Invalid RunReport" in str(e):
|
|
3935
|
+
console.print(
|
|
3936
|
+
"[red]❌ Schema invalid: run report structure failed validation[/red]"
|
|
3937
|
+
)
|
|
3938
|
+
code = 2
|
|
3939
|
+
else:
|
|
3940
|
+
console.print(f"[red]❌ Pipeline execution failed: {e}[/red]")
|
|
3941
|
+
code = _resolve_exit_code(e, profile=profile)
|
|
3942
|
+
raise typer.Exit(code) from e
|
|
3943
|
+
finally:
|
|
3944
|
+
# Cleanup snapshot directory if used (always print once per run)
|
|
3945
|
+
try:
|
|
3946
|
+
if snapshot_tmpdir and not no_cleanup:
|
|
3947
|
+
try:
|
|
3948
|
+
import shutil as _sh
|
|
3949
|
+
|
|
3950
|
+
_sh.rmtree(snapshot_tmpdir, ignore_errors=True)
|
|
3951
|
+
except Exception:
|
|
3952
|
+
pass
|
|
3953
|
+
finally:
|
|
3954
|
+
console.print("cleanup: removed")
|
|
3955
|
+
else:
|
|
3956
|
+
console.print("cleanup: skipped")
|
|
3957
|
+
except Exception:
|
|
3958
|
+
# Best-effort cleanup printing; never raise from finally
|
|
3959
|
+
pass
|
|
3960
|
+
|
|
3961
|
+
|
|
3962
|
+
def _format_debug_metric_diffs(
|
|
3963
|
+
pm: dict[str, float] | None,
|
|
3964
|
+
metrics: dict[str, float] | None,
|
|
3965
|
+
baseline_report_data: dict | None,
|
|
3966
|
+
) -> str:
|
|
3967
|
+
"""Build a compact DEBUG_METRIC_DIFFS line comparing current snapshot vs legacy ppl_*.
|
|
3968
|
+
|
|
3969
|
+
Returns a semicolon-separated string of deltas like
|
|
3970
|
+
"final: v1-v1 = +0.000000000; Δlog(final): +0.000000000; ...". Safe to call with
|
|
3971
|
+
missing fields; non-finite entries are skipped.
|
|
3972
|
+
"""
|
|
3973
|
+
import math as _m
|
|
3974
|
+
|
|
3975
|
+
if not isinstance(pm, dict) or not isinstance(metrics, dict):
|
|
3976
|
+
return ""
|
|
3977
|
+
diffs: list[str] = []
|
|
3978
|
+
try:
|
|
3979
|
+
pm_blk = metrics.get("primary_metric", {}) if isinstance(metrics, dict) else {}
|
|
3980
|
+
ppl_final_v1 = float(pm_blk.get("final", float("nan")))
|
|
3981
|
+
except Exception:
|
|
3982
|
+
ppl_final_v1 = float("nan")
|
|
3983
|
+
try:
|
|
3984
|
+
ppl_prev_v1 = float(pm_blk.get("preview", float("nan")))
|
|
3985
|
+
except Exception:
|
|
3986
|
+
ppl_prev_v1 = float("nan")
|
|
3987
|
+
try:
|
|
3988
|
+
ppl_final_v2 = float(pm.get("final", float("nan")))
|
|
3989
|
+
except Exception:
|
|
3990
|
+
ppl_final_v2 = float("nan")
|
|
3991
|
+
try:
|
|
3992
|
+
ppl_prev_v2 = float(pm.get("preview", float("nan")))
|
|
3993
|
+
except Exception:
|
|
3994
|
+
ppl_prev_v2 = float("nan")
|
|
3995
|
+
|
|
3996
|
+
if _m.isfinite(ppl_final_v1) and _m.isfinite(ppl_final_v2):
|
|
3997
|
+
diffs.append(f"final: v1-v1 = {ppl_final_v2 - ppl_final_v1:+.9f}")
|
|
3998
|
+
try:
|
|
3999
|
+
diffs.append(
|
|
4000
|
+
f"Δlog(final): {_m.log(ppl_final_v2) - _m.log(ppl_final_v1):+.9f}"
|
|
4001
|
+
)
|
|
4002
|
+
except Exception:
|
|
4003
|
+
pass
|
|
4004
|
+
if _m.isfinite(ppl_prev_v1) and _m.isfinite(ppl_prev_v2):
|
|
4005
|
+
diffs.append(f"preview: v1-v1 = {ppl_prev_v2 - ppl_prev_v1:+.9f}")
|
|
4006
|
+
try:
|
|
4007
|
+
diffs.append(
|
|
4008
|
+
f"Δlog(preview): {_m.log(ppl_prev_v2) - _m.log(ppl_prev_v1):+.9f}"
|
|
4009
|
+
)
|
|
4010
|
+
except Exception:
|
|
4011
|
+
pass
|
|
4012
|
+
|
|
4013
|
+
# ratio vs baseline
|
|
4014
|
+
try:
|
|
4015
|
+
r_v2 = float(pm.get("ratio_vs_baseline", float("nan")))
|
|
4016
|
+
except Exception:
|
|
4017
|
+
r_v2 = float("nan")
|
|
4018
|
+
# prefer PM ratio when present
|
|
4019
|
+
r_v1 = float(pm_blk.get("ratio_vs_baseline", float("nan")))
|
|
4020
|
+
if (not _m.isfinite(r_v1)) and isinstance(baseline_report_data, dict):
|
|
4021
|
+
try:
|
|
4022
|
+
base_fin = float(
|
|
4023
|
+
(
|
|
4024
|
+
(baseline_report_data.get("metrics") or {}).get("primary_metric")
|
|
4025
|
+
or {}
|
|
4026
|
+
).get("final")
|
|
4027
|
+
)
|
|
4028
|
+
if _m.isfinite(base_fin) and base_fin > 0 and _m.isfinite(ppl_final_v1):
|
|
4029
|
+
r_v1 = ppl_final_v1 / base_fin
|
|
4030
|
+
except Exception:
|
|
4031
|
+
pass
|
|
4032
|
+
if _m.isfinite(r_v1) and _m.isfinite(r_v2):
|
|
4033
|
+
diffs.append(f"ratio_vs_baseline: v1-v1 = {r_v2 - r_v1:+.9f}")
|
|
4034
|
+
return "; ".join(diffs)
|
|
4035
|
+
|
|
4036
|
+
|
|
4037 + # Provide a module shim so tests can patch 'src.invarlock.cli.commands.run.shutil.*'.
4038 + try:  # best-effort; harmless in production
4039 +     _shim = _types.ModuleType(__name__ + ".shutil")
4040 +
4041 +     def _shim_getattr(name: str):  # pragma: no cover
4042 +         return getattr(shutil, name)
4043 +
4044 +     _shim.__getattr__ = _shim_getattr  # type: ignore[attr-defined]
4045 +     _shim.disk_usage = shutil.disk_usage  # type: ignore[attr-defined]
4046 +     _shim.rmtree = shutil.rmtree  # type: ignore[attr-defined]
4047 +     _sys.modules[__name__ + ".shutil"] = _shim
4048 +     _sys.modules["src." + __name__ + ".shutil"] = _shim
4049 + except Exception:
4050 +     pass
4051 +
4052 +
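The shim above only registers an alias module in `sys.modules`; it changes nothing at runtime. A hypothetical test sketch (not from the package's own test suite) shows the kind of dotted patch target it is meant to support:

```python
# Hypothetical test sketch, not part of the package diff. Because the module
# registers 'invarlock.cli.commands.run.shutil' in sys.modules, the dotted
# mock.patch target below resolves without an ImportError.
from unittest import mock

import invarlock.cli.commands.run  # importing the module installs the shim


def test_cleanup_survives_rmtree_failure() -> None:
    with mock.patch(
        "invarlock.cli.commands.run.shutil.rmtree",
        side_effect=OSError("simulated failure"),
    ):
        # ...invoke the cleanup path under test here...
        pass
```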
4053 + def _normalize_overhead_result(
4054 +     payload: dict[str, object] | None, profile: str | None = None
4055 + ) -> dict[str, object]:
4056 +     """Normalize guard-overhead payload for tiny/degenerate runs.
4057 +
4058 +     If the computed overhead ratio is missing or non-finite, mark the check as
4059 +     not evaluated and passed to avoid spurious gate failures in tiny runs.
4060 +     """
4061 +     payload = dict(payload or {})
4062 +     try:
4063 +         ratio = payload.get("overhead_ratio")
4064 +         val = float(ratio) if isinstance(ratio, int | float) else float("nan")
4065 +     except Exception:
4066 +         val = float("nan")
4067 +     if not (isinstance(val, float) and math.isfinite(val)):
4068 +         payload["evaluated"] = False
4069 +         payload["passed"] = True
4070 +     return payload
4071 +
4072 +
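A minimal sketch of the normalization behaviour, assuming the module imports cleanly; the payload keys mirror the ones read above and the values are made up:

```python
# Illustrative sketch, not part of the package diff.
from invarlock.cli.commands.run import _normalize_overhead_result

# A degenerate (tiny-run) payload: the ratio is NaN, so the check is
# downgraded to "not evaluated" and forced to pass.
degenerate = {"overhead_ratio": float("nan"), "passed": False}
normalized = _normalize_overhead_result(degenerate)
assert normalized["evaluated"] is False and normalized["passed"] is True

# A healthy payload with a finite ratio is returned unchanged (as a copy).
healthy = {"overhead_ratio": 1.004, "passed": True}
assert _normalize_overhead_result(healthy) == healthy
```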
4073 + # helper moved to invarlock.cli.overhead_utils
4074 +
4075 +
4076 + def _print_guard_overhead_summary(
4077 +     console: Console, guard_overhead_info: dict[str, Any]
4078 + ) -> float:
4079 +     """Print a concise guard-overhead console summary. Returns threshold fraction used."""
4080 +     evaluated = bool(guard_overhead_info.get("evaluated", True))
4081 +     if not evaluated:
4082 +         console.print("🛡️ Guard Overhead: not evaluated")
4083 +         return GUARD_OVERHEAD_THRESHOLD
4084 +     overhead_status = (
4085 +         "✅ PASS" if guard_overhead_info.get("passed", True) else "❌ FAIL"
4086 +     )
4087 +     overhead_percent = guard_overhead_info.get("overhead_percent")
4088 +     if isinstance(overhead_percent, (int | float)) and math.isfinite(
4089 +         float(overhead_percent)
4090 +     ):
4091 +         overhead_display = f"{float(overhead_percent):+.2f}%"
4092 +     else:
4093 +         ratio_value = guard_overhead_info.get("overhead_ratio")
4094 +         if isinstance(ratio_value, (int | float)) and math.isfinite(float(ratio_value)):
4095 +             overhead_display = f"{float(ratio_value):.3f}x"
4096 +         else:
4097 +             # Avoid any 'nanx' or ambiguous output
4098 +             overhead_display = "not evaluated"
4099 +     threshold_percent = guard_overhead_info.get("overhead_threshold", 0.01)
4100 +     try:
4101 +         threshold_fraction = float(threshold_percent)
4102 +     except (TypeError, ValueError):
4103 +         threshold_fraction = GUARD_OVERHEAD_THRESHOLD
4104 +     threshold_display = f"≤ +{threshold_fraction * 100:.1f}%"
4105 +     console.print(
4106 +         f"🛡️ Guard Overhead: {overhead_status} {overhead_display} ({threshold_display})"
4107 +     )
4108 +     return threshold_fraction
4109 +
4110 +
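A small usage sketch for the summary printer. The `Console` import assumes the type hints above refer to rich's console (which the styled output suggests), the info-dict keys mirror the ones the helper reads, and the numbers are invented:

```python
# Illustrative sketch, not part of the package diff.
from rich.console import Console  # assumption: Console here is rich's Console

from invarlock.cli.commands.run import _print_guard_overhead_summary

threshold = _print_guard_overhead_summary(
    Console(),
    {
        "evaluated": True,
        "passed": True,
        "overhead_percent": 0.42,     # measured +0.42% guard overhead
        "overhead_threshold": 0.01,   # 1% budget, expressed as a fraction
    },
)
# Prints roughly: 🛡️ Guard Overhead: ✅ PASS +0.42% (≤ +1.0%)
# and returns the threshold fraction (0.01 here) for downstream gating.
```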
4111 + def _print_retry_summary(console: Console, retry_controller: Any | None) -> None:
4112 +     """Print a one-line retry summary when retries were attempted."""
4113 +     try:
4114 +         if retry_controller and getattr(retry_controller, "attempt_history", None):
4115 +             summary = retry_controller.get_attempt_summary()
4116 +             console.print(
4117 +                 f"\n📊 Retry Summary: {summary['total_attempts']} attempts in {summary['elapsed_time']:.1f}s"
4118 +             )
4119 +     except Exception:
4120 +         # Never break the run for summary printing
4121 +         pass
4122 +
4123 +
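To show the shape `_print_retry_summary` expects without building a real `RetryController`, here is a hypothetical stand-in; an `attempt_history` attribute plus a `get_attempt_summary()` returning `total_attempts` and `elapsed_time` are the only pieces the helper touches:

```python
# Illustrative sketch, not part of the package diff; the stand-in object is
# hypothetical and only mimics the attributes the helper reads.
from types import SimpleNamespace

from rich.console import Console  # assumption, as above

from invarlock.cli.commands.run import _print_retry_summary

fake_controller = SimpleNamespace(
    attempt_history=[1, 2, 3],
    get_attempt_summary=lambda: {"total_attempts": 3, "elapsed_time": 42.5},
)
_print_retry_summary(Console(), fake_controller)
# Prints roughly: 📊 Retry Summary: 3 attempts in 42.5s
```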
4124 + def _init_retry_controller(
4125 +     *,
4126 +     until_pass: bool,
4127 +     max_attempts: int,
4128 +     timeout: int | None,
4129 +     baseline: str | None,
4130 +     console: Console,
4131 + ):
4132 +     """Initialize RetryController with consistent console prints."""
4133 +     retry_controller = None
4134 +     if until_pass:
4135 +         from invarlock.core.retry import RetryController
4136 +
4137 +         retry_controller = RetryController(
4138 +             max_attempts=max_attempts, timeout=timeout, verbose=True
4139 +         )
4140 +         console.print(f"🔄 Retry mode enabled: max {max_attempts} attempts")
4141 +         if baseline:
4142 +             console.print(f"📋 Using baseline: {baseline}")
4143 +     else:
4144 +         if baseline:
4145 +             console.print(f"📋 Using baseline: {baseline}")
4146 +     return retry_controller
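Finally, a sketch of wiring up retries through the helper above; the baseline path is hypothetical, `RetryController` comes from `invarlock.core.retry` as shown in the function body, and the `Console` assumption is the same as earlier:

```python
# Illustrative sketch, not part of the package diff.
from rich.console import Console  # assumption, as above

from invarlock.cli.commands.run import _init_retry_controller

controller = _init_retry_controller(
    until_pass=True,
    max_attempts=3,
    timeout=600,                        # passed through to RetryController
    baseline="reports/baseline.json",   # hypothetical baseline path
    console=Console(),
)
# With until_pass=True this returns a RetryController; otherwise it returns None.
print(type(controller).__name__)
```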