invarlock-0.3.3-py3-none-any.whl → invarlock-0.3.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -200,6 +200,10 @@ def _free_model_memory(model: object | None) -> None:
         pass
 
 
+class _SnapshotRestoreFailed(RuntimeError):
+    """Internal signal for snapshot restore failures during retries."""
+
+
 def _should_measure_overhead(profile_normalized: str) -> tuple[bool, bool]:
     """Return (measure_guard_overhead, skip_overhead) derived from env/profile."""
 
@@ -366,6 +370,8 @@ def _tensor_or_list_to_ints(values: Any) -> list[int]:
         return _to_int_list(raw)
     try:
         return _to_int_list(list(raw))
+    except (typer.Exit, SystemExit, click.exceptions.Exit):
+        raise
     except Exception:
         pass
     # Numpy arrays: treat as list-like
@@ -553,30 +559,69 @@ def _extract_pairing_schedule(report: dict[str, Any] | None) -> dict[str, Any] |
     if not isinstance(windows, dict):
         return None
 
-    def _sanitize(section_key: str) -> dict[str, Any] | None:
+    def _wrap_single_row(raw: Any, *, expected_rows: int) -> list | None:
+        if not isinstance(raw, list):
+            return None
+        if expected_rows == 1 and raw and not isinstance(raw[0], list):
+            return [raw]
+        return raw
+
+    def _sanitize(section_key: str, *, start_id: int) -> dict[str, Any] | None:
         section = windows.get(section_key)
         if not isinstance(section, dict):
             return None
-        window_ids = list(section.get("window_ids", []))
-        input_ids_raw = section.get("input_ids", [])
+        input_ids_raw = section.get("input_ids")
         if not isinstance(input_ids_raw, list):
             return None
-        input_ids = [list(seq) for seq in input_ids_raw]
+        input_ids = [_tensor_or_list_to_ints(seq) for seq in input_ids_raw]
+        if not input_ids:
+            return None
+
+        window_ids_raw = section.get("window_ids")
+        window_ids: list[int] = []
+        if isinstance(window_ids_raw, list):
+            if len(window_ids_raw) != len(input_ids):
+                return None
+            for wid in window_ids_raw:
+                try:
+                    window_ids.append(int(wid))
+                except Exception:
+                    return None
+        else:
+            window_ids = list(range(int(start_id), int(start_id) + len(input_ids)))
+
         attention_raw = section.get("attention_masks")
-        if isinstance(attention_raw, list) and all(
-            isinstance(mask, list) for mask in attention_raw
-        ):
-            attention_masks = [list(mask) for mask in attention_raw]
+        attention_masks: list[list[int]]
+        if isinstance(attention_raw, list):
+            maybe = _wrap_single_row(attention_raw, expected_rows=len(input_ids))
+            if isinstance(maybe, list) and all(
+                isinstance(mask, list) for mask in maybe
+            ):
+                attention_masks = [_tensor_or_list_to_ints(mask) for mask in maybe]
+            else:
+                attention_masks = [
+                    [1 if int(token) != 0 else 0 for token in seq] for seq in input_ids
+                ]
         else:
             attention_masks = [
-                [1 if token != 0 else 0 for token in seq] for seq in input_ids
+                [1 if int(token) != 0 else 0 for token in seq] for seq in input_ids
             ]
+        if len(attention_masks) != len(input_ids):
+            return None
+        for seq, mask in zip(input_ids, attention_masks, strict=False):
+            if len(mask) != len(seq):
+                return None
 
         labels_raw = section.get("labels")
         labels: list[list[int]] | None = None
         if isinstance(labels_raw, list) and labels_raw:
+            maybe_labels = _wrap_single_row(labels_raw, expected_rows=len(input_ids))
+            if not isinstance(maybe_labels, list) or len(maybe_labels) != len(
+                input_ids
+            ):
+                return None
             labels = []
-            for idx, raw_label in enumerate(labels_raw):
+            for idx, raw_label in enumerate(maybe_labels):
                 label_list = _tensor_or_list_to_ints(raw_label)
                 if idx < len(input_ids):
                     target_len = len(input_ids[idx])
@@ -588,12 +633,22 @@ def _extract_pairing_schedule(report: dict[str, Any] | None) -> dict[str, Any] |
                     label_list = label_list[:target_len]
                 labels.append(label_list)
 
-        masked_counts = None
-        if isinstance(section.get("masked_token_counts"), list):
-            masked_counts = [int(v) for v in section["masked_token_counts"]]
-        actual_counts = None
-        if isinstance(section.get("actual_token_counts"), list):
-            actual_counts = [int(v) for v in section["actual_token_counts"]]
+        masked_counts: list[int] | None = None
+        if section.get("masked_token_counts") is not None:
+            raw = section.get("masked_token_counts")
+            if isinstance(raw, int) and len(input_ids) == 1:
+                raw = [raw]
+            if not isinstance(raw, list) or len(raw) != len(input_ids):
+                return None
+            masked_counts = [int(v) for v in raw]
+        actual_counts: list[int] | None = None
+        if section.get("actual_token_counts") is not None:
+            raw = section.get("actual_token_counts")
+            if isinstance(raw, int) and len(input_ids) == 1:
+                raw = [raw]
+            if not isinstance(raw, list) or len(raw) != len(input_ids):
+                return None
+            actual_counts = [int(v) for v in raw]
 
         payload: dict[str, Any] = {
             "window_ids": window_ids,
@@ -608,8 +663,10 @@ def _extract_pairing_schedule(report: dict[str, Any] | None) -> dict[str, Any] |
             payload["actual_token_counts"] = actual_counts
         return payload
 
-    preview = _sanitize("preview")
-    final = _sanitize("final")
+    preview = _sanitize("preview", start_id=0)
+    if not preview:
+        return None
+    final = _sanitize("final", start_id=len(preview.get("input_ids") or []))
     if preview and final:
         return {"preview": preview, "final": final}
     return None
@@ -834,11 +891,30 @@ def _extract_model_load_kwargs(cfg: InvarLockConfig) -> dict[str, Any]:
     model = data.get("model") if isinstance(data, dict) else None
     if not isinstance(model, dict):
         return {}
-    return {
+    extra = {
         key: value
         for key, value in model.items()
         if key not in {"id", "adapter", "device"} and value is not None
     }
+    # Backwards-compatible aliasing: config `dtype` → HF `torch_dtype`.
+    if "dtype" in extra and "torch_dtype" not in extra:
+        extra["torch_dtype"] = extra.pop("dtype")
+
+    # Normalize torch_dtype when present (keep as string for JSON-ability).
+    if "torch_dtype" in extra and isinstance(extra.get("torch_dtype"), str):
+        dtype_str = str(extra.get("torch_dtype") or "").strip().lower()
+        aliases = {
+            "fp16": "float16",
+            "half": "float16",
+            "bf16": "bfloat16",
+            "fp32": "float32",
+        }
+        if dtype_str in aliases:
+            extra["torch_dtype"] = aliases[dtype_str]
+        elif dtype_str:
+            extra["torch_dtype"] = dtype_str
+
+    return extra
 
 
 def _load_model_with_cfg(adapter: Any, cfg: InvarLockConfig, device: str) -> Any:
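
For reference, the `torch_dtype` normalization added in this hunk reduces to a small pure function. The sketch below mirrors the hunk's logic under a hypothetical name (`normalize_dtype` is illustrative, not an invarlock API):

def normalize_dtype(value: str) -> str:
    # Shorthand dtype strings map to canonical HF names; other non-empty
    # strings are lowercased and passed through; empty input is left as-is.
    aliases = {"fp16": "float16", "half": "float16", "bf16": "bfloat16", "fp32": "float32"}
    dtype_str = value.strip().lower()
    if dtype_str in aliases:
        return aliases[dtype_str]
    return dtype_str if dtype_str else value

assert normalize_dtype("FP16") == "float16"
assert normalize_dtype("bfloat16") == "bfloat16"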
@@ -888,6 +964,7 @@ def _run_bare_control(
     console: Console,
     resolved_loss_type: str,
     profile_normalized: str | None,
+    snapshot_provenance: dict[str, bool] | None = None,
     skip_model_load: bool = False,
 ) -> dict[str, Any] | None:
     """Execute the bare-control run for overhead estimation and return payload."""
@@ -903,26 +980,38 @@ def _run_bare_control(
     bare_context.setdefault("validation", {})["guard_overhead_mode"] = "bare"
     bare_config.context = bare_context
 
-    if restore_fn and model is not None:
-        restore_fn()
-        bare_target_model = model
-    elif skip_model_load:
-        bare_target_model = model or SimpleNamespace(name="bare_stub_model")
-    else:
-        bare_target_model = adapter.load_model(cfg.model.id, device=resolved_device)
-
-    bare_report = bare_runner.execute(
-        model=bare_target_model,
-        adapter=adapter,
-        edit=edit_op,
-        guards=[],
-        config=bare_config,
-        calibration_data=calibration_data,
-        auto_config=auto_config,
-        edit_config=edit_config,
-        preview_n=preview_count,
-        final_n=final_count,
-    )
+    private_model_loaded = False
+    bare_target_model = None
+    try:
+        if restore_fn and model is not None:
+            try:
+                restore_fn()
+            except Exception as exc:
+                raise _SnapshotRestoreFailed(str(exc)) from exc
+            bare_target_model = model
+        elif skip_model_load:
+            bare_target_model = model or SimpleNamespace(name="bare_stub_model")
+        else:
+            bare_target_model = _load_model_with_cfg(adapter, cfg, resolved_device)
+            private_model_loaded = True
+            if snapshot_provenance is not None:
+                snapshot_provenance["reload_path_used"] = True
+
+        bare_report = bare_runner.execute(
+            model=bare_target_model,
+            adapter=adapter,
+            edit=edit_op,
+            guards=[],
+            config=bare_config,
+            calibration_data=calibration_data,
+            auto_config=auto_config,
+            edit_config=edit_config,
+            preview_n=preview_count,
+            final_n=final_count,
+        )
+    finally:
+        if private_model_loaded:
+            _free_model_memory(bare_target_model)
 
     bare_ppl_final = None
     bare_ppl_preview = None
@@ -994,16 +1083,22 @@ def _execute_guarded_run(
     restore_fn: Any | None,
     resolved_device: str,
     console: Console,
+    snapshot_provenance: dict[str, bool] | None = None,
     skip_model_load: bool = False,
 ) -> tuple[Any, Any]:
     """Restore or load model and execute the guarded CoreRunner."""
     if restore_fn and model is not None:
-        restore_fn()
+        try:
+            restore_fn()
+        except Exception as exc:
+            raise _SnapshotRestoreFailed(str(exc)) from exc
     elif skip_model_load:
         model = model or SimpleNamespace(name="guarded_stub_model")
     else:
         console.print(f"🔧 Loading model: {cfg.model.id} (attempt 1)")
-        model = adapter.load_model(cfg.model.id, device=resolved_device)
+        model = _load_model_with_cfg(adapter, cfg, resolved_device)
+        if snapshot_provenance is not None:
+            snapshot_provenance["reload_path_used"] = True
 
     core_report = runner.execute(
         model=model,
@@ -1077,7 +1172,22 @@ def _compute_provider_digest(report: dict[str, Any]) -> dict[str, str] | None:
         wids = sec.get("window_ids")
         if isinstance(wids, list):
             all_ids.extend(list(wids))
-    ids_sha = _hash_json(sorted(all_ids)) if all_ids else None
+    ids_sha = None
+    if all_ids:
+        # Prefer ints when possible; fall back to strings to avoid mixed-type sorting.
+        ids_int: list[int] = []
+        use_ints = True
+        for raw in all_ids:
+            try:
+                ids_int.append(int(raw))
+            except Exception:
+                use_ints = False
+                break
+        if use_ints:
+            ids_sha = _hash_json(sorted(ids_int))
+        else:
+            ids_str = [str(v) for v in all_ids]
+            ids_sha = _hash_json(sorted(ids_str))
 
     # tokenizer hash: prefer meta.tokenizer_hash then data.tokenizer_hash
     tok_hash = None
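
The int-first sort above sidesteps Python 3's TypeError when sorted() sees a mix of ints and strings (as can happen when legacy "preview::0"-style window ids meet integer ids). A standalone sketch of the same fallback, under a hypothetical helper name:

def stable_ids_key(all_ids: list) -> list:
    # Sort as ints when every id coerces cleanly; otherwise sort a uniform
    # string projection so sorted() never compares mixed types.
    try:
        return sorted(int(v) for v in all_ids)
    except (TypeError, ValueError):
        return sorted(str(v) for v in all_ids)

assert stable_ids_key([3, "1", 2]) == [1, 2, 3]
assert stable_ids_key(["final::0", 1]) == ["1", "final::0"]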
@@ -1107,6 +1217,7 @@ def _validate_and_harvest_baseline_schedule(
     *,
     tokenizer_hash: str | None,
     resolved_loss_type: str,
+    profile: str | None = None,
     baseline_path_str: str | None = None,
     console: Console | None = None,
 ) -> dict[str, Any]:
@@ -1123,6 +1234,10 @@ def _validate_and_harvest_baseline_schedule(
 
     def _fail_schedule(reason: str) -> None:
         path = baseline_path_str or "baseline"
+        prof = (profile or "dev").strip().lower()
+        message = f"PAIRING-EVIDENCE-MISSING: {path}: {reason}"
+        if prof in {"ci", "release"}:
+            raise InvarlockError(code="E001", message=message)
         _print(
             f"[red]❌ Baseline pairing schedule '{path}' is incompatible: {reason}[/red]"
         )
@@ -1140,6 +1255,178 @@ def _validate_and_harvest_baseline_schedule(
         value = baseline_meta.get(field)
         return value if value is not None else default
 
+    # Structural integrity checks (fail closed in CI/Release)
+    try:
+        prev = (
+            pairing_schedule.get("preview")
+            if isinstance(pairing_schedule, dict)
+            else None
+        )
+        fin = (
+            pairing_schedule.get("final")
+            if isinstance(pairing_schedule, dict)
+            else None
+        )
+        if not isinstance(prev, dict) or not isinstance(fin, dict):
+            _fail_schedule("missing preview/final evaluation_windows sections")
+
+        def _arm_check(
+            label: str, section: dict[str, Any]
+        ) -> tuple[list[int], list[list[int]]]:
+            wids = section.get("window_ids")
+            toks = section.get("input_ids")
+            masks = section.get("attention_masks")
+            if not isinstance(wids, list) or not isinstance(toks, list):
+                _fail_schedule(f"invalid {label} section: missing window_ids/input_ids")
+            if len(wids) != len(toks):
+                _fail_schedule(
+                    f"{label} coherence error: len(window_ids)={len(wids)} len(input_ids)={len(toks)}"
+                )
+            ids_int: list[int] = []
+            seqs: list[list[int]] = []
+            for idx, (wid, seq) in enumerate(zip(wids, toks, strict=False)):
+                try:
+                    wid_int = int(wid)
+                except Exception:
+                    _fail_schedule(
+                        f"{label} window_ids contains non-int at index {idx}"
+                    )
+                ids_int.append(wid_int)
+                seq_ints = _tensor_or_list_to_ints(seq)
+                if not seq_ints:
+                    _fail_schedule(f"{label} input_ids empty at index {idx}")
+                seqs.append(seq_ints)
+
+            # attention_masks are required for pairing, but legacy baselines may omit them.
+            # When absent, default to all-ones masks (cannot infer padding reliably).
+            masks_rows: list[list[int]] = []
+            masks_missing = masks is None or masks == []
+            if (
+                isinstance(masks, list)
+                and masks
+                and len(seqs) == 1
+                and not isinstance(masks[0], list)
+            ):  # type: ignore[index]
+                masks = [masks]
+
+            if isinstance(masks, list) and masks:
+                if len(masks) != len(seqs):
+                    _fail_schedule(
+                        f"{label} coherence error: len(attention_masks)={len(masks)} len(input_ids)={len(seqs)}"
+                    )
+                for j, (seq_ints, mask) in enumerate(zip(seqs, masks, strict=False)):
+                    if not isinstance(mask, list):
+                        _fail_schedule(
+                            f"{label} attention_masks row is not a list at index {j}"
+                        )
+                    mask_ints = _tensor_or_list_to_ints(mask)
+                    if len(mask_ints) != len(seq_ints):
+                        _fail_schedule(
+                            f"{label} attention_masks length mismatch at index {j}"
+                        )
+                    masks_rows.append(mask_ints)
+            else:
+                masks_missing = True
+                masks_rows = [[1] * len(seq) for seq in seqs]
+
+            if masks_missing:
+                try:
+                    section["attention_masks"] = masks_rows
+                except Exception:
+                    pass
+
+            # Optional MLM fields must align when present.
+            labels = section.get("labels")
+            if isinstance(labels, list) and labels:
+                if len(labels) != len(seqs):
+                    _fail_schedule(f"{label} labels length mismatch")
+                for j, row in enumerate(labels):
+                    row_ints = _tensor_or_list_to_ints(row)
+                    if len(row_ints) != len(seqs[j]):
+                        _fail_schedule(f"{label} labels length mismatch at index {j}")
+
+            for key in ("masked_token_counts", "actual_token_counts"):
+                if section.get(key) is not None:
+                    raw_counts = section.get(key)
+                    if not isinstance(raw_counts, list) or len(raw_counts) != len(seqs):
+                        _fail_schedule(f"{label} {key} length mismatch")
+            return ids_int, seqs
+
+        prev_ids, prev_seqs = _arm_check("preview", prev)
+        fin_ids, fin_seqs = _arm_check("final", fin)
+
+        if len(set(prev_ids)) != len(prev_ids):
+            _fail_schedule("duplicate window_ids detected in preview arm")
+        if len(set(fin_ids)) != len(fin_ids):
+            _fail_schedule("duplicate window_ids detected in final arm")
+        if set(prev_ids) & set(fin_ids):
+            _fail_schedule("window_ids overlap between preview and final arms")
+
+        def _hash_tokens(tokens: list[int]) -> bytes:
+            if not tokens:
+                return b""
+            token_array = array("I", (int(token) & 0xFFFFFFFF for token in tokens))
+            return hashlib.blake2b(token_array.tobytes(), digest_size=16).digest()
+
+        prev_hashes = [_hash_tokens(seq) for seq in prev_seqs]
+        fin_hashes = [_hash_tokens(seq) for seq in fin_seqs]
+        if len(set(prev_hashes)) != len(prev_hashes):
+            _fail_schedule("duplicate token sequences detected in preview arm")
+        if len(set(fin_hashes)) != len(fin_hashes):
+            _fail_schedule("duplicate token sequences detected in final arm")
+        if set(prev_hashes) & set(fin_hashes):
+            _fail_schedule("preview/final token sequence overlap detected")
+
+        # Optional: validate baseline hashes when present in baseline report data
+        expected_preview_hash = _hash_sequences(prev_seqs)
+        expected_final_hash = _hash_sequences(fin_seqs)
+        expected_dataset_hash = hashlib.blake2s(
+            (expected_preview_hash + expected_final_hash).encode("utf-8"),
+            digest_size=16,
+        ).hexdigest()
+        baseline_preview_hash = baseline_meta.get("preview_hash")
+        baseline_final_hash = baseline_meta.get("final_hash")
+        baseline_dataset_hash = baseline_meta.get("dataset_hash")
+        if (
+            isinstance(baseline_preview_hash, str)
+            and baseline_preview_hash
+            and baseline_preview_hash != expected_preview_hash
+        ):
+            prof = (profile or "dev").strip().lower()
+            if prof in {"ci", "release"}:
+                _fail_schedule("preview_hash mismatch vs baseline report data")
+            _print(
+                "[yellow]⚠️ Baseline preview_hash mismatch; continuing in dev profile.[/yellow]"
+            )
+        if (
+            isinstance(baseline_final_hash, str)
+            and baseline_final_hash
+            and baseline_final_hash != expected_final_hash
+        ):
+            prof = (profile or "dev").strip().lower()
+            if prof in {"ci", "release"}:
+                _fail_schedule("final_hash mismatch vs baseline report data")
+            _print(
+                "[yellow]⚠️ Baseline final_hash mismatch; continuing in dev profile.[/yellow]"
+            )
+        if (
+            isinstance(baseline_dataset_hash, str)
+            and baseline_dataset_hash
+            and baseline_dataset_hash != expected_dataset_hash
+        ):
+            prof = (profile or "dev").strip().lower()
+            if prof in {"ci", "release"}:
+                _fail_schedule("dataset_hash mismatch vs baseline report data")
+            _print(
+                "[yellow]⚠️ Baseline dataset_hash mismatch; continuing in dev profile.[/yellow]"
+            )
+    except InvarlockError:
+        raise
+    except typer.Exit:
+        raise
+    except Exception as exc:  # noqa: BLE001
+        _fail_schedule(f"failed to validate baseline schedule integrity ({exc})")
+
     # Adopt counts from the schedule, warning if they differ from cfg
     baseline_preview = len(pairing_schedule["preview"].get("input_ids") or [])
     baseline_final = len(pairing_schedule["final"].get("input_ids") or [])
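
The `_hash_tokens` helper added above packs token ids into a fixed-width array before hashing, so duplicate/overlap detection compares 128-bit digests instead of whole sequences. A self-contained sketch of the same construction (stdlib only; the standalone name is illustrative):

import hashlib
from array import array

def hash_tokens(tokens: list[int]) -> bytes:
    # Mask each id to 32 bits, pack with native-endian array("I"),
    # then take a BLAKE2b-128 digest, as in the hunk above.
    if not tokens:
        return b""
    token_array = array("I", (int(token) & 0xFFFFFFFF for token in tokens))
    return hashlib.blake2b(token_array.tobytes(), digest_size=16).digest()

assert hash_tokens([1, 2, 3]) == hash_tokens([1, 2, 3])
assert hash_tokens([1, 2, 3]) != hash_tokens([3, 2, 1])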
@@ -1268,20 +1555,32 @@ def _enforce_provider_parity(
         return
     sd = subject_digest or {}
     bd = baseline_digest or {}
+    subj_ids = sd.get("ids_sha256")
+    base_ids = bd.get("ids_sha256")
     subj_tok = sd.get("tokenizer_sha256")
     base_tok = bd.get("tokenizer_sha256")
     subj_mask = sd.get("masking_sha256")
     base_mask = bd.get("masking_sha256")
     # Missing digest information in CI/Release → abort
     if not (
-        isinstance(subj_tok, str)
+        isinstance(subj_ids, str)
+        and isinstance(base_ids, str)
+        and subj_ids
+        and base_ids
+        and isinstance(subj_tok, str)
         and isinstance(base_tok, str)
         and subj_tok
         and base_tok
     ):
         raise InvarlockError(
             code="E004",
-            message="PROVIDER-DIGEST-MISSING: subject or baseline missing tokenizer digest",
+            message="PROVIDER-DIGEST-MISSING: subject or baseline missing ids/tokenizer digest",
+        )
+    # Window-ids mismatch → abort
+    if subj_ids != base_ids:
+        raise InvarlockError(
+            code="E006",
+            message="IDS-DIGEST-MISMATCH: subject and baseline window IDs differ",
         )
     # Tokenizer mismatch → abort with code
     if subj_tok != base_tok:
@@ -1771,38 +2070,83 @@ def run_command(
     pairing_schedule: dict[str, Any] | None = None
     if baseline:
         baseline_path = Path(baseline)
-        if baseline_path.exists():
+        profile_normalized = (profile or "").strip().lower()
+        strict_baseline = profile_normalized in {"ci", "release"}
+        if not baseline_path.exists():
+            msg = (
+                "PAIRING-EVIDENCE-MISSING: baseline report path does not exist "
+                f"({baseline})"
+            )
+            if strict_baseline:
+                raise InvarlockError(code="E001", message=msg)
+            console.print(
+                f"[yellow]⚠️ {msg}. Falling back to dataset schedule.[/yellow]"
+            )
+        else:
             try:
                 with baseline_path.open(encoding="utf-8") as f:
                     baseline_report_data = json.load(f)
+            except Exception as exc:  # noqa: BLE001
+                msg = f"PAIRING-EVIDENCE-MISSING: baseline report JSON parse failed ({exc})"
+                if strict_baseline:
+                    raise InvarlockError(code="E001", message=msg) from exc
+                console.print(
+                    f"[yellow]⚠️ {msg}. Falling back to dataset schedule.[/yellow]"
+                )
+                baseline_report_data = None
+            if isinstance(baseline_report_data, dict):
                 pairing_schedule = _extract_pairing_schedule(baseline_report_data)
                 if pairing_schedule:
+                    # Normalize baseline report in-memory so downstream digest/parity
+                    # computations see a consistent window_id + mask shape even for
+                    # legacy baselines missing some fields.
+                    try:
+                        baseline_report_data["evaluation_windows"] = (
+                            pairing_schedule
+                        )
+                    except Exception:
+                        pass
+                    # Harvest tokenizer hash provenance from baseline when present.
+                    try:
+                        if not tokenizer_hash:
+                            tok = None
+                            meta = (
+                                baseline_report_data.get("meta")
+                                if isinstance(
+                                    baseline_report_data.get("meta"), dict
+                                )
+                                else {}
+                            )
+                            data = (
+                                baseline_report_data.get("data")
+                                if isinstance(
+                                    baseline_report_data.get("data"), dict
+                                )
+                                else {}
+                            )
+                            if isinstance(meta, dict):
+                                tok = meta.get("tokenizer_hash")
+                            if not tok and isinstance(data, dict):
+                                tok = data.get("tokenizer_hash")
+                            if isinstance(tok, str) and tok:
+                                tokenizer_hash = tok
+                    except Exception:
+                        pass
                     console.print(
                         "🧬 Loaded baseline evaluation schedule for pairing"
                     )
-                elif (profile or "").lower() == "release":
-                    console.print(
-                        f"[red]❌ Baseline report '{baseline}' does not contain evaluation_windows required for pairing.[/red]"
-                    )
-                    raise typer.Exit(1)
                 else:
+                    msg = (
+                        "PAIRING-EVIDENCE-MISSING: baseline report missing or invalid "
+                        f"evaluation_windows ({baseline})"
+                    )
+                    if strict_baseline:
+                        raise InvarlockError(code="E001", message=msg)
                     console.print(
-                        f"[yellow]⚠️ Baseline report '{baseline}' lacks evaluation_windows; falling back to dataset schedule.[/yellow]"
+                        f"[yellow]⚠️ {msg}. Falling back to dataset schedule.[/yellow]"
                     )
                     baseline_report_data = None
                     pairing_schedule = None
-            except typer.Exit:
-                raise
-            except Exception as exc:  # noqa: BLE001
-                console.print(
-                    f"[yellow]⚠️ Failed to load baseline report '{baseline}': {exc}. Falling back to dataset schedule.[/yellow]"
-                )
-                baseline_report_data = None
-                pairing_schedule = None
-        else:
-            console.print(
-                f"[yellow]⚠️ Baseline report '{baseline}' not found. Falling back to dataset schedule.[/yellow]"
-            )
 
     requested_preview = int(getattr(cfg.dataset, "preview_n", 0))
     requested_final = int(getattr(cfg.dataset, "final_n", 0))
@@ -1994,6 +2338,7 @@ def run_command(
             baseline_report_data,
             tokenizer_hash=tokenizer_hash,
             resolved_loss_type=resolved_loss_type,
+            profile=profile,
             baseline_path_str=str(baseline) if baseline else None,
             console=console,
         )
@@ -2787,12 +3132,16 @@ def run_command(
     model = None
     restore_fn = None
     snapshot_tmpdir: str | None = None
+    snapshot_provenance: dict[str, bool] = {
+        "restore_failed": False,
+        "reload_path_used": False,
+    }
 
     # Try single-load with snapshot/restore if adapter supports it; fallback to reload per attempt
     try:
         # Load once
         console.print(f"🔧 Loading model once: {cfg.model.id}")
-        model = adapter.load_model(cfg.model.id, device=resolved_device)
+        model = _load_model_with_cfg(adapter, cfg, resolved_device)
 
         # No edit-specific bootstrap logic
 
@@ -2960,12 +3309,26 @@ def run_command(
 
             restore_fn = _restore
         elif mode == "bytes":
-            base_blob = adapter.snapshot(model)  # type: ignore[attr-defined]
+            supports_chunked = hasattr(adapter, "snapshot_chunked") and hasattr(
+                adapter, "restore_chunked"
+            )
+            try:
+                base_blob = adapter.snapshot(model)  # type: ignore[attr-defined]
+            except Exception:
+                if not supports_chunked:
+                    raise
+                snapshot_tmpdir = adapter.snapshot_chunked(model)  # type: ignore[attr-defined]
+
+                def _restore_fallback_chunked():
+                    adapter.restore_chunked(model, snapshot_tmpdir)  # type: ignore[attr-defined]
 
-            def _restore2():
-                adapter.restore(model, base_blob)  # type: ignore[attr-defined]
+                restore_fn = _restore_fallback_chunked
+            else:
+
+                def _restore2():
+                    adapter.restore(model, base_blob)  # type: ignore[attr-defined]
 
-            restore_fn = _restore2
+                restore_fn = _restore2
         else:
             # reload path - properly free GPU memory before setting to None
             _free_model_memory(model)
@@ -3009,62 +3372,88 @@ def run_command(
             )
 
         guard_overhead_payload: dict[str, Any] | None = None
-        if skip_overhead and profile_normalized in {"ci", "release"}:
-            guard_overhead_payload = {
-                "overhead_threshold": GUARD_OVERHEAD_THRESHOLD,
-                "evaluated": False,
-                "passed": True,
-                "skipped": True,
-                "skip_reason": "INVARLOCK_SKIP_OVERHEAD_CHECK",
-                "mode": "skipped",
-                "source": "env:INVARLOCK_SKIP_OVERHEAD_CHECK",
-                "messages": [
-                    "Overhead check skipped via INVARLOCK_SKIP_OVERHEAD_CHECK"
-                ],
-                "warnings": [],
-                "errors": [],
-                "checks": {},
-            }
-        elif measure_guard_overhead:
-            guard_overhead_payload = _run_bare_control(
+        try:
+            if skip_overhead and profile_normalized in {"ci", "release"}:
+                guard_overhead_payload = {
+                    "overhead_threshold": GUARD_OVERHEAD_THRESHOLD,
+                    "evaluated": False,
+                    "passed": True,
+                    "skipped": True,
+                    "skip_reason": "INVARLOCK_SKIP_OVERHEAD_CHECK",
+                    "mode": "skipped",
+                    "source": "env:INVARLOCK_SKIP_OVERHEAD_CHECK",
+                    "messages": [
+                        "Overhead check skipped via INVARLOCK_SKIP_OVERHEAD_CHECK"
+                    ],
+                    "warnings": [],
+                    "errors": [],
+                    "checks": {},
+                }
+            elif measure_guard_overhead:
+                guard_overhead_payload = _run_bare_control(
+                    adapter=adapter,
+                    edit_op=edit_op,
+                    cfg=cfg,
+                    model=model,
+                    run_config=run_config,
+                    calibration_data=calibration_data,
+                    auto_config=auto_config,
+                    edit_config=edit_config,
+                    preview_count=preview_count,
+                    final_count=final_count,
+                    seed_bundle=seed_bundle,
+                    resolved_device=resolved_device,
+                    restore_fn=restore_fn,
+                    console=console,
+                    resolved_loss_type=resolved_loss_type,
+                    profile_normalized=profile_normalized,
+                    snapshot_provenance=snapshot_provenance,
+                    skip_model_load=skip_model_load,
+                )
+
+            # Ensure clean state for guarded run
+            core_report, model = _execute_guarded_run(
+                runner=runner,
                 adapter=adapter,
-            edit_op=edit_op,
-            cfg=cfg,
                 model=model,
+                cfg=cfg,
+                edit_op=edit_op,
                 run_config=run_config,
+                guards=guards,
                 calibration_data=calibration_data,
                 auto_config=auto_config,
                 edit_config=edit_config,
                 preview_count=preview_count,
                 final_count=final_count,
-            seed_bundle=seed_bundle,
-            resolved_device=resolved_device,
                 restore_fn=restore_fn,
+                resolved_device=resolved_device,
                 console=console,
-            resolved_loss_type=resolved_loss_type,
-            profile_normalized=profile_normalized,
+                snapshot_provenance=snapshot_provenance,
                 skip_model_load=skip_model_load,
             )
-
-        # Ensure clean state for guarded run
-        core_report, model = _execute_guarded_run(
-            runner=runner,
-            adapter=adapter,
-            model=model,
-            cfg=cfg,
-            edit_op=edit_op,
-            run_config=run_config,
-            guards=guards,
-            calibration_data=calibration_data,
-            auto_config=auto_config,
-            edit_config=edit_config,
-            preview_count=preview_count,
-            final_count=final_count,
-            restore_fn=restore_fn,
-            resolved_device=resolved_device,
-            console=console,
-            skip_model_load=skip_model_load,
-        )
+        except _SnapshotRestoreFailed as exc:
+            snapshot_provenance["restore_failed"] = True
+            _free_model_memory(model)
+            model = None
+            restore_fn = None
+            console.print(
+                "[yellow]⚠️ Snapshot restore failed; switching to reload-per-attempt.[/yellow]"
+            )
+            console.print(f"[yellow]↳ {exc}[/yellow]")
+            if retry_controller:
+                retry_controller.record_attempt(
+                    attempt,
+                    {
+                        "passed": False,
+                        "failures": ["restore_failed"],
+                        "validation": {},
+                    },
+                    edit_config,
+                )
+                if retry_controller.should_retry(False):
+                    attempt += 1
+                    continue
+            raise typer.Exit(1) from exc
 
         if not hasattr(core_report, "context") or core_report.context is None:
             core_report.context = {}
@@ -3205,6 +3594,16 @@ def run_command(
         if tokenizer_hash:
             report["meta"]["tokenizer_hash"] = tokenizer_hash
 
+        # Snapshot/restore provenance (survives retries).
+        try:
+            prov = report.setdefault("provenance", {})
+            prov["restore_failed"] = bool(snapshot_provenance.get("restore_failed"))
+            prov["reload_path_used"] = bool(
+                snapshot_provenance.get("reload_path_used")
+            )
+        except Exception:
+            pass
+
         # Transfer edit information
         if hasattr(core_report, "edit") and core_report.edit:
             edit_deltas = core_report.edit.get("deltas", {})
@@ -3464,7 +3863,7 @@ def run_command(
                 ],
             },
         }
-        elif had_baseline and (profile or "").lower() == "release":
+        elif had_baseline and (profile or "").lower() in {"ci", "release"}:
             console.print(
                 "[red]❌ [INVARLOCK:E001] PAIRING-SCHEDULE-MISMATCH: baseline pairing requested but evaluation windows were not produced. Check capacity/pairing config.[/red]"
             )
@@ -3482,11 +3881,12 @@ def run_command(
             except Exception:
                 return 0
 
+        preview_window_count = len(preview_records)
+        final_window_count = len(final_records)
+
         report["evaluation_windows"] = {
             "preview": {
-                "window_ids": [
-                    f"preview::{i}" for i in range(len(preview_records))
-                ],
+                "window_ids": list(range(preview_window_count)),
                 "input_ids": [
                     list(r["input_ids"]) for r in preview_records
                 ],
@@ -3507,9 +3907,12 @@ def run_command(
                 ),
             },
             "final": {
-                "window_ids": [
-                    f"final::{i}" for i in range(len(final_records))
-                ],
+                "window_ids": list(
+                    range(
+                        preview_window_count,
+                        preview_window_count + final_window_count,
+                    )
+                ),
                 "input_ids": [list(r["input_ids"]) for r in final_records],
                 "attention_masks": [
                     list(r["attention_mask"]) for r in final_records
@@ -3549,18 +3952,16 @@ def run_command(
         # Strict parity checks in CI/Release when baseline present
         try:
             if isinstance(baseline_report_data, dict):
-                base_prov = (
-                    baseline_report_data.get("provenance", {})
-                    if isinstance(
-                        baseline_report_data.get("provenance"), dict
+                base_digest = None
+                base_prov = baseline_report_data.get("provenance")
+                if isinstance(base_prov, dict):
+                    base_pd = base_prov.get("provider_digest")
+                    if isinstance(base_pd, dict):
+                        base_digest = base_pd
+                if base_digest is None:
+                    base_digest = _compute_provider_digest(
+                        baseline_report_data
                     )
-                    else {}
-                )
-                base_digest = (
-                    base_prov.get("provider_digest")
-                    if isinstance(base_prov, dict)
-                    else None
-                )
                 _enforce_provider_parity(
                     provider_digest,
                     base_digest,
@@ -3576,6 +3977,8 @@ def run_command(
                     _fail_run(str(_e))
                 except Exception:
                     pass
+        except (typer.Exit, SystemExit, click.exceptions.Exit):
+            raise
         except Exception:
             pass
 
@@ -3823,30 +4226,46 @@ def run_command(
             console.print(f"[red]{err}[/red]")
             raise typer.Exit(code)
 
-        # Additional guard: paired_windows collapse (0) in CI/Release
-        try:
+        # Paired-run enforcement: baseline provided must be truly paired in CI/Release.
+        if baseline and profile_normalized in {"ci", "release"}:
+            pairing_reason = metrics_section.get("window_pairing_reason")
+            if pairing_reason is not None:
+                err = InvarlockError(
+                    code="E001",
+                    message=(
+                        "PAIRING-SCHEDULE-MISMATCH: baseline pairing requested but run was not paired "
+                        f"(window_pairing_reason={pairing_reason})"
+                    ),
+                    details={"window_pairing_reason": pairing_reason},
+                )
+                code = _resolve_exit_code(err, profile=profile_normalized)
+                console.print(f"[red]{err}[/red]")
+                raise typer.Exit(code)
+
             paired_windows_val = metrics_section.get("paired_windows")
-            if (
-                profile_normalized in {"ci", "release"}
-                and isinstance(paired_windows_val, (int | float))
-                and int(paired_windows_val) == 0
-            ):
+            paired_windows_int = None
+            try:
+                if paired_windows_val is not None and not isinstance(
+                    paired_windows_val, bool
+                ):
+                    paired_windows_int = int(paired_windows_val)
+            except Exception:
+                paired_windows_int = None
+            if paired_windows_int is None or paired_windows_int <= 0:
                 err = InvarlockError(
                     code="E001",
                     message=(
-                        "PAIRED-WINDOWS-COLLAPSED: paired_windows=0 under paired schedule. "
+                        "PAIRED-WINDOWS-COLLAPSED: paired_windows<=0 under paired baseline. "
                         "Check device stability, dataset windows, or edit scope."
                     ),
                     details={
-                        "paired_windows": int(paired_windows_val),
+                        "paired_windows": paired_windows_val,
                         "profile": profile_normalized,
                     },
                 )
                 code = _resolve_exit_code(err, profile=profile_normalized)
                 console.print(f"[red]{err}[/red]")
                 raise typer.Exit(code)
-        except Exception:
-            pass
 
         expected_preview = effective_preview or getattr(
             cfg.dataset, "preview_n", preview_count_report
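
The paired-windows guard above replaces a silent try/except: pass with explicit coercion; the acceptance rule it enforces can be summarized as follows (a sketch with a hypothetical helper name, not package code):

def paired_windows_ok(value: object) -> bool:
    # bool is rejected explicitly because bool subclasses int (True == 1),
    # which would otherwise count as one paired window.
    if value is None or isinstance(value, bool):
        return False
    try:
        return int(value) > 0
    except (TypeError, ValueError):
        return False

assert paired_windows_ok(4)
assert not paired_windows_ok(0)
assert not paired_windows_ok(True)
assert not paired_windows_ok("n/a")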