invarlock 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -200,6 +200,10 @@ def _free_model_memory(model: object | None) -> None:
         pass
 
 
+class _SnapshotRestoreFailed(RuntimeError):
+    """Internal signal for snapshot restore failures during retries."""
+
+
 def _should_measure_overhead(profile_normalized: str) -> tuple[bool, bool]:
     """Return (measure_guard_overhead, skip_overhead) derived from env/profile."""
 
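
The new `_SnapshotRestoreFailed` class is a sentinel exception: raising it lets the retry loop later in this diff tell a failed snapshot restore apart from ordinary run failures. A minimal sketch of the pattern, with illustrative names (these are not invarlock APIs):

# Sketch of the sentinel-exception pattern; names are illustrative.
class SnapshotRestoreFailed(RuntimeError):
    """Raised when an in-place snapshot restore fails."""

def restore_or_signal(restore):
    try:
        restore()
    except Exception as exc:
        # Wrap the original error so the retry loop can branch on restore
        # failures without catching unrelated exceptions too broadly.
        raise SnapshotRestoreFailed(str(exc)) from exc

def failing_restore():
    raise OSError("snapshot blob corrupted")

try:
    restore_or_signal(failing_restore)
except SnapshotRestoreFailed:
    print("falling back to reload-per-attempt")
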
@@ -301,6 +305,12 @@ def _hash_sequences(seqs: Sequence[Sequence[int]] | Iterable[Sequence[int]]) -> str:
     """Compute a stable digest for a sequence of integer token sequences."""
     hasher = hashlib.blake2s(digest_size=16)
     for seq in seqs:
+        try:
+            seq_len = len(seq)
+        except TypeError:
+            seq = list(seq)
+            seq_len = len(seq)
+        hasher.update(seq_len.to_bytes(4, "little", signed=False))
         arr = array("I", (int(token) & 0xFFFFFFFF for token in seq))
         hasher.update(arr.tobytes())
     return hasher.hexdigest()
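
The six added lines prefix each sequence with its length before hashing. Without the prefix, the digest is not injective over sequence boundaries: `[[1, 2], [3]]` and `[[1], [2, 3]]` feed identical bytes to the hasher. A standalone demonstration (not the packaged function):

import hashlib
from array import array

def digest(seqs, length_prefix):
    hasher = hashlib.blake2s(digest_size=16)
    for seq in seqs:
        if length_prefix:
            hasher.update(len(seq).to_bytes(4, "little", signed=False))
        hasher.update(array("I", (t & 0xFFFFFFFF for t in seq)).tobytes())
    return hasher.hexdigest()

a, b = [[1, 2], [3]], [[1], [2, 3]]
assert digest(a, False) == digest(b, False)  # boundary collision
assert digest(a, True) != digest(b, True)    # length prefix disambiguates
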
@@ -360,6 +370,8 @@ def _tensor_or_list_to_ints(values: Any) -> list[int]:
         return _to_int_list(raw)
     try:
         return _to_int_list(list(raw))
+    except (typer.Exit, SystemExit, click.exceptions.Exit):
+        raise
     except Exception:
         pass
     # Numpy arrays: treat as list-like
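
`typer.Exit` and `click.exceptions.Exit` derive from `Exception`, so the broad `except Exception` below would otherwise swallow a deliberate CLI exit raised inside the conversion helpers; re-raising them first preserves exit behavior. A sketch of the ordering, assuming a caller-supplied `convert` callable:

import click
import typer

def safe_convert(convert, raw):
    # Order matters: control-flow exceptions must be re-raised before the
    # broad handler below can swallow them.
    try:
        return convert(raw)
    except (typer.Exit, SystemExit, click.exceptions.Exit):
        raise
    except Exception:
        return None
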
@@ -547,30 +559,69 @@ def _extract_pairing_schedule(report: dict[str, Any] | None) -> dict[str, Any] | None:
     if not isinstance(windows, dict):
         return None
 
-    def _sanitize(section_key: str) -> dict[str, Any] | None:
+    def _wrap_single_row(raw: Any, *, expected_rows: int) -> list | None:
+        if not isinstance(raw, list):
+            return None
+        if expected_rows == 1 and raw and not isinstance(raw[0], list):
+            return [raw]
+        return raw
+
+    def _sanitize(section_key: str, *, start_id: int) -> dict[str, Any] | None:
         section = windows.get(section_key)
         if not isinstance(section, dict):
             return None
-        window_ids = list(section.get("window_ids", []))
-        input_ids_raw = section.get("input_ids", [])
+        input_ids_raw = section.get("input_ids")
         if not isinstance(input_ids_raw, list):
             return None
-        input_ids = [list(seq) for seq in input_ids_raw]
+        input_ids = [_tensor_or_list_to_ints(seq) for seq in input_ids_raw]
+        if not input_ids:
+            return None
+
+        window_ids_raw = section.get("window_ids")
+        window_ids: list[int] = []
+        if isinstance(window_ids_raw, list):
+            if len(window_ids_raw) != len(input_ids):
+                return None
+            for wid in window_ids_raw:
+                try:
+                    window_ids.append(int(wid))
+                except Exception:
+                    return None
+        else:
+            window_ids = list(range(int(start_id), int(start_id) + len(input_ids)))
+
         attention_raw = section.get("attention_masks")
-        if isinstance(attention_raw, list) and all(
-            isinstance(mask, list) for mask in attention_raw
-        ):
-            attention_masks = [list(mask) for mask in attention_raw]
+        attention_masks: list[list[int]]
+        if isinstance(attention_raw, list):
+            maybe = _wrap_single_row(attention_raw, expected_rows=len(input_ids))
+            if isinstance(maybe, list) and all(
+                isinstance(mask, list) for mask in maybe
+            ):
+                attention_masks = [_tensor_or_list_to_ints(mask) for mask in maybe]
+            else:
+                attention_masks = [
+                    [1 if int(token) != 0 else 0 for token in seq] for seq in input_ids
+                ]
         else:
             attention_masks = [
-                [1 if token != 0 else 0 for token in seq] for seq in input_ids
+                [1 if int(token) != 0 else 0 for token in seq] for seq in input_ids
             ]
+        if len(attention_masks) != len(input_ids):
+            return None
+        for seq, mask in zip(input_ids, attention_masks, strict=False):
+            if len(mask) != len(seq):
+                return None
 
         labels_raw = section.get("labels")
         labels: list[list[int]] | None = None
         if isinstance(labels_raw, list) and labels_raw:
+            maybe_labels = _wrap_single_row(labels_raw, expected_rows=len(input_ids))
+            if not isinstance(maybe_labels, list) or len(maybe_labels) != len(
+                input_ids
+            ):
+                return None
             labels = []
-            for idx, raw_label in enumerate(labels_raw):
+            for idx, raw_label in enumerate(maybe_labels):
                 label_list = _tensor_or_list_to_ints(raw_label)
                 if idx < len(input_ids):
                     target_len = len(input_ids[idx])
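
`_wrap_single_row` tolerates legacy baselines that serialized a single window as a flat list instead of a one-row matrix. A sketch of the same normalization on plain lists:

def wrap_single_row(raw, expected_rows):
    # Mirrors the helper above: promote a flat row to a 1-row matrix
    # only when exactly one row is expected.
    if not isinstance(raw, list):
        return None
    if expected_rows == 1 and raw and not isinstance(raw[0], list):
        return [raw]
    return raw

print(wrap_single_row([1, 1, 0], expected_rows=1))        # [[1, 1, 0]]
print(wrap_single_row([[1, 1], [1, 0]], expected_rows=2))  # unchanged
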
@@ -582,12 +633,22 @@ def _extract_pairing_schedule(report: dict[str, Any] | None) -> dict[str, Any] | None:
                     label_list = label_list[:target_len]
                 labels.append(label_list)
 
-        masked_counts = None
-        if isinstance(section.get("masked_token_counts"), list):
-            masked_counts = [int(v) for v in section["masked_token_counts"]]
-        actual_counts = None
-        if isinstance(section.get("actual_token_counts"), list):
-            actual_counts = [int(v) for v in section["actual_token_counts"]]
+        masked_counts: list[int] | None = None
+        if section.get("masked_token_counts") is not None:
+            raw = section.get("masked_token_counts")
+            if isinstance(raw, int) and len(input_ids) == 1:
+                raw = [raw]
+            if not isinstance(raw, list) or len(raw) != len(input_ids):
+                return None
+            masked_counts = [int(v) for v in raw]
+        actual_counts: list[int] | None = None
+        if section.get("actual_token_counts") is not None:
+            raw = section.get("actual_token_counts")
+            if isinstance(raw, int) and len(input_ids) == 1:
+                raw = [raw]
+            if not isinstance(raw, list) or len(raw) != len(input_ids):
+                return None
+            actual_counts = [int(v) for v in raw]
 
         payload: dict[str, Any] = {
             "window_ids": window_ids,
@@ -602,8 +663,10 @@ def _extract_pairing_schedule(report: dict[str, Any] | None) -> dict[str, Any] | None:
             payload["actual_token_counts"] = actual_counts
         return payload
 
-    preview = _sanitize("preview")
-    final = _sanitize("final")
+    preview = _sanitize("preview", start_id=0)
+    if not preview:
+        return None
+    final = _sanitize("final", start_id=len(preview.get("input_ids") or []))
     if preview and final:
         return {"preview": preview, "final": final}
     return None
@@ -828,11 +891,30 @@ def _extract_model_load_kwargs(cfg: InvarLockConfig) -> dict[str, Any]:
     model = data.get("model") if isinstance(data, dict) else None
     if not isinstance(model, dict):
         return {}
-    return {
+    extra = {
         key: value
         for key, value in model.items()
         if key not in {"id", "adapter", "device"} and value is not None
     }
+    # Backwards-compatible aliasing: config `dtype` → HF `torch_dtype`.
+    if "dtype" in extra and "torch_dtype" not in extra:
+        extra["torch_dtype"] = extra.pop("dtype")
+
+    # Normalize torch_dtype when present (keep as string for JSON-ability).
+    if "torch_dtype" in extra and isinstance(extra.get("torch_dtype"), str):
+        dtype_str = str(extra.get("torch_dtype") or "").strip().lower()
+        aliases = {
+            "fp16": "float16",
+            "half": "float16",
+            "bf16": "bfloat16",
+            "fp32": "float32",
+        }
+        if dtype_str in aliases:
+            extra["torch_dtype"] = aliases[dtype_str]
+        elif dtype_str:
+            extra["torch_dtype"] = dtype_str
+
+    return extra
 
 
 def _load_model_with_cfg(adapter: Any, cfg: InvarLockConfig, device: str) -> Any:
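
The alias table canonicalizes shorthand dtype strings before they reach the model loader, and values stay strings so the resolved kwargs remain JSON-serializable. A standalone sketch of the normalization (the helper name is hypothetical):

_DTYPE_ALIASES = {"fp16": "float16", "half": "float16",
                  "bf16": "bfloat16", "fp32": "float32"}

def normalize_dtype(value: str) -> str:
    # Lowercase/trim, then map shorthand names to canonical torch names;
    # already-canonical strings pass through unchanged.
    key = value.strip().lower()
    return _DTYPE_ALIASES.get(key, key)

assert normalize_dtype(" FP16 ") == "float16"
assert normalize_dtype("bfloat16") == "bfloat16"
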
@@ -882,6 +964,7 @@ def _run_bare_control(
     console: Console,
     resolved_loss_type: str,
     profile_normalized: str | None,
+    snapshot_provenance: dict[str, bool] | None = None,
    skip_model_load: bool = False,
 ) -> dict[str, Any] | None:
     """Execute the bare-control run for overhead estimation and return payload."""
@@ -897,26 +980,38 @@ def _run_bare_control(
     bare_context.setdefault("validation", {})["guard_overhead_mode"] = "bare"
     bare_config.context = bare_context
 
-    if restore_fn and model is not None:
-        restore_fn()
-        bare_target_model = model
-    elif skip_model_load:
-        bare_target_model = model or SimpleNamespace(name="bare_stub_model")
-    else:
-        bare_target_model = adapter.load_model(cfg.model.id, device=resolved_device)
-
-    bare_report = bare_runner.execute(
-        model=bare_target_model,
-        adapter=adapter,
-        edit=edit_op,
-        guards=[],
-        config=bare_config,
-        calibration_data=calibration_data,
-        auto_config=auto_config,
-        edit_config=edit_config,
-        preview_n=preview_count,
-        final_n=final_count,
-    )
+    private_model_loaded = False
+    bare_target_model = None
+    try:
+        if restore_fn and model is not None:
+            try:
+                restore_fn()
+            except Exception as exc:
+                raise _SnapshotRestoreFailed(str(exc)) from exc
+            bare_target_model = model
+        elif skip_model_load:
+            bare_target_model = model or SimpleNamespace(name="bare_stub_model")
+        else:
+            bare_target_model = _load_model_with_cfg(adapter, cfg, resolved_device)
+            private_model_loaded = True
+            if snapshot_provenance is not None:
+                snapshot_provenance["reload_path_used"] = True
+
+        bare_report = bare_runner.execute(
+            model=bare_target_model,
+            adapter=adapter,
+            edit=edit_op,
+            guards=[],
+            config=bare_config,
+            calibration_data=calibration_data,
+            auto_config=auto_config,
+            edit_config=edit_config,
+            preview_n=preview_count,
+            final_n=final_count,
+        )
+    finally:
+        if private_model_loaded:
+            _free_model_memory(bare_target_model)
 
     bare_ppl_final = None
     bare_ppl_preview = None
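
Wrapping the bare-control run in try/finally guarantees that a privately loaded model is released even if `execute` raises, which matters on GPU where a leaked reference pins memory. The shape of the pattern, with hypothetical callables standing in for the real adapter and runner:

def bare_control(load, execute, free, shared_model=None):
    privately_loaded = False
    model = shared_model
    try:
        if model is None:
            model = load()          # private copy just for this run
            privately_loaded = True
        return execute(model)
    finally:
        if privately_loaded:
            free(model)             # released even when execute() raises
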
@@ -988,16 +1083,22 @@ def _execute_guarded_run(
     restore_fn: Any | None,
     resolved_device: str,
     console: Console,
+    snapshot_provenance: dict[str, bool] | None = None,
     skip_model_load: bool = False,
 ) -> tuple[Any, Any]:
     """Restore or load model and execute the guarded CoreRunner."""
     if restore_fn and model is not None:
-        restore_fn()
+        try:
+            restore_fn()
+        except Exception as exc:
+            raise _SnapshotRestoreFailed(str(exc)) from exc
     elif skip_model_load:
         model = model or SimpleNamespace(name="guarded_stub_model")
     else:
         console.print(f"🔧 Loading model: {cfg.model.id} (attempt 1)")
-        model = adapter.load_model(cfg.model.id, device=resolved_device)
+        model = _load_model_with_cfg(adapter, cfg, resolved_device)
+        if snapshot_provenance is not None:
+            snapshot_provenance["reload_path_used"] = True
 
     core_report = runner.execute(
         model=model,
@@ -1071,7 +1172,22 @@ def _compute_provider_digest(report: dict[str, Any]) -> dict[str, str] | None:
         wids = sec.get("window_ids")
         if isinstance(wids, list):
             all_ids.extend(list(wids))
-    ids_sha = _hash_json(sorted(all_ids)) if all_ids else None
+    ids_sha = None
+    if all_ids:
+        # Prefer ints when possible; fall back to strings to avoid mixed-type sorting.
+        ids_int: list[int] = []
+        use_ints = True
+        for raw in all_ids:
+            try:
+                ids_int.append(int(raw))
+            except Exception:
+                use_ints = False
+                break
+        if use_ints:
+            ids_sha = _hash_json(sorted(ids_int))
+        else:
+            ids_str = [str(v) for v in all_ids]
+            ids_sha = _hash_json(sorted(ids_str))
 
     # tokenizer hash: prefer meta.tokenizer_hash then data.tokenizer_hash
     tok_hash = None
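
Python 3 refuses to order mixed int/str lists, so hashing `sorted(all_ids)` directly would crash as soon as any report carried string window IDs; the replacement coerces everything to a single type first. A quick demonstration:

ids = [3, "1", 2]
try:
    sorted(ids)
except TypeError as exc:
    print(exc)  # '<' not supported between instances of 'str' and 'int'

# Coerce to ints when every ID converts, else compare as strings.
try:
    canonical = sorted(int(v) for v in ids)
except (TypeError, ValueError):
    canonical = sorted(str(v) for v in ids)
print(canonical)  # [1, 2, 3]
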
@@ -1101,6 +1217,7 @@ def _validate_and_harvest_baseline_schedule(
     *,
     tokenizer_hash: str | None,
     resolved_loss_type: str,
+    profile: str | None = None,
     baseline_path_str: str | None = None,
     console: Console | None = None,
 ) -> dict[str, Any]:
@@ -1117,6 +1234,10 @@ def _validate_and_harvest_baseline_schedule(
 
     def _fail_schedule(reason: str) -> None:
         path = baseline_path_str or "baseline"
+        prof = (profile or "dev").strip().lower()
+        message = f"PAIRING-EVIDENCE-MISSING: {path}: {reason}"
+        if prof in {"ci", "release"}:
+            raise InvarlockError(code="E001", message=message)
         _print(
             f"[red]❌ Baseline pairing schedule '{path}' is incompatible: {reason}[/red]"
         )
@@ -1134,6 +1255,178 @@ def _validate_and_harvest_baseline_schedule(
         value = baseline_meta.get(field)
         return value if value is not None else default
 
+    # Structural integrity checks (fail closed in CI/Release)
+    try:
+        prev = (
+            pairing_schedule.get("preview")
+            if isinstance(pairing_schedule, dict)
+            else None
+        )
+        fin = (
+            pairing_schedule.get("final")
+            if isinstance(pairing_schedule, dict)
+            else None
+        )
+        if not isinstance(prev, dict) or not isinstance(fin, dict):
+            _fail_schedule("missing preview/final evaluation_windows sections")
+
+        def _arm_check(
+            label: str, section: dict[str, Any]
+        ) -> tuple[list[int], list[list[int]]]:
+            wids = section.get("window_ids")
+            toks = section.get("input_ids")
+            masks = section.get("attention_masks")
+            if not isinstance(wids, list) or not isinstance(toks, list):
+                _fail_schedule(f"invalid {label} section: missing window_ids/input_ids")
+            if len(wids) != len(toks):
+                _fail_schedule(
+                    f"{label} coherence error: len(window_ids)={len(wids)} len(input_ids)={len(toks)}"
+                )
+            ids_int: list[int] = []
+            seqs: list[list[int]] = []
+            for idx, (wid, seq) in enumerate(zip(wids, toks, strict=False)):
+                try:
+                    wid_int = int(wid)
+                except Exception:
+                    _fail_schedule(
+                        f"{label} window_ids contains non-int at index {idx}"
+                    )
+                ids_int.append(wid_int)
+                seq_ints = _tensor_or_list_to_ints(seq)
+                if not seq_ints:
+                    _fail_schedule(f"{label} input_ids empty at index {idx}")
+                seqs.append(seq_ints)
+
+            # attention_masks are required for pairing, but legacy baselines may omit them.
+            # When absent, default to all-ones masks (cannot infer padding reliably).
+            masks_rows: list[list[int]] = []
+            masks_missing = masks is None or masks == []
+            if (
+                isinstance(masks, list)
+                and masks
+                and len(seqs) == 1
+                and not isinstance(masks[0], list)
+            ):  # type: ignore[index]
+                masks = [masks]
+
+            if isinstance(masks, list) and masks:
+                if len(masks) != len(seqs):
+                    _fail_schedule(
+                        f"{label} coherence error: len(attention_masks)={len(masks)} len(input_ids)={len(seqs)}"
+                    )
+                for j, (seq_ints, mask) in enumerate(zip(seqs, masks, strict=False)):
+                    if not isinstance(mask, list):
+                        _fail_schedule(
+                            f"{label} attention_masks row is not a list at index {j}"
+                        )
+                    mask_ints = _tensor_or_list_to_ints(mask)
+                    if len(mask_ints) != len(seq_ints):
+                        _fail_schedule(
+                            f"{label} attention_masks length mismatch at index {j}"
+                        )
+                    masks_rows.append(mask_ints)
+            else:
+                masks_missing = True
+                masks_rows = [[1] * len(seq) for seq in seqs]
+
+            if masks_missing:
+                try:
+                    section["attention_masks"] = masks_rows
+                except Exception:
+                    pass
+
+            # Optional MLM fields must align when present.
+            labels = section.get("labels")
+            if isinstance(labels, list) and labels:
+                if len(labels) != len(seqs):
+                    _fail_schedule(f"{label} labels length mismatch")
+                for j, row in enumerate(labels):
+                    row_ints = _tensor_or_list_to_ints(row)
+                    if len(row_ints) != len(seqs[j]):
+                        _fail_schedule(f"{label} labels length mismatch at index {j}")
+
+            for key in ("masked_token_counts", "actual_token_counts"):
+                if section.get(key) is not None:
+                    raw_counts = section.get(key)
+                    if not isinstance(raw_counts, list) or len(raw_counts) != len(seqs):
+                        _fail_schedule(f"{label} {key} length mismatch")
+            return ids_int, seqs
+
+        prev_ids, prev_seqs = _arm_check("preview", prev)
+        fin_ids, fin_seqs = _arm_check("final", fin)
+
+        if len(set(prev_ids)) != len(prev_ids):
+            _fail_schedule("duplicate window_ids detected in preview arm")
+        if len(set(fin_ids)) != len(fin_ids):
+            _fail_schedule("duplicate window_ids detected in final arm")
+        if set(prev_ids) & set(fin_ids):
+            _fail_schedule("window_ids overlap between preview and final arms")
+
+        def _hash_tokens(tokens: list[int]) -> bytes:
+            if not tokens:
+                return b""
+            token_array = array("I", (int(token) & 0xFFFFFFFF for token in tokens))
+            return hashlib.blake2b(token_array.tobytes(), digest_size=16).digest()
+
+        prev_hashes = [_hash_tokens(seq) for seq in prev_seqs]
+        fin_hashes = [_hash_tokens(seq) for seq in fin_seqs]
+        if len(set(prev_hashes)) != len(prev_hashes):
+            _fail_schedule("duplicate token sequences detected in preview arm")
+        if len(set(fin_hashes)) != len(fin_hashes):
+            _fail_schedule("duplicate token sequences detected in final arm")
+        if set(prev_hashes) & set(fin_hashes):
+            _fail_schedule("preview/final token sequence overlap detected")
+
+        # Optional: validate baseline hashes when present in baseline report data
+        expected_preview_hash = _hash_sequences(prev_seqs)
+        expected_final_hash = _hash_sequences(fin_seqs)
+        expected_dataset_hash = hashlib.blake2s(
+            (expected_preview_hash + expected_final_hash).encode("utf-8"),
+            digest_size=16,
+        ).hexdigest()
+        baseline_preview_hash = baseline_meta.get("preview_hash")
+        baseline_final_hash = baseline_meta.get("final_hash")
+        baseline_dataset_hash = baseline_meta.get("dataset_hash")
+        if (
+            isinstance(baseline_preview_hash, str)
+            and baseline_preview_hash
+            and baseline_preview_hash != expected_preview_hash
+        ):
+            prof = (profile or "dev").strip().lower()
+            if prof in {"ci", "release"}:
+                _fail_schedule("preview_hash mismatch vs baseline report data")
+            _print(
+                "[yellow]⚠️ Baseline preview_hash mismatch; continuing in dev profile.[/yellow]"
+            )
+        if (
+            isinstance(baseline_final_hash, str)
+            and baseline_final_hash
+            and baseline_final_hash != expected_final_hash
+        ):
+            prof = (profile or "dev").strip().lower()
+            if prof in {"ci", "release"}:
+                _fail_schedule("final_hash mismatch vs baseline report data")
+            _print(
+                "[yellow]⚠️ Baseline final_hash mismatch; continuing in dev profile.[/yellow]"
+            )
+        if (
+            isinstance(baseline_dataset_hash, str)
+            and baseline_dataset_hash
+            and baseline_dataset_hash != expected_dataset_hash
+        ):
+            prof = (profile or "dev").strip().lower()
+            if prof in {"ci", "release"}:
+                _fail_schedule("dataset_hash mismatch vs baseline report data")
+            _print(
+                "[yellow]⚠️ Baseline dataset_hash mismatch; continuing in dev profile.[/yellow]"
+            )
+    except InvarlockError:
+        raise
+    except typer.Exit:
+        raise
+    except Exception as exc:  # noqa: BLE001
+        _fail_schedule(f"failed to validate baseline schedule integrity ({exc})")
+
     # Adopt counts from the schedule, warning if they differ from cfg
     baseline_preview = len(pairing_schedule["preview"].get("input_ids") or [])
     baseline_final = len(pairing_schedule["final"].get("input_ids") or [])
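
The dataset hash checked above chains the two arm digests, so it can be recomputed from a report without re-tokenizing anything. A sketch using the same construction as the validation code:

import hashlib

def dataset_hash(preview_hash: str, final_hash: str) -> str:
    # blake2s over the concatenated hex digests, mirroring the check above.
    return hashlib.blake2s(
        (preview_hash + final_hash).encode("utf-8"), digest_size=16
    ).hexdigest()

# Any drift in either arm (or in their order) changes the combined digest.
assert dataset_hash("a", "b") != dataset_hash("b", "a")
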
@@ -1262,20 +1555,32 @@ def _enforce_provider_parity(
         return
     sd = subject_digest or {}
     bd = baseline_digest or {}
+    subj_ids = sd.get("ids_sha256")
+    base_ids = bd.get("ids_sha256")
     subj_tok = sd.get("tokenizer_sha256")
     base_tok = bd.get("tokenizer_sha256")
     subj_mask = sd.get("masking_sha256")
     base_mask = bd.get("masking_sha256")
     # Missing digest information in CI/Release → abort
     if not (
-        isinstance(subj_tok, str)
+        isinstance(subj_ids, str)
+        and isinstance(base_ids, str)
+        and subj_ids
+        and base_ids
+        and isinstance(subj_tok, str)
         and isinstance(base_tok, str)
         and subj_tok
         and base_tok
     ):
         raise InvarlockError(
             code="E004",
-            message="PROVIDER-DIGEST-MISSING: subject or baseline missing tokenizer digest",
+            message="PROVIDER-DIGEST-MISSING: subject or baseline missing ids/tokenizer digest",
+        )
+    # Window-ids mismatch → abort
+    if subj_ids != base_ids:
+        raise InvarlockError(
+            code="E006",
+            message="IDS-DIGEST-MISMATCH: subject and baseline window IDs differ",
         )
     # Tokenizer mismatch → abort with code
     if subj_tok != base_tok:
@@ -1765,38 +2070,83 @@ def run_command(
     pairing_schedule: dict[str, Any] | None = None
     if baseline:
         baseline_path = Path(baseline)
-        if baseline_path.exists():
+        profile_normalized = (profile or "").strip().lower()
+        strict_baseline = profile_normalized in {"ci", "release"}
+        if not baseline_path.exists():
+            msg = (
+                "PAIRING-EVIDENCE-MISSING: baseline report path does not exist "
+                f"({baseline})"
+            )
+            if strict_baseline:
+                raise InvarlockError(code="E001", message=msg)
+            console.print(
+                f"[yellow]⚠️ {msg}. Falling back to dataset schedule.[/yellow]"
+            )
+        else:
             try:
                 with baseline_path.open(encoding="utf-8") as f:
                     baseline_report_data = json.load(f)
+            except Exception as exc:  # noqa: BLE001
+                msg = f"PAIRING-EVIDENCE-MISSING: baseline report JSON parse failed ({exc})"
+                if strict_baseline:
+                    raise InvarlockError(code="E001", message=msg) from exc
+                console.print(
+                    f"[yellow]⚠️ {msg}. Falling back to dataset schedule.[/yellow]"
+                )
+                baseline_report_data = None
+            if isinstance(baseline_report_data, dict):
                 pairing_schedule = _extract_pairing_schedule(baseline_report_data)
                 if pairing_schedule:
+                    # Normalize baseline report in-memory so downstream digest/parity
+                    # computations see a consistent window_id + mask shape even for
+                    # legacy baselines missing some fields.
+                    try:
+                        baseline_report_data["evaluation_windows"] = (
+                            pairing_schedule
+                        )
+                    except Exception:
+                        pass
+                    # Harvest tokenizer hash provenance from baseline when present.
+                    try:
+                        if not tokenizer_hash:
+                            tok = None
+                            meta = (
+                                baseline_report_data.get("meta")
+                                if isinstance(
+                                    baseline_report_data.get("meta"), dict
+                                )
+                                else {}
+                            )
+                            data = (
+                                baseline_report_data.get("data")
+                                if isinstance(
+                                    baseline_report_data.get("data"), dict
+                                )
+                                else {}
+                            )
+                            if isinstance(meta, dict):
+                                tok = meta.get("tokenizer_hash")
+                            if not tok and isinstance(data, dict):
+                                tok = data.get("tokenizer_hash")
+                            if isinstance(tok, str) and tok:
+                                tokenizer_hash = tok
+                    except Exception:
+                        pass
                     console.print(
                         "🧬 Loaded baseline evaluation schedule for pairing"
                     )
-                elif (profile or "").lower() == "release":
-                    console.print(
-                        f"[red]❌ Baseline report '{baseline}' does not contain evaluation_windows required for pairing.[/red]"
-                    )
-                    raise typer.Exit(1)
                 else:
+                    msg = (
+                        "PAIRING-EVIDENCE-MISSING: baseline report missing or invalid "
+                        f"evaluation_windows ({baseline})"
+                    )
+                    if strict_baseline:
+                        raise InvarlockError(code="E001", message=msg)
                     console.print(
-                        f"[yellow]⚠️ Baseline report '{baseline}' lacks evaluation_windows; falling back to dataset schedule.[/yellow]"
+                        f"[yellow]⚠️ {msg}. Falling back to dataset schedule.[/yellow]"
                     )
                     baseline_report_data = None
                     pairing_schedule = None
-            except typer.Exit:
-                raise
-            except Exception as exc:  # noqa: BLE001
-                console.print(
-                    f"[yellow]⚠️ Failed to load baseline report '{baseline}': {exc}. Falling back to dataset schedule.[/yellow]"
-                )
-                baseline_report_data = None
-                pairing_schedule = None
-        else:
-            console.print(
-                f"[yellow]⚠️ Baseline report '{baseline}' not found. Falling back to dataset schedule.[/yellow]"
-            )
 
     requested_preview = int(getattr(cfg.dataset, "preview_n", 0))
     requested_final = int(getattr(cfg.dataset, "final_n", 0))
@@ -1988,6 +2338,7 @@ def run_command(
             baseline_report_data,
             tokenizer_hash=tokenizer_hash,
             resolved_loss_type=resolved_loss_type,
+            profile=profile,
             baseline_path_str=str(baseline) if baseline else None,
             console=console,
         )
@@ -2781,12 +3132,16 @@ def run_command(
     model = None
     restore_fn = None
     snapshot_tmpdir: str | None = None
+    snapshot_provenance: dict[str, bool] = {
+        "restore_failed": False,
+        "reload_path_used": False,
+    }
 
     # Try single-load with snapshot/restore if adapter supports it; fallback to reload per attempt
     try:
         # Load once
         console.print(f"🔧 Loading model once: {cfg.model.id}")
-        model = adapter.load_model(cfg.model.id, device=resolved_device)
+        model = _load_model_with_cfg(adapter, cfg, resolved_device)
 
         # No edit-specific bootstrap logic
 
@@ -2954,12 +3309,26 @@ def run_command(
 
             restore_fn = _restore
         elif mode == "bytes":
-            base_blob = adapter.snapshot(model)  # type: ignore[attr-defined]
+            supports_chunked = hasattr(adapter, "snapshot_chunked") and hasattr(
+                adapter, "restore_chunked"
+            )
+            try:
+                base_blob = adapter.snapshot(model)  # type: ignore[attr-defined]
+            except Exception:
+                if not supports_chunked:
+                    raise
+                snapshot_tmpdir = adapter.snapshot_chunked(model)  # type: ignore[attr-defined]
+
+                def _restore_fallback_chunked():
+                    adapter.restore_chunked(model, snapshot_tmpdir)  # type: ignore[attr-defined]
+
+                restore_fn = _restore_fallback_chunked
+            else:
 
-            def _restore2():
-                adapter.restore(model, base_blob)  # type: ignore[attr-defined]
+                def _restore2():
+                    adapter.restore(model, base_blob)  # type: ignore[attr-defined]
 
-            restore_fn = _restore2
+                restore_fn = _restore2
         else:
             # reload path - properly free GPU memory before setting to None
             _free_model_memory(model)
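
Bytes-mode snapshotting now degrades gracefully: if the one-shot `snapshot` raises (for instance on a state dict too large to hold as a single blob), the code falls back to duck-typed `snapshot_chunked`/`restore_chunked` when the adapter provides both. A toy adapter showing the control flow (method names mirror the diff; the adapter itself is made up):

class Adapter:
    def snapshot(self, model):
        raise MemoryError("state dict too large for one blob")
    def snapshot_chunked(self, model):
        return "/tmp/snap_dir"  # hypothetical chunk directory
    def restore_chunked(self, model, path):
        print(f"restoring {model} from {path}")

def make_restore(adapter, model):
    chunked = hasattr(adapter, "snapshot_chunked") and hasattr(
        adapter, "restore_chunked"
    )
    try:
        blob = adapter.snapshot(model)
    except Exception:
        if not chunked:
            raise  # no fallback available: propagate the original error
        path = adapter.snapshot_chunked(model)
        return lambda: adapter.restore_chunked(model, path)
    return lambda: adapter.restore(model, blob)

make_restore(Adapter(), "m")()  # restoring m from /tmp/snap_dir
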
@@ -3003,62 +3372,88 @@ def run_command(
             )
 
     guard_overhead_payload: dict[str, Any] | None = None
-    if skip_overhead and profile_normalized in {"ci", "release"}:
-        guard_overhead_payload = {
-            "overhead_threshold": GUARD_OVERHEAD_THRESHOLD,
-            "evaluated": False,
-            "passed": True,
-            "skipped": True,
-            "skip_reason": "INVARLOCK_SKIP_OVERHEAD_CHECK",
-            "mode": "skipped",
-            "source": "env:INVARLOCK_SKIP_OVERHEAD_CHECK",
-            "messages": [
-                "Overhead check skipped via INVARLOCK_SKIP_OVERHEAD_CHECK"
-            ],
-            "warnings": [],
-            "errors": [],
-            "checks": {},
-        }
-    elif measure_guard_overhead:
-        guard_overhead_payload = _run_bare_control(
+    try:
+        if skip_overhead and profile_normalized in {"ci", "release"}:
+            guard_overhead_payload = {
+                "overhead_threshold": GUARD_OVERHEAD_THRESHOLD,
+                "evaluated": False,
+                "passed": True,
+                "skipped": True,
+                "skip_reason": "INVARLOCK_SKIP_OVERHEAD_CHECK",
+                "mode": "skipped",
+                "source": "env:INVARLOCK_SKIP_OVERHEAD_CHECK",
+                "messages": [
+                    "Overhead check skipped via INVARLOCK_SKIP_OVERHEAD_CHECK"
+                ],
+                "warnings": [],
+                "errors": [],
+                "checks": {},
+            }
+        elif measure_guard_overhead:
+            guard_overhead_payload = _run_bare_control(
+                adapter=adapter,
+                edit_op=edit_op,
+                cfg=cfg,
+                model=model,
+                run_config=run_config,
+                calibration_data=calibration_data,
+                auto_config=auto_config,
+                edit_config=edit_config,
+                preview_count=preview_count,
+                final_count=final_count,
+                seed_bundle=seed_bundle,
+                resolved_device=resolved_device,
+                restore_fn=restore_fn,
+                console=console,
+                resolved_loss_type=resolved_loss_type,
+                profile_normalized=profile_normalized,
+                snapshot_provenance=snapshot_provenance,
+                skip_model_load=skip_model_load,
+            )
+
+        # Ensure clean state for guarded run
+        core_report, model = _execute_guarded_run(
+            runner=runner,
             adapter=adapter,
-            edit_op=edit_op,
-            cfg=cfg,
             model=model,
+            cfg=cfg,
+            edit_op=edit_op,
             run_config=run_config,
+            guards=guards,
             calibration_data=calibration_data,
             auto_config=auto_config,
             edit_config=edit_config,
             preview_count=preview_count,
             final_count=final_count,
-            seed_bundle=seed_bundle,
-            resolved_device=resolved_device,
             restore_fn=restore_fn,
+            resolved_device=resolved_device,
             console=console,
-            resolved_loss_type=resolved_loss_type,
-            profile_normalized=profile_normalized,
+            snapshot_provenance=snapshot_provenance,
             skip_model_load=skip_model_load,
         )
-
-    # Ensure clean state for guarded run
-    core_report, model = _execute_guarded_run(
-        runner=runner,
-        adapter=adapter,
-        model=model,
-        cfg=cfg,
-        edit_op=edit_op,
-        run_config=run_config,
-        guards=guards,
-        calibration_data=calibration_data,
-        auto_config=auto_config,
-        edit_config=edit_config,
-        preview_count=preview_count,
-        final_count=final_count,
-        restore_fn=restore_fn,
-        resolved_device=resolved_device,
-        console=console,
-        skip_model_load=skip_model_load,
-    )
+    except _SnapshotRestoreFailed as exc:
+        snapshot_provenance["restore_failed"] = True
+        _free_model_memory(model)
+        model = None
+        restore_fn = None
+        console.print(
+            "[yellow]⚠️ Snapshot restore failed; switching to reload-per-attempt.[/yellow]"
+        )
+        console.print(f"[yellow]↳ {exc}[/yellow]")
+        if retry_controller:
+            retry_controller.record_attempt(
+                attempt,
+                {
+                    "passed": False,
+                    "failures": ["restore_failed"],
+                    "validation": {},
+                },
+                edit_config,
+            )
+            if retry_controller.should_retry(False):
+                attempt += 1
+                continue
+        raise typer.Exit(1) from exc
 
     if not hasattr(core_report, "context") or core_report.context is None:
         core_report.context = {}
@@ -3199,6 +3594,16 @@ def run_command(
         if tokenizer_hash:
             report["meta"]["tokenizer_hash"] = tokenizer_hash
 
+        # Snapshot/restore provenance (survives retries).
+        try:
+            prov = report.setdefault("provenance", {})
+            prov["restore_failed"] = bool(snapshot_provenance.get("restore_failed"))
+            prov["reload_path_used"] = bool(
+                snapshot_provenance.get("reload_path_used")
+            )
+        except Exception:
+            pass
+
         # Transfer edit information
         if hasattr(core_report, "edit") and core_report.edit:
             edit_deltas = core_report.edit.get("deltas", {})
@@ -3458,7 +3863,7 @@ def run_command(
                 ],
             },
         }
-        elif had_baseline and (profile or "").lower() == "release":
+        elif had_baseline and (profile or "").lower() in {"ci", "release"}:
            console.print(
                "[red]❌ [INVARLOCK:E001] PAIRING-SCHEDULE-MISMATCH: baseline pairing requested but evaluation windows were not produced. Check capacity/pairing config.[/red]"
            )
@@ -3476,11 +3881,12 @@ def run_command(
             except Exception:
                 return 0
 
+        preview_window_count = len(preview_records)
+        final_window_count = len(final_records)
+
         report["evaluation_windows"] = {
             "preview": {
-                "window_ids": [
-                    f"preview::{i}" for i in range(len(preview_records))
-                ],
+                "window_ids": list(range(preview_window_count)),
                 "input_ids": [
                     list(r["input_ids"]) for r in preview_records
                 ],
@@ -3501,9 +3907,12 @@ def run_command(
                 ),
             },
             "final": {
-                "window_ids": [
-                    f"final::{i}" for i in range(len(final_records))
-                ],
+                "window_ids": list(
+                    range(
+                        preview_window_count,
+                        preview_window_count + final_window_count,
+                    )
+                ),
                 "input_ids": [list(r["input_ids"]) for r in final_records],
                 "attention_masks": [
                     list(r["attention_mask"]) for r in final_records
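
Window IDs are now one global integer range — preview first, final continuing after it — replacing the old `preview::i`/`final::i` strings. That keeps the two arms disjoint and the IDs homogeneously sortable for the provider digest. The numbering scheme:

def window_ids(n_preview: int, n_final: int) -> tuple[list[int], list[int]]:
    preview = list(range(n_preview))
    final = list(range(n_preview, n_preview + n_final))
    return preview, final

p, f = window_ids(3, 2)
print(p, f)                 # [0, 1, 2] [3, 4]
assert not set(p) & set(f)  # arms never overlap
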
@@ -3543,18 +3952,16 @@ def run_command(
         # Strict parity checks in CI/Release when baseline present
         try:
             if isinstance(baseline_report_data, dict):
-                base_prov = (
-                    baseline_report_data.get("provenance", {})
-                    if isinstance(
-                        baseline_report_data.get("provenance"), dict
+                base_digest = None
+                base_prov = baseline_report_data.get("provenance")
+                if isinstance(base_prov, dict):
+                    base_pd = base_prov.get("provider_digest")
+                    if isinstance(base_pd, dict):
+                        base_digest = base_pd
+                if base_digest is None:
+                    base_digest = _compute_provider_digest(
+                        baseline_report_data
                     )
-                    else {}
-                )
-                base_digest = (
-                    base_prov.get("provider_digest")
-                    if isinstance(base_prov, dict)
-                    else None
-                )
                 _enforce_provider_parity(
                     provider_digest,
                     base_digest,
@@ -3570,6 +3977,8 @@ def run_command(
                     _fail_run(str(_e))
                 except Exception:
                     pass
+        except (typer.Exit, SystemExit, click.exceptions.Exit):
+            raise
         except Exception:
             pass
 
@@ -3817,30 +4226,46 @@ def run_command(
         console.print(f"[red]{err}[/red]")
         raise typer.Exit(code)
 
-    # Additional guard: paired_windows collapse (0) in CI/Release
-    try:
+    # Paired-run enforcement: baseline provided must be truly paired in CI/Release.
+    if baseline and profile_normalized in {"ci", "release"}:
+        pairing_reason = metrics_section.get("window_pairing_reason")
+        if pairing_reason is not None:
+            err = InvarlockError(
+                code="E001",
+                message=(
+                    "PAIRING-SCHEDULE-MISMATCH: baseline pairing requested but run was not paired "
+                    f"(window_pairing_reason={pairing_reason})"
+                ),
+                details={"window_pairing_reason": pairing_reason},
+            )
+            code = _resolve_exit_code(err, profile=profile_normalized)
+            console.print(f"[red]{err}[/red]")
+            raise typer.Exit(code)
+
         paired_windows_val = metrics_section.get("paired_windows")
-        if (
-            profile_normalized in {"ci", "release"}
-            and isinstance(paired_windows_val, (int | float))
-            and int(paired_windows_val) == 0
-        ):
+        paired_windows_int = None
+        try:
+            if paired_windows_val is not None and not isinstance(
+                paired_windows_val, bool
+            ):
+                paired_windows_int = int(paired_windows_val)
+        except Exception:
+            paired_windows_int = None
+        if paired_windows_int is None or paired_windows_int <= 0:
            err = InvarlockError(
                code="E001",
                message=(
-                    "PAIRED-WINDOWS-COLLAPSED: paired_windows=0 under paired schedule. "
+                    "PAIRED-WINDOWS-COLLAPSED: paired_windows<=0 under paired baseline. "
                    "Check device stability, dataset windows, or edit scope."
                ),
                details={
-                    "paired_windows": int(paired_windows_val),
+                    "paired_windows": paired_windows_val,
                    "profile": profile_normalized,
                },
            )
            code = _resolve_exit_code(err, profile=profile_normalized)
            console.print(f"[red]{err}[/red]")
            raise typer.Exit(code)
-    except Exception:
-        pass
 
     expected_preview = effective_preview or getattr(
         cfg.dataset, "preview_n", preview_count_report
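
The explicit `isinstance(..., bool)` exclusion matters because `bool` is a subclass of `int` in Python: without it, a stray `paired_windows: true` in a report would coerce to `1` and slip past the collapse check. A small sketch of the coercion rule:

def coerce_paired_windows(value):
    # Reject bools explicitly: isinstance(True, int) is True in Python.
    if value is None or isinstance(value, bool):
        return None
    try:
        return int(value)
    except (TypeError, ValueError):
        return None

assert isinstance(True, int)               # the footgun
assert coerce_paired_windows(True) is None
assert coerce_paired_windows("12") == 12
assert coerce_paired_windows(0) == 0       # still fails the <= 0 check
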