invarlock-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,4146 @@
1
+ """
2
+ InvarLock CLI Run Command
3
+ =========================
4
+
5
+ Run a guarded pipeline from a YAML config. Intended for local smoke tests,
6
+ plugin demos, and development. Advanced: for pairwise certification,
7
+ prefer Compare & Certify via `invarlock certify --baseline ... --subject ...`.
8
+ """
9
+
10
+ import copy
11
+ import hashlib
12
+ import json
13
+ import math
14
+ import os
15
+ import random
16
+ import shutil
17
+ import sys as _sys
18
+ import types as _types
19
+ from array import array
20
+ from collections.abc import Iterable, Sequence
21
+ from datetime import datetime
22
+ from pathlib import Path
23
+ from types import SimpleNamespace
24
+ from typing import Any
25
+
26
+ import click
27
+ import numpy as np
28
+ import psutil
29
+ import typer
30
+ from rich.console import Console
31
+
32
+ try:
33
+ import torch
34
+ except ImportError:
35
+ torch = None # type: ignore[assignment]
36
+
37
+ from invarlock.cli.errors import InvarlockError
38
+ from invarlock.cli.utils import (
39
+ coerce_float as _coerce_float,
40
+ )
41
+ from invarlock.cli.utils import (
42
+ coerce_int as _coerce_int,
43
+ )
44
+ from invarlock.cli.utils import (
45
+ coerce_option as _coerce_option,
46
+ )
47
+ from invarlock.core.exceptions import (
48
+ ConfigError as _CfgErr,
49
+ )
50
+ from invarlock.core.exceptions import (
51
+ DataError as _DataErr,
52
+ )
53
+ from invarlock.core.exceptions import (
54
+ ValidationError as _ValErr,
55
+ )
56
+ from invarlock.model_profile import detect_model_profile, resolve_tokenizer
57
+ from invarlock.model_utils import set_seed
58
+ from invarlock.reporting.validate import validate_guard_overhead
59
+
60
+ from ..config import (
61
+ InvarLockConfig,
62
+ )
63
+ from ..overhead_utils import _extract_pm_snapshot_for_overhead
64
+
65
+ console = Console()
66
+ LIGHT_IMPORT = os.getenv("INVARLOCK_LIGHT_IMPORT", "").strip().lower() in {
67
+ "1",
68
+ "true",
69
+ "yes",
70
+ }
71
+
72
+ # Release profile window planning constants
73
+ RELEASE_BUFFER_FRACTION = 0.12
74
+ RELEASE_MIN_WINDOWS_PER_ARM = 200
75
+ RELEASE_CALIBRATION_MIN = 16
76
+ RELEASE_CALIBRATION_MAX = 24
77
+ GUARD_OVERHEAD_THRESHOLD = 0.01
78
+
79
+
80
+ # Common dataset split aliases we probe in order when not explicitly set
81
+ SPLIT_ALIASES: tuple[str, ...] = ("validation", "val", "dev", "eval", "test")
82
+
83
+
84
+ def _choose_dataset_split(
85
+ *, requested: str | None, available: list[str] | None
86
+ ) -> tuple[str, bool]:
87
+ """
88
+ Choose a dataset split deterministically.
89
+
90
+ Returns (split, used_fallback). If `requested` is provided, returns it verbatim.
91
+ Else tries SPLIT_ALIASES in order; if none present, falls back to the first
92
+ available split (sorted for determinism). If `available` is None/empty, returns
93
+ ('validation', True) as a last resort so the run does not crash.
94
+ """
95
+ try:
96
+ if isinstance(requested, str) and requested:
97
+ return requested, False
98
+ except Exception:
99
+ pass
100
+ avail = list(available) if isinstance(available, list) and available else []
101
+ if avail:
102
+ for cand in SPLIT_ALIASES:
103
+ if cand in avail:
104
+ return cand, True
105
+ return sorted(avail)[0], True
106
+ return "validation", True
107
+
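For reference, a minimal sketch of the fallback behaviour above, assuming the private helper is imported from `invarlock.cli.commands.run` (the module this diff adds; it is not part of the public API):

    from invarlock.cli.commands.run import _choose_dataset_split

    # An explicit request is returned verbatim and not marked as a fallback.
    assert _choose_dataset_split(requested="test", available=["train", "test"]) == ("test", False)
    # No request: the first alias present in SPLIT_ALIASES wins.
    assert _choose_dataset_split(requested=None, available=["train", "val"]) == ("val", True)
    # No alias matches: the lexicographically first split keeps the choice deterministic.
    assert _choose_dataset_split(requested=None, available=["b", "a"]) == ("a", True)
    # Nothing known about the dataset: default to 'validation' instead of crashing.
    assert _choose_dataset_split(requested=None, available=None) == ("validation", True)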
108
+
109
+ def _persist_ref_masks(core_report: Any, run_dir: Path) -> Path | None:
110
+ """Persist reference keep indices to artifact if present."""
111
+
112
+ edit_section = (
113
+ core_report.get("edit")
114
+ if isinstance(core_report, dict)
115
+ else getattr(core_report, "edit", None)
116
+ )
117
+ if not isinstance(edit_section, dict):
118
+ return None
119
+
120
+ artifacts_section = edit_section.get("artifacts")
121
+ if not isinstance(artifacts_section, dict):
122
+ return None
123
+
124
+ mask_payload = artifacts_section.get("mask_payload")
125
+ if not isinstance(mask_payload, dict) or not mask_payload:
126
+ return None
127
+
128
+ payload_copy = copy.deepcopy(mask_payload)
129
+ meta_section = payload_copy.setdefault("meta", {})
130
+ meta_section.setdefault("generated_at", datetime.now().isoformat())
131
+
132
+ target_dir = run_dir / "artifacts" / "edit_masks"
133
+ target_dir.mkdir(parents=True, exist_ok=True)
134
+ mask_path = target_dir / "masks.json"
135
+ with mask_path.open("w", encoding="utf-8") as handle:
136
+ json.dump(payload_copy, handle, indent=2, sort_keys=True)
137
+ handle.write("\n")
138
+
139
+ return mask_path
140
+
141
+
142
+ def _resolve_exit_code(exc: Exception, *, profile: str | None) -> int:
143
+ """Resolve exit code based on exception type and profile.
144
+
145
+ - ValueError("Invalid RunReport...") → 2 (schema/shape issue)
146
+ - InvarlockError in CI/Release → 3 (hard abort)
147
+ - All other cases → 1 (generic failure)
148
+ """
149
+ try:
150
+ prof = (profile or "").strip().lower()
151
+ except Exception:
152
+ prof = ""
153
+ # Schema/validation classes and known shapes → exit 2
154
+ if isinstance(exc, _CfgErr | _ValErr | _DataErr):
155
+ return 2
156
+ if isinstance(exc, ValueError) and "Invalid RunReport" in str(exc):
157
+ return 2
158
+ if isinstance(exc, InvarlockError) and prof in {"ci", "release"}:
159
+ return 3
160
+ return 1
161
+
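A short sketch of the exit-code mapping, assuming the helper and `InvarlockError` import as in this module (the error is constructed with the same keyword arguments used elsewhere in this file):

    from invarlock.cli.commands.run import _resolve_exit_code
    from invarlock.cli.errors import InvarlockError

    # Schema/shape problems map to exit 2 regardless of profile.
    assert _resolve_exit_code(ValueError("Invalid RunReport: missing meta"), profile=None) == 2
    # Hard aborts escalate to exit 3 only under CI/Release profiles.
    assert _resolve_exit_code(InvarlockError(code="E002", message="mismatch"), profile="release") == 3
    # Anything else is a generic failure.
    assert _resolve_exit_code(RuntimeError("boom"), profile="dev") == 1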
162
+
163
+ ## NOTE: Deprecated legacy helper `_check_pairability_or_abort` was removed.
164
+ ## Provider parity and pairing guarantees are enforced via guard digests and
165
+ ## invariant checks during run execution.
166
+
167
+
168
+ def _hash_sequences(seqs: Sequence[Sequence[int]] | Iterable[Sequence[int]]) -> str:
169
+ """Compute a stable digest for a sequence of integer token sequences."""
170
+ hasher = hashlib.blake2s(digest_size=16)
171
+ for seq in seqs:
172
+ arr = array("I", (int(token) & 0xFFFFFFFF for token in seq))
173
+ hasher.update(arr.tobytes())
174
+ return hasher.hexdigest()
175
+
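For illustration (hypothetical token ids), the digest is order-sensitive and fixed-width:

    from invarlock.cli.commands.run import _hash_sequences

    a = _hash_sequences([[101, 2009, 102], [101, 2054, 102]])
    b = _hash_sequences([[101, 2009, 102], [101, 2054, 102]])
    c = _hash_sequences([[101, 2054, 102], [101, 2009, 102]])
    assert a == b        # same sequences in the same order → same digest
    assert a != c        # reordering the windows changes the digest (with overwhelming probability)
    assert len(a) == 32  # blake2s with digest_size=16 yields a 32-char hex string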
176
+
177
+ def _compute_mask_positions_digest(windows: dict[str, Any]) -> str | None:
178
+ """Compute a rolled hash of MLM mask positions across windows.
179
+
180
+ Expects windows of the shape { 'preview': {...}, 'final': {...} } with
181
+ 'labels' and optional 'window_ids' in each section. Positions where
182
+ labels != -100 are treated as masked.
183
+ """
184
+ try:
185
+ # Simple, dependency-light digest of positions where labels != -100
186
+ hasher = hashlib.blake2s(digest_size=16)
187
+ any_masked = False
188
+ for arm in ("preview", "final"):
189
+ sec = windows.get(arm)
190
+ if not isinstance(sec, dict):
191
+ continue
192
+ labels = sec.get("labels")
193
+ if not isinstance(labels, list) or not labels:
194
+ continue
195
+ hasher.update(arm.encode("utf-8"))
196
+ for row in labels:
197
+ row_list = _tensor_or_list_to_ints(row)
198
+ if not row_list:
199
+ continue
200
+ found = False
201
+ for idx, v in enumerate(row_list):
202
+ if int(v) != -100:
203
+ hasher.update(b"1")
204
+ hasher.update(idx.to_bytes(4, "little", signed=False))
205
+ found = True
206
+ if found:
207
+ any_masked = True
208
+ hasher.update(b"|")
209
+ if not any_masked:
210
+ return None
211
+ digest = hasher.hexdigest()
212
+ return digest if digest else None
213
+ except Exception:
214
+ return None
215
+
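A small sketch with hand-written label rows (hypothetical values); only positions whose label is not -100 contribute, and an input with no masked positions yields None:

    from invarlock.cli.commands.run import _compute_mask_positions_digest

    windows = {
        "preview": {"labels": [[-100, 2054, -100]]},  # position 1 masked
        "final": {"labels": [[-100, -100, 1996]]},    # position 2 masked
    }
    digest = _compute_mask_positions_digest(windows)
    assert isinstance(digest, str) and len(digest) == 32
    # All labels == -100 → nothing masked → None
    assert _compute_mask_positions_digest({"preview": {"labels": [[-100, -100]]}}) is None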
216
+
217
+ def _to_int_list(values: Sequence[int] | Iterable[int]) -> list[int]:
218
+ return [int(v) for v in values]
219
+
220
+
221
+ def _tensor_or_list_to_ints(values: Any) -> list[int]:
222
+ """Coerce possible tensor/list-like inputs to a list[int]."""
223
+ try:
224
+ # Torch tensors: `.tolist()` path
225
+ if torch is not None and hasattr(values, "tolist"):
226
+ raw = values.tolist()
227
+ if isinstance(raw, list):
228
+ return _to_int_list(raw)
229
+ try:
230
+ return _to_int_list(list(raw))
231
+ except Exception:
232
+ pass
233
+ # Numpy arrays: treat as list-like
234
+ if isinstance(values, np.ndarray | list | tuple):
235
+ return _to_int_list(list(values))
236
+ # Iterables of ints
237
+ if isinstance(values, Iterable):
238
+ return _to_int_list(values)
239
+ except Exception:
240
+ pass
241
+ return []
242
+
243
+
244
+ def _safe_int(value: Any, default: int = 0) -> int:
245
+ try:
246
+ return int(value)
247
+ except (TypeError, ValueError):
248
+ return default
249
+
250
+
251
+ def _derive_mlm_seed(base_seed: int, window_id: str | int, position: int) -> int:
252
+ payload = f"{base_seed}:{window_id}:{position}".encode()
253
+ digest = hashlib.blake2s(payload, digest_size=8).digest()
254
+ return int.from_bytes(digest, "little", signed=False)
255
+
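The derived seed is a pure function of its inputs, so masking decisions stay reproducible per window and position. A minimal sketch:

    from invarlock.cli.commands.run import _derive_mlm_seed

    s1 = _derive_mlm_seed(42, "preview:0", 7)
    s2 = _derive_mlm_seed(42, "preview:0", 7)
    s3 = _derive_mlm_seed(42, "preview:0", 8)
    assert s1 == s2          # deterministic for identical inputs
    assert s1 != s3          # changing the position changes the seed (with overwhelming probability)
    assert 0 <= s1 < 2**64   # 8-byte little-endian digest → unsigned 64-bit seed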
256
+
257
+ def _apply_mlm_masks(
258
+ records: list[dict[str, Any]],
259
+ *,
260
+ tokenizer: Any,
261
+ mask_prob: float,
262
+ seed: int,
263
+ random_token_prob: float,
264
+ original_token_prob: float,
265
+ prefix: str,
266
+ ) -> tuple[int, list[int]]:
267
+ """Apply basic BERT-style MLM masking to tokenized records in-place."""
268
+ if mask_prob <= 0.0:
269
+ zeroed = []
270
+ for record in records:
271
+ length = len(record["input_ids"])
272
+ record["labels"] = [-100] * length
273
+ record["mlm_masked"] = 0
274
+ zeroed.append(0)
275
+ return 0, zeroed
276
+
277
+ vocab_size = _safe_int(getattr(tokenizer, "vocab_size", 0))
278
+ # Require an explicit mask token id for MLM
279
+ mask_token_id = getattr(tokenizer, "mask_token_id", None)
280
+ if mask_token_id is None:
281
+ raise RuntimeError(
282
+ "Tokenizer does not define mask_token_id; required for MLM evaluation."
283
+ )
284
+ try:
285
+ mask_token_id = int(mask_token_id)
286
+ except Exception:
287
+ mask_token_id = _safe_int(mask_token_id, 0)
288
+
289
+ # Build special token id set to avoid masking them
290
+ special_ids = set()
291
+ for attr in (
292
+ "cls_token_id",
293
+ "sep_token_id",
294
+ "bos_token_id",
295
+ "eos_token_id",
296
+ "pad_token_id",
297
+ ):
298
+ val = getattr(tokenizer, attr, None)
299
+ if val is not None:
300
+ try:
301
+ special_ids.add(int(val))
302
+ except Exception:
303
+ pass
304
+ try:
305
+ special_ids.update(
306
+ int(t) for t in getattr(tokenizer, "all_special_ids", []) or []
307
+ )
308
+ except Exception:
309
+ pass
310
+
311
+ masked_total = 0
312
+ masked_counts: list[int] = []
313
+ for idx_record, record in enumerate(records):
314
+ window_id = record.get("window_id", f"{prefix}:{idx_record}")
315
+ input_ids = _tensor_or_list_to_ints(record.get("input_ids", []))
316
+ attention = _tensor_or_list_to_ints(record.get("attention_mask", []))
317
+ labels = [-100] * len(input_ids)
318
+
319
+ masked = 0
320
+ for pos, (tok, att) in enumerate(zip(input_ids, attention, strict=False)):
321
+ if not att:
322
+ continue
323
+ if int(tok) in special_ids:
324
+ continue
325
+ if random.random() < mask_prob:
326
+ rng = random.Random(_derive_mlm_seed(seed, window_id, pos))
327
+ labels[pos] = int(tok)
328
+ r = rng.random()
329
+ if r < 1.0 - (random_token_prob + original_token_prob):
330
+ input_ids[pos] = mask_token_id
331
+ elif r < 1.0 - original_token_prob and vocab_size > 0:
332
+ rng2 = random.Random(_derive_mlm_seed(seed + 17, window_id, pos))
333
+ input_ids[pos] = rng2.randint(0, max(1, vocab_size - 1))
334
+ masked += 1
335
+
336
+ # Ensure at least one masked token for stability
337
+ if masked == 0:
338
+ candidate_positions = [
339
+ p
340
+ for p, (tok, att) in enumerate(zip(input_ids, attention, strict=False))
341
+ if att and int(tok) not in special_ids
342
+ ]
343
+ if candidate_positions:
344
+ pos = candidate_positions[len(candidate_positions) // 2]
345
+ rng = random.Random(_derive_mlm_seed(seed + 17, window_id, pos))
346
+ labels[pos] = int(input_ids[pos])
347
+ masked = 1
348
+ r = rng.random()
349
+ if r < 1.0 - (random_token_prob + original_token_prob):
350
+ input_ids[pos] = mask_token_id
351
+ elif r < 1.0 - original_token_prob and vocab_size > 0:
352
+ input_ids[pos] = rng.randrange(vocab_size)
353
+
354
+ record["input_ids"] = _to_int_list(input_ids)
355
+ record["attention_mask"] = _to_int_list(attention)
356
+ record["labels"] = _to_int_list(labels)
357
+ record["mlm_masked"] = masked
358
+ masked_total += masked
359
+ masked_counts.append(masked)
360
+
361
+ return masked_total, masked_counts
362
+
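With the defaults used later in this module (mask_prob=0.15, random_token_prob=0.1, original_token_prob=0.1), a selected position becomes the mask token with probability 0.8, a random token with 0.1, and keeps its original id with 0.1, mirroring BERT-style masking. A minimal sketch with a toy tokenizer (hypothetical attribute values; only the attributes the helper reads are provided):

    from types import SimpleNamespace
    from invarlock.cli.commands.run import _apply_mlm_masks

    tok = SimpleNamespace(
        vocab_size=100,
        mask_token_id=4,
        cls_token_id=1,
        sep_token_id=2,
        pad_token_id=0,
        bos_token_id=None,
        eos_token_id=None,
        all_special_ids=[0, 1, 2, 4],
    )
    records = [{
        "window_id": "w0",
        "input_ids": [1, 11, 12, 13, 2],
        "attention_mask": [1, 1, 1, 1, 1],
    }]
    total, per_window = _apply_mlm_masks(
        records, tokenizer=tok, mask_prob=0.15, seed=42,
        random_token_prob=0.1, original_token_prob=0.1, prefix="preview",
    )
    # Labels stay aligned with input_ids (-100 for unmasked positions), and the
    # fallback branch guarantees at least one masked non-special position.
    assert len(records[0]["labels"]) == len(records[0]["input_ids"])
    assert total == sum(per_window) >= 1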
363
+
364
+ def _tokenizer_digest(tokenizer: Any) -> str:
365
+ """Compute a stable digest for a tokenizer config.
366
+
367
+ Tries, in order: get_vocab().items(), `vocab` attribute if list-like, else
368
+ hashes a small set of informative attributes.
369
+ """
370
+ try:
371
+ if hasattr(tokenizer, "get_vocab"):
372
+ try:
373
+ items = getattr(tokenizer.get_vocab(), "items", None)
374
+ if callable(items):
375
+ pairs = list(items())
376
+ # Filter non-string keys for stability
377
+ pairs = [
378
+ (str(k), int(v)) for k, v in pairs if isinstance(k, str | int)
379
+ ]
380
+ payload = json.dumps(sorted(pairs), separators=(",", ":")).encode()
381
+ return hashlib.sha256(payload).hexdigest()
382
+ except Exception:
383
+ pass
384
+ # Fallback to `vocab` attribute (e.g., list of pairs)
385
+ vocab = getattr(tokenizer, "vocab", None)
386
+ if isinstance(vocab, list):
387
+ try:
388
+ payload = json.dumps(
389
+ [(str(k), int(v)) for k, v in vocab], separators=(",", ":")
390
+ ).encode()
391
+ return hashlib.sha256(payload).hexdigest()
392
+ except Exception:
393
+ pass
394
+ # Last resort: small attribute set
395
+ attrs = {
396
+ "name": getattr(tokenizer, "name_or_path", None),
397
+ "eos": getattr(tokenizer, "eos_token", None),
398
+ "pad": getattr(tokenizer, "pad_token", None),
399
+ "size": _safe_int(getattr(tokenizer, "vocab_size", 0)),
400
+ }
401
+ return hashlib.sha256(json.dumps(attrs, sort_keys=True).encode()).hexdigest()
402
+ except Exception:
403
+ return "unknown-tokenizer"
404
+
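A sketch with a hypothetical stand-in object; on the `get_vocab()` path the digest depends only on the vocabulary contents, so two tokenizers with identical vocabularies hash identically:

    from invarlock.cli.commands.run import _tokenizer_digest

    class ToyTokenizer:  # hypothetical: exposes only get_vocab()
        def get_vocab(self):
            return {"<pad>": 0, "hello": 1, "world": 2}

    assert _tokenizer_digest(ToyTokenizer()) == _tokenizer_digest(ToyTokenizer())
    assert len(_tokenizer_digest(ToyTokenizer())) == 64  # sha256 hex digest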
405
+
406
+ def _extract_pairing_schedule(report: dict[str, Any] | None) -> dict[str, Any] | None:
407
+ """Extract sanitized pairing schedule from a baseline-like report shape.
408
+
409
+ Returns a dict with 'preview' and 'final' sections when present. Each section
410
+ contains window_ids, input_ids, attention_masks, optional labels and counts.
411
+ """
412
+ if not isinstance(report, dict):
413
+ return None
414
+ windows = report.get("evaluation_windows")
415
+ if not isinstance(windows, dict):
416
+ return None
417
+
418
+ def _sanitize(section_key: str) -> dict[str, Any] | None:
419
+ section = windows.get(section_key)
420
+ if not isinstance(section, dict):
421
+ return None
422
+ window_ids = list(section.get("window_ids", []))
423
+ input_ids_raw = section.get("input_ids", [])
424
+ if not isinstance(input_ids_raw, list):
425
+ return None
426
+ input_ids = [list(seq) for seq in input_ids_raw]
427
+ attention_raw = section.get("attention_masks")
428
+ if isinstance(attention_raw, list) and all(
429
+ isinstance(mask, list) for mask in attention_raw
430
+ ):
431
+ attention_masks = [list(mask) for mask in attention_raw]
432
+ else:
433
+ attention_masks = [
434
+ [1 if token != 0 else 0 for token in seq] for seq in input_ids
435
+ ]
436
+
437
+ labels_raw = section.get("labels")
438
+ labels: list[list[int]] | None = None
439
+ if isinstance(labels_raw, list) and labels_raw:
440
+ labels = []
441
+ for idx, raw_label in enumerate(labels_raw):
442
+ label_list = _tensor_or_list_to_ints(raw_label)
443
+ if idx < len(input_ids):
444
+ target_len = len(input_ids[idx])
445
+ if len(label_list) < target_len:
446
+ label_list = label_list + [-100] * (
447
+ target_len - len(label_list)
448
+ )
449
+ elif len(label_list) > target_len:
450
+ label_list = label_list[:target_len]
451
+ labels.append(label_list)
452
+
453
+ masked_counts = None
454
+ if isinstance(section.get("masked_token_counts"), list):
455
+ masked_counts = [int(v) for v in section["masked_token_counts"]]
456
+ actual_counts = None
457
+ if isinstance(section.get("actual_token_counts"), list):
458
+ actual_counts = [int(v) for v in section["actual_token_counts"]]
459
+
460
+ payload: dict[str, Any] = {
461
+ "window_ids": window_ids,
462
+ "input_ids": input_ids,
463
+ "attention_masks": attention_masks,
464
+ }
465
+ if labels is not None:
466
+ payload["labels"] = labels
467
+ if masked_counts is not None:
468
+ payload["masked_token_counts"] = masked_counts
469
+ if actual_counts is not None:
470
+ payload["actual_token_counts"] = actual_counts
471
+ return payload
472
+
473
+ preview = _sanitize("preview")
474
+ final = _sanitize("final")
475
+ if preview and final:
476
+ return {"preview": preview, "final": final}
477
+ return None
478
+
479
+
480
+ def _prepare_config_for_run(
481
+ *,
482
+ config_path: str,
483
+ profile: str | None,
484
+ edit: str | None,
485
+ tier: str | None,
486
+ probes: int | None,
487
+ console: Console,
488
+ ) -> InvarLockConfig:
489
+ """Load InvarLock config and apply CLI/profile overrides deterministically."""
490
+ # Local import to allow test monkeypatching of invarlock.cli.config functions
491
+ from ..config import (
492
+ apply_edit_override as _apply_edit_override,
493
+ )
494
+ from ..config import (
495
+ apply_profile as _apply_profile,
496
+ )
497
+ from ..config import (
498
+ load_config as _load_config,
499
+ )
500
+ from ..config import (
501
+ resolve_edit_kind as _resolve_edit_kind,
502
+ )
503
+
504
+ console.print(f"📋 Loading configuration: {config_path}")
505
+ cfg = _load_config(config_path)
506
+
507
+ # Apply profile if specified (dev is a no-op)
508
+ if profile and str(profile).lower() in {"ci", "release"}:
509
+ console.print(f"🎯 Applying profile: {profile}")
510
+ try:
511
+ cfg = _apply_profile(cfg, profile)
512
+ except Exception as exc:
513
+ console.print(f"[red]❌ {exc}[/red]")
514
+ raise typer.Exit(1) from exc
515
+
516
+ # Apply edit override
517
+ if edit:
518
+ try:
519
+ edit_name = _resolve_edit_kind(edit)
520
+ console.print(f"✂️ Edit override: {edit} → {edit_name}")
521
+ cfg = _apply_edit_override(cfg, edit)
522
+ except ValueError as e:
523
+ console.print(f"[red]❌ {e}[/red]")
524
+ raise typer.Exit(1) from e
525
+
526
+ # Apply CLI overrides for auto configuration
527
+ if tier or probes is not None:
528
+ if tier and tier not in ["conservative", "balanced", "aggressive", "none"]:
529
+ console.print(
530
+ f"[red]❌ Invalid tier '{tier}'. Valid options: conservative, balanced, aggressive, none[/red]"
531
+ )
532
+ raise typer.Exit(1)
533
+ if probes is not None and (probes < 0 or probes > 10):
534
+ console.print(
535
+ f"[red]❌ Invalid probes '{probes}'. Must be between 0 and 10[/red]"
536
+ )
537
+ raise typer.Exit(1)
538
+
539
+ cfg_dict = cfg.model_dump()
540
+ auto_section = (
541
+ cfg_dict.get("auto") if isinstance(cfg_dict.get("auto"), dict) else {}
542
+ )
543
+ cfg_dict["auto"] = auto_section
544
+ if tier:
545
+ auto_section["tier"] = tier
546
+ console.print(f"🎛️ Auto tier override: {tier}")
547
+ if probes is not None:
548
+ auto_section["probes"] = probes
549
+ console.print(f"🔬 Auto probes override: {probes}")
550
+ cfg = InvarLockConfig(cfg_dict)
551
+
552
+ # Resolve adapter:auto to a concrete built-in adapter if requested
553
+ try:
554
+ from ..adapter_auto import apply_auto_adapter_if_needed as _apply_auto
555
+
556
+ cfg = _apply_auto(cfg)
557
+ except Exception:
558
+ pass
559
+
560
+ return cfg
561
+
562
+
563
+ def _maybe_plan_release_windows(
564
+ capacity_meta: dict[str, Any],
565
+ *,
566
+ requested_preview: int,
567
+ requested_final: int,
568
+ max_calibration: int,
569
+ console: Console,
570
+ ) -> dict[str, Any]:
571
+ """Thin wrapper around _plan_release_windows to improve readability."""
572
+ return _plan_release_windows(
573
+ capacity_meta,
574
+ requested_preview=requested_preview,
575
+ requested_final=requested_final,
576
+ max_calibration=max_calibration,
577
+ console=console,
578
+ )
579
+
580
+
581
+ def _print_pipeline_start(console: Console) -> None:
582
+ console.print("🚀 Starting InvarLock pipeline...")
583
+
584
+
585
+ def _emit_run_artifacts(
586
+ *, report: Any, out_dir: Path, filename_prefix: str, console: Console
587
+ ) -> dict[str, str]:
588
+ """Save run report and return emitted artifact paths."""
589
+ from invarlock.reporting.report import save_report as _save_report
590
+
591
+ console.print("💾 Saving run report...")
592
+ return _save_report(
593
+ report, out_dir, formats=["json"], filename_prefix=filename_prefix
594
+ )
595
+
596
+
597
+ def _resolve_device_and_output(
598
+ cfg: Any, *, device: str | None, out: str | None, console: Console
599
+ ) -> tuple[str, Path]:
600
+ """Resolve device and output directory with validation and logging."""
601
+ from ..device import (
602
+ resolve_device as _resolve_device,
603
+ )
604
+ from ..device import (
605
+ validate_device_for_config as _validate,
606
+ )
607
+
608
+ try:
609
+ cfg_device = getattr(cfg.model, "device", None)
610
+ except Exception:
611
+ cfg_device = None
612
+ target_device = device or cfg_device or "auto"
613
+ resolved_device = _resolve_device(target_device)
614
+ console.print(
615
+ f"Device: {resolved_device} (requested={target_device}, resolved={resolved_device})"
616
+ )
617
+ is_valid, error_msg = _validate(resolved_device)
618
+ if not is_valid:
619
+ console.print(f"[red]❌ Device validation failed: {error_msg}[/red]")
620
+ raise typer.Exit(1)
621
+
622
+ # Determine output directory (support both 'output.dir' and legacy 'out.dir')
623
+ if out:
624
+ output_dir = Path(out)
625
+ else:
626
+ try:
627
+ output_dir = Path(cfg.output.dir)
628
+ except Exception:
629
+ try:
630
+ output_dir = Path(cfg.out.dir) # type: ignore[attr-defined]
631
+ except Exception:
632
+ output_dir = Path("runs")
633
+ output_dir.mkdir(parents=True, exist_ok=True)
634
+ return str(resolved_device), output_dir
635
+
636
+
637
+ def _resolve_provider_and_split(
638
+ cfg: Any,
639
+ model_profile: Any,
640
+ *,
641
+ get_provider_fn: Any,
642
+ provider_kwargs: dict[str, Any] | None = None,
643
+ console: Console,
644
+ resolved_device: str | None = None,
645
+ ) -> tuple[Any, str, bool]:
646
+ """Resolve dataset provider and split, returning (provider, split, used_fallback)."""
647
+ provider_name = None
648
+ provider_kwargs = dict(provider_kwargs or {})
649
+ try:
650
+ provider_val = cfg.dataset.provider
651
+ except Exception:
652
+ provider_val = None
653
+ if isinstance(provider_val, str) and provider_val:
654
+ provider_name = provider_val
655
+ else:
656
+ try:
657
+ provider_name = provider_val.kind # type: ignore[attr-defined]
658
+ try:
659
+ for k, v in provider_val.items(): # type: ignore[attr-defined]
660
+ if k != "kind" and v is not None and v != "":
661
+ provider_kwargs[k] = v
662
+ except Exception:
663
+ pass
664
+ except Exception:
665
+ provider_name = None
666
+ if not provider_name:
667
+ provider_name = getattr(model_profile, "default_provider", None) or "wikitext2"
668
+ # Pass device hint only to providers that understand it (currently WikiText-2)
669
+ if resolved_device and provider_name == "wikitext2":
670
+ provider_kwargs.setdefault("device_hint", resolved_device)
671
+ data_provider = get_provider_fn(provider_name, **provider_kwargs)
672
+
673
+ requested_split = None
674
+ try:
675
+ requested_split = getattr(cfg.dataset, "split", None)
676
+ except Exception:
677
+ requested_split = None
678
+ available_splits = None
679
+ if hasattr(data_provider, "available_splits"):
680
+ try:
681
+ available_splits = list(data_provider.available_splits()) # type: ignore[attr-defined]
682
+ except Exception:
683
+ available_splits = None
684
+ resolved_split, used_fallback_split = _choose_dataset_split(
685
+ requested=requested_split, available=available_splits
686
+ )
687
+ return data_provider, resolved_split, used_fallback_split
688
+
689
+
690
+ def _run_bare_control(
691
+ *,
692
+ adapter: Any,
693
+ edit_op: Any,
694
+ cfg: Any,
695
+ model: Any,
696
+ run_config: Any,
697
+ calibration_data: list[Any],
698
+ auto_config: Any,
699
+ edit_config: Any,
700
+ preview_count: int,
701
+ final_count: int,
702
+ seed_bundle: dict[str, int | None],
703
+ resolved_device: str,
704
+ restore_fn: Any | None,
705
+ console: Console,
706
+ resolved_loss_type: str,
707
+ profile_normalized: str | None,
708
+ skip_model_load: bool = False,
709
+ ) -> dict[str, Any] | None:
710
+ """Execute the bare-control run for overhead estimation and return payload."""
711
+ from invarlock.core.runner import CoreRunner as _CoreRunner
712
+
713
+ console.print("🧪 Running bare control (guards disabled) for overhead check")
714
+ set_seed(seed_bundle["python"]) # type: ignore[arg-type]
715
+
716
+ bare_runner = _CoreRunner()
717
+ bare_config = copy.deepcopy(run_config)
718
+ bare_config.event_path = None
719
+ bare_context = copy.deepcopy(run_config.context)
720
+ bare_context.setdefault("validation", {})["guard_overhead_mode"] = "bare"
721
+ bare_config.context = bare_context
722
+
723
+ if restore_fn and model is not None:
724
+ restore_fn()
725
+ bare_target_model = model
726
+ elif skip_model_load:
727
+ bare_target_model = model or SimpleNamespace(name="bare_stub_model")
728
+ else:
729
+ bare_target_model = adapter.load_model(cfg.model.id, device=resolved_device)
730
+
731
+ bare_report = bare_runner.execute(
732
+ model=bare_target_model,
733
+ adapter=adapter,
734
+ edit=edit_op,
735
+ guards=[],
736
+ config=bare_config,
737
+ calibration_data=calibration_data,
738
+ auto_config=auto_config,
739
+ edit_config=edit_config,
740
+ preview_n=preview_count,
741
+ final_n=final_count,
742
+ )
743
+
744
+ bare_ppl_final = None
745
+ bare_ppl_preview = None
746
+ if hasattr(bare_report, "metrics") and bare_report.metrics:
747
+ bare_pm = bare_report.metrics.get("primary_metric", {})
748
+ bare_ppl_final = bare_pm.get("final") if isinstance(bare_pm, dict) else None
749
+ bare_ppl_preview = bare_pm.get("preview") if isinstance(bare_pm, dict) else None
750
+
751
+ if profile_normalized in {"ci", "release"}:
752
+
753
+ def _finite(x: Any) -> bool:
754
+ try:
755
+ return isinstance(x, (int | float)) and math.isfinite(float(x))
756
+ except Exception:
757
+ return False
758
+
759
+ if not (_finite(bare_ppl_preview) and _finite(bare_ppl_final)):
760
+ console.print(
761
+ "[yellow]⚠️ Primary metric non-finite during bare control; continuing with diagnostics.[/yellow]"
762
+ )
763
+
764
+ payload: dict[str, Any] = {
765
+ "overhead_threshold": GUARD_OVERHEAD_THRESHOLD,
766
+ "messages": [],
767
+ "warnings": [],
768
+ "errors": [],
769
+ "checks": {},
770
+ "source": f"{profile_normalized or 'ci'}_profile",
771
+ }
772
+
773
+ if getattr(bare_report, "status", "").lower() not in {"success", "completed", "ok"}:
774
+ payload["warnings"].append(
775
+ f"Bare run status: {getattr(bare_report, 'status', 'unknown')}"
776
+ )
777
+
778
+ try:
779
+ lk = str(resolved_loss_type or "causal").lower()
780
+ if lk == "mlm":
781
+ pm_kind_bare = "ppl_mlm"
782
+ elif lk in {"seq2seq", "s2s", "t5"}:
783
+ pm_kind_bare = "ppl_seq2seq"
784
+ else:
785
+ pm_kind_bare = "ppl_causal"
786
+ pm_bare = _extract_pm_snapshot_for_overhead(bare_report, kind=pm_kind_bare)
787
+ if isinstance(pm_bare, dict) and pm_bare:
788
+ payload["bare_report"] = {"metrics": {"primary_metric": pm_bare}}
789
+ except Exception:
790
+ pass
791
+
792
+ set_seed(seed_bundle["python"]) # type: ignore[arg-type]
793
+ return payload
794
+
795
+
796
+ def _execute_guarded_run(
797
+ *,
798
+ runner: Any,
799
+ adapter: Any,
800
+ model: Any,
801
+ cfg: Any,
802
+ edit_op: Any,
803
+ run_config: Any,
804
+ guards: list[Any],
805
+ calibration_data: list[Any],
806
+ auto_config: Any,
807
+ edit_config: Any,
808
+ preview_count: int,
809
+ final_count: int,
810
+ restore_fn: Any | None,
811
+ resolved_device: str,
812
+ console: Console,
813
+ skip_model_load: bool = False,
814
+ ) -> tuple[Any, Any]:
815
+ """Restore or load model and execute the guarded CoreRunner."""
816
+ if restore_fn and model is not None:
817
+ restore_fn()
818
+ elif skip_model_load:
819
+ model = model or SimpleNamespace(name="guarded_stub_model")
820
+ else:
821
+ console.print(f"🔧 Loading model: {cfg.model.id} (attempt 1)")
822
+ model = adapter.load_model(cfg.model.id, device=resolved_device)
823
+
824
+ core_report = runner.execute(
825
+ model=model,
826
+ adapter=adapter,
827
+ edit=edit_op,
828
+ guards=guards,
829
+ config=run_config,
830
+ calibration_data=calibration_data,
831
+ auto_config=auto_config,
832
+ edit_config=edit_config,
833
+ preview_n=preview_count,
834
+ final_n=final_count,
835
+ )
836
+ return core_report, model
837
+
838
+
839
+ def _postprocess_and_summarize(
840
+ *,
841
+ report: dict[str, Any],
842
+ run_dir: Path,
843
+ run_config: Any,
844
+ window_plan: dict[str, Any] | None,
845
+ dataset_meta: dict[str, Any],
846
+ match_fraction: float | None,
847
+ overlap_fraction: float | None,
848
+ console: Console,
849
+ ) -> None:
850
+ """Finalize report windows stats and print/save summary artifacts."""
851
+ try:
852
+ ds = report.setdefault("dataset", {}).setdefault("windows", {})
853
+ stats = ds.setdefault("stats", {})
854
+ if match_fraction is not None:
855
+ stats["window_match_fraction"] = float(match_fraction)
856
+ if overlap_fraction is not None:
857
+ stats["window_overlap_fraction"] = float(overlap_fraction)
858
+ try:
859
+ if isinstance(window_plan, dict) and "coverage_ok" in window_plan:
860
+ stats["coverage"] = bool(window_plan.get("coverage_ok"))
861
+ except Exception:
862
+ pass
863
+ except Exception:
864
+ pass
865
+
866
+ saved_files = _emit_run_artifacts(
867
+ report=report, out_dir=run_dir, filename_prefix="report", console=console
868
+ )
869
+ console.print("[green]✅ Run completed successfully![/green]")
870
+ console.print(f"📄 Report: {saved_files['json']}")
871
+ if run_config.event_path:
872
+ console.print(f"📝 Events: {run_config.event_path}")
873
+
874
+
875
+ def _compute_provider_digest(report: dict[str, Any]) -> dict[str, str] | None:
876
+ """Compute provider digest (ids/tokenizer/masking) from report context.
877
+
878
+ Returns a dict with keys: ids_sha256, tokenizer_sha256, masking_sha256?
879
+ """
880
+ # Prefer centralized digest helpers
881
+ from invarlock.utils.digest import hash_json as _hash_json
882
+
883
+ windows = report.get("evaluation_windows") if isinstance(report, dict) else None
884
+ if not isinstance(windows, dict) or not windows:
885
+ return None
886
+ # window_ids digest across preview+final (sorted for stability)
887
+ all_ids: list = []
888
+ for key in ("preview", "final"):
889
+ sec = windows.get(key)
890
+ if not isinstance(sec, dict):
891
+ continue
892
+ wids = sec.get("window_ids")
893
+ if isinstance(wids, list):
894
+ all_ids.extend(list(wids))
895
+ ids_sha = _hash_json(sorted(all_ids)) if all_ids else None
896
+
897
+ # tokenizer hash: prefer meta.tokenizer_hash then data.tokenizer_hash
898
+ tok_hash = None
899
+ meta = report.get("meta") if isinstance(report.get("meta"), dict) else None
900
+ if isinstance(meta, dict):
901
+ tok_hash = meta.get("tokenizer_hash")
902
+ if not tok_hash and isinstance(report.get("data"), dict):
903
+ tok_hash = report["data"].get("tokenizer_hash")
904
+
905
+ # masking hash from mask positions
906
+ masking = _compute_mask_positions_digest(windows)
907
+
908
+ digest: dict[str, str] = {}
909
+ if isinstance(ids_sha, str) and ids_sha:
910
+ digest["ids_sha256"] = ids_sha
911
+ if isinstance(tok_hash, str) and tok_hash:
912
+ digest["tokenizer_sha256"] = str(tok_hash)
913
+ if isinstance(masking, str) and masking:
914
+ digest["masking_sha256"] = masking
915
+ return digest or None
916
+
917
+
918
+ def _validate_and_harvest_baseline_schedule(
919
+ cfg: Any,
920
+ pairing_schedule: dict[str, Any],
921
+ baseline_report_data: dict[str, Any] | None,
922
+ *,
923
+ tokenizer_hash: str | None,
924
+ resolved_loss_type: str,
925
+ baseline_path_str: str | None = None,
926
+ console: Console | None = None,
927
+ ) -> dict[str, Any]:
928
+ """Validate baseline schedule compatibility and harvest dataset metadata.
929
+
930
+ Returns a mapping with keys: effective_preview, effective_final, preview_count,
931
+ final_count, dataset_meta, window_plan, calibration_data.
932
+ """
933
+
934
+ # Helpers
935
+ def _print(msg: str) -> None:
936
+ if console is not None:
937
+ console.print(msg)
938
+
939
+ def _fail_schedule(reason: str) -> None:
940
+ path = baseline_path_str or "baseline"
941
+ _print(
942
+ f"[red]❌ Baseline pairing schedule '{path}' is incompatible: {reason}[/red]"
943
+ )
944
+ raise typer.Exit(1)
945
+
946
+ baseline_meta = (
947
+ baseline_report_data.get("data")
948
+ if isinstance(baseline_report_data, dict)
949
+ else {}
950
+ )
951
+ if not isinstance(baseline_meta, dict):
952
+ baseline_meta = {}
953
+
954
+ def _extract_meta(field: str, default: Any = None) -> Any:
955
+ value = baseline_meta.get(field)
956
+ return value if value is not None else default
957
+
958
+ # Adopt counts from the schedule, warning if they differ from cfg
959
+ baseline_preview = len(pairing_schedule["preview"].get("input_ids") or [])
960
+ baseline_final = len(pairing_schedule["final"].get("input_ids") or [])
961
+ cfg_preview = getattr(cfg.dataset, "preview_n", None)
962
+ cfg_final = getattr(cfg.dataset, "final_n", None)
963
+ if (
964
+ cfg_preview is not None
965
+ and baseline_preview is not None
966
+ and baseline_preview != cfg_preview
967
+ ) or (
968
+ cfg_final is not None
969
+ and baseline_final is not None
970
+ and baseline_final != cfg_final
971
+ ):
972
+ _print(
973
+ "[yellow]⚠️ Adjusting evaluation window counts to match baseline schedule "
974
+ f"({baseline_preview}/{baseline_final}).[/yellow]"
975
+ )
976
+
977
+ effective_preview = int(baseline_preview)
978
+ effective_final = int(baseline_final)
979
+ preview_count = effective_preview
980
+ final_count = effective_final
981
+
982
+ # Validate key dataset parameters
983
+ cfg_seq_len = getattr(cfg.dataset, "seq_len", None)
984
+ baseline_seq_len = _extract_meta("seq_len")
985
+ if (
986
+ cfg_seq_len is not None
987
+ and baseline_seq_len is not None
988
+ and baseline_seq_len != cfg_seq_len
989
+ ):
990
+ _fail_schedule(
991
+ f"sequence length mismatch (baseline {baseline_seq_len} vs config {cfg_seq_len})"
992
+ )
993
+
994
+ cfg_stride = getattr(cfg.dataset, "stride", getattr(cfg.dataset, "seq_len", None))
995
+ baseline_stride = _extract_meta("stride")
996
+ if (
997
+ baseline_stride is not None
998
+ and cfg_stride is not None
999
+ and baseline_stride != cfg_stride
1000
+ ):
1001
+ _fail_schedule(
1002
+ f"stride mismatch (baseline {baseline_stride} vs config {cfg_stride})"
1003
+ )
1004
+
1005
+ cfg_dataset = getattr(cfg.dataset, "provider", None)
1006
+ if cfg_dataset is None:
1007
+ cfg_dataset = getattr(cfg.dataset, "dataset", None)
1008
+ baseline_dataset = _extract_meta("dataset")
1009
+ if (
1010
+ baseline_dataset is not None
1011
+ and cfg_dataset is not None
1012
+ and baseline_dataset != cfg_dataset
1013
+ ):
1014
+ _fail_schedule(
1015
+ f"dataset mismatch (baseline {baseline_dataset} vs config {cfg_dataset})"
1016
+ )
1017
+
1018
+ cfg_split = getattr(cfg.dataset, "split", "validation")
1019
+ baseline_split = _extract_meta("split")
1020
+ if (
1021
+ baseline_split is not None
1022
+ and cfg_split is not None
1023
+ and baseline_split != cfg_split
1024
+ ):
1025
+ _fail_schedule(
1026
+ f"split mismatch (baseline {baseline_split} vs config {cfg_split})"
1027
+ )
1028
+
1029
+ baseline_tokenizer_hash = baseline_meta.get("tokenizer_hash")
1030
+ if (
1031
+ baseline_tokenizer_hash
1032
+ and tokenizer_hash
1033
+ and baseline_tokenizer_hash != tokenizer_hash
1034
+ ):
1035
+ _fail_schedule(
1036
+ "tokenizer hash mismatch between baseline and current configuration"
1037
+ )
1038
+
1039
+ dataset_meta = {
1040
+ key: baseline_meta.get(key)
1041
+ for key in (
1042
+ "tokenizer_hash",
1043
+ "tokenizer_name",
1044
+ "vocab_size",
1045
+ "bos_token",
1046
+ "eos_token",
1047
+ "pad_token",
1048
+ "add_prefix_space",
1049
+ "dataset_hash",
1050
+ "preview_hash",
1051
+ "final_hash",
1052
+ "preview_total_tokens",
1053
+ "final_total_tokens",
1054
+ )
1055
+ if baseline_meta.get(key) is not None
1056
+ }
1057
+ dataset_meta["loss_type"] = resolved_loss_type
1058
+ window_plan = baseline_meta.get("window_plan")
1059
+ calibration_data: list[Any] | None = []
1060
+
1061
+ return {
1062
+ "effective_preview": effective_preview,
1063
+ "effective_final": effective_final,
1064
+ "preview_count": preview_count,
1065
+ "final_count": final_count,
1066
+ "dataset_meta": dataset_meta,
1067
+ "window_plan": window_plan,
1068
+ "calibration_data": calibration_data,
1069
+ }
1070
+
1071
+
1072
+ def _enforce_provider_parity(
1073
+ subject_digest: dict | None, baseline_digest: dict | None, *, profile: str | None
1074
+ ) -> None:
1075
+ """Enforce tokenizer/masking parity rules for CI/Release profiles.
1076
+
1077
+ - If tokenizers differ in CI/Release, abort.
1078
+ - If tokenizers match but masking digests differ (MLM), abort.
1079
+ No-ops outside CI/Release.
1080
+ """
1081
+ prof = (profile or "").strip().lower()
1082
+ if prof not in {"ci", "release"}:
1083
+ return
1084
+ sd = subject_digest or {}
1085
+ bd = baseline_digest or {}
1086
+ subj_tok = sd.get("tokenizer_sha256")
1087
+ base_tok = bd.get("tokenizer_sha256")
1088
+ subj_mask = sd.get("masking_sha256")
1089
+ base_mask = bd.get("masking_sha256")
1090
+ # Missing digest information in CI/Release → abort
1091
+ if not (
1092
+ isinstance(subj_tok, str)
1093
+ and isinstance(base_tok, str)
1094
+ and subj_tok
1095
+ and base_tok
1096
+ ):
1097
+ raise InvarlockError(
1098
+ code="E004",
1099
+ message="PROVIDER-DIGEST-MISSING: subject or baseline missing tokenizer digest",
1100
+ )
1101
+ # Tokenizer mismatch → abort with code
1102
+ if subj_tok != base_tok:
1103
+ raise InvarlockError(
1104
+ code="E002",
1105
+ message="TOKENIZER-DIGEST-MISMATCH: subject and baseline tokenizers differ",
1106
+ )
1107
+ # Masking mismatch under identical tokenizers → abort
1108
+ if (
1109
+ isinstance(subj_mask, str)
1110
+ and isinstance(base_mask, str)
1111
+ and subj_mask
1112
+ and base_mask
1113
+ and subj_mask != base_mask
1114
+ ):
1115
+ raise InvarlockError(
1116
+ code="E003",
1117
+ message="MASK-PARITY-MISMATCH: mask positions differ under matched tokenizers",
1118
+ )
1119
+
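A sketch of the parity rules with hypothetical digests: outside CI/Release the call is a no-op, while under CI/Release a tokenizer mismatch aborts with an `InvarlockError`:

    from invarlock.cli.commands.run import _enforce_provider_parity
    from invarlock.cli.errors import InvarlockError

    same = {"tokenizer_sha256": "abc", "masking_sha256": "m1"}
    other = {"tokenizer_sha256": "def", "masking_sha256": "m1"}

    _enforce_provider_parity(same, other, profile="dev")  # no-op outside CI/Release

    try:
        _enforce_provider_parity(same, other, profile="ci")
        raise AssertionError("expected InvarlockError E002 (tokenizer digest mismatch)")
    except InvarlockError:
        pass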
1120
+
1121
+ def _resolve_metric_and_provider(
1122
+ cfg: Any,
1123
+ model_profile: Any,
1124
+ *,
1125
+ resolved_loss_type: str | None = None,
1126
+ ) -> tuple[str, str, dict[str, float]]:
1127
+ """Resolve metric kind, provider kind, and metric options from config with precedence.
1128
+
1129
+ Precedence: CLI args (not handled here) → config → ModelProfile defaults → legacy fallback.
1130
+ The primary metric (metric-v1) is canonical during the dev phase; there are no env-flag toggles.
1131
+ """
1132
+ # Provider kind
1133
+ provider_val = None
1134
+ try:
1135
+ provider_val = cfg.dataset.provider
1136
+ except Exception:
1137
+ provider_val = None
1138
+ provider_kind = None
1139
+ if isinstance(provider_val, str) and provider_val:
1140
+ provider_kind = provider_val
1141
+ else:
1142
+ # Support object-like config sections (e.g., InvarLockConfig _Obj)
1143
+ try:
1144
+ provider_kind = provider_val.kind
1145
+ except Exception:
1146
+ try:
1147
+ provider_kind = provider_val.get("kind") # type: ignore[attr-defined]
1148
+ except Exception:
1149
+ provider_kind = None
1150
+ if not provider_kind and hasattr(model_profile, "default_provider"):
1151
+ provider_kind = model_profile.default_provider
1152
+ # Fallback to a known provider name supported by get_provider()
1153
+ if not provider_kind:
1154
+ provider_kind = "wikitext2"
1155
+
1156
+ # Metric config
1157
+ metric_cfg = None
1158
+ try:
1159
+ eval_section = cfg.eval
1160
+ metric_cfg = getattr(eval_section, "metric", None)
1161
+ except Exception:
1162
+ metric_cfg = None
1163
+
1164
+ metric_kind = None
1165
+ reps = None
1166
+ ci_level = None
1167
+ if metric_cfg is not None:
1168
+ try:
1169
+ metric_kind = (
1170
+ metric_cfg.get("kind")
1171
+ if isinstance(metric_cfg, dict)
1172
+ else metric_cfg.kind
1173
+ )
1174
+ except Exception:
1175
+ metric_kind = None
1176
+ try:
1177
+ reps = (
1178
+ metric_cfg.get("reps")
1179
+ if isinstance(metric_cfg, dict)
1180
+ else metric_cfg.reps
1181
+ )
1182
+ except Exception:
1183
+ reps = None
1184
+ try:
1185
+ ci_level = (
1186
+ metric_cfg.get("ci_level")
1187
+ if isinstance(metric_cfg, dict)
1188
+ else metric_cfg.ci_level
1189
+ )
1190
+ except Exception:
1191
+ ci_level = None
1192
+
1193
+ # Resolve metric kind from config
1194
+ if isinstance(metric_kind, str) and metric_kind:
1195
+ mk = metric_kind.strip().lower()
1196
+ if mk == "auto":
1197
+ metric_kind = None
1198
+ else:
1199
+ metric_kind = mk
1200
+ else:
1201
+ metric_kind = None
1202
+
1203
+ # Fallback to model profile default or legacy resolution by loss type
1204
+ if not metric_kind and hasattr(model_profile, "default_metric"):
1205
+ metric_kind = model_profile.default_metric
1206
+ if not metric_kind:
1207
+ # Legacy: map from loss kind
1208
+ lk = (resolved_loss_type or "causal").lower()
1209
+ if lk == "mlm":
1210
+ metric_kind = "ppl_mlm"
1211
+ elif lk in {"seq2seq", "s2s", "t5"}:
1212
+ metric_kind = "ppl_seq2seq"
1213
+ else:
1214
+ metric_kind = "ppl_causal"
1215
+
1216
+ # Metric options dict if present
1217
+ opts: dict[str, float] = {}
1218
+ if reps is not None:
1219
+ try:
1220
+ opts["reps"] = float(int(reps))
1221
+ except Exception:
1222
+ pass
1223
+ if ci_level is not None:
1224
+ try:
1225
+ opts["ci_level"] = float(ci_level)
1226
+ except Exception:
1227
+ pass
1228
+
1229
+ return str(metric_kind), str(provider_kind), opts
1230
+
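A minimal sketch of the precedence using SimpleNamespace stand-ins for the config and model profile (hypothetical values): with `kind` set to "auto" and no profile default, the metric falls back to the loss-type mapping, while reps/ci_level are carried through as options:

    from types import SimpleNamespace
    from invarlock.cli.commands.run import _resolve_metric_and_provider

    cfg = SimpleNamespace(
        dataset=SimpleNamespace(provider="wikitext2"),
        eval=SimpleNamespace(metric={"kind": "auto", "reps": 3, "ci_level": 0.95}),
    )
    profile = SimpleNamespace()  # no default_metric / default_provider on the profile
    kind, provider, opts = _resolve_metric_and_provider(cfg, profile, resolved_loss_type="mlm")
    assert (kind, provider) == ("ppl_mlm", "wikitext2")
    assert opts == {"reps": 3.0, "ci_level": 0.95}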
1231
+
1232
+ def _plan_release_windows(
1233
+ capacity: dict[str, Any],
1234
+ *,
1235
+ requested_preview: int,
1236
+ requested_final: int,
1237
+ max_calibration: int,
1238
+ console: Console | None = None,
1239
+ ) -> dict[str, Any]:
1240
+ """Derive release-tier window plan based on dataset capacity."""
1241
+ available_unique = int(capacity.get("available_unique", 0))
1242
+ available_nonoverlap = int(capacity.get("available_nonoverlap", 0))
1243
+ total_tokens = int(capacity.get("total_tokens", 0))
1244
+ dedupe_rate = float(capacity.get("dedupe_rate", 0.0))
1245
+ candidate_unique = capacity.get("candidate_unique")
1246
+ if candidate_unique is not None and int(candidate_unique) > 0:
1247
+ effective_unique = int(candidate_unique)
1248
+ else:
1249
+ effective_unique = available_unique
1250
+ candidate_limit = capacity.get("candidate_limit")
1251
+
1252
+ target_per_arm = int(min(requested_preview, requested_final))
1253
+ if target_per_arm <= 0:
1254
+ target_per_arm = requested_preview or requested_final or 1
1255
+
1256
+ max_calibration = max(0, int(max_calibration or 0))
1257
+ if max_calibration > 0:
1258
+ calibration_windows = max(
1259
+ RELEASE_CALIBRATION_MIN,
1260
+ min(RELEASE_CALIBRATION_MAX, max_calibration // 10),
1261
+ )
1262
+ else:
1263
+ calibration_windows = RELEASE_CALIBRATION_MIN
1264
+ calibration_windows = min(calibration_windows, available_unique)
1265
+
1266
+ buffer_windows = math.ceil(RELEASE_BUFFER_FRACTION * effective_unique)
1267
+ reserve_windows = min(effective_unique, calibration_windows + buffer_windows)
1268
+ available_for_eval = max(0, effective_unique - reserve_windows)
1269
+ actual_per_arm_raw = available_for_eval // 2
1270
+
1271
+ coverage_ok = actual_per_arm_raw >= RELEASE_MIN_WINDOWS_PER_ARM
1272
+ if not coverage_ok:
1273
+ raise RuntimeError(
1274
+ "Release profile capacity insufficient: "
1275
+ f"available_unique={available_unique}, reserve={reserve_windows} "
1276
+ f"(calibration={calibration_windows}, buffer={buffer_windows}), "
1277
+ f"usable_per_arm={actual_per_arm_raw}, "
1278
+ f"requires ≥{RELEASE_MIN_WINDOWS_PER_ARM} per arm."
1279
+ )
1280
+
1281
+ actual_per_arm = min(target_per_arm, actual_per_arm_raw)
1282
+
1283
+ if console:
1284
+ candidate_msg = ""
1285
+ if candidate_unique is not None:
1286
+ candidate_msg = f", candidate_unique={int(candidate_unique)}" + (
1287
+ f"/{int(candidate_limit)}" if candidate_limit is not None else ""
1288
+ )
1289
+ console.print(
1290
+ "📏 Release window capacity:"
1291
+ f" unique={available_unique}, reserve={reserve_windows} "
1292
+ f"(calib {calibration_windows}, buffer {buffer_windows}), "
1293
+ f"usable={available_for_eval}, "
1294
+ f"per-arm raw={actual_per_arm_raw} → selected {actual_per_arm} "
1295
+ f"(target {target_per_arm}{candidate_msg})"
1296
+ )
1297
+ if actual_per_arm < target_per_arm:
1298
+ console.print(
1299
+ "[yellow]⚠️ Adjusted per-arm windows down from "
1300
+ f"{target_per_arm} to {actual_per_arm} based on capacity.[/yellow]"
1301
+ )
1302
+
1303
+ plan = {
1304
+ "profile": "release",
1305
+ "requested_preview": int(requested_preview),
1306
+ "requested_final": int(requested_final),
1307
+ "target_per_arm": target_per_arm,
1308
+ "min_per_arm": RELEASE_MIN_WINDOWS_PER_ARM,
1309
+ "actual_preview": int(actual_per_arm),
1310
+ "actual_final": int(actual_per_arm),
1311
+ "actual_per_arm_raw": int(actual_per_arm_raw),
1312
+ "coverage_ok": coverage_ok,
1313
+ "capacity": {
1314
+ "total_tokens": total_tokens,
1315
+ "available_nonoverlap": available_nonoverlap,
1316
+ "available_unique": available_unique,
1317
+ "effective_unique": effective_unique,
1318
+ "dedupe_rate": dedupe_rate,
1319
+ "calibration": calibration_windows,
1320
+ "buffer_fraction": RELEASE_BUFFER_FRACTION,
1321
+ "buffer_windows": buffer_windows,
1322
+ "reserve_windows": reserve_windows,
1323
+ "usable_after_reserve": available_for_eval,
1324
+ },
1325
+ }
1326
+ if candidate_unique is not None:
1327
+ plan["capacity"]["candidate_unique"] = int(candidate_unique)
1328
+ if candidate_limit is not None:
1329
+ plan["capacity"]["candidate_limit"] = int(candidate_limit)
1330
+ return plan
1331
+
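A worked example with hypothetical capacity numbers: 600 unique windows and max_calibration=200 give calibration = max(16, min(24, 200 // 10)) = 20, buffer = ceil(0.12 * 600) = 72, reserve = 92, so (600 - 92) // 2 = 254 usable windows per arm. That clears the 200-window floor and leaves a requested 250 per arm intact:

    from invarlock.cli.commands.run import _plan_release_windows

    capacity = {
        "available_unique": 600,
        "available_nonoverlap": 600,
        "total_tokens": 600 * 512,  # hypothetical
        "dedupe_rate": 0.0,
    }
    plan = _plan_release_windows(
        capacity, requested_preview=250, requested_final=250, max_calibration=200
    )
    assert plan["coverage_ok"] and plan["actual_final"] == 250
    assert plan["capacity"]["reserve_windows"] == 92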
1332
+
1333
+ # Check if core components are available
1334
+ try:
1335
+ from invarlock.core.api import RunConfig # noqa: F401
1336
+ from invarlock.core.registry import get_registry # noqa: F401
1337
+
1338
+ HAS_CORE_COMPONENTS = True
1339
+ except ImportError:
1340
+ HAS_CORE_COMPONENTS = False
1341
+
1342
+
1343
+ def run_command(
1344
+ config: str = typer.Option(
1345
+ ..., "--config", "-c", help="Path to YAML configuration file"
1346
+ ),
1347
+ device: str | None = typer.Option(
1348
+ None, "--device", help="Device override (auto|cuda|mps|cpu)"
1349
+ ),
1350
+ profile: str | None = typer.Option(
1351
+ None, "--profile", help="Profile to apply (ci|release)"
1352
+ ),
1353
+ out: str | None = typer.Option(None, "--out", help="Output directory override"),
1354
+ edit: str | None = typer.Option(None, "--edit", help="Edit kind (quant|mixed)"),
1355
+ tier: str | None = typer.Option(
1356
+ None,
1357
+ "--tier",
1358
+ help="Auto-tuning tier override (conservative|balanced|aggressive)",
1359
+ ),
1360
+ probes: int | None = typer.Option(
1361
+ None, "--probes", help="Number of micro-probes (0=deterministic, >0=adaptive)"
1362
+ ),
1363
+ until_pass: bool = typer.Option(
1364
+ False, "--until-pass", help="Retry until certificate passes (max 3 attempts)"
1365
+ ),
1366
+ max_attempts: int = typer.Option(
1367
+ 3, "--max-attempts", help="Maximum retry attempts for --until-pass mode"
1368
+ ),
1369
+ timeout: int | None = typer.Option(
1370
+ None, "--timeout", help="Timeout in seconds for --until-pass mode"
1371
+ ),
1372
+ baseline: str | None = typer.Option(
1373
+ None,
1374
+ "--baseline",
1375
+ help="Path to baseline report.json for certificate validation",
1376
+ ),
1377
+ no_cleanup: bool = typer.Option(
1378
+ False, "--no-cleanup", help="Skip cleanup of temporary artifacts"
1379
+ ),
1380
+ ):
1381
+ """
1382
+ Run InvarLock pipeline with the given configuration.
1383
+
1384
+ The command assembles non-overlapping preview/final windows, executes the
1385
+ GuardChain (invariants → spectral → RMT → variance), checks pairing/overlap
1386
+ invariants, enforces guard overhead ≤ 1%, and emits a run report plus JSONL
1387
+ events suitable for certificate generation.
1388
+ """
1389
+
1390
+ try:
1391
+ from typer.models import OptionInfo as _TyperOptionInfo # noqa: F401
1392
+ except Exception: # pragma: no cover - typer internals may change
1393
+ _TyperOptionInfo = () # type: ignore[assignment]
1394
+
1395
+ config = _coerce_option(config)
1396
+ device = _coerce_option(device)
1397
+ profile = _coerce_option(profile)
1398
+ out = _coerce_option(out)
1399
+ edit = _coerce_option(edit)
1400
+ tier = _coerce_option(tier)
1401
+ probes = _coerce_option(probes)
1402
+ until_pass = bool(_coerce_option(until_pass, False))
1403
+ max_attempts = int(_coerce_option(max_attempts, 3))
1404
+ timeout = _coerce_option(timeout)
1405
+ baseline = _coerce_option(baseline)
1406
+ no_cleanup = bool(_coerce_option(no_cleanup, False))
1407
+
1408
+ # Use shared CLI coercers from invarlock.cli.utils
1409
+
1410
+ def _fail_run(message: str) -> None:
1411
+ console.print(f"[red]❌ {message}[/red]")
1412
+ # Generic failure path → exit 1 (InvarlockError paths handle code 3 separately)
1413
+ raise typer.Exit(1)
1414
+
1415
+ # Fail fast when torch is missing so users see a clear extras hint instead of
1416
+ # a raw ModuleNotFoundError from deeper imports.
1417
+ try:
1418
+ import torch as _torch # type: ignore[import]
1419
+
1420
+ _ = _torch # pragma: no cover
1421
+ except (ImportError, ModuleNotFoundError) as e:
1422
+ console.print(
1423
+ "❌ Torch is required for this command. "
1424
+ 'Install extras with: pip install "invarlock[hf]" '
1425
+ 'or "invarlock[adapters]".',
1426
+ style="red",
1427
+ markup=False,
1428
+ )
1429
+ raise typer.Exit(1) from e
1430
+
1431
+ # use module-level _extract_pairing_schedule
1432
+
1433
+ # use module-level _to_int_list, _tensor_or_list_to_ints, _safe_int
1434
+
1435
+ # Use the module-level _hash_sequences to avoid duplication
1436
+
1437
+ # use module-level _derive_mlm_seed
1438
+
1439
+ # use module-level _apply_mlm_masks
1440
+
1441
+ # use module-level _tokenizer_digest
1442
+
1443
+ try:
1444
+ # Import InvarLock components
1445
+ from invarlock.core.api import RunConfig
1446
+ from invarlock.core.registry import get_registry
1447
+ from invarlock.core.runner import CoreRunner
1448
+ from invarlock.eval.data import EvaluationWindow, get_provider
1449
+ from invarlock.reporting.report_types import create_empty_report
1450
+
1451
+ # Load and validate configuration via helper (preserves console prints)
1452
+ cfg = _prepare_config_for_run(
1453
+ config_path=config,
1454
+ profile=profile,
1455
+ edit=edit,
1456
+ tier=tier,
1457
+ probes=probes,
1458
+ console=console,
1459
+ )
1460
+
1461
+ # cfg prepared by helper above
1462
+
1463
+ adapter_name = str(getattr(cfg.model, "adapter", "")).lower()
1464
+ model_id_raw = str(getattr(cfg.model, "id", ""))
1465
+ model_profile = detect_model_profile(
1466
+ model_id=model_id_raw, adapter=adapter_name
1467
+ )
1468
+ tokenizer_hash: str | None = None
1469
+ tokenizer: Any | None = None
1470
+
1471
+ loss_cfg = getattr(cfg.eval, "loss", None)
1472
+ resolved_loss_type = (
1473
+ str(getattr(loss_cfg, "type", "auto")).lower() if loss_cfg else "auto"
1474
+ )
1475
+ if resolved_loss_type == "auto":
1476
+ resolved_loss_type = model_profile.default_loss
1477
+ use_mlm = resolved_loss_type == "mlm"
1478
+ mask_prob = _coerce_float(getattr(loss_cfg, "mask_prob", None), 0.15)
1479
+ mask_seed = _coerce_int(getattr(loss_cfg, "seed", None), 42)
1480
+ random_token_prob = _coerce_float(
1481
+ getattr(loss_cfg, "random_token_prob", None), 0.1
1482
+ )
1483
+ original_token_prob = _coerce_float(
1484
+ getattr(loss_cfg, "original_token_prob", None), 0.1
1485
+ )
1486
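+ # These defaults mirror the familiar BERT-style MLM recipe: roughly 15% of
+ # tokens are selected for prediction, of which about 10% are swapped for
+ # random tokens and about 10% are left unchanged, with the remainder replaced
+ # by the mask token (the exact split is delegated to _apply_mlm_masks).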
+ if loss_cfg is not None and getattr(loss_cfg, "type", None) == "auto":
1487
+ try:
1488
+ loss_cfg.type = resolved_loss_type # type: ignore[assignment]
1489
+ except Exception:
1490
+ pass
1491
+
1492
+ # Set deterministic seeds for Python/NumPy/Torch and record provenance
1493
+ raw_seed_value = 42
1494
+ if hasattr(cfg, "dataset"):
1495
+ try:
1496
+ raw_seed_value = getattr(cfg.dataset, "seed", 42)
1497
+ except Exception:
1498
+ raw_seed_value = 42
1499
+ try:
1500
+ seed_value = int(raw_seed_value)
1501
+ except (TypeError, ValueError, OverflowError):
1502
+ seed_value = 42
1503
+ set_seed(seed_value)
1504
+ # Enforce deterministic algorithms in CI/Release profiles when torch is available
1505
+ profile_label = (str(profile or "").lower()) if profile else None
1506
+ if torch is not None and profile_label in {"ci", "release"}:
1507
+ try: # pragma: no cover - behavior depends on torch availability
1508
+ if hasattr(torch, "use_deterministic_algorithms"):
1509
+ torch.use_deterministic_algorithms(True, warn_only=False)
1510
+ if hasattr(torch.backends, "cudnn"):
1511
+ torch.backends.cudnn.benchmark = False
1512
+ try:
1513
+ torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
1514
+ except Exception:
1515
+ pass
1516
+ except Exception:
1517
+ # If we cannot enforce determinism here, we will rely on core checks
1518
+ pass
1519
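+ # Note: on CUDA, torch.use_deterministic_algorithms(True) also expects
+ # CUBLAS_WORKSPACE_CONFIG to be set for some cuBLAS ops, e.g.
+ #   export CUBLAS_WORKSPACE_CONFIG=:4096:8
+ # The variable is only recorded (not set) in env_flags further below.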
+ try:
1520
+ numpy_seed = int(np.random.get_state()[1][0])
1521
+ except Exception:
1522
+ numpy_seed = seed_value
1523
+ torch_seed = None
1524
+ if torch is not None:
1525
+ try:
1526
+ torch_seed = int(torch.initial_seed())
1527
+ except Exception:
1528
+ torch_seed = seed_value
1529
+ seed_bundle = {
1530
+ "python": int(seed_value),
1531
+ "numpy": int(numpy_seed),
1532
+ "torch": int(torch_seed) if torch_seed is not None else None,
1533
+ }
1534
+ console.print(
1535
+ "🎲 Deterministic seeds → "
1536
+ f"python={seed_bundle['python']}, numpy={seed_bundle['numpy']}, "
1537
+ f"torch={seed_bundle['torch'] if seed_bundle['torch'] is not None else 'N/A'}"
1538
+ )
1539
+
1540
+ # Resolve device and output directory
1541
+ resolved_device, output_dir = _resolve_device_and_output(
1542
+ cfg, device=device, out=out, console=console
1543
+ )
1544
+
1545
+ # Create run directory with timestamp
1546
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
1547
+ run_dir = output_dir / timestamp
1548
+ run_dir.mkdir(parents=True, exist_ok=True)
1549
+
1550
+ run_id = f"{output_dir.name}-{timestamp}" if output_dir.name else timestamp
1551
+
1552
+ console.print(f"📁 Output directory: {run_dir}")
1553
+ console.print(f"🆔 Run ID: {run_id}")
1554
+
1555
+ # Initialize retry controller if --until-pass mode enabled
1556
+ retry_controller = _init_retry_controller(
1557
+ until_pass=until_pass,
1558
+ max_attempts=max_attempts,
1559
+ timeout=timeout,
1560
+ baseline=baseline,
1561
+ console=console,
1562
+ )
1563
+
1564
+ baseline_report_data: dict[str, Any] | None = None
1565
+ pairing_schedule: dict[str, Any] | None = None
1566
+ if baseline:
1567
+ baseline_path = Path(baseline)
1568
+ if baseline_path.exists():
1569
+ try:
1570
+ with baseline_path.open(encoding="utf-8") as f:
1571
+ baseline_report_data = json.load(f)
1572
+ pairing_schedule = _extract_pairing_schedule(baseline_report_data)
1573
+ if pairing_schedule:
1574
+ console.print(
1575
+ "🧬 Loaded baseline evaluation schedule for pairing"
1576
+ )
1577
+ elif (profile or "").lower() == "release":
1578
+ console.print(
1579
+ f"[red]❌ Baseline report '{baseline}' does not contain evaluation_windows required for pairing.[/red]"
1580
+ )
1581
+ raise typer.Exit(1)
1582
+ else:
1583
+ console.print(
1584
+ f"[yellow]⚠️ Baseline report '{baseline}' lacks evaluation_windows; falling back to dataset schedule.[/yellow]"
1585
+ )
1586
+ baseline_report_data = None
1587
+ pairing_schedule = None
1588
+ except typer.Exit:
1589
+ raise
1590
+ except Exception as exc: # noqa: BLE001
1591
+ console.print(
1592
+ f"[yellow]⚠️ Failed to load baseline report '{baseline}': {exc}. Falling back to dataset schedule.[/yellow]"
1593
+ )
1594
+ baseline_report_data = None
1595
+ pairing_schedule = None
1596
+ else:
1597
+ console.print(
1598
+ f"[yellow]⚠️ Baseline report '{baseline}' not found. Falling back to dataset schedule.[/yellow]"
1599
+ )
1600
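+ # For reference, the pairing schedule consumed below is keyed per arm, roughly:
+ #   {"preview": {"input_ids": [...], "attention_masks": [...],
+ #                "window_ids": [...], "labels": [...],
+ #                "masked_token_counts": [...]},
+ #    "final": {...}}
+ # (field names taken from the accesses later in this function; optional keys
+ # may be absent).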
+
1601
+ requested_preview = int(getattr(cfg.dataset, "preview_n", 0))
1602
+ requested_final = int(getattr(cfg.dataset, "final_n", 0))
1603
+ effective_preview = requested_preview
1604
+ effective_final = requested_final
1605
+ preview_count = effective_preview
1606
+ final_count = effective_final
1607
+ # Default split prior to provider resolution; updated if provider exposes splits
1608
+ try:
1609
+ resolved_split = getattr(cfg.dataset, "split", None) or "validation"
1610
+ except Exception:
1611
+ resolved_split = "validation"
1612
+ used_fallback_split: bool = False
1613
+
1614
+ # Execute the pipeline using CoreRunner
1615
+ _print_pipeline_start(console)
1616
+
1617
+ # Get registry and create components
1618
+ registry = get_registry()
1619
+ adapter = registry.get_adapter(cfg.model.adapter)
1620
+ edit_name = getattr(getattr(cfg, "edit", None), "name", None)
1621
+ if not isinstance(edit_name, str) or not edit_name.strip():
1622
+ console.print(
1623
+ "[red]❌ Edit configuration must specify a non-empty `edit.name`.[/red]"
1624
+ )
1625
+ raise typer.Exit(1)
1626
+ try:
1627
+ edit_op = registry.get_edit(edit_name.strip())
1628
+ except Exception:
1629
+ console.print(
1630
+ f"[yellow]⚠️ Unknown edit '{edit_name.strip()}'. Using pass-through shim.[/yellow]"
1631
+ )
1632
+ edit_op = SimpleNamespace(name=edit_name.strip())
1633
+
1634
+ adapter_meta = registry.get_plugin_metadata(cfg.model.adapter, "adapters")
1635
+ try:
1636
+ from invarlock.cli.provenance import (
1637
+ extract_adapter_provenance,
1638
+ ) # local import to avoid CLI import cycles
1639
+
1640
+ prov = extract_adapter_provenance(cfg.model.adapter)
1641
+ # Attach a small, stable provenance dict under adapter plugin metadata
1642
+ adapter_meta["provenance"] = prov.to_dict()
1643
+ except Exception:
1644
+ # Best-effort only; absence should not break runs
1645
+ pass
1646
+ try:
1647
+ edit_meta = registry.get_plugin_metadata(edit_name.strip(), "edits")
1648
+ except Exception:
1649
+ edit_meta = {
1650
+ "name": edit_name.strip(),
1651
+ "module": "edits.unknown",
1652
+ "version": "unknown",
1653
+ }
1654
+
1655
+ guards = []
1656
+ guard_metadata: list[dict[str, Any]] = []
1657
+ for guard_name in cfg.guards.order:
1658
+ if guard_name != "noop":
1659
+ try:
1660
+ guard = registry.get_guard(guard_name)
1661
+ guards.append(guard)
1662
+ guard_metadata.append(
1663
+ registry.get_plugin_metadata(guard_name, "guards")
1664
+ )
1665
+ except KeyError:
1666
+ console.print(
1667
+ f"[yellow]⚠️ Guard '{guard_name}' not found, skipping[/yellow]"
1668
+ )
1669
+ plugin_provenance = {
1670
+ "adapter": adapter_meta,
1671
+ "edit": edit_meta,
1672
+ "guards": guard_metadata,
1673
+ }
1674
+
1675
+ console.print(f"🔌 Adapter: {adapter.name}")
1676
+
1677
+ # Create run configuration
1678
+ def _to_serialisable_dict(section: object) -> dict[str, Any]:
1679
+ """Coerce config fragments to plain dicts.
1680
+
1681
+ Handles InvarLockConfig sections (which wrap dicts in a private `_Obj` with
1682
+ `_data`) so downstream components (core.runner) see canonical mappings,
1683
+ e.g. `eval.bootstrap.replicates`.
1684
+ """
1685
+ # Prefer native dump methods
1686
+ if hasattr(section, "model_dump"):
1687
+ return section.model_dump() # type: ignore[return-value]
1688
+ if hasattr(section, "dict"):
1689
+ try:
1690
+ return section.dict() # type: ignore[return-value]
1691
+ except Exception:
1692
+ pass
1693
+ # Unwrap CLI _Obj wrapper used by InvarLockConfig for attribute access
1694
+ try:
1695
+ raw = getattr(section, "_data", None)
1696
+ if isinstance(raw, dict):
1697
+ return raw
1698
+ except Exception:
1699
+ pass
1700
+ # Already a mapping
1701
+ if isinstance(section, dict):
1702
+ return section
1703
+ # Best-effort attribute dump
1704
+ try:
1705
+ data = vars(section)
1706
+ # Common case: {'_data': {...}}
1707
+ if isinstance(data, dict) and isinstance(data.get("_data"), dict):
1708
+ return data["_data"]
1709
+ return data # type: ignore[return-value]
1710
+ except TypeError:
1711
+ return {}
1712
+
1713
+ def _dump_guard(section: object) -> dict[str, Any]:
1714
+ data = _to_serialisable_dict(section)
1715
+ return data if isinstance(data, dict) else {}
1716
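+ # Minimal illustration of the unwrapping above: for a section that stores its
+ # mapping in `_data`, the helper returns the plain dict, e.g.
+ #   section._data == {"replicates": 200}
+ #   _to_serialisable_dict(section) == {"replicates": 200}
+ # ("replicates" is just a placeholder key here).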
+
1717
+ guard_overrides = {
1718
+ "spectral": _dump_guard(getattr(cfg.guards, "spectral", {})),
1719
+ "rmt": _dump_guard(getattr(cfg.guards, "rmt", {})),
1720
+ "variance": _dump_guard(getattr(cfg.guards, "variance", {})),
1721
+ "invariants": _dump_guard(getattr(cfg.guards, "invariants", {})),
1722
+ }
1723
+
1724
+ if model_profile.invariants:
1725
+ invariants_policy = guard_overrides.setdefault("invariants", {})
1726
+ existing_checks = invariants_policy.get("profile_checks", [])
1727
+ if isinstance(existing_checks, list | tuple | set):
1728
+ checks_list = [str(item) for item in existing_checks]
1729
+ elif existing_checks:
1730
+ checks_list = [str(existing_checks)]
1731
+ else:
1732
+ checks_list = []
1733
+ for invariant in model_profile.invariants:
1734
+ invariant_name = str(invariant)
1735
+ if invariant_name not in checks_list:
1736
+ checks_list.append(invariant_name)
1737
+ invariants_policy["profile_checks"] = checks_list
1738
+
1739
+ run_context = {
1740
+ "eval": _to_serialisable_dict(cfg.eval),
1741
+ "dataset": _to_serialisable_dict(cfg.dataset),
1742
+ "guards": guard_overrides,
1743
+ "profile": profile if profile else "",
1744
+ "pairing_baseline": pairing_schedule,
1745
+ "seeds": seed_bundle,
1746
+ "plugins": plugin_provenance,
1747
+ "run_id": run_id,
1748
+ }
1749
+ run_context["model_profile"] = {
1750
+ "family": model_profile.family,
1751
+ "default_loss": model_profile.default_loss,
1752
+ "module_selectors": model_profile.module_selectors,
1753
+ "invariants": model_profile.invariants,
1754
+ "cert_lints": model_profile.cert_lints,
1755
+ }
1756
+ extra_context = _to_serialisable_dict(getattr(cfg, "context", {}))
1757
+ if isinstance(extra_context, dict):
1758
+ run_context.update(extra_context)
1759
+ try:
1760
+ run_context.setdefault("eval", {}).setdefault("loss", {})[
1761
+ "resolved_type"
1762
+ ] = resolved_loss_type
1763
+ except Exception:
1764
+ pass
1765
+ run_config = RunConfig(
1766
+ device=resolved_device,
1767
+ max_pm_ratio=getattr(cfg.eval, "max_pm_ratio", 1.5),
1768
+ event_path=run_dir / "events.jsonl",
1769
+ context=run_context,
1770
+ )
1771
+ skip_model_load = False
1772
+
1773
+ # Load model using adapter
1774
+ # Load calibration data if dataset is configured
1775
+ calibration_data = None
1776
+ dataset_meta: dict[str, Any] = {}
1777
+ baseline_meta: dict[str, Any] = {}
1778
+ window_plan: dict[str, Any] | None = None
1779
+ if pairing_schedule:
1780
+ harvested = _validate_and_harvest_baseline_schedule(
1781
+ cfg,
1782
+ pairing_schedule,
1783
+ baseline_report_data,
1784
+ tokenizer_hash=tokenizer_hash,
1785
+ resolved_loss_type=resolved_loss_type,
1786
+ baseline_path_str=str(baseline) if baseline else None,
1787
+ console=console,
1788
+ )
1789
+ effective_preview = harvested["effective_preview"]
1790
+ effective_final = harvested["effective_final"]
1791
+ preview_count = harvested["preview_count"]
1792
+ final_count = harvested["final_count"]
1793
+ dataset_meta = harvested["dataset_meta"]
1794
+ window_plan = harvested["window_plan"]
1795
+ calibration_data = harvested["calibration_data"]
1796
+ if use_mlm and tokenizer is None:
1797
+ try:
1798
+ tokenizer, tokenizer_hash = resolve_tokenizer(model_profile)
1799
+ except Exception as exc:
1800
+ console.print(f"[red]❌ {exc}[/red]")
1801
+ raise typer.Exit(1) from exc
1802
+ preview_window_ids = pairing_schedule["preview"].get("window_ids")
1803
+ preview_labels = pairing_schedule["preview"].get("labels")
1804
+ for idx, (input_ids, attention_mask) in enumerate(
1805
+ zip(
1806
+ pairing_schedule["preview"]["input_ids"],
1807
+ pairing_schedule["preview"]["attention_masks"],
1808
+ strict=False,
1809
+ )
1810
+ ):
1811
+ window_id = (
1812
+ preview_window_ids[idx]
1813
+ if preview_window_ids and idx < len(preview_window_ids)
1814
+ else idx
1815
+ )
1816
+ entry = {
1817
+ "input_ids": input_ids,
1818
+ "attention_mask": attention_mask,
1819
+ "window_id": f"preview::{window_id}",
1820
+ }
1821
+ if use_mlm:
1822
+ labels_list: list[int] = []
1823
+ if isinstance(preview_labels, list) and idx < len(preview_labels):
1824
+ labels_list = _tensor_or_list_to_ints(preview_labels[idx])
1825
+ if labels_list and any(token != -100 for token in labels_list):
1826
+ entry["labels"] = labels_list
1827
+ entry["mlm_masked"] = sum(
1828
+ 1 for token in labels_list if token != -100
1829
+ )
1830
+ else:
1831
+ entry["labels"] = []
1832
+ entry["mlm_masked"] = 0
1833
+ # Prefer masked_token_counts if present in schedule
1834
+ mtc = pairing_schedule["preview"].get("masked_token_counts")
1835
+ if isinstance(mtc, list) and idx < len(mtc):
1836
+ try:
1837
+ entry["mlm_masked"] = int(mtc[idx])
1838
+ except Exception:
1839
+ pass
1840
+ calibration_data.append(entry)
1841
+ final_window_ids = pairing_schedule["final"].get("window_ids")
1842
+ final_labels = pairing_schedule["final"].get("labels")
1843
+ for idx, (input_ids, attention_mask) in enumerate(
1844
+ zip(
1845
+ pairing_schedule["final"]["input_ids"],
1846
+ pairing_schedule["final"]["attention_masks"],
1847
+ strict=False,
1848
+ )
1849
+ ):
1850
+ window_id = (
1851
+ final_window_ids[idx]
1852
+ if final_window_ids and idx < len(final_window_ids)
1853
+ else idx
1854
+ )
1855
+ entry = {
1856
+ "input_ids": input_ids,
1857
+ "attention_mask": attention_mask,
1858
+ "window_id": f"final::{window_id}",
1859
+ }
1860
+ if use_mlm:
1861
+ labels_list: list[int] = []
1862
+ if isinstance(final_labels, list) and idx < len(final_labels):
1863
+ labels_list = _tensor_or_list_to_ints(final_labels[idx])
1864
+ if labels_list and any(token != -100 for token in labels_list):
1865
+ entry["labels"] = labels_list
1866
+ entry["mlm_masked"] = sum(
1867
+ 1 for token in labels_list if token != -100
1868
+ )
1869
+ else:
1870
+ entry["labels"] = []
1871
+ entry["mlm_masked"] = 0
1872
+ # Prefer masked_token_counts if present in schedule
1873
+ mtc = pairing_schedule["final"].get("masked_token_counts")
1874
+ if isinstance(mtc, list) and idx < len(mtc):
1875
+ try:
1876
+ entry["mlm_masked"] = int(mtc[idx])
1877
+ except Exception:
1878
+ pass
1879
+ calibration_data.append(entry)
1880
+ preview_count = len(pairing_schedule["preview"]["input_ids"])
1881
+ final_count = len(pairing_schedule["final"]["input_ids"])
1882
+ effective_preview = int(preview_count)
1883
+ effective_final = int(final_count)
1884
+ preview_mask_total = 0
1885
+ final_mask_total = 0
1886
+ preview_mask_counts: list[int] = []
1887
+ final_mask_counts: list[int] = []
1888
+ if use_mlm:
1889
+ preview_entries = calibration_data[:preview_count]
1890
+ final_entries = calibration_data[preview_count:]
1891
+
1892
+ def _needs_masks(entries):
1893
+ missing_any = False
1894
+ counts = []
1895
+ for entry in entries:
1896
+ labels_val = entry.get("labels")
1897
+ has_label_masks = bool(
1898
+ isinstance(labels_val, list)
1899
+ and any(token != -100 for token in labels_val)
1900
+ )
1901
+ existing_count = int(entry.get("mlm_masked", 0))
1902
+ if not has_label_masks and existing_count <= 0:
1903
+ missing_any = True
1904
+ counts.append(int(entry.get("mlm_masked", 0)))
1905
+ return missing_any, counts
1906
+
1907
+ preview_missing, preview_counts_existing = _needs_masks(preview_entries)
1908
+ final_missing, final_counts_existing = _needs_masks(final_entries)
1909
+
1910
+ if preview_missing:
1911
+ preview_mask_total, preview_mask_counts = _apply_mlm_masks(
1912
+ preview_entries,
1913
+ tokenizer=tokenizer,
1914
+ mask_prob=mask_prob,
1915
+ seed=mask_seed,
1916
+ random_token_prob=random_token_prob,
1917
+ original_token_prob=original_token_prob,
1918
+ prefix="preview",
1919
+ )
1920
+ else:
1921
+ preview_mask_counts = preview_counts_existing
1922
+ preview_mask_total = sum(preview_mask_counts)
1923
+
1924
+ if final_missing:
1925
+ final_mask_total, final_mask_counts = _apply_mlm_masks(
1926
+ final_entries,
1927
+ tokenizer=tokenizer,
1928
+ mask_prob=mask_prob,
1929
+ seed=mask_seed,
1930
+ random_token_prob=random_token_prob,
1931
+ original_token_prob=original_token_prob,
1932
+ prefix="final",
1933
+ )
1934
+ else:
1935
+ final_mask_counts = final_counts_existing
1936
+ final_mask_total = sum(final_mask_counts)
1937
+
1938
+ # Ensure counts and labels are set on the entries
1939
+ if preview_mask_counts:
1940
+ for entry, count in zip(
1941
+ preview_entries, preview_mask_counts, strict=False
1942
+ ):
1943
+ entry["mlm_masked"] = int(count)
1944
+ if final_mask_counts:
1945
+ for entry, count in zip(
1946
+ final_entries, final_mask_counts, strict=False
1947
+ ):
1948
+ entry["mlm_masked"] = int(count)
1949
+
1950
+ if preview_count > 0 and preview_mask_total <= 0:
1951
+ _fail_run(
1952
+ "Baseline pairing schedule provided no masked tokens for preview windows; "
1953
+ "ensure MLM labels are present in the baseline report."
1954
+ )
1955
+ if final_count > 0 and final_mask_total <= 0:
1956
+ _fail_run(
1957
+ "Baseline pairing schedule provided no masked tokens for final windows; "
1958
+ "ensure MLM labels are present in the baseline report."
1959
+ )
1960
+
1961
+ dataset_meta["masked_tokens_preview"] = int(preview_mask_total)
1962
+ dataset_meta["masked_tokens_final"] = int(final_mask_total)
1963
+ dataset_meta["masked_tokens_total"] = int(
1964
+ preview_mask_total + final_mask_total
1965
+ )
1966
+ if os.environ.get("INVARLOCK_DEBUG_TRACE"):
1967
+ console.print(
1968
+ f"[debug] MLM pairing masks → preview={preview_mask_total}, final={final_mask_total}"
1969
+ )
1970
+ if "preview_total_tokens" not in dataset_meta:
1971
+ dataset_meta["preview_total_tokens"] = sum(
1972
+ len(_tensor_or_list_to_ints(seq))
1973
+ for seq in pairing_schedule["preview"]["input_ids"]
1974
+ )
1975
+ if "final_total_tokens" not in dataset_meta:
1976
+ dataset_meta["final_total_tokens"] = sum(
1977
+ len(_tensor_or_list_to_ints(seq))
1978
+ for seq in pairing_schedule["final"]["input_ids"]
1979
+ )
1980
+ if "preview_hash" not in dataset_meta:
1981
+ preview_hash = _hash_sequences(
1982
+ _tensor_or_list_to_ints(seq)
1983
+ for seq in pairing_schedule["preview"]["input_ids"]
1984
+ )
1985
+ dataset_meta["preview_hash"] = preview_hash
1986
+ else:
1987
+ preview_hash = dataset_meta["preview_hash"]
1988
+ if "final_hash" not in dataset_meta:
1989
+ final_hash = _hash_sequences(
1990
+ _tensor_or_list_to_ints(seq)
1991
+ for seq in pairing_schedule["final"]["input_ids"]
1992
+ )
1993
+ dataset_meta["final_hash"] = final_hash
1994
+ else:
1995
+ final_hash = dataset_meta["final_hash"]
1996
+ if "dataset_hash" not in dataset_meta:
1997
+ dataset_meta["dataset_hash"] = hashlib.blake2s(
1998
+ (str(preview_hash) + str(final_hash)).encode("utf-8"),
1999
+ digest_size=16,
2000
+ ).hexdigest()
2001
+ if not window_plan:
2002
+ window_capacity = (
2003
+ baseline_meta.get("window_capacity")
2004
+ if isinstance(baseline_meta, dict)
2005
+ else {}
2006
+ )
2007
+ window_plan = {
2008
+ "profile": (profile or "").lower() or "baseline",
2009
+ "requested_preview": int(preview_count),
2010
+ "requested_final": int(final_count),
2011
+ "actual_preview": int(preview_count),
2012
+ "actual_final": int(final_count),
2013
+ "coverage_ok": True,
2014
+ "capacity": window_capacity or {},
2015
+ }
2016
+ if isinstance(window_plan, dict):
2017
+ dataset_meta.setdefault("window_plan", window_plan)
2018
+ capacity_meta = window_plan.get("capacity")
2019
+ if capacity_meta and "window_capacity" not in dataset_meta:
2020
+ dataset_meta["window_capacity"] = capacity_meta
2021
+ elif cfg.dataset.provider:
2022
+ console.print(f"📊 Loading dataset: {cfg.dataset.provider}")
2023
+ # Pass through provider-specific kwargs when available
2024
+ provider_kwargs = {}
2025
+ for key in (
2026
+ "dataset_name",
2027
+ "config_name",
2028
+ "text_field",
2029
+ "src_field",
2030
+ "tgt_field",
2031
+ "cache_dir",
2032
+ "max_samples",
2033
+ # Local providers (e.g., local_jsonl)
2034
+ "file",
2035
+ "path",
2036
+ "data_files",
2037
+ ):
2038
+ try:
2039
+ val = getattr(cfg.dataset, key)
2040
+ except Exception:
2041
+ val = None
2042
+ if val is not None and val != "":
2043
+ provider_kwargs[key] = val
2044
+ # Resolve provider kind from config (supports string or mapping with kind)
2045
+ provider_val = getattr(cfg.dataset, "provider", None)
2046
+ provider_name = None
2047
+ if isinstance(provider_val, dict):
2048
+ provider_name = provider_val.get("kind")
2049
+ # Include nested provider-specific kwargs
2050
+ for k, v in provider_val.items():
2051
+ if k != "kind" and v is not None and v != "":
2052
+ provider_kwargs[k] = v
2053
+ elif isinstance(provider_val, str):
2054
+ provider_name = provider_val # noqa: F841
2055
+ else:
2056
+ # Support mapping-like provider configs (e.g., _Obj with .get)
2057
+ try:
2058
+ _ = provider_val.get("kind") # type: ignore[attr-defined]
2059
+ # Try to expose nested entries
2060
+ try:
2061
+ for k, v in provider_val._data.items(): # type: ignore[attr-defined]
2062
+ if k != "kind" and v is not None and v != "":
2063
+ provider_kwargs[k] = v
2064
+ except Exception:
2065
+ # Fallback: if items() exists
2066
+ try:
2067
+ for k, v in provider_val.items(): # type: ignore[attr-defined]
2068
+ if k != "kind" and v is not None and v != "":
2069
+ provider_kwargs[k] = v
2070
+ except Exception:
2071
+ pass
2072
+ except Exception:
2073
+ _ = None
2074
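+ # Example of a mapping-style provider handled above (the nested keys match the
+ # kwargs loop; "hf_text" and the dataset names are illustrative placeholders):
+ #   dataset:
+ #     provider: {kind: hf_text, dataset_name: wikitext,
+ #                config_name: wikitext-2-raw-v1, text_field: text}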
+ data_provider, resolved_split, used_fallback_split = (
2075
+ _resolve_provider_and_split(
2076
+ cfg,
2077
+ model_profile,
2078
+ get_provider_fn=get_provider,
2079
+ provider_kwargs=provider_kwargs,
2080
+ console=console,
2081
+ resolved_device=resolved_device,
2082
+ )
2083
+ )
2084
+
2085
+ # Load tokenizer for dataset processing
2086
+ try:
2087
+ tokenizer, tokenizer_hash = resolve_tokenizer(model_profile)
2088
+ except Exception as exc:
2089
+ console.print(f"[red]❌ {exc}[/red]")
2090
+ raise typer.Exit(1) from exc
2091
+
2092
+ dataset_stride = getattr(
2093
+ cfg.dataset, "stride", getattr(cfg.dataset, "seq_len", 0) // 2
2094
+ )
2095
+ release_profile = (profile or "").lower() == "release"
2096
+ if release_profile and not pairing_schedule:
2097
+ estimate_fn = getattr(data_provider, "estimate_capacity", None)
2098
+ if callable(estimate_fn):
2099
+ capacity_fast = bool(getattr(cfg.eval, "capacity_fast", False))
2100
+ capacity_meta = estimate_fn(
2101
+ tokenizer=tokenizer,
2102
+ seq_len=cfg.dataset.seq_len,
2103
+ stride=dataset_stride,
2104
+ split=resolved_split,
2105
+ target_total=requested_preview + requested_final,
2106
+ fast_mode=capacity_fast,
2107
+ )
2108
+ variance_policy = getattr(cfg.guards, "variance", None)
2109
+ max_calibration = (
2110
+ getattr(variance_policy, "max_calib", 0)
2111
+ if variance_policy is not None
2112
+ else 0
2113
+ )
2114
+ try:
2115
+ window_plan = _maybe_plan_release_windows(
2116
+ capacity_meta,
2117
+ requested_preview=requested_preview,
2118
+ requested_final=requested_final,
2119
+ max_calibration=max_calibration,
2120
+ console=console,
2121
+ )
2122
+ except RuntimeError as err:
2123
+ console.print(f"[red]❌ {err}[/red]")
2124
+ raise typer.Exit(1) from err
2125
+
2126
+ actual_per_arm = int(window_plan["actual_preview"])
2127
+ effective_preview = actual_per_arm
2128
+ effective_final = actual_per_arm
2129
+ preview_count = effective_preview
2130
+ final_count = effective_final
2131
+ dataset_stride = getattr(
2132
+ cfg.dataset, "stride", getattr(cfg.dataset, "seq_len", 0)
2133
+ )
2134
+ else:
2135
+ console.print(
2136
+ "[yellow]⚠️ Release profile requested but dataset provider "
2137
+ "does not expose capacity estimation; using configured window counts.[/yellow]"
2138
+ )
2139
+
2140
+ preview_records: list[tuple[list[int], list[int]]] = []
2141
+ final_records: list[tuple[list[int], list[int]]] = []
2142
+
2143
+ while True:
2144
+ preview_window, final_window = data_provider.windows(
2145
+ tokenizer=tokenizer,
2146
+ seq_len=cfg.dataset.seq_len,
2147
+ stride=getattr(cfg.dataset, "stride", cfg.dataset.seq_len // 2),
2148
+ preview_n=effective_preview,
2149
+ final_n=effective_final,
2150
+ seed=getattr(cfg.dataset, "seed", 42),
2151
+ split=resolved_split,
2152
+ )
2153
+
2154
+ preview_count = len(getattr(preview_window, "input_ids", []))
2155
+ final_count = len(getattr(final_window, "input_ids", []))
2156
+ is_eval_window = isinstance(
2157
+ preview_window, EvaluationWindow
2158
+ ) and isinstance(final_window, EvaluationWindow)
2159
+ if is_eval_window:
2160
+ if (
2161
+ preview_count != effective_preview
2162
+ or final_count != effective_final
2163
+ ):
2164
+ _fail_run(
2165
+ "Dataset provider returned mismatched preview/final counts "
2166
+ f"({preview_count}/{final_count}) "
2167
+ f"expected ({effective_preview}/{effective_final}). "
2168
+ "CI/Release profiles require exact parity."
2169
+ )
2170
+ else:
2171
+ preview_count = effective_preview
2172
+ final_count = effective_final
2173
+
2174
+ # Optional: provider-supplied labels for seq2seq
2175
+ provider_labels_prev = None
2176
+ provider_labels_fin = None
2177
+ try:
2178
+ provider_labels_prev = getattr(
2179
+ data_provider, "last_preview_labels", None
2180
+ )
2181
+ provider_labels_fin = getattr(
2182
+ data_provider, "last_final_labels", None
2183
+ )
2184
+ except Exception:
2185
+ provider_labels_prev = None
2186
+ provider_labels_fin = None
2187
+
2188
+ preview_records = []
2189
+ preview_indices_raw = getattr(preview_window, "indices", [])
2190
+ if isinstance(preview_indices_raw, list):
2191
+ preview_indices = preview_indices_raw
2192
+ else:
2193
+ try:
2194
+ preview_indices = list(preview_indices_raw)
2195
+ except TypeError:
2196
+ preview_indices = []
2197
+ for idx_local, (input_ids, attention_mask) in enumerate(
2198
+ zip(
2199
+ preview_window.input_ids,
2200
+ preview_window.attention_masks,
2201
+ strict=False,
2202
+ )
2203
+ ):
2204
+ input_ids_list = _tensor_or_list_to_ints(input_ids)
2205
+ attention_mask_list = (
2206
+ _tensor_or_list_to_ints(attention_mask)
2207
+ if attention_mask is not None
2208
+ else [1] * len(input_ids_list)
2209
+ )
2210
+ dataset_index = (
2211
+ _safe_int(preview_indices[idx_local])
2212
+ if idx_local < len(preview_indices)
2213
+ else idx_local
2214
+ )
2215
+ rec = {
2216
+ "input_ids": input_ids_list,
2217
+ "attention_mask": attention_mask_list,
2218
+ "dataset_index": dataset_index,
2219
+ }
2220
+ # Attach provider labels for seq2seq if available
2221
+ if provider_labels_prev is not None and idx_local < len(
2222
+ provider_labels_prev
2223
+ ):
2224
+ rec["labels"] = _tensor_or_list_to_ints(
2225
+ provider_labels_prev[idx_local]
2226
+ )
2227
+ preview_records.append(rec)
2228
+
2229
+ final_records = []
2230
+ final_indices_raw = getattr(final_window, "indices", [])
2231
+ if isinstance(final_indices_raw, list):
2232
+ final_indices = final_indices_raw
2233
+ else:
2234
+ try:
2235
+ final_indices = list(final_indices_raw)
2236
+ except TypeError:
2237
+ final_indices = []
2238
+ for idx_local, (input_ids, attention_mask) in enumerate(
2239
+ zip(
2240
+ final_window.input_ids,
2241
+ final_window.attention_masks,
2242
+ strict=False,
2243
+ )
2244
+ ):
2245
+ input_ids_list = _tensor_or_list_to_ints(input_ids)
2246
+ attention_mask_list = (
2247
+ _tensor_or_list_to_ints(attention_mask)
2248
+ if attention_mask is not None
2249
+ else [1] * len(input_ids_list)
2250
+ )
2251
+ dataset_index = (
2252
+ _safe_int(final_indices[idx_local])
2253
+ if idx_local < len(final_indices)
2254
+ else idx_local
2255
+ )
2256
+ final_records.append(
2257
+ {
2258
+ "input_ids": input_ids_list,
2259
+ "attention_mask": attention_mask_list,
2260
+ "dataset_index": dataset_index,
2261
+ }
2262
+ )
2263
+
2264
+ if use_mlm:
2265
+ temp_preview_records = [
2266
+ {
2267
+ "input_ids": list(rec["input_ids"]),
2268
+ "attention_mask": list(rec["attention_mask"]),
2269
+ "dataset_index": rec.get("dataset_index"),
2270
+ "window_id": rec.get("window_id"),
2271
+ }
2272
+ for rec in preview_records
2273
+ ]
2274
+ temp_final_records = [
2275
+ {
2276
+ "input_ids": list(rec["input_ids"]),
2277
+ "attention_mask": list(rec["attention_mask"]),
2278
+ "dataset_index": rec.get("dataset_index"),
2279
+ "window_id": rec.get("window_id"),
2280
+ }
2281
+ for rec in final_records
2282
+ ]
2283
+ _apply_mlm_masks(
2284
+ temp_preview_records,
2285
+ tokenizer=tokenizer,
2286
+ mask_prob=mask_prob,
2287
+ seed=mask_seed,
2288
+ random_token_prob=random_token_prob,
2289
+ original_token_prob=original_token_prob,
2290
+ prefix="preview",
2291
+ )
2292
+ _apply_mlm_masks(
2293
+ temp_final_records,
2294
+ tokenizer=tokenizer,
2295
+ mask_prob=mask_prob,
2296
+ seed=mask_seed,
2297
+ random_token_prob=random_token_prob,
2298
+ original_token_prob=original_token_prob,
2299
+ prefix="final",
2300
+ )
2301
+ records_for_signatures = temp_preview_records + temp_final_records
2302
+ else:
2303
+ records_for_signatures = preview_records + final_records
2304
+
2305
+ signatures = []
2306
+ for record in records_for_signatures:
2307
+ tokens = record["input_ids"]
2308
+ masks = record["attention_mask"]
2309
+ signatures.append(
2310
+ tuple(
2311
+ tok
2312
+ for tok, mask in zip(tokens, masks, strict=False)
2313
+ if mask
2314
+ )
2315
+ )
2316
+
2317
+ unique_sequences = len(set(signatures))
2318
+ combined_total = len(signatures)
2319
+ if unique_sequences == combined_total:
2320
+ break
2321
+
2322
+ deficit = combined_total - unique_sequences
2323
+ reduction = max(5, int(deficit) if deficit > 0 else 1)
2324
+ proposed_per_arm = preview_count - reduction
2325
+ if proposed_per_arm >= preview_count:
2326
+ proposed_per_arm = preview_count - 1
2327
+ min_per_arm_floor = RELEASE_MIN_WINDOWS_PER_ARM
2328
+ if window_plan is None or window_plan.get("profile") != "release":
2329
+ min_per_arm_floor = max(
2330
+ 10,
2331
+ min(
2332
+ int(requested_preview or 0) or RELEASE_MIN_WINDOWS_PER_ARM,
2333
+ int(requested_final or 0) or RELEASE_MIN_WINDOWS_PER_ARM,
2334
+ )
2335
+ // 2,
2336
+ )
2337
+ if proposed_per_arm < min_per_arm_floor:
2338
+ raise RuntimeError(
2339
+ "Unable to construct non-overlapping windows within minimum window floor."
2340
+ )
2341
+ console.print(
2342
+ f"[yellow]⚠️ Detected {deficit} duplicate windows; reducing per-arm windows to {proposed_per_arm} and retrying stratification.[/yellow]"
2343
+ )
2344
+
2345
+ effective_preview = proposed_per_arm
2346
+ effective_final = proposed_per_arm
2347
+ preview_count = effective_preview
2348
+ final_count = effective_final
2349
+ if window_plan is not None:
2350
+ window_plan.setdefault("dedupe_adjustments", []).append(
2351
+ {
2352
+ "deficit": int(deficit),
2353
+ "proposed_per_arm": int(proposed_per_arm),
2354
+ }
2355
+ )
2356
+ window_plan["actual_preview"] = proposed_per_arm
2357
+ window_plan["actual_final"] = proposed_per_arm
2358
+ continue
2359
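+ # Worked example of the dedupe retry above: 100 preview + 100 final windows
+ # with 197 unique signatures gives deficit = 3, reduction = max(5, 3) = 5, so
+ # the next pass retries with 95 windows per arm (subject to the per-arm floor).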
+
2360
+ if window_plan is None:
2361
+ window_plan = {
2362
+ "profile": (profile or "").lower() or "default",
2363
+ "requested_preview": int(requested_preview),
2364
+ "requested_final": int(requested_final),
2365
+ "actual_preview": int(preview_count),
2366
+ "actual_final": int(final_count),
2367
+ "coverage_ok": preview_count == final_count,
2368
+ "capacity": {},
2369
+ }
2370
+ else:
2371
+ window_plan["actual_preview"] = int(preview_count)
2372
+ window_plan["actual_final"] = int(final_count)
2373
+ window_plan["coverage_ok"] = (
2374
+ window_plan.get("coverage_ok", True)
2375
+ and preview_count == final_count
2376
+ )
2377
+
2378
+ calibration_data: list[dict[str, Any]] = []
2379
+ preview_mask_total = 0
2380
+ final_mask_total = 0
2381
+ preview_mask_counts: list[int] = []
2382
+ final_mask_counts: list[int] = []
2383
+ if use_mlm:
2384
+ preview_mask_total, preview_mask_counts = _apply_mlm_masks(
2385
+ preview_records,
2386
+ tokenizer=tokenizer,
2387
+ mask_prob=mask_prob,
2388
+ seed=mask_seed,
2389
+ random_token_prob=random_token_prob,
2390
+ original_token_prob=original_token_prob,
2391
+ prefix="preview",
2392
+ )
2393
+ final_mask_total, final_mask_counts = _apply_mlm_masks(
2394
+ final_records,
2395
+ tokenizer=tokenizer,
2396
+ mask_prob=mask_prob,
2397
+ seed=mask_seed,
2398
+ random_token_prob=random_token_prob,
2399
+ original_token_prob=original_token_prob,
2400
+ prefix="final",
2401
+ )
2402
+ else:
2403
+ preview_mask_counts = [0] * len(preview_records)
2404
+ final_mask_counts = [0] * len(final_records)
2405
+
2406
+ preview_sequences = [record["input_ids"] for record in preview_records]
2407
+ for idx, record in enumerate(preview_records):
2408
+ entry = {
2409
+ "input_ids": record["input_ids"],
2410
+ "attention_mask": record["attention_mask"],
2411
+ "window_id": f"preview::{idx}",
2412
+ "dataset_index": record.get("dataset_index"),
2413
+ "mlm_masked": record.get("mlm_masked", 0),
2414
+ }
2415
+ if use_mlm:
2416
+ entry["labels"] = record.get(
2417
+ "labels", [-100] * len(record["input_ids"])
2418
+ )
2419
+ calibration_data.append(entry)
2420
+
2421
+ final_sequences = [record["input_ids"] for record in final_records]
2422
+ for idx, record in enumerate(final_records):
2423
+ entry = {
2424
+ "input_ids": record["input_ids"],
2425
+ "attention_mask": record["attention_mask"],
2426
+ "window_id": f"final::{idx}",
2427
+ "dataset_index": record.get("dataset_index"),
2428
+ "mlm_masked": record.get("mlm_masked", 0),
2429
+ }
2430
+ if use_mlm:
2431
+ entry["labels"] = record.get(
2432
+ "labels", [-100] * len(record["input_ids"])
2433
+ )
2434
+ elif provider_labels_fin is not None and idx < len(provider_labels_fin):
2435
+ entry["labels"] = _tensor_or_list_to_ints(provider_labels_fin[idx])
2436
+ calibration_data.append(entry)
2437
+
2438
+ masked_tokens_total = preview_mask_total + final_mask_total
2439
+ preview_hash = _hash_sequences(preview_sequences)
2440
+ final_hash = _hash_sequences(final_sequences)
2441
+ dataset_meta = {
2442
+ "tokenizer_name": getattr(tokenizer, "name_or_path", "unknown"),
2443
+ "tokenizer_hash": tokenizer_hash
2444
+ if tokenizer_hash is not None
2445
+ else _tokenizer_digest(tokenizer),
2446
+ "vocab_size": _safe_int(getattr(tokenizer, "vocab_size", 0)),
2447
+ "bos_token": getattr(tokenizer, "bos_token", None),
2448
+ "eos_token": getattr(tokenizer, "eos_token", None),
2449
+ "pad_token": getattr(tokenizer, "pad_token", None),
2450
+ "add_prefix_space": getattr(tokenizer, "add_prefix_space", None),
2451
+ "dataset_hash": hashlib.blake2s(
2452
+ (preview_hash + final_hash).encode("utf-8"), digest_size=16
2453
+ ).hexdigest(),
2454
+ "preview_hash": preview_hash,
2455
+ "final_hash": final_hash,
2456
+ "preview_total_tokens": sum(len(seq) for seq in preview_sequences),
2457
+ "final_total_tokens": sum(len(seq) for seq in final_sequences),
2458
+ }
2459
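+ # The combined digest can be re-derived from a report's data section, e.g.:
+ #   import hashlib
+ #   recomputed = hashlib.blake2s(
+ #       (data["preview_hash"] + data["final_hash"]).encode("utf-8"),
+ #       digest_size=16,
+ #   ).hexdigest()
+ # (assuming `data` holds the preview_hash/final_hash strings emitted here).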
+ dataset_meta["loss_type"] = resolved_loss_type
2460
+ if use_mlm:
2461
+ dataset_meta["masked_tokens_preview"] = int(preview_mask_total)
2462
+ dataset_meta["masked_tokens_final"] = int(final_mask_total)
2463
+ dataset_meta["masked_tokens_total"] = int(masked_tokens_total)
2464
+ if window_plan:
2465
+ dataset_meta["window_plan"] = window_plan
2466
+ capacity_meta = window_plan.get("capacity")
2467
+ if capacity_meta:
2468
+ dataset_meta["window_capacity"] = capacity_meta
2469
+ strat_stats = getattr(data_provider, "stratification_stats", None)
2470
+ if strat_stats:
2471
+ dataset_meta["stratification"] = strat_stats
2472
+ scorer_profile = getattr(data_provider, "scorer_profile", None)
2473
+ if scorer_profile:
2474
+ dataset_meta["scorer_profile"] = scorer_profile
2475
+
2476
+ try:
2477
+ run_context["dataset"]["preview_n"] = preview_count
2478
+ run_context["dataset"]["final_n"] = final_count
2479
+ except Exception:
2480
+ pass
2481
+ run_context["dataset_meta"] = dataset_meta
2482
+ if window_plan:
2483
+ run_context["window_plan"] = window_plan
2484
+
2485
+ if os.environ.get("INVARLOCK_DEBUG_TRACE"):
2486
+ console.print(
2487
+ "[debug] calibration batch size => preview="
2488
+ f"{preview_count} final={final_count} total={len(calibration_data)}"
2489
+ )
2490
+ if use_mlm and calibration_data:
2491
+ masked_preview = sum(
2492
+ entry.get("mlm_masked", 0)
2493
+ for entry in calibration_data[:preview_count]
2494
+ )
2495
+ masked_final = sum(
2496
+ entry.get("mlm_masked", 0)
2497
+ for entry in calibration_data[preview_count:]
2498
+ )
2499
+ console.print(
2500
+ f"[debug] masked tokens (preview/final) = {masked_preview}/{masked_final}"
2501
+ )
2502
+ console.print(
2503
+ f"[debug] sample labels first preview entry (first 10) = {calibration_data[0]['labels'][:10]}"
2504
+ )
2505
+
2506
+ # Execute the real pipeline using CoreRunner
2507
+ console.print(f"⚙️ Executing pipeline with {len(guards)} guards...")
2508
+ runner = CoreRunner()
2509
+
2510
+ # Prepare auto configuration for tier resolution
2511
+ # Build auto configuration with safe fallbacks when section/keys are absent
2512
+ try:
2513
+ auto_enabled = bool(cfg.auto.enabled)
2514
+ except Exception:
2515
+ auto_enabled = False
2516
+ try:
2517
+ auto_tier = cfg.auto.tier
2518
+ except Exception:
2519
+ auto_tier = "balanced"
2520
+ try:
2521
+ auto_probes = int(cfg.auto.probes)
2522
+ except Exception:
2523
+ auto_probes = 0
2524
+ try:
2525
+ auto_target_ratio = float(cfg.auto.target_pm_ratio)
2526
+ except Exception:
2527
+ auto_target_ratio = 2.0
2528
+
2529
+ auto_config = {
2530
+ "enabled": auto_enabled,
2531
+ "tier": auto_tier,
2532
+ "probes": auto_probes,
2533
+ "target_pm_ratio": auto_target_ratio,
2534
+ }
2535
+
2536
+ # Extract edit configuration parameters
2537
+ edit_config = {}
2538
+ if hasattr(cfg.edit, "plan") and cfg.edit.plan:
2539
+ try:
2540
+ # Accept plain dicts, dict-like wrappers, or nested objects
2541
+ plan_obj = getattr(cfg.edit, "plan", {})
2542
+ if isinstance(plan_obj, dict):
2543
+ edit_config = dict(plan_obj)
2544
+ else:
2545
+ # Best-effort unwrap for InvarLockConfig _Obj wrapper
2546
+ plan_data = getattr(plan_obj, "_data", None)
2547
+ if isinstance(plan_data, dict):
2548
+ edit_config = dict(plan_data)
2549
+ elif hasattr(plan_obj, "items"):
2550
+ edit_config = dict(plan_obj) # type: ignore[arg-type]
2551
+ except (TypeError, AttributeError):
2552
+ pass
2553
+ elif hasattr(cfg.edit, "parameters") and cfg.edit.parameters:
2554
+ try:
2555
+ if hasattr(cfg.edit.parameters, "items"):
2556
+ edit_config = dict(cfg.edit.parameters)
2557
+ elif isinstance(cfg.edit.parameters, dict):
2558
+ edit_config = cfg.edit.parameters
2559
+ except (TypeError, AttributeError):
2560
+ pass
2561
+
2562
+ if (
2563
+ model_profile.module_selectors
2564
+ and "module_selectors" not in edit_config
2565
+ and isinstance(model_profile.module_selectors, dict)
2566
+ ):
2567
+ edit_config["module_selectors"] = {
2568
+ key: list(values)
2569
+ for key, values in model_profile.module_selectors.items()
2570
+ }
2571
+
2572
+ console.print(f"✂️ Edit: {edit_op.name}")
2573
+ console.print(f"🛡️ Guards: {[g.name for g in guards]}")
2574
+
2575
+ # Model load/snapshot strategy
2576
+ model = None
2577
+ restore_fn = None
2578
+ snapshot_tmpdir: str | None = None
2579
+
2580
+ # Try single-load with snapshot/restore if the adapter supports it; fall back to reloading per attempt
2581
+ try:
2582
+ # Load once
2583
+ console.print(f"🔧 Loading model once: {cfg.model.id}")
2584
+ model = adapter.load_model(cfg.model.id, device=resolved_device)
2585
+
2586
+ # No edit-specific bootstrap logic
2587
+
2588
+ def _estimate_model_bytes(m: Any) -> int:
2589
+ total = 0
2590
+ try:
2591
+ for _, p in getattr(m, "named_parameters", lambda: [])():
2592
+ try:
2593
+ total += int(p.element_size() * p.nelement())
2594
+ except Exception:
2595
+ pass
2596
+ for _, b in getattr(m, "named_buffers", lambda: [])():
2597
+ try:
2598
+ total += int(b.element_size() * b.nelement())
2599
+ except Exception:
2600
+ pass
2601
+ except Exception:
2602
+ return 0
2603
+ return total
2604
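+ # Rough sanity check for the estimate above: a ~124M-parameter float32 model is
+ # about 124e6 * 4 bytes ≈ 496 MB (≈ 473 MiB under the 1024-based division used
+ # in _choose_snapshot_mode below), before counting buffers.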
+
2605
+ # Load snapshot config from config.context.snapshot (highest precedence)
2606
+ cfg_snapshot = {}
2607
+ try:
2608
+ cfg_context = _to_serialisable_dict(getattr(cfg, "context", {}))
2609
+ if isinstance(cfg_context, dict):
2610
+ cfg_snapshot = _to_serialisable_dict(
2611
+ cfg_context.get("snapshot", {})
2612
+ )
2613
+ if not isinstance(cfg_snapshot, dict):
2614
+ cfg_snapshot = {}
2615
+ except Exception:
2616
+ cfg_snapshot = {}
2617
+
2618
+ def _choose_snapshot_mode() -> str:
2619
+ # Precedence: config > env > auto
2620
+ cfg_mode = (
2621
+ str(cfg_snapshot.get("mode", "")).lower()
2622
+ if isinstance(cfg_snapshot, dict)
2623
+ else ""
2624
+ )
2625
+ mode_env = str(
2626
+ os.environ.get("INVARLOCK_SNAPSHOT_MODE", "auto")
2627
+ ).lower()
2628
+ supports_chunked = hasattr(adapter, "snapshot_chunked") and hasattr(
2629
+ adapter, "restore_chunked"
2630
+ )
2631
+ supports_bytes = hasattr(adapter, "snapshot") and hasattr(
2632
+ adapter, "restore"
2633
+ )
2634
+ if cfg_mode in {"bytes", "chunked"}:
2635
+ if cfg_mode == "bytes" and supports_bytes:
2636
+ return "bytes"
2637
+ if cfg_mode == "chunked" and supports_chunked:
2638
+ return "chunked"
2639
+ # fallback preference
2640
+ if supports_bytes:
2641
+ return "bytes"
2642
+ if supports_chunked:
2643
+ return "chunked"
2644
+ return "reload"
2645
+ if mode_env in {"bytes", "chunked"}:
2646
+ if mode_env == "bytes" and supports_bytes:
2647
+ return "bytes"
2648
+ if mode_env == "chunked" and supports_chunked:
2649
+ return "chunked"
2650
+ # fallback preference
2651
+ if supports_bytes:
2652
+ return "bytes"
2653
+ if supports_chunked:
2654
+ return "chunked"
2655
+ return "reload"
2656
+ # auto
2657
+ est_mb = _estimate_model_bytes(model) / (1024.0 * 1024.0)
2658
+ # RAM-based heuristic
2659
+ try:
2660
+ ram = psutil.virtual_memory()
2661
+ avail_mb = float(getattr(ram, "available", 0)) / (1024.0 * 1024.0)
2662
+ except Exception:
2663
+ avail_mb = 0.0
2664
+ # fraction: config override > env > default 0.4
2665
+ frac = 0.4
2666
+ try:
2667
+ if (
2668
+ isinstance(cfg_snapshot, dict)
2669
+ and cfg_snapshot.get("ram_fraction") is not None
2670
+ ):
2671
+ frac = float(cfg_snapshot.get("ram_fraction"))
2672
+ else:
2673
+ frac = float(
2674
+ os.environ.get("INVARLOCK_SNAPSHOT_AUTO_RAM_FRACTION", frac)
2675
+ )
2676
+ except Exception:
2677
+ pass
2678
+ # Threshold (MB): if no RAM info, use config threshold_mb or the env fallback; else derive from avail_mb * frac
2679
+ if avail_mb > 0:
2680
+ threshold_mb = avail_mb * max(0.0, min(frac, 1.0))
2681
+ else:
2682
+ try:
2683
+ if (
2684
+ isinstance(cfg_snapshot, dict)
2685
+ and cfg_snapshot.get("threshold_mb") is not None
2686
+ ):
2687
+ threshold_mb = float(cfg_snapshot.get("threshold_mb"))
2688
+ else:
2689
+ threshold_mb = float(
2690
+ os.environ.get("INVARLOCK_SNAPSHOT_THRESHOLD_MB", "768")
2691
+ )
2692
+ except Exception:
2693
+ threshold_mb = 768.0
2694
+ # Disk availability for chunked
2695
+ try:
2696
+ tmpdir = None
2697
+ if isinstance(cfg_snapshot, dict):
2698
+ tmpdir = cfg_snapshot.get("temp_dir") or None
2699
+ if not tmpdir:
2700
+ tmpdir = (
2701
+ os.environ.get("TMPDIR") or os.environ.get("TMP") or "/tmp"
2702
+ )
2703
+ du = shutil.disk_usage(tmpdir)
2704
+ free_mb = float(du.free) / (1024.0 * 1024.0)
2705
+ except Exception:
2706
+ free_mb = 0.0
2707
+ # Disk margin ratio: config > default 1.2
2708
+ margin = 1.2
2709
+ try:
2710
+ if (
2711
+ isinstance(cfg_snapshot, dict)
2712
+ and cfg_snapshot.get("disk_free_margin_ratio") is not None
2713
+ ):
2714
+ margin = float(cfg_snapshot.get("disk_free_margin_ratio"))
2715
+ except Exception:
2716
+ pass
2717
+ # Choose chunked if model snapshot is a large fraction of available RAM and disk has room
2718
+ if (
2719
+ supports_chunked
2720
+ and est_mb >= threshold_mb
2721
+ and (free_mb <= 0.0 or est_mb * margin <= free_mb)
2722
+ ):
2723
+ return "chunked"
2724
+ # Otherwise prefer bytes when supported
2725
+ if supports_bytes:
2726
+ # If RAM is extremely low and even a bytes snapshot is likely risky, fall back to chunked when possible
2727
+ if (
2728
+ supports_chunked
2729
+ and avail_mb > 0
2730
+ and est_mb >= max(64.0, avail_mb * 0.8)
2731
+ and (free_mb <= 0.0 or est_mb * margin <= free_mb)
2732
+ ):
2733
+ return "chunked"
2734
+ return "bytes"
2735
+ if supports_chunked:
2736
+ return "chunked"
2737
+ return "reload"
2738
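+ # Worked example of the "auto" branch: with ~8 GiB of RAM available and the
+ # default ram_fraction of 0.4, the threshold is ~3277 MB, so a ~500 MB snapshot
+ # stays in memory ("bytes"), whereas a ~6 GB model switches to "chunked"
+ # provided the temp dir has roughly est * 1.2 (the disk_free_margin_ratio) free.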
+
2739
+ mode = _choose_snapshot_mode()
2740
+ # Emit deterministic snapshot mode status line
2741
+ console.print(
2742
+ f"snapshot_mode: {'enabled' if mode in {'bytes', 'chunked'} else 'disabled'}"
2743
+ )
2744
+ if mode == "chunked":
2745
+ snapshot_tmpdir = adapter.snapshot_chunked(model) # type: ignore[attr-defined]
2746
+
2747
+ def _restore():
2748
+ adapter.restore_chunked(model, snapshot_tmpdir) # type: ignore[attr-defined]
2749
+
2750
+ restore_fn = _restore
2751
+ elif mode == "bytes":
2752
+ base_blob = adapter.snapshot(model) # type: ignore[attr-defined]
2753
+
2754
+ def _restore2():
2755
+ adapter.restore(model, base_blob) # type: ignore[attr-defined]
2756
+
2757
+ restore_fn = _restore2
2758
+ else:
2759
+ # reload path
2760
+ model = None
2761
+ restore_fn = None
2762
+ except Exception:
2763
+ # On any failure, fall back to reload-per-attempt path
2764
+ model = None
2765
+ restore_fn = None
2766
+
2767
+ # Retry loop: all report processing happens inside the loop
2768
+ attempt = 1
2769
+ profile_normalized = (profile or "").lower()
2770
+ measure_guard_overhead = profile_normalized in {"ci", "release"}
2771
+
2772
+ while True:
2773
+ # Reset RNG streams each attempt to guarantee determinism across retries
2774
+ set_seed(seed_bundle["python"])
2775
+
2776
+ if retry_controller:
2777
+ console.print(f"\n🚀 Attempt {attempt}/{max_attempts}")
2778
+ if attempt > 1:
2779
+ console.print(f"🔄 Retry attempt {attempt}/{max_attempts}")
2780
+ else:
2781
+ if attempt > 1:
2782
+ console.print(f"\n🚀 Attempt {attempt}")
2783
+
2784
+ # Adjust parameters for retry attempts
2785
+ if retry_controller and attempt > 1:
2786
+ from invarlock.core.retry import adjust_edit_params
2787
+
2788
+ edit_config = adjust_edit_params(
2789
+ edit_op.name, edit_config, attempt, None
2790
+ )
2791
+
2792
+ guard_overhead_payload: dict[str, Any] | None = None
2793
+ if measure_guard_overhead:
2794
+ guard_overhead_payload = _run_bare_control(
2795
+ adapter=adapter,
2796
+ edit_op=edit_op,
2797
+ cfg=cfg,
2798
+ model=model,
2799
+ run_config=run_config,
2800
+ calibration_data=calibration_data,
2801
+ auto_config=auto_config,
2802
+ edit_config=edit_config,
2803
+ preview_count=preview_count,
2804
+ final_count=final_count,
2805
+ seed_bundle=seed_bundle,
2806
+ resolved_device=resolved_device,
2807
+ restore_fn=restore_fn,
2808
+ console=console,
2809
+ resolved_loss_type=resolved_loss_type,
2810
+ profile_normalized=profile_normalized,
2811
+ skip_model_load=skip_model_load,
2812
+ )
2813
+
2814
+ # Ensure clean state for guarded run
2815
+ core_report, model = _execute_guarded_run(
2816
+ runner=runner,
2817
+ adapter=adapter,
2818
+ model=model,
2819
+ cfg=cfg,
2820
+ edit_op=edit_op,
2821
+ run_config=run_config,
2822
+ guards=guards,
2823
+ calibration_data=calibration_data,
2824
+ auto_config=auto_config,
2825
+ edit_config=edit_config,
2826
+ preview_count=preview_count,
2827
+ final_count=final_count,
2828
+ restore_fn=restore_fn,
2829
+ resolved_device=resolved_device,
2830
+ console=console,
2831
+ skip_model_load=skip_model_load,
2832
+ )
2833
+
2834
+ if not hasattr(core_report, "context") or core_report.context is None:
2835
+ core_report.context = {}
2836
+
2837
+ # Convert CoreRunner report to evaluation report
2838
+ report = create_empty_report()
2839
+
2840
+ # Code provenance: commit hash and InvarLock version
2841
+ commit_value = (
2842
+ getattr(cfg.meta, "commit", "") if hasattr(cfg, "meta") else ""
2843
+ )
2844
+ if not commit_value:
2845
+ try:
2846
+ import subprocess
2847
+
2848
+ commit_value = (
2849
+ subprocess.check_output(
2850
+ ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL
2851
+ )
2852
+ .decode("utf-8", "ignore")
2853
+ .strip()
2854
+ )
2855
+ except Exception:
2856
+ commit_value = ""
2857
+ invarlock_version = None
2858
+ try:
2859
+ from invarlock import __version__ as _invarlock_version
2860
+
2861
+ invarlock_version = _invarlock_version
2862
+ except Exception:
2863
+ invarlock_version = None
2864
+
2865
+ # Collect determinism/env flags
2866
+ env_flags: dict[str, object] = {}
2867
+ try:
2868
+ import os as _os
2869
+
2870
+ if torch is not None:
2871
+ try:
2872
+ det_enabled = getattr(
2873
+ torch, "are_deterministic_algorithms_enabled", None
2874
+ )
2875
+ if callable(det_enabled):
2876
+ env_flags["torch_deterministic_algorithms"] = bool(
2877
+ det_enabled()
2878
+ )
2879
+ except Exception:
2880
+ pass
2881
+ try:
2882
+ tf32_matmul = getattr(
2883
+ getattr(torch.backends, "cuda", object()), "matmul", None
2884
+ )
2885
+ if tf32_matmul is not None and hasattr(
2886
+ tf32_matmul, "allow_tf32"
2887
+ ):
2888
+ env_flags["cuda_matmul_allow_tf32"] = bool(
2889
+ tf32_matmul.allow_tf32
2890
+ )
2891
+ except Exception:
2892
+ pass
2893
+ try:
2894
+ cudnn_mod = getattr(torch.backends, "cudnn", None)
2895
+ if cudnn_mod is not None:
2896
+ env_flags["cudnn_allow_tf32"] = bool(
2897
+ getattr(cudnn_mod, "allow_tf32", None)
2898
+ )
2899
+ env_flags["cudnn_deterministic"] = bool(
2900
+ getattr(cudnn_mod, "deterministic", None)
2901
+ )
2902
+ env_flags["cudnn_benchmark"] = bool(
2903
+ getattr(cudnn_mod, "benchmark", None)
2904
+ )
2905
+ except Exception:
2906
+ pass
2907
+ try:
2908
+ env_flags["mps_available"] = bool(
2909
+ getattr(torch.backends, "mps", None)
2910
+ and torch.backends.mps.is_available()
2911
+ )
2912
+ except Exception:
2913
+ pass
2914
+ # Common environment variables for determinism
2915
+ env_flags["CUBLAS_WORKSPACE_CONFIG"] = _os.environ.get(
2916
+ "CUBLAS_WORKSPACE_CONFIG"
2917
+ )
2918
+ except Exception:
2919
+ env_flags = {}
2920
+
2921
+ meta_payload = {
2922
+ "model_id": cfg.model.id,
2923
+ "adapter": cfg.model.adapter,
2924
+ "device": str(resolved_device),
2925
+ "commit": commit_value,
2926
+ "seed": seed_bundle["python"],
2927
+ "seeds": seed_bundle,
2928
+ "ts": datetime.now().isoformat(),
2929
+ "auto": auto_config,
2930
+ }
2931
+ if invarlock_version:
2932
+ meta_payload["invarlock_version"] = invarlock_version
2933
+ if env_flags:
2934
+ meta_payload["env_flags"] = env_flags
2935
+ report["meta"].update(meta_payload)
2936
+ report["meta"]["model_profile"] = {
2937
+ "family": model_profile.family,
2938
+ "default_loss": model_profile.default_loss,
2939
+ "module_selectors": model_profile.module_selectors,
2940
+ "invariants": list(model_profile.invariants),
2941
+ "cert_lints": [dict(lint) for lint in model_profile.cert_lints],
2942
+ }
2943
+
2944
+ report["data"].update(
2945
+ {
2946
+ "dataset": cfg.dataset.provider,
2947
+ # Resolved split (explicit or inferred)
2948
+ "split": resolved_split,
2949
+ "seq_len": cfg.dataset.seq_len,
2950
+ "stride": getattr(cfg.dataset, "stride", cfg.dataset.seq_len // 2),
2951
+ "preview_n": _safe_int(preview_count),
2952
+ "final_n": _safe_int(final_count),
2953
+ }
2954
+ )
2955
+ dataset_meta_context = core_report.context.get("dataset_meta", {})
2956
+ if isinstance(dataset_meta_context, dict):
2957
+ report["data"].update(dataset_meta_context)
2958
+ dataset_tokenizer_hash = dataset_meta_context.get("tokenizer_hash")
2959
+ if (
2960
+ not tokenizer_hash
2961
+ and isinstance(dataset_tokenizer_hash, str)
2962
+ and dataset_tokenizer_hash
2963
+ ):
2964
+ tokenizer_hash = dataset_tokenizer_hash
2965
+
2966
+ if tokenizer_hash:
2967
+ report["meta"]["tokenizer_hash"] = tokenizer_hash
2968
+
2969
+ # Transfer edit information
2970
+ if hasattr(core_report, "edit") and core_report.edit:
2971
+ edit_deltas = core_report.edit.get("deltas", {})
2972
+ report["edit"].update(
2973
+ {
2974
+ "name": edit_op.name,
2975
+ "plan_digest": core_report.edit.get(
2976
+ "plan_digest", str(hash(str(core_report.edit)))
2977
+ ),
2978
+ "deltas": {
2979
+ "params_changed": edit_deltas.get("params_changed", 0),
2980
+ "sparsity": edit_deltas.get("sparsity", None),
2981
+ "bitwidth_map": edit_deltas.get("bitwidth_map", None),
2982
+ "layers_modified": edit_deltas.get("layers_modified", 0),
2983
+ },
2984
+ }
2985
+ )
2986
+ for key in (
2987
+ "algorithm",
2988
+ "algorithm_version",
2989
+ "implementation",
2990
+ "scope",
2991
+ "ranking",
2992
+ "grouping",
2993
+ "budgets",
2994
+ "seed",
2995
+ "mask_digest",
2996
+ ):
2997
+ if key in core_report.edit:
2998
+ report["edit"][key] = copy.deepcopy(core_report.edit[key])
2999
+ if isinstance(core_report.context, dict):
3000
+ core_report.context.setdefault("edit", {})
3001
+ core_report.context["edit"].update(
3002
+ {
3003
+ "name": edit_op.name,
3004
+ "params_changed": edit_deltas.get("params_changed", 0),
3005
+ "layers_modified": edit_deltas.get("layers_modified", 0),
3006
+ }
3007
+ )
3008
+
3009
+ mask_artifact_path = _persist_ref_masks(core_report, run_dir)
3010
+ if mask_artifact_path:
3011
+ report.setdefault("artifacts", {})
3012
+ report["artifacts"]["masks_path"] = str(mask_artifact_path)
3013
+
3014
+ # Transfer metrics (PM-only: do not write ppl_* fields)
3015
+ if hasattr(core_report, "metrics") and core_report.metrics:
3016
+ metrics_payload = {
3017
+ "latency_ms_per_tok": core_report.metrics.get(
3018
+ "latency_ms_per_tok", 0.0
3019
+ ),
3020
+ "memory_mb_peak": core_report.metrics.get("memory_mb_peak", 0.0),
3021
+ "spectral": {},
3022
+ "rmt": {},
3023
+ "invariants": {},
3024
+ }
3025
+ window_plan_ctx = core_report.context.get("window_plan")
3026
+ if isinstance(window_plan_ctx, dict):
3027
+ metrics_payload["window_plan"] = window_plan_ctx
3028
+ capacity_meta = window_plan_ctx.get("capacity")
3029
+ if isinstance(capacity_meta, dict):
3030
+ metrics_payload["window_capacity"] = capacity_meta
3031
+ stats_section = metrics_payload.setdefault("stats", {})
3032
+ if isinstance(stats_section, dict):
3033
+ stats_section.update(
3034
+ {
3035
+ "requested_preview": window_plan_ctx.get(
3036
+ "requested_preview"
3037
+ ),
3038
+ "requested_final": window_plan_ctx.get(
3039
+ "requested_final"
3040
+ ),
3041
+ "actual_preview": window_plan_ctx.get("actual_preview"),
3042
+ "actual_final": window_plan_ctx.get("actual_final"),
3043
+ "coverage_ok": window_plan_ctx.get("coverage_ok"),
3044
+ }
3045
+ )
3046
+ optional_keys = [
3047
+ "logloss_preview",
3048
+ "logloss_final",
3049
+ "logloss_delta",
3050
+ "logloss_preview_ci",
3051
+ "logloss_final_ci",
3052
+ "logloss_delta_ci",
3053
+ "bootstrap",
3054
+ "window_overlap_fraction",
3055
+ "window_match_fraction",
3056
+ "window_pairing_reason",
3057
+ "window_pairing_preview",
3058
+ "window_pairing_final",
3059
+ "paired_windows",
3060
+ "paired_delta_summary",
3061
+ "preview_total_tokens",
3062
+ "final_total_tokens",
3063
+ "masked_tokens_total",
3064
+ "masked_tokens_preview",
3065
+ "masked_tokens_final",
3066
+ "reduction",
3067
+ ]
3068
+ for key in optional_keys:
3069
+ if key in core_report.metrics:
3070
+ metrics_payload[key] = core_report.metrics[key]
3071
+ metrics_payload["loss_type"] = resolved_loss_type
3072
+ if metrics_payload.get("loss_type") is None and isinstance(
3073
+ dataset_meta_context, dict
3074
+ ):
3075
+ metrics_payload["loss_type"] = dataset_meta_context.get(
3076
+ "loss_type", resolved_loss_type
3077
+ )
3078
+ if isinstance(dataset_meta_context, dict):
3079
+ for meta_key in (
3080
+ "masked_tokens_total",
3081
+ "masked_tokens_preview",
3082
+ "masked_tokens_final",
3083
+ ):
3084
+ if (
3085
+ meta_key not in metrics_payload
3086
+ and dataset_meta_context.get(meta_key) is not None
3087
+ ):
3088
+ metrics_payload[meta_key] = dataset_meta_context[meta_key]
3089
+ report["metrics"].update(metrics_payload)
3090
+
3091
+ if guard_overhead_payload is not None:
3092
+ # Compute guarded primary-metric snapshot; pass structured reports into validator
3093
+ try:
3094
+ # Map loss type to ppl family kind
3095
+ lk = str(resolved_loss_type or "causal").lower()
3096
+ if lk == "mlm":
3097
+ pm_kind_for_overhead = "ppl_mlm"
3098
+ elif lk in {"seq2seq", "s2s", "t5"}:
3099
+ pm_kind_for_overhead = "ppl_seq2seq"
3100
+ else:
3101
+ pm_kind_for_overhead = "ppl_causal"
3102
+
3103
+ # Prefer computing from the in-memory core_report windows to avoid ordering issues
3104
+ pm_guarded = _extract_pm_snapshot_for_overhead(
3105
+ core_report, kind=pm_kind_for_overhead
3106
+ )
3107
+ if not isinstance(pm_guarded, dict) or not pm_guarded:
3108
+ pm_guarded = _extract_pm_snapshot_for_overhead(
3109
+ report, kind=pm_kind_for_overhead
3110
+ )
3111
+
3112
+ guard_overhead_payload["guarded_report"] = (
3113
+ {"metrics": {"primary_metric": pm_guarded}}
3114
+ if isinstance(pm_guarded, dict) and pm_guarded
3115
+ else None
3116
+ )
3117
+ except Exception:
3118
+ guard_overhead_payload["guarded_report"] = None
3119
+ bare_struct = guard_overhead_payload.get("bare_report") or {}
3120
+ guarded_struct = guard_overhead_payload.get("guarded_report") or {}
3121
+ # Be robust to mocks or minimal objects returned by validators
3122
+ result = validate_guard_overhead(
3123
+ bare_struct,
3124
+ guarded_struct,
3125
+ overhead_threshold=guard_overhead_payload.get(
3126
+ "overhead_threshold", GUARD_OVERHEAD_THRESHOLD
3127
+ ),
3128
+ )
3129
+ try:
3130
+ messages = list(getattr(result, "messages", []))
3131
+ except Exception: # pragma: no cover - defensive
3132
+ messages = []
3133
+ try:
3134
+ warnings = list(getattr(result, "warnings", []))
3135
+ except Exception: # pragma: no cover - defensive
3136
+ warnings = []
3137
+ try:
3138
+ errors = list(getattr(result, "errors", []))
3139
+ except Exception: # pragma: no cover - defensive
3140
+ errors = []
3141
+ try:
3142
+ checks = dict(getattr(result, "checks", {}))
3143
+ except Exception: # pragma: no cover - defensive
3144
+ checks = {}
3145
+ metrics_obj = getattr(result, "metrics", {})
3146
+ if not isinstance(metrics_obj, dict):
3147
+ metrics_obj = {}
3148
+ overhead_ratio = metrics_obj.get("overhead_ratio")
3149
+ if overhead_ratio is None:
3150
+ overhead_ratio = getattr(result, "overhead_ratio", None)
3151
+ overhead_percent = metrics_obj.get("overhead_percent")
3152
+ if overhead_percent is None:
3153
+ overhead_percent = getattr(result, "overhead_percent", None)
3154
+ passed_flag = bool(getattr(result, "passed", False))
3155
+
3156
+ guard_overhead_payload.update(
3157
+ {
3158
+ "messages": messages,
3159
+ "warnings": warnings,
3160
+ "errors": errors,
3161
+ "checks": checks,
3162
+ "overhead_ratio": overhead_ratio,
3163
+ "overhead_percent": overhead_percent,
3164
+ "passed": passed_flag,
3165
+ "evaluated": True,
3166
+ }
3167
+ )
3168
+ # Normalize for non-finite/degenerate cases
3169
+ guard_overhead_payload = _normalize_overhead_result(
3170
+ guard_overhead_payload, profile=profile_normalized
3171
+ )
3172
+ report["guard_overhead"] = guard_overhead_payload
3173
+
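
For readers skimming the diff: the payload assembled above carries `overhead_ratio`, `overhead_percent`, `passed`, and `evaluated` fields. Below is a minimal standalone sketch of how such a gate could be evaluated; the helper name and the exact ratio/percent convention are illustrative assumptions, not the package's `validate_guard_overhead` implementation.

import math

def sketch_overhead_gate(bare_pm_final: float, guarded_pm_final: float,
                         overhead_threshold: float = 0.01) -> dict:
    # Illustrative only: compare guarded vs bare primary metric against a budget.
    if not (math.isfinite(bare_pm_final) and bare_pm_final > 0
            and math.isfinite(guarded_pm_final)):
        # Degenerate inputs: mark the gate as not evaluated (mirrors _normalize_overhead_result).
        return {"evaluated": False, "passed": True}
    ratio = guarded_pm_final / bare_pm_final
    return {
        "evaluated": True,
        "overhead_ratio": ratio,
        "overhead_percent": (ratio - 1.0) * 100.0,  # assumed convention
        "passed": ratio <= 1.0 + overhead_threshold,
    }

print(sketch_overhead_gate(12.50, 12.55))  # 0.4% overhead passes a 1% budget
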
3174
+ had_baseline = bool(baseline and Path(baseline).exists())
3175
+ if (
3176
+ hasattr(core_report, "evaluation_windows")
3177
+ and core_report.evaluation_windows
3178
+ ):
3179
+ preview_windows = core_report.evaluation_windows.get("preview", {})
3180
+ final_windows = core_report.evaluation_windows.get("final", {})
3181
+ report["evaluation_windows"] = {
3182
+ "preview": {
3183
+ "window_ids": list(preview_windows.get("window_ids", [])),
3184
+ "logloss": list(preview_windows.get("logloss", [])),
3185
+ "input_ids": [
3186
+ list(seq) for seq in preview_windows.get("input_ids", [])
3187
+ ],
3188
+ "attention_masks": [
3189
+ list(mask)
3190
+ for mask in preview_windows.get("attention_masks", [])
3191
+ ],
3192
+ "token_counts": list(preview_windows.get("token_counts", [])),
3193
+ "masked_token_counts": list(
3194
+ preview_windows.get("masked_token_counts", [])
3195
+ ),
3196
+ "actual_token_counts": list(
3197
+ preview_windows.get("actual_token_counts", [])
3198
+ ),
3199
+ "labels": [
3200
+ list(seq) for seq in preview_windows.get("labels", [])
3201
+ ],
3202
+ },
3203
+ "final": {
3204
+ "window_ids": list(final_windows.get("window_ids", [])),
3205
+ "logloss": list(final_windows.get("logloss", [])),
3206
+ "input_ids": [
3207
+ list(seq) for seq in final_windows.get("input_ids", [])
3208
+ ],
3209
+ "attention_masks": [
3210
+ list(mask)
3211
+ for mask in final_windows.get("attention_masks", [])
3212
+ ],
3213
+ "token_counts": list(final_windows.get("token_counts", [])),
3214
+ "masked_token_counts": list(
3215
+ final_windows.get("masked_token_counts", [])
3216
+ ),
3217
+ "actual_token_counts": list(
3218
+ final_windows.get("actual_token_counts", [])
3219
+ ),
3220
+ "labels": [
3221
+ list(seq) for seq in final_windows.get("labels", [])
3222
+ ],
3223
+ },
3224
+ }
3225
+ elif had_baseline and (profile or "").lower() == "release":
3226
+ console.print(
3227
+ "[red]❌ [INVARLOCK:E001] PAIRING-SCHEDULE-MISMATCH: baseline pairing requested but evaluation windows were not produced. Check capacity/pairing config.[/red]"
3228
+ )
3229
+ raise typer.Exit(3)
3230
+ else:
3231
+ # Populate evaluation_windows directly from assembled records when the
3232
+ # runner did not provide a structured window payload. This ensures
3233
+ # provenance (provider_digest) can be computed even in lightweight/dev
3234
+ # runs and unit tests that stub the runner.
3235
+ try:
3236
+
3237
+ def _tokens(rec: dict[str, Any]) -> int:
3238
+ try:
3239
+ return int(len(rec.get("input_ids", []) or []))
3240
+ except Exception:
3241
+ return 0
3242
+
3243
+ report["evaluation_windows"] = {
3244
+ "preview": {
3245
+ "window_ids": [
3246
+ f"preview::{i}" for i in range(len(preview_records))
3247
+ ],
3248
+ "input_ids": [
3249
+ list(r["input_ids"]) for r in preview_records
3250
+ ],
3251
+ "attention_masks": [
3252
+ list(r["attention_mask"]) for r in preview_records
3253
+ ],
3254
+ "token_counts": [_tokens(r) for r in preview_records],
3255
+ **(
3256
+ {
3257
+ "masked_token_counts": list(preview_mask_counts),
3258
+ "labels": [
3259
+ r.get("labels", [-100] * len(r["input_ids"]))
3260
+ for r in preview_records
3261
+ ],
3262
+ }
3263
+ if use_mlm
3264
+ else {}
3265
+ ),
3266
+ },
3267
+ "final": {
3268
+ "window_ids": [
3269
+ f"final::{i}" for i in range(len(final_records))
3270
+ ],
3271
+ "input_ids": [list(r["input_ids"]) for r in final_records],
3272
+ "attention_masks": [
3273
+ list(r["attention_mask"]) for r in final_records
3274
+ ],
3275
+ "token_counts": [_tokens(r) for r in final_records],
3276
+ **(
3277
+ {
3278
+ "masked_token_counts": list(final_mask_counts),
3279
+ "labels": [
3280
+ r.get("labels", [-100] * len(r["input_ids"]))
3281
+ for r in final_records
3282
+ ],
3283
+ }
3284
+ if use_mlm
3285
+ else {}
3286
+ ),
3287
+ },
3288
+ }
3289
+ except Exception:
3290
+ # Best-effort: provenance digest will be skipped if windows cannot be built
3291
+ pass
3292
+
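
For reference, a tiny hand-built example of the `evaluation_windows` payload the fallback path above assembles (token values are made up; the MLM-only `masked_token_counts`/`labels` keys appear only when `use_mlm` is set):

evaluation_windows_example = {
    "preview": {
        "window_ids": ["preview::0", "preview::1"],
        "input_ids": [[101, 7592, 102], [101, 2088, 102]],
        "attention_masks": [[1, 1, 1], [1, 1, 1]],
        "token_counts": [3, 3],
    },
    "final": {
        "window_ids": ["final::0"],
        "input_ids": [[101, 2742, 102]],
        "attention_masks": [[1, 1, 1]],
        "token_counts": [3],
    },
}
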
3293
+ # Attach provider digest and dataset split provenance when available
3294
+ try:
3295
+ prov = report.setdefault("provenance", {})
3296
+ # Always record dataset split provenance for visibility
3297
+ try:
3298
+ prov["dataset_split"] = str(resolved_split)
3299
+ prov["split_fallback"] = bool(used_fallback_split)
3300
+ except Exception:
3301
+ pass
3302
+ provider_digest = _compute_provider_digest(report)
3303
+ if provider_digest:
3304
+ prov["provider_digest"] = provider_digest
3305
+ # Attach digest version for future evolution
3306
+ prov["digest_version"] = 1
3307
+ # Strict parity checks in CI/Release when baseline present
3308
+ try:
3309
+ if isinstance(baseline_report_data, dict):
3310
+ base_prov = (
3311
+ baseline_report_data.get("provenance", {})
3312
+ if isinstance(
3313
+ baseline_report_data.get("provenance"), dict
3314
+ )
3315
+ else {}
3316
+ )
3317
+ base_digest = (
3318
+ base_prov.get("provider_digest")
3319
+ if isinstance(base_prov, dict)
3320
+ else None
3321
+ )
3322
+ _enforce_provider_parity(
3323
+ provider_digest,
3324
+ base_digest,
3325
+ profile=(str(profile).lower() if profile else None),
3326
+ )
3327
+ except InvarlockError as ce:
3328
+ console.print(str(ce))
3329
+ # Map to profile-aware exit code: dev→1, ci/release→3
3330
+ raise typer.Exit(
3331
+ _resolve_exit_code(ce, profile=profile)
3332
+ ) from None
3333
+ except RuntimeError as _e:
3334
+ _fail_run(str(_e))
3335
+ except Exception:
3336
+ pass
3337
+ except Exception:
3338
+ pass
3339
+
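
The provenance block written above ends up with roughly this shape (field names are taken from the code; the digest value is a placeholder):

provenance_example = {
    "dataset_split": "validation",          # resolved_split
    "split_fallback": False,                # whether a fallback split was used
    "provider_digest": "0000000000000000",  # placeholder digest value
    "digest_version": 1,
}
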
3340
+ # Transfer guard results
3341
+ if hasattr(core_report, "guards") and core_report.guards:
3342
+ for guard_name, guard_result in core_report.guards.items():
3343
+ guard_entry = {
3344
+ "name": guard_name,
3345
+ "passed": guard_result.get("passed"),
3346
+ "action": guard_result.get("action"),
3347
+ "policy": guard_result.get("policy", {}),
3348
+ "metrics": guard_result.get("metrics", {}),
3349
+ "actions": guard_result.get("actions", []),
3350
+ "violations": guard_result.get("violations", []),
3351
+ "warnings": guard_result.get("warnings", []),
3352
+ "errors": guard_result.get("errors", []),
3353
+ "details": guard_result.get("details", {}),
3354
+ }
3355
+ for extra_key in ("final_z_scores", "module_family_map"):
3356
+ if extra_key in guard_result:
3357
+ guard_entry[extra_key] = guard_result[extra_key]
3358
+ report["guards"].append(guard_entry)
3359
+
3360
+ # Set artifacts
3361
+ report["artifacts"].update(
3362
+ {
3363
+ "events_path": str(run_config.event_path)
3364
+ if run_config.event_path
3365
+ else "",
3366
+ "logs_path": "",
3367
+ "checkpoint_path": None,
3368
+ }
3369
+ )
3370
+
3371
+ # Optional: export HF-loadable model snapshot when requested
3372
+ export_env = str(
3373
+ os.environ.get("INVARLOCK_EXPORT_MODEL", "")
3374
+ ).strip().lower() in {
3375
+ "1",
3376
+ "true",
3377
+ "yes",
3378
+ "on",
3379
+ }
3380
+ save_model_cfg = False
3381
+ try:
3382
+ save_model_cfg = bool(
3383
+ getattr(getattr(cfg, "output", {}), "save_model", False)
3384
+ )
3385
+ except Exception:
3386
+ save_model_cfg = False
3387
+ if export_env or save_model_cfg:
3388
+ try:
3389
+ # Resolve destination with precedence:
3390
+ # 1) cfg.output.model_dir (absolute or relative to run_dir)
3391
+ # 2) env INVARLOCK_EXPORT_DIR (absolute or relative)
3392
+ # 3) cfg.output.model_subdir (under run_dir)
3393
+ # 4) default: run_dir / "model"
3394
+ export_dir: Path | None = None
3395
+ # (1) explicit model_dir in config
3396
+ try:
3397
+ out_cfg = getattr(cfg, "output", None)
3398
+ model_dir_cfg = None
3399
+ if out_cfg is not None:
3400
+ model_dir_cfg = getattr(
3401
+ out_cfg, "model_dir", None
3402
+ ) or getattr(out_cfg, "model_path", None)
3403
+ if model_dir_cfg:
3404
+ p = Path(str(model_dir_cfg))
3405
+ export_dir = p if p.is_absolute() else (run_dir / p)
3406
+ except Exception:
3407
+ export_dir = None
3408
+ # (2) env override
3409
+ if export_dir is None:
3410
+ env_dir_raw = os.environ.get("INVARLOCK_EXPORT_DIR", "")
3411
+ if isinstance(env_dir_raw, str) and env_dir_raw.strip():
3412
+ p = Path(env_dir_raw.strip())
3413
+ export_dir = p if p.is_absolute() else (run_dir / p)
3414
+ # (3) config subdir
3415
+ if export_dir is None:
3416
+ export_subdir = "model"
3417
+ try:
3418
+ export_subdir = str(
3419
+ getattr(
3420
+ getattr(cfg, "output", {}), "model_subdir", "model"
3421
+ )
3422
+ )
3423
+ except Exception:
3424
+ export_subdir = "model"
3425
+ export_dir = run_dir / export_subdir
3426
+
3427
+ # Ensure directory exists
3428
+ ok = False
3429
+ if hasattr(adapter, "save_pretrained") and model is not None:
3430
+ ok = bool(adapter.save_pretrained(model, export_dir)) # type: ignore[attr-defined]
3431
+ if ok:
3432
+ report["artifacts"]["checkpoint_path"] = str(export_dir)
3433
+ else:
3434
+ console.print(
3435
+ "[yellow]⚠️ Model export requested but adapter did not save a HF directory.[/yellow]"
3436
+ )
3437
+ except Exception:
3438
+ console.print(
3439
+ "[yellow]⚠️ Model export requested but failed due to an unexpected error.[/yellow]"
3440
+ )
3441
+
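
The export destination above follows a four-step precedence: `cfg.output.model_dir`, then the `INVARLOCK_EXPORT_DIR` environment variable, then `cfg.output.model_subdir`, then `run_dir / "model"`. A condensed standalone sketch of that precedence, using a hypothetical helper name:

import os
from pathlib import Path

def sketch_resolve_export_dir(run_dir: Path, model_dir: str | None = None,
                              model_subdir: str = "model") -> Path:
    # Illustrative precedence: explicit dir > env override > configured subdir > default.
    if model_dir:
        p = Path(model_dir)
        return p if p.is_absolute() else run_dir / p
    env_dir = os.environ.get("INVARLOCK_EXPORT_DIR", "").strip()
    if env_dir:
        p = Path(env_dir)
        return p if p.is_absolute() else run_dir / p
    return run_dir / (model_subdir or "model")

print(sketch_resolve_export_dir(Path("runs/2024-01-01")))  # runs/2024-01-01/model
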
3442
+ # Set flags
3443
+ report["flags"].update(
3444
+ {
3445
+ "guard_recovered": any(
3446
+ not g.get("passed", True)
3447
+                        for g in (getattr(core_report, "guards", None) or {}).values()
3448
+ if hasattr(core_report, "guards") and core_report.guards
3449
+ ),
3450
+ "rollback_reason": None,
3451
+ }
3452
+ )
3453
+
3454
+ metrics_section = report.get("metrics", {}) or {}
3455
+ data_section = report.get("data", {}) or {}
3456
+ preview_count_report = data_section.get("preview_n")
3457
+ final_count_report = data_section.get("final_n")
3458
+
3459
+ # Classification metric (accuracy) — deterministic smoke path
3460
+ # If loss type is explicitly 'classification', derive accuracy
3461
+ # counts from evaluation windows using a deterministic label rule.
3462
+ try:
3463
+ loss_type_ctx = (
3464
+ run_config.context.get("eval", {})
3465
+ .get("loss", {})
3466
+ .get("resolved_type")
3467
+ )
3468
+ except Exception:
3469
+ loss_type_ctx = None
3470
+ if str(loss_type_ctx).lower() == "classification":
3471
+ try:
3472
+ from invarlock.eval.primary_metric import compute_accuracy_counts
3473
+
3474
+ # Prefer in-memory core_report.evaluation_windows (includes input_ids)
3475
+ ew = {}
3476
+ try:
3477
+ if hasattr(core_report, "evaluation_windows") and isinstance(
3478
+ core_report.evaluation_windows, dict
3479
+ ):
3480
+ ew = core_report.evaluation_windows # type: ignore[assignment]
3481
+ except Exception:
3482
+ ew = {}
3483
+ if not ew:
3484
+ # Fallback to the soon-to-be persisted report windows (may lack input_ids)
3485
+ ew = (
3486
+ report.get("evaluation_windows", {})
3487
+ if isinstance(report.get("evaluation_windows"), dict)
3488
+ else {}
3489
+ )
3490
+ prev_rec = []
3491
+ fin_rec = []
3492
+ if isinstance(ew, dict):
3493
+ prev = ew.get("preview", {})
3494
+ fin = ew.get("final", {})
3495
+ if isinstance(prev, dict):
3496
+ prev_rec = [
3497
+ {"input_ids": seq}
3498
+ for seq in prev.get("input_ids", []) or []
3499
+ if isinstance(seq, list)
3500
+ ]
3501
+ if isinstance(fin, dict):
3502
+ fin_rec = [
3503
+ {"input_ids": seq}
3504
+ for seq in fin.get("input_ids", []) or []
3505
+ if isinstance(seq, list)
3506
+ ]
3507
+ c_prev, n_prev = compute_accuracy_counts(prev_rec)
3508
+ c_fin, n_fin = compute_accuracy_counts(fin_rec)
3509
+ # If we could not derive counts (no windows persisted), fall back to
3510
+ # deterministic pseudo-accuracy based on configured window counts.
3511
+ used_pseudo_counts = False
3512
+ if n_prev == 0 and n_fin == 0:
3513
+ try:
3514
+ prev_n_cfg = getattr(cfg.dataset, "preview_n", None)
3515
+ fin_n_cfg = getattr(cfg.dataset, "final_n", None)
3516
+ except Exception:
3517
+ prev_n_cfg = None
3518
+ fin_n_cfg = None
3519
+ try:
3520
+ prev_n = int(preview_count_report or prev_n_cfg or 0)
3521
+ fin_n = int(final_count_report or fin_n_cfg or 0)
3522
+ except Exception:
3523
+ prev_n = 0
3524
+ fin_n = 0
3525
+ c_prev, n_prev = (prev_n, prev_n) if prev_n > 0 else (0, 0)
3526
+ c_fin, n_fin = (fin_n, fin_n) if fin_n > 0 else (0, 0)
3527
+ used_pseudo_counts = prev_n > 0 or fin_n > 0
3528
+ classification_metrics = {
3529
+ "preview": {"correct_total": int(c_prev), "total": int(n_prev)},
3530
+ "final": {"correct_total": int(c_fin), "total": int(n_fin)},
3531
+ }
3532
+ # Tag source of counts for downstream rendering/doctor
3533
+ if used_pseudo_counts:
3534
+ classification_metrics["counts_source"] = "pseudo_config"
3535
+ # Add a provenance crumb for transparency
3536
+ try:
3537
+ prov = report.setdefault("provenance", {})
3538
+ notes = prov.setdefault("metric_notes", [])
3539
+ if isinstance(notes, list):
3540
+ notes.append(
3541
+ "accuracy: pseudo counts from preview_n/final_n"
3542
+ )
3543
+ except Exception:
3544
+ pass
3545
+ else:
3546
+ classification_metrics["counts_source"] = "measured"
3547
+ report.setdefault("metrics", {})["classification"] = (
3548
+ classification_metrics
3549
+ )
3550
+ # Convenience: top-level accuracy (final)
3551
+ if n_fin > 0:
3552
+ report["metrics"]["accuracy"] = float(c_fin / n_fin)
3553
+ except Exception:
3554
+ pass
3555
+
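
As used above, `compute_accuracy_counts` takes a list of window records (dicts with an `input_ids` list) and returns a `(correct_total, total)` pair, from which accuracy is `correct_total / total`. A small usage sketch with arbitrary token values; the deterministic label rule itself lives in `invarlock.eval.primary_metric`:

from invarlock.eval.primary_metric import compute_accuracy_counts

records = [
    {"input_ids": [101, 2023, 2003, 102]},
    {"input_ids": [101, 2178, 7953, 102]},
]
correct, total = compute_accuracy_counts(records)
accuracy = correct / total if total else 0.0
print(correct, total, accuracy)
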
3556
+ match_fraction = metrics_section.get("window_match_fraction")
3557
+ if match_fraction is not None and not math.isclose(
3558
+ match_fraction, 1.0, rel_tol=0.0, abs_tol=1e-9
3559
+ ):
3560
+ err = InvarlockError(
3561
+ code="E001",
3562
+ message=(
3563
+ f"PAIRING-SCHEDULE-MISMATCH: window_match_fraction={match_fraction:.3f}"
3564
+ ),
3565
+ details={"window_match_fraction": float(match_fraction)},
3566
+ )
3567
+ code = _resolve_exit_code(err, profile=profile_normalized)
3568
+ console.print(f"[red]{err}[/red]")
3569
+ raise typer.Exit(code)
3570
+
3571
+ overlap_fraction = metrics_section.get("window_overlap_fraction")
3572
+ if overlap_fraction is not None and overlap_fraction > 1e-9:
3573
+ err = InvarlockError(
3574
+ code="E001",
3575
+ message=(
3576
+ f"PAIRING-SCHEDULE-MISMATCH: window_overlap_fraction={overlap_fraction:.3f}"
3577
+ ),
3578
+ details={"window_overlap_fraction": float(overlap_fraction)},
3579
+ )
3580
+ code = _resolve_exit_code(err, profile=profile_normalized)
3581
+ console.print(f"[red]{err}[/red]")
3582
+ raise typer.Exit(code)
3583
+
3584
+ # Additional guard: paired_windows collapse (0) in CI/Release
3585
+ try:
3586
+ paired_windows_val = metrics_section.get("paired_windows")
3587
+ if (
3588
+ profile_normalized in {"ci", "release"}
3589
+ and isinstance(paired_windows_val, (int | float))
3590
+ and int(paired_windows_val) == 0
3591
+ ):
3592
+ err = InvarlockError(
3593
+ code="E001",
3594
+ message=(
3595
+ "PAIRED-WINDOWS-COLLAPSED: paired_windows=0 under paired schedule. "
3596
+ "Check device stability, dataset windows, or edit scope."
3597
+ ),
3598
+ details={
3599
+ "paired_windows": int(paired_windows_val),
3600
+ "profile": profile_normalized,
3601
+ },
3602
+ )
3603
+ code = _resolve_exit_code(err, profile=profile_normalized)
3604
+ console.print(f"[red]{err}[/red]")
3605
+ raise typer.Exit(code)
3606
+ except Exception:
3607
+ pass
3608
+
3609
+ expected_preview = effective_preview or getattr(
3610
+ cfg.dataset, "preview_n", preview_count_report
3611
+ )
3612
+ expected_final = effective_final or getattr(
3613
+ cfg.dataset, "final_n", final_count_report
3614
+ )
3615
+ if (
3616
+ preview_count_report is not None
3617
+ and expected_preview is not None
3618
+ and int(preview_count_report) != int(expected_preview)
3619
+ ) or (
3620
+ final_count_report is not None
3621
+ and expected_final is not None
3622
+ and int(final_count_report) != int(expected_final)
3623
+ ):
3624
+ err = InvarlockError(
3625
+ code="E001",
3626
+ message=(
3627
+ "PAIRING-SCHEDULE-MISMATCH: counts do not match configuration after stratification"
3628
+ ),
3629
+ details={
3630
+ "preview_used": int(preview_count_report or -1),
3631
+ "preview_expected": int(expected_preview or -1),
3632
+ "final_used": int(final_count_report or -1),
3633
+ "final_expected": int(expected_final or -1),
3634
+ },
3635
+ )
3636
+ code = _resolve_exit_code(err, profile=profile_normalized)
3637
+ console.print(f"[red]{err}[/red]")
3638
+ raise typer.Exit(code)
3639
+
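
Several of the pairing gates above reduce to simple predicates over the reported window statistics; a compact, illustrative restatement of three of them (thresholds copied from the checks):

import math

def pairing_gates_ok(match_fraction: float | None, overlap_fraction: float | None,
                     paired_windows: int | None, strict_profile: bool) -> bool:
    # True when the paired evaluation schedule looks intact.
    if match_fraction is not None and not math.isclose(
        match_fraction, 1.0, rel_tol=0.0, abs_tol=1e-9
    ):
        return False  # PAIRING-SCHEDULE-MISMATCH
    if overlap_fraction is not None and overlap_fraction > 1e-9:
        return False  # preview/final windows overlap
    if strict_profile and paired_windows == 0:
        return False  # PAIRED-WINDOWS-COLLAPSED (ci/release profiles only)
    return True

print(pairing_gates_ok(1.0, 0.0, 128, strict_profile=True))  # True
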
3640
+ # Compute metric-v1 snapshot (primary_metric) — canonical path
3641
+ try:
3642
+ metric_kind_resolved, _provider_kind, metric_opts = (
3643
+ _resolve_metric_and_provider(
3644
+ cfg, model_profile, resolved_loss_type=resolved_loss_type
3645
+ )
3646
+ )
3647
+ if metric_kind_resolved:
3648
+ from invarlock.eval.primary_metric import (
3649
+ compute_primary_metric_from_report,
3650
+ )
3651
+
3652
+ pm = compute_primary_metric_from_report(
3653
+ report, kind=metric_kind_resolved, baseline=baseline_report_data
3654
+ )
3655
+ report.setdefault("metrics", {})["primary_metric"] = pm
3656
+ # Attach configured reps/ci_level when provided
3657
+ if metric_opts:
3658
+ try:
3659
+ if "reps" in metric_opts:
3660
+ report["metrics"]["primary_metric"]["reps"] = int(
3661
+ metric_opts["reps"]
3662
+ ) # type: ignore[index]
3663
+ if "ci_level" in metric_opts:
3664
+ report["metrics"]["primary_metric"]["ci_level"] = float(
3665
+ metric_opts["ci_level"]
3666
+ ) # type: ignore[index]
3667
+ except Exception:
3668
+ pass
3669
+ # Shadow parity check against legacy ppl fields (best-effort)
3670
+ try:
3671
+ pm_blk = report.get("metrics", {}).get("primary_metric", {})
3672
+ ppl_final_v1 = float(pm_blk.get("final"))
3673
+ ppl_final_v2 = float(pm.get("final", float("nan")))
3674
+ if math.isfinite(ppl_final_v1) and math.isfinite(ppl_final_v2):
3675
+ if not math.isclose(
3676
+ ppl_final_v1, ppl_final_v2, rel_tol=1e-9, abs_tol=1e-9
3677
+ ):
3678
+ report.setdefault("metrics", {}).setdefault(
3679
+ "_metric_v1_mismatch", {}
3680
+ )["ppl_final_diff"] = ppl_final_v2 - ppl_final_v1
3681
+ # Optional: dual-write diffs logging for ppl_* metrics
3682
+ debug_diffs = str(
3683
+ os.environ.get("DEBUG_METRIC_DIFFS", "")
3684
+ ).strip().lower() in {"1", "true", "yes", "on"}
3685
+ if debug_diffs and str(pm.get("kind", "")).startswith("ppl"):
3686
+ diffs_line = _format_debug_metric_diffs(
3687
+ pm, report.get("metrics", {}), baseline_report_data
3688
+ )
3689
+ if diffs_line:
3690
+ console.print(
3691
+ "[dim]DEBUG_METRIC_DIFFS: " + diffs_line + "[/dim]"
3692
+ )
3693
+ except Exception:
3694
+ pass
3695
+ except Exception:
3696
+ # Non-fatal: metric-v1 snapshot should not break runs
3697
+ pass
3698
+
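
For orientation, the `primary_metric` snapshot attached above ends up with roughly this shape (key names appear in the code above; the values here are placeholders):

primary_metric_example = {
    "kind": "ppl_causal",        # resolved via _resolve_metric_and_provider
    "preview": 13.1,
    "final": 12.9,
    "ratio_vs_baseline": 0.998,  # present when a baseline report was supplied
    "reps": 200,                 # copied from metric_opts when configured
    "ci_level": 0.95,
}
print(primary_metric_example["kind"], primary_metric_example["final"])
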
3699
+ # No deprecation notices in dev-phase: primary_metric is canonical.
3700
+
3701
+ # Derive dataset.windows.stats (PM-only surface)
3702
+ try:
3703
+ ds = report.setdefault("dataset", {}).setdefault("windows", {})
3704
+ stats = ds.setdefault("stats", {})
3705
+ if match_fraction is not None:
3706
+ stats["window_match_fraction"] = float(match_fraction)
3707
+ if overlap_fraction is not None:
3708
+ stats["window_overlap_fraction"] = float(overlap_fraction)
3709
+ try:
3710
+ if isinstance(window_plan, dict) and "coverage_ok" in window_plan:
3711
+ stats["coverage"] = bool(window_plan.get("coverage_ok"))
3712
+ except Exception:
3713
+ pass
3714
+ except Exception:
3715
+ pass
3716
+
3717
+ _postprocess_and_summarize(
3718
+ report=report,
3719
+ run_dir=run_dir,
3720
+ run_config=run_config,
3721
+ window_plan=window_plan,
3722
+ dataset_meta=dataset_meta,
3723
+ match_fraction=match_fraction,
3724
+ overlap_fraction=overlap_fraction,
3725
+ console=console,
3726
+ )
3727
+
3728
+ # Metrics display
3729
+ pm_obj = None
3730
+ try:
3731
+ pm_obj = report.get("metrics", {}).get("primary_metric")
3732
+ except Exception:
3733
+ pm_obj = None
3734
+ if isinstance(pm_obj, dict) and pm_obj:
3735
+ try:
3736
+ pm_kind = str(pm_obj.get("kind", "primary")).lower()
3737
+ pm_prev = pm_obj.get("preview")
3738
+ pm_fin = pm_obj.get("final")
3739
+ if isinstance(pm_prev, (int | float)) and isinstance(
3740
+ pm_fin, (int | float)
3741
+ ):
3742
+ console.print(
3743
+ f"📌 Primary Metric [{pm_kind}] — preview: {pm_prev:.3f}, final: {pm_fin:.3f}"
3744
+ )
3745
+ ratio_vs_base = pm_obj.get("ratio_vs_baseline")
3746
+ if isinstance(ratio_vs_base, (int | float)) and math.isfinite(
3747
+ ratio_vs_base
3748
+ ):
3749
+ console.print(
3750
+ f"🔗 Ratio vs baseline [{pm_kind}]: {ratio_vs_base:.3f}"
3751
+ )
3752
+ except Exception:
3753
+ pass
3754
+ # Legacy ppl_* console block removed in favor of primary_metric summary
3755
+
3756
+ guard_overhead_info = report.get("guard_overhead")
3757
+ if guard_overhead_info:
3758
+ threshold_fraction = _print_guard_overhead_summary(
3759
+ console, guard_overhead_info
3760
+ )
3761
+ if not guard_overhead_info.get("passed", True):
3762
+ console.print(
3763
+ "[red]⚠️ Guard overhead gate FAILED: Guards add more than the permitted budget[/red]"
3764
+ )
3765
+ # Only fail hard when the overhead check was actually evaluated
3766
+ # (e.g., for causal LMs with available bare/guarded PM). For
3767
+ # masked LM flows where ppl-like PM is undefined, record as not evaluated
3768
+ # and continue without aborting the run.
3769
+ loss_type_ctx = None
3770
+ try:
3771
+ loss_type_ctx = (
3772
+ run_config.context.get("eval", {})
3773
+ .get("loss", {})
3774
+ .get("resolved_type")
3775
+ )
3776
+ except Exception:
3777
+ loss_type_ctx = None
3778
+ if (
3779
+ measure_guard_overhead
3780
+ and guard_overhead_info.get("evaluated", False)
3781
+ and str(loss_type_ctx).lower() != "mlm"
3782
+ ):
3783
+ _fail_run(
3784
+ "Guard overhead gate exceeded the configured budget "
3785
+ f"(>{threshold_fraction * 100:.1f}% increase)"
3786
+ )
3787
+
3788
+ # Drift gate status is no longer surfaced in console; rely on certificate gates
3789
+
3790
+ # Certificate validation for --until-pass mode
3791
+ if retry_controller and baseline:
3792
+ from invarlock.reporting.certificate import make_certificate
3793
+
3794
+ try:
3795
+ baseline_report = baseline_report_data
3796
+ if baseline_report is None and baseline:
3797
+ baseline_path = Path(baseline)
3798
+ with baseline_path.open(encoding="utf-8") as f:
3799
+ baseline_report = json.load(f)
3800
+
3801
+ if baseline_report is None:
3802
+ raise FileNotFoundError("Baseline report unavailable")
3803
+
3804
+ console.print("📜 Generating safety certificate...")
3805
+ certificate = make_certificate(report, baseline_report)
3806
+
3807
+ validation = certificate.get("validation", {})
3808
+ certificate_passed = all(validation.values())
3809
+
3810
+ failed_gates = [k for k, v in validation.items() if not v]
3811
+ result_summary = {
3812
+ "passed": certificate_passed,
3813
+ "failures": failed_gates,
3814
+ "validation": validation,
3815
+ }
3816
+ retry_controller.record_attempt(
3817
+ attempt, result_summary, edit_config
3818
+ )
3819
+
3820
+ if certificate_passed:
3821
+ console.print("[green]✅ Certificate PASSED all gates![/green]")
3822
+ break
3823
+ else:
3824
+ console.print(
3825
+ f"[yellow]⚠️ Certificate FAILED gates: {', '.join(failed_gates)}[/yellow]"
3826
+ )
3827
+
3828
+ # Auto-tune mask-only heads (binary search on keep count)
3829
+ try:
3830
+ head_section = None
3831
+ for k in ("heads", "head_budget", "head_budgets"):
3832
+ if isinstance(edit_config.get(k), dict):
3833
+ head_section = edit_config[k]
3834
+ break
3835
+ search = (
3836
+ head_section.get("_auto_search")
3837
+ if isinstance(head_section, dict)
3838
+ else None
3839
+ )
3840
+ if isinstance(search, dict) and head_section.get(
3841
+ "mask_only"
3842
+ ):
3843
+ keep_low = int(search.get("keep_low", 0))
3844
+ keep_high = int(
3845
+ search.get(
3846
+ "keep_high", search.get("total_heads", 0)
3847
+ )
3848
+ )
3849
+ keep_current = int(
3850
+ search.get("keep_current", keep_high)
3851
+ )
3852
+ # If the quality gate (PM) is unacceptable, increase keep (less pruning); if only other gates failed, be conservative and increase keep slightly
3853
+ pm_ok = bool(
3854
+ validation.get("primary_metric_acceptable", False)
3855
+ )
3856
+ if not pm_ok:
3857
+ keep_low = max(keep_low, keep_current)
3858
+ else:
3859
+ # drift/spectral/etc failed: ease pruning
3860
+ keep_low = max(keep_low, keep_current)
3861
+ next_keep = int((keep_low + keep_high + 1) // 2)
3862
+ search.update(
3863
+ {
3864
+ "keep_low": keep_low,
3865
+ "keep_high": keep_high,
3866
+ "keep_current": next_keep,
3867
+ }
3868
+ )
3869
+ head_section["global_k"] = next_keep
3870
+ console.print(
3871
+ f"🔧 Auto-tune adjust: global_k → {next_keep} (bounds {keep_low}-{keep_high})"
3872
+ )
3873
+ except Exception:
3874
+ pass
3875
+
3876
+ if retry_controller.should_retry(certificate_passed):
3877
+ attempt += 1
3878
+ continue
3879
+ else:
3880
+ console.print(
3881
+ f"[red]❌ Exhausted retry budget after {attempt} attempts[/red]"
3882
+ )
3883
+ break
3884
+
3885
+ except Exception as cert_error:
3886
+ console.print(
3887
+ f"[yellow]⚠️ Certificate validation failed: {cert_error}[/yellow]"
3888
+ )
3889
+ if retry_controller:
3890
+ retry_controller.record_attempt(
3891
+ attempt,
3892
+ {
3893
+ "passed": False,
3894
+ "failures": ["certificate_error"],
3895
+ "validation": {},
3896
+ },
3897
+ edit_config,
3898
+ )
3899
+ break
3900
+ else:
3901
+ if retry_controller:
3902
+ retry_controller.record_attempt(
3903
+ attempt,
3904
+ {"passed": True, "failures": [], "validation": {}},
3905
+ edit_config,
3906
+ )
3907
+ # No retry mode - single run
3908
+ break
3909
+
3910
+ # Show retry summary if applicable
3911
+ _print_retry_summary(console, retry_controller)
3912
+
3913
+ # (moved) Cleanup printing occurs after loop to guarantee execution
3914
+ pass
3915
+
3916
+ # Normal path falls through; cleanup handled below in finally
3917
+
3918
+ except FileNotFoundError as e:
3919
+ console.print(f"[red]❌ Configuration file not found: {e}[/red]")
3920
+ raise typer.Exit(1) from e
3921
+ except InvarlockError as ce:
3922
+ # InvarlockError → code 3 only in CI/Release; dev → 1
3923
+ console.print(str(ce))
3924
+ raise typer.Exit(_resolve_exit_code(ce, profile=profile)) from ce
3925
+ except (typer.Exit, SystemExit, click.exceptions.Exit):
3926
+ # Preserve explicit exit codes (e.g., parity checks, user-triggered exits)
3927
+ raise
3928
+ except Exception as e:
3929
+ if os.environ.get("INVARLOCK_DEBUG_TRACE"):
3930
+ import traceback
3931
+
3932
+ traceback.print_exc()
3933
+ # Emit a clearer message for schema failures (exit 2)
3934
+ if isinstance(e, ValueError) and "Invalid RunReport" in str(e):
3935
+ console.print(
3936
+ "[red]❌ Schema invalid: run report structure failed validation[/red]"
3937
+ )
3938
+ code = 2
3939
+ else:
3940
+ console.print(f"[red]❌ Pipeline execution failed: {e}[/red]")
3941
+ code = _resolve_exit_code(e, profile=profile)
3942
+ raise typer.Exit(code) from e
3943
+ finally:
3944
+ # Cleanup snapshot directory if used (always print once per run)
3945
+ try:
3946
+ if snapshot_tmpdir and not no_cleanup:
3947
+ try:
3948
+ import shutil as _sh
3949
+
3950
+ _sh.rmtree(snapshot_tmpdir, ignore_errors=True)
3951
+ except Exception:
3952
+ pass
3953
+ finally:
3954
+ console.print("cleanup: removed")
3955
+ else:
3956
+ console.print("cleanup: skipped")
3957
+ except Exception:
3958
+ # Best-effort cleanup printing; never raise from finally
3959
+ pass
3960
+
3961
+
3962
+ def _format_debug_metric_diffs(
3963
+ pm: dict[str, float] | None,
3964
+ metrics: dict[str, float] | None,
3965
+ baseline_report_data: dict | None,
3966
+ ) -> str:
3967
+    """Build a compact DEBUG_METRIC_DIFFS line comparing a fresh primary-metric snapshot vs the stored report metrics.
3968
+
3969
+ Returns a semicolon-separated string of deltas like
3970
+    "final: v2-v1 = +0.000000000; Δlog(final): +0.000000000; ...". Safe to call with
3971
+ missing fields; non-finite entries are skipped.
3972
+ """
3973
+ import math as _m
3974
+
3975
+ if not isinstance(pm, dict) or not isinstance(metrics, dict):
3976
+ return ""
3977
+ diffs: list[str] = []
3978
+ try:
3979
+ pm_blk = metrics.get("primary_metric", {}) if isinstance(metrics, dict) else {}
3980
+ ppl_final_v1 = float(pm_blk.get("final", float("nan")))
3981
+ except Exception:
3982
+ ppl_final_v1 = float("nan")
3983
+ try:
3984
+ ppl_prev_v1 = float(pm_blk.get("preview", float("nan")))
3985
+ except Exception:
3986
+ ppl_prev_v1 = float("nan")
3987
+ try:
3988
+ ppl_final_v2 = float(pm.get("final", float("nan")))
3989
+ except Exception:
3990
+ ppl_final_v2 = float("nan")
3991
+ try:
3992
+ ppl_prev_v2 = float(pm.get("preview", float("nan")))
3993
+ except Exception:
3994
+ ppl_prev_v2 = float("nan")
3995
+
3996
+ if _m.isfinite(ppl_final_v1) and _m.isfinite(ppl_final_v2):
3997
+        diffs.append(f"final: v2-v1 = {ppl_final_v2 - ppl_final_v1:+.9f}")
3998
+ try:
3999
+ diffs.append(
4000
+ f"Δlog(final): {_m.log(ppl_final_v2) - _m.log(ppl_final_v1):+.9f}"
4001
+ )
4002
+ except Exception:
4003
+ pass
4004
+ if _m.isfinite(ppl_prev_v1) and _m.isfinite(ppl_prev_v2):
4005
+        diffs.append(f"preview: v2-v1 = {ppl_prev_v2 - ppl_prev_v1:+.9f}")
4006
+ try:
4007
+ diffs.append(
4008
+ f"Δlog(preview): {_m.log(ppl_prev_v2) - _m.log(ppl_prev_v1):+.9f}"
4009
+ )
4010
+ except Exception:
4011
+ pass
4012
+
4013
+ # ratio vs baseline
4014
+ try:
4015
+ r_v2 = float(pm.get("ratio_vs_baseline", float("nan")))
4016
+ except Exception:
4017
+ r_v2 = float("nan")
4018
+ # prefer PM ratio when present
4019
+ r_v1 = float(pm_blk.get("ratio_vs_baseline", float("nan")))
4020
+ if (not _m.isfinite(r_v1)) and isinstance(baseline_report_data, dict):
4021
+ try:
4022
+ base_fin = float(
4023
+ (
4024
+ (baseline_report_data.get("metrics") or {}).get("primary_metric")
4025
+ or {}
4026
+ ).get("final")
4027
+ )
4028
+ if _m.isfinite(base_fin) and base_fin > 0 and _m.isfinite(ppl_final_v1):
4029
+ r_v1 = ppl_final_v1 / base_fin
4030
+ except Exception:
4031
+ pass
4032
+ if _m.isfinite(r_v1) and _m.isfinite(r_v2):
4033
+        diffs.append(f"ratio_vs_baseline: v2-v1 = {r_v2 - r_v1:+.9f}")
4034
+ return "; ".join(diffs)
4035
+
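
A quick usage sketch of the helper above, assuming it is called from this same module with two hand-made snapshots (values invented):

pm_new = {"final": 12.91, "preview": 13.05, "ratio_vs_baseline": 1.001}
stored_metrics = {
    "primary_metric": {"final": 12.90, "preview": 13.05, "ratio_vs_baseline": 1.000}
}
line = _format_debug_metric_diffs(pm_new, stored_metrics, baseline_report_data=None)
print(line)  # semicolon-separated deltas between the two snapshots
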
4036
+
4037
+ # Provide a module shim so tests can patch 'src.invarlock.cli.commands.run.shutil.*'.
4038
+ try: # best-effort; harmless in production
4039
+ _shim = _types.ModuleType(__name__ + ".shutil")
4040
+
4041
+ def _shim_getattr(name: str): # pragma: no cover
4042
+ return getattr(shutil, name)
4043
+
4044
+ _shim.__getattr__ = _shim_getattr # type: ignore[attr-defined]
4045
+ _shim.disk_usage = shutil.disk_usage # type: ignore[attr-defined]
4046
+ _shim.rmtree = shutil.rmtree # type: ignore[attr-defined]
4047
+ _sys.modules[__name__ + ".shutil"] = _shim
4048
+ _sys.modules["src." + __name__ + ".shutil"] = _shim
4049
+ except Exception:
4050
+ pass
4051
+
4052
+
4053
+ def _normalize_overhead_result(
4054
+ payload: dict[str, object] | None, profile: str | None = None
4055
+ ) -> dict[str, object]:
4056
+ """Normalize guard-overhead payload for tiny/degenerate runs.
4057
+
4058
+ If the computed overhead ratio is missing or non-finite, mark the check as
4059
+ not evaluated and passed to avoid spurious gate failures in tiny runs.
4060
+ """
4061
+ payload = dict(payload or {})
4062
+ try:
4063
+ ratio = payload.get("overhead_ratio")
4064
+ val = float(ratio) if isinstance(ratio, int | float) else float("nan")
4065
+ except Exception:
4066
+ val = float("nan")
4067
+ if not (isinstance(val, float) and math.isfinite(val)):
4068
+ payload["evaluated"] = False
4069
+ payload["passed"] = True
4070
+ return payload
4071
+
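
For example, a payload whose ratio could not be computed is downgraded to "not evaluated" rather than failing the gate (again assuming the call is made from this module):

degenerate = {"overhead_ratio": float("nan"), "passed": False, "evaluated": True}
normalized = _normalize_overhead_result(degenerate, profile="ci")
print(normalized["evaluated"], normalized["passed"])  # False True
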
4072
+
4073
+ # helper moved to invarlock.cli.overhead_utils
4074
+
4075
+
4076
+ def _print_guard_overhead_summary(
4077
+ console: Console, guard_overhead_info: dict[str, Any]
4078
+ ) -> float:
4079
+ """Print a concise guard-overhead console summary. Returns threshold fraction used."""
4080
+ evaluated = bool(guard_overhead_info.get("evaluated", True))
4081
+ if not evaluated:
4082
+ console.print("🛡️ Guard Overhead: not evaluated")
4083
+ return GUARD_OVERHEAD_THRESHOLD
4084
+ overhead_status = (
4085
+ "✅ PASS" if guard_overhead_info.get("passed", True) else "❌ FAIL"
4086
+ )
4087
+ overhead_percent = guard_overhead_info.get("overhead_percent")
4088
+ if isinstance(overhead_percent, (int | float)) and math.isfinite(
4089
+ float(overhead_percent)
4090
+ ):
4091
+ overhead_display = f"{float(overhead_percent):+.2f}%"
4092
+ else:
4093
+ ratio_value = guard_overhead_info.get("overhead_ratio")
4094
+ if isinstance(ratio_value, (int | float)) and math.isfinite(float(ratio_value)):
4095
+ overhead_display = f"{float(ratio_value):.3f}x"
4096
+ else:
4097
+ # Avoid any 'nanx' or ambiguous output
4098
+ overhead_display = "not evaluated"
4099
+ threshold_percent = guard_overhead_info.get("overhead_threshold", 0.01)
4100
+ try:
4101
+ threshold_fraction = float(threshold_percent)
4102
+ except (TypeError, ValueError):
4103
+ threshold_fraction = GUARD_OVERHEAD_THRESHOLD
4104
+ threshold_display = f"≤ +{threshold_fraction * 100:.1f}%"
4105
+ console.print(
4106
+ f"🛡️ Guard Overhead: {overhead_status} {overhead_display} ({threshold_display})"
4107
+ )
4108
+ return threshold_fraction
4109
+
4110
+
4111
+ def _print_retry_summary(console: Console, retry_controller: Any | None) -> None:
4112
+ """Print a one-line retry summary when retries were attempted."""
4113
+ try:
4114
+ if retry_controller and getattr(retry_controller, "attempt_history", None):
4115
+ summary = retry_controller.get_attempt_summary()
4116
+ console.print(
4117
+ f"\n📊 Retry Summary: {summary['total_attempts']} attempts in {summary['elapsed_time']:.1f}s"
4118
+ )
4119
+ except Exception:
4120
+ # Never break the run for summary printing
4121
+ pass
4122
+
4123
+
4124
+ def _init_retry_controller(
4125
+ *,
4126
+ until_pass: bool,
4127
+ max_attempts: int,
4128
+ timeout: int | None,
4129
+ baseline: str | None,
4130
+ console: Console,
4131
+ ):
4132
+ """Initialize RetryController with consistent console prints."""
4133
+ retry_controller = None
4134
+ if until_pass:
4135
+ from invarlock.core.retry import RetryController
4136
+
4137
+ retry_controller = RetryController(
4138
+ max_attempts=max_attempts, timeout=timeout, verbose=True
4139
+ )
4140
+ console.print(f"🔄 Retry mode enabled: max {max_attempts} attempts")
4141
+ if baseline:
4142
+ console.print(f"📋 Using baseline: {baseline}")
4143
+ else:
4144
+ if baseline:
4145
+ console.print(f"📋 Using baseline: {baseline}")
4146
+ return retry_controller
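
A short call sketch for the initializer above, assuming `Console` is rich's console type (as the markup used elsewhere in this file suggests); with `until_pass=False` it only echoes the baseline path and returns `None`:

from rich.console import Console

controller = _init_retry_controller(
    until_pass=False,
    max_attempts=3,
    timeout=None,
    baseline="runs/baseline/report.json",  # placeholder path
    console=Console(),
)
print(controller)  # None when retry mode is disabled
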