invarlock 0.3.2-py3-none-any.whl → 0.3.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invarlock/core/runner.py CHANGED
@@ -41,6 +41,31 @@ BOOTSTRAP_COVERAGE_REQUIREMENTS = {
  __all__ = ["CoreRunner"]


+ _BOOL_TRUE = {"1", "true", "yes", "on"}
+ _BOOL_FALSE = {"0", "false", "no", "off"}
+
+
+ def _coerce_bool(value: Any) -> bool | None:
+ if isinstance(value, bool):
+ return value
+ if isinstance(value, int) and value in {0, 1}:
+ return bool(value)
+ if isinstance(value, str):
+ lowered = value.strip().lower()
+ if lowered in _BOOL_TRUE:
+ return True
+ if lowered in _BOOL_FALSE:
+ return False
+ return None
+
+
+ def _env_flag(name: str) -> bool | None:
+ raw = os.environ.get(name)
+ if raw is None:
+ return None
+ return _coerce_bool(raw)
+
+
  def _collect_cuda_flags() -> dict[str, Any]:
  """Capture deterministic CUDA configuration for provenance."""
  flags: dict[str, Any] = {}
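The added `_coerce_bool`/`_env_flag` helpers normalize boolean-ish values from config and the environment. A minimal sketch of the expected behavior, assuming the semantics shown in the hunk above (the assertions are illustrative, not shipped code):

    assert _coerce_bool("YES") is True        # case-insensitive, whitespace-tolerant
    assert _coerce_bool(" off ") is False
    assert _coerce_bool("maybe") is None      # unrecognized strings resolve to None
    assert _coerce_bool(1) is True            # only 0/1 integers are accepted
    os.environ["INVARLOCK_EVAL_STRICT"] = "0"
    assert _env_flag("INVARLOCK_EVAL_STRICT") is False
    assert _env_flag("SOME_UNSET_VARIABLE") is None   # unset env vars resolve to None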
@@ -80,6 +105,8 @@ class CoreRunner:
  def __init__(self):
  self.event_logger: EventLogger | None = None
  self.checkpoint_manager: CheckpointManager | None = None
+ self._active_model: Any | None = None
+ self._active_adapter: ModelAdapter | None = None

  def execute(
  self,
@@ -111,6 +138,8 @@ class CoreRunner:
  """
  # Initialize services
  self._initialize_services(config)
+ self._active_model = model
+ self._active_adapter = adapter

  # Create report
  report = RunReport()
@@ -135,6 +164,19 @@ class CoreRunner:
  # Store auto configuration for tier resolution
  if auto_config:
  report.meta["auto"] = auto_config
+ # Ensure tier/profile context is available to guards + evaluation code.
+ if isinstance(config.context, dict):
+ existing_auto = config.context.get("auto")
+ if isinstance(existing_auto, dict):
+ merged_auto = dict(existing_auto)
+ merged_auto.update(auto_config)
+ config.context["auto"] = merged_auto
+ else:
+ config.context["auto"] = dict(auto_config)
+ try:
+ report.context["auto"] = config.context["auto"]
+ except Exception:
+ pass

  report.status = RunStatus.RUNNING.value

@@ -156,7 +198,13 @@ class CoreRunner:

  # Phase 2: Prepare guards (must happen before edit)
  self._prepare_guards_phase(
- model, adapter, guards, calibration_data, report, auto_config
+ model,
+ adapter,
+ guards,
+ calibration_data,
+ report,
+ auto_config,
+ config,
  )

  # Phase 3: Apply edit
@@ -197,10 +245,12 @@ class CoreRunner:
  return report

  except Exception as e:
- self._handle_error(e, report)
+ self._handle_error(e, report, model=model, adapter=adapter)
  return report

  finally:
+ self._active_model = None
+ self._active_adapter = None
  self._cleanup_services()

  def _initialize_services(self, config: RunConfig) -> None:
@@ -307,12 +357,16 @@ class CoreRunner:
  calibration_data: Any,
  report: RunReport,
  auto_config: dict[str, Any] | None = None,
+ config: RunConfig | None = None,
  ) -> None:
  """Phase 2: Prepare safety guards with tier-resolved policies."""
  self._log_event(
  "guards_prepare", "start", LogLevel.INFO, {"count": len(guards)}
  )

+ policy_flags = self._resolve_policy_flags(config)
+ strict_guard_prepare = policy_flags["strict_guard_prepare"]
+
  # Resolve tier policies before guard preparation
  tier_policies = self._resolve_guard_policies(report, auto_config)

@@ -374,6 +428,13 @@ class CoreRunner:
  LogLevel.ERROR,
  {"guard": guard.name, "error": str(e)},
  )
+ report.meta.setdefault("guard_prepare_failures", []).append(
+ {"guard": guard.name, "error": str(e)}
+ )
+ if strict_guard_prepare:
+ raise RuntimeError(
+ f"Guard '{guard.name}' prepare failed: {e}"
+ ) from e

  # Store resolved policies in report for certificate
  report.meta["tier_policies"] = tier_policies
@@ -522,6 +583,20 @@ class CoreRunner:
  }
  eval_windows = {"preview": {}, "final": {}}

+ policy_flags = self._resolve_policy_flags(config)
+ eval_error = metrics.get("eval_error") if isinstance(metrics, dict) else None
+ if eval_error:
+ if policy_flags["strict_eval"]:
+ raise RuntimeError(
+ f"Evaluation failed: {eval_error.get('message', 'unknown error')}"
+ )
+ self._log_event(
+ "eval",
+ "soft_fail",
+ LogLevel.WARNING,
+ {"message": eval_error.get("message"), "type": eval_error.get("type")},
+ )
+
  # Store metrics in report
  if hasattr(report, "metrics"):
  report.metrics.update(metrics)
@@ -569,6 +644,18 @@ class CoreRunner:
  process = psutil.Process(os.getpid())
  initial_memory = process.memory_info().rss / 1024 / 1024 # MB

+ policy_flags = self._resolve_policy_flags(config)
+ allow_materialize = policy_flags["allow_calibration_materialize"]
+
+ if not hasattr(calibration_data, "__len__"):
+ if allow_materialize and hasattr(calibration_data, "__iter__"):
+ calibration_data = list(calibration_data)
+ else:
+ raise ValueError(
+ "Calibration data must define __len__ (or enable materialization "
+ "via INVARLOCK_ALLOW_CALIBRATION_MATERIALIZE or context.run.allow_calibration_materialize)."
+ )
+
  total_available = (
  len(calibration_data) if hasattr(calibration_data, "__len__") else 0
  )
@@ -635,8 +722,34 @@ class CoreRunner:
  )
  final_n = remaining

- preview_data = calibration_data[:preview_n]
- final_data = calibration_data[final_start : final_start + final_n]
+ def _slice_calibration(start: int, count: int) -> list[Any]:
+ nonlocal calibration_data
+ end = start + count
+ try:
+ sliced = calibration_data[start:end]
+ return sliced if isinstance(sliced, list) else list(sliced)
+ except Exception as err:
+ if hasattr(calibration_data, "__getitem__") and hasattr(
+ calibration_data, "__len__"
+ ):
+ try:
+ return [calibration_data[i] for i in range(start, end)]
+ except Exception:
+ pass
+ if allow_materialize and hasattr(calibration_data, "__iter__"):
+ calibration_data = (
+ calibration_data
+ if isinstance(calibration_data, list)
+ else list(calibration_data)
+ )
+ return calibration_data[start:end]
+ raise TypeError(
+ "Calibration data must support slicing or random access. "
+ "Provide a list/sequence or enable materialization."
+ ) from err
+
+ preview_data = _slice_calibration(0, preview_n)
+ final_data = _slice_calibration(final_start, final_n)

  eval_context: dict[str, Any] = {}
  if config and isinstance(config.context, dict):
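Calibration data without `__len__` (for example a generator) is now rejected unless materialization is explicitly enabled. A hedged sketch of the opt-in, assuming `config` is a RunConfig whose `context` is a plain dict (the generator below is hypothetical):

    # Option 1: environment variable
    os.environ["INVARLOCK_ALLOW_CALIBRATION_MATERIALIZE"] = "1"
    # Option 2: run context
    config.context.setdefault("run", {})["allow_calibration_materialize"] = True

    def calibration_stream():
        # no __len__ and no slicing; with the flag enabled it is materialized
        # into a list before the preview/final windows are sliced
        for _ in range(8):
            yield {"input_ids": [1, 2, 3]}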
@@ -685,6 +798,7 @@ class CoreRunner:
  bootstrap_seed = (
  bootstrap_seed_cfg if bootstrap_seed_cfg is not None else dataset_seed
  )
+ eval_error: dict[str, Any] | None = None
  try:
  bootstrap_seed = int(bootstrap_seed) if bootstrap_seed is not None else 0
  except (TypeError, ValueError):
@@ -1156,6 +1270,12 @@ class CoreRunner:

  # primary_metric consumers use log-space intervals; skip ppl-space tuple here

+ paired_weights: list[float] | None = None
+ if preview_token_counts:
+ paired_weights = [float(max(w, 0)) for w in preview_token_counts]
+ elif final_token_counts:
+ paired_weights = [float(max(w, 0)) for w in final_token_counts]
+
  if (
  bootstrap_enabled
  and final_log_losses
@@ -1166,6 +1286,7 @@ class CoreRunner:
  delta_log_ci = compute_paired_delta_log_ci(
  final_log_losses,
  preview_log_losses,
+ weights=paired_weights,
  method=delta_method,
  replicates=bootstrap_replicates,
  alpha=bootstrap_alpha,
@@ -1195,6 +1316,10 @@ class CoreRunner:
  delta_weights = [
  float(max(preview_token_counts[i], 1)) for i in range(limit)
  ]
+ elif final_token_counts and len(final_token_counts) >= limit:
+ delta_weights = [
+ float(max(final_token_counts[i], 1)) for i in range(limit)
+ ]

  degenerate_delta = False
  degenerate_reason: str | None = None
@@ -1239,6 +1364,26 @@ class CoreRunner:
  return 0.0
  return max(0.0, (len(hashes) - unique) / len(hashes))

+ def _overlap_fraction_from_config(cfg: RunConfig | None) -> float | None:
+ if not cfg or not isinstance(cfg.context, dict):
+ return None
+ dataset_cfg = cfg.context.get("dataset", {})
+ if not isinstance(dataset_cfg, dict):
+ return None
+ seq_len_val = dataset_cfg.get("seq_len")
+ stride_val = dataset_cfg.get("stride", seq_len_val)
+ try:
+ seq_len_f = float(seq_len_val)
+ stride_f = float(stride_val)
+ except (TypeError, ValueError):
+ return None
+ if not math.isfinite(seq_len_f) or seq_len_f <= 0:
+ return None
+ if not math.isfinite(stride_f) or stride_f < 0:
+ return None
+ overlap = (seq_len_f - stride_f) / seq_len_f
+ return max(0.0, min(1.0, float(overlap)))
+
  def _compare_with_baseline(
  run_ids: list[int],
  run_tokens: list[list[int]],
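`_overlap_fraction_from_config` derives window overlap from the dataset's `seq_len`/`stride` instead of inferring it from token-hash duplicates. A small illustrative example of the formula it applies:

    seq_len, stride = 512, 384
    overlap = max(0.0, min(1.0, (seq_len - stride) / seq_len))
    # overlap == 0.25: adjacent windows share 25% of their tokens.
    # stride >= seq_len (or a missing stride, which defaults to seq_len) gives 0.0;
    # a missing/invalid dataset config returns None, which callers treat as "unknown".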
@@ -1344,10 +1489,23 @@ class CoreRunner:
  preview_pair_stats["expected"] + final_pair_stats["expected"]
  )
  total_matched = preview_pair_stats["matched"] + final_pair_stats["matched"]
+ total_unexpected = len(preview_pair_stats["unexpected_ids"]) + len(
+ final_pair_stats["unexpected_ids"]
+ )
+ match_denominator = total_expected + total_unexpected
  window_match_fraction = (
- float(total_matched / total_expected) if total_expected > 0 else 1.0
+ float(total_matched / match_denominator)
+ if match_denominator > 0
+ else 1.0
  )
- window_overlap_fraction = _duplicate_fraction(preview_tokens + final_tokens)
+ duplicate_fraction = _duplicate_fraction(preview_tokens + final_tokens)
+ overlap_fraction = _overlap_fraction_from_config(config)
+ overlap_unknown = False
+ if overlap_fraction is None:
+ overlap_unknown = True
+ overlap_fraction = 1.0
+ window_overlap_fraction = float(overlap_fraction)
+ count_mismatch = preview_batches_ct != final_batches_ct

  pairing_reason = None
  if total_expected > 0:
@@ -1362,8 +1520,14 @@ class CoreRunner:
  pairing_reason = stats_dict.get("reason") or f"{label}_mismatch"
  break
  if pairing_reason is None:
- if window_overlap_fraction > 0.0:
+ if overlap_unknown:
+ pairing_reason = "overlap_unknown"
+ elif window_overlap_fraction > 0.0:
+ pairing_reason = "overlapping_windows"
+ elif duplicate_fraction > 0.0:
  pairing_reason = "duplicate_windows"
+ elif count_mismatch:
+ pairing_reason = "count_mismatch"
  elif not pairing_context:
  pairing_reason = preview_pair_stats.get(
  "reason"
@@ -1389,7 +1553,8 @@ class CoreRunner:
  "window_overlap_warning",
  LogLevel.WARNING,
  {
- "duplicate_fraction": window_overlap_fraction,
+ "overlap_fraction": window_overlap_fraction,
+ "duplicate_fraction": duplicate_fraction,
  "match_fraction": window_match_fraction,
  "preview": preview_pair_stats,
  "final": final_pair_stats,
@@ -1403,7 +1568,11 @@ class CoreRunner:
  )
  if window_overlap_fraction > 0.0:
  raise RuntimeError(
- f"Window duplication detected (overlap_fraction={window_overlap_fraction:.3f})"
+ f"Window overlap detected (overlap_fraction={window_overlap_fraction:.3f})"
+ )
+ if count_mismatch:
+ raise RuntimeError(
+ f"Window count mismatch detected (preview={preview_batches_ct}, final={final_batches_ct})"
  )

  tier = "balanced"
@@ -1419,8 +1588,7 @@ class CoreRunner:
  def _meets_requirement(actual: int, required: int) -> bool:
  if required <= 0:
  return True
- slack = max(1, int(required * 0.95))
- return actual >= slack
+ return actual >= required

  preview_required = int(coverage_requirements.get("preview", 0))
  final_required = int(coverage_requirements.get("final", 0))
@@ -1468,7 +1636,7 @@ class CoreRunner:
  "replicates": int(bootstrap_replicates),
  "seed": int(bootstrap_seed),
  "ci_band": float(ci_band),
- "window_duplicate_fraction": float(window_overlap_fraction),
+ "window_duplicate_fraction": float(duplicate_fraction),
  "window_match_fraction": float(window_match_fraction),
  "coverage": {
  "tier": tier,
@@ -1498,6 +1666,7 @@ class CoreRunner:
  LogLevel.ERROR,
  {"message": f"Primary-metric computation failed: {exc}"},
  )
+ eval_error = {"type": type(exc).__name__, "message": str(exc)}

  pm_ratio = pm_final / pm_preview if pm_preview > 0 else 1.0

@@ -1598,6 +1767,8 @@ class CoreRunner:
  "degenerate_reason": degenerate_reason,
  },
  }
+ if eval_error:
+ metrics["eval_error"] = eval_error

  eval_windows = {
  "preview": {
@@ -1669,6 +1840,18 @@ class CoreRunner:
  except Exception:
  pass

+ def _maybe_sync() -> None:
+ try:
+ is_cuda = False
+ if hasattr(device, "type"):
+ is_cuda = device.type == "cuda"
+ elif isinstance(device, str):
+ is_cuda = device.startswith("cuda")
+ if is_cuda and torch.cuda.is_available():
+ torch.cuda.synchronize()
+ except Exception:
+ pass
+
  # Simple timing measurement
  with torch.no_grad():
  try:
@@ -1707,21 +1890,22 @@ class CoreRunner:
  labels=labels_t,
  token_type_ids=token_type_t,
  )
- else:
- return model(
- input_ids,
- labels=labels_t,
- token_type_ids=token_type_t,
- )
+ return model(
+ input_ids,
+ labels=labels_t,
+ token_type_ids=token_type_t,
+ )

  # Warmup
  for _ in range(3):
  _ = _call_model()

  # Measure
+ _maybe_sync()
  start_time = time.time()
  for _ in range(10):
  _ = _call_model()
+ _maybe_sync()
  end_time = time.time()

  total_time = (end_time - start_time) * 1000 # Convert to ms
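The `_maybe_sync()` calls added around the timing loop matter because CUDA kernel launches are asynchronous; without a device synchronize the host timer can stop before the queued forward passes finish. A hedged sketch of the same pattern in isolation (hypothetical helper, not part of the package's API):

    def timed_forward_ms(model, batch, device):
        if str(device).startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize()   # drain pending kernels before starting the clock
        start = time.time()
        with torch.no_grad():
            _ = model(**batch)
        if str(device).startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize()   # wait for the forward pass to actually complete
        return (time.time() - start) * 1000.0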
@@ -1826,21 +2010,34 @@ class CoreRunner:
  pm = metrics.get("primary_metric", {}) if isinstance(metrics, dict) else {}
  pm_prev = pm.get("preview") if isinstance(pm, dict) else None
  pm_fin = pm.get("final") if isinstance(pm, dict) else None
- try:
- drift_ratio = (
- float(pm_fin) / float(pm_prev)
- if isinstance(pm_fin, (int | float))
- and isinstance(pm_prev, (int | float))
- and float(pm_prev) > 0.0
- else float("inf")
- )
- except Exception:
- drift_ratio = float("inf")
- spike_threshold = getattr(config, "spike_threshold", 2.0)
- is_catastrophic_spike = drift_ratio > spike_threshold
+ pm_kind = str(pm.get("kind", "")).lower() if isinstance(pm, dict) else ""
+ is_ppl_metric = pm_kind.startswith("ppl")
+
+ drift_ratio: float | None = None
+ if is_ppl_metric:
+ try:
+ if isinstance(pm_fin, (int | float)) and isinstance(
+ pm_prev, (int | float)
+ ):
+ pm_prev_val = float(pm_prev)
+ pm_fin_val = float(pm_fin)
+ if (
+ pm_prev_val > 0.0
+ and math.isfinite(pm_prev_val)
+ and math.isfinite(pm_fin_val)
+ ):
+ drift_ratio = pm_fin_val / pm_prev_val
+ except Exception:
+ drift_ratio = None

- # Check if standard metrics are acceptable against configured max ratio
- metrics_acceptable = drift_ratio <= getattr(config, "max_pm_ratio", 2.0)
+ if drift_ratio is None:
+ is_catastrophic_spike = False
+ metrics_acceptable = True
+ else:
+ spike_threshold = getattr(config, "spike_threshold", 2.0)
+ is_catastrophic_spike = drift_ratio > spike_threshold
+ # Check if standard metrics are acceptable against configured max ratio
+ metrics_acceptable = drift_ratio <= getattr(config, "max_pm_ratio", 2.0)

  # Determine rollback reason and status
  rollback_reason = None
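With this change, the spike/ratio rollback gate only fires for perplexity-style primary metrics; for other metric kinds `drift_ratio` stays `None` and the gate is skipped. An illustrative check, assuming a `primary_metric` payload shaped like the one built earlier in this file:

    pm = {"kind": "accuracy", "preview": 0.71, "final": 0.69}
    is_ppl_metric = str(pm.get("kind", "")).lower().startswith("ppl")
    # is_ppl_metric is False -> drift_ratio stays None, so
    # is_catastrophic_spike is False and metrics_acceptable is True.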
@@ -1907,7 +2104,13 @@ class CoreRunner:

  return status

- def _handle_error(self, error: Exception, report: RunReport) -> None:
+ def _handle_error(
+ self,
+ error: Exception,
+ report: RunReport,
+ model: Any | None = None,
+ adapter: ModelAdapter | None = None,
+ ) -> None:
  """Handle pipeline errors."""
  report.status = RunStatus.FAILED.value
  report.error = str(error)
@@ -1924,13 +2127,26 @@ class CoreRunner:
  if self.checkpoint_manager and "initial_checkpoint" in report.meta:
  try:
  checkpoint_id = report.meta["initial_checkpoint"]
- # Would need model and adapter here for actual rollback
+ effective_model = model or self._active_model
+ effective_adapter = adapter or self._active_adapter
+ restored = False
+ if effective_model is not None and effective_adapter is not None:
+ restored = self.checkpoint_manager.restore_checkpoint(
+ effective_model, effective_adapter, checkpoint_id
+ )
  self._log_event(
  "runner",
  "emergency_rollback",
  LogLevel.WARNING,
- {"checkpoint": checkpoint_id},
+ {"checkpoint": checkpoint_id, "restored": restored},
  )
+ if not restored:
+ self._log_event(
+ "runner",
+ "rollback_failed",
+ LogLevel.CRITICAL,
+ {"checkpoint": checkpoint_id, "error": "restore_failed"},
+ )
  except Exception as rollback_error:
  self._log_event(
  "runner",
@@ -2039,3 +2255,57 @@ class CoreRunner:
  "verbose": config.verbose,
  "guards": config.context.get("guards", {}) if config.context else {},
  }
+
+ def _resolve_policy_flags(self, config: RunConfig | None) -> dict[str, bool]:
+ run_ctx: dict[str, Any] = {}
+ eval_ctx: dict[str, Any] = {}
+ if config and isinstance(config.context, dict):
+ run_ctx = (
+ config.context.get("run", {})
+ if isinstance(config.context.get("run"), dict)
+ else {}
+ )
+ eval_ctx = (
+ config.context.get("eval", {})
+ if isinstance(config.context.get("eval"), dict)
+ else {}
+ )
+
+ def _resolve_flag(
+ *,
+ run_key: str,
+ eval_keys: tuple[str, ...],
+ env_key: str,
+ default: bool,
+ ) -> bool:
+ val = _coerce_bool(run_ctx.get(run_key))
+ if val is None:
+ for key in eval_keys:
+ val = _coerce_bool(eval_ctx.get(key))
+ if val is not None:
+ break
+ env_val = _env_flag(env_key)
+ if env_val is not None:
+ val = env_val
+ return default if val is None else bool(val)
+
+ return {
+ "strict_eval": _resolve_flag(
+ run_key="strict_eval",
+ eval_keys=("strict_errors", "strict"),
+ env_key="INVARLOCK_EVAL_STRICT",
+ default=True,
+ ),
+ "strict_guard_prepare": _resolve_flag(
+ run_key="strict_guard_prepare",
+ eval_keys=(),
+ env_key="INVARLOCK_GUARD_PREPARE_STRICT",
+ default=True,
+ ),
+ "allow_calibration_materialize": _resolve_flag(
+ run_key="allow_calibration_materialize",
+ eval_keys=("materialize_calibration", "allow_iterable_calibration"),
+ env_key="INVARLOCK_ALLOW_CALIBRATION_MATERIALIZE",
+ default=False,
+ ),
+ }
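`_resolve_policy_flags` gives each policy flag a fixed precedence: `context["run"]` is consulted first, then the listed `context["eval"]` keys, and a set environment variable overrides both; otherwise the default applies (`strict_eval` and `strict_guard_prepare` default to true, `allow_calibration_materialize` to false). A hedged sketch of that precedence, assuming `runner` and `config` are a CoreRunner and a RunConfig with a dict context:

    config.context["run"] = {"strict_eval": "false"}   # context asks for lenient eval
    os.environ["INVARLOCK_EVAL_STRICT"] = "1"           # env var wins -> strict
    assert runner._resolve_policy_flags(config)["strict_eval"] is True

    del os.environ["INVARLOCK_EVAL_STRICT"]             # without the env var, context applies
    assert runner._resolve_policy_flags(config)["strict_eval"] is False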
@@ -16,6 +16,7 @@ from invarlock.core.exceptions import ValidationError
  def paired_delta_mean_ci(
  subject: Iterable[float],
  baseline: Iterable[float],
+ weights: Iterable[float] | None = None,
  *,
  reps: int = 2000,
  seed: int = 0,
@@ -27,7 +28,7 @@ def paired_delta_mean_ci(

  Notes:
  - When `method == 'bca'`, this dispatches to the core BCa implementation.
- - `weights` are currently not supported; pass pre-aggregated per-example values.
+ - Optional `weights` apply token-weighted resampling when provided.
  """
  alpha = 1.0 - float(ci_level)
  if method not in {"bca", "percentile"}:
@@ -43,6 +44,7 @@ def paired_delta_mean_ci(
  return _paired_delta_bca(
  list(subject),
  list(baseline),
+ weights=list(weights) if weights is not None else None,
  method="bca" if method == "bca" else "percentile",
  replicates=int(reps),
  alpha=alpha,
@@ -214,9 +214,15 @@ class _PPLCausal(PrimaryMetric):
  ) -> dict[str, Any]:
  subj = self._coerce_contrib_array(subject)
  base = self._coerce_contrib_array(baseline)
- # Compute simple (unweighted) per-example arrays in log space; weights ignored for bootstrap here
+ # Compute per-example arrays in log space; use weights for paired bootstrap
  subj_vals = [v for (v, _w) in subj]
  base_vals = [v for (v, _w) in base]
+ pair_weights = []
+ for (_sv, sw), (_bv, bw) in zip(subj, base, strict=False):
+ weight = bw if math.isfinite(bw) and bw > 0 else sw
+ if not math.isfinite(weight) or weight <= 0:
+ weight = 1.0
+ pair_weights.append(float(weight))

  # Points in display space
  def _point(
@@ -249,15 +255,24 @@ class _PPLCausal(PrimaryMetric):
  dlog_lo, dlog_hi = compute_paired_delta_log_ci(
  subj_vals,
  base_vals,
+ weights=pair_weights,
  method="bca",
  replicates=reps_eff,
  alpha=alpha,
  seed=seed_eff,
  )
- delta_log = float(
- sum((s - b) for s, b in zip(subj_vals, base_vals, strict=False))
- / max(1, min(len(subj_vals), len(base_vals)))
- )
+ if pair_weights and len(pair_weights) >= min(len(subj_vals), len(base_vals)):
+ sw = 0.0
+ swx = 0.0
+ for s, b, w in zip(subj_vals, base_vals, pair_weights, strict=False):
+ sw += w
+ swx += w * (s - b)
+ delta_log = float(swx / sw) if sw > 0 else float("nan")
+ else:
+ delta_log = float(
+ sum((s - b) for s, b in zip(subj_vals, base_vals, strict=False))
+ / max(1, min(len(subj_vals), len(base_vals)))
+ )
  ratio = self.display_transform(delta_log)
  return {
  "kind": self.kind,