invarlock-0.3.1-py3-none-any.whl → invarlock-0.3.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. invarlock/__init__.py +1 -1
  2. invarlock/_data/runtime/tiers.yaml +61 -0
  3. invarlock/adapters/hf_loading.py +97 -0
  4. invarlock/calibration/__init__.py +6 -0
  5. invarlock/calibration/spectral_null.py +301 -0
  6. invarlock/calibration/variance_ve.py +154 -0
  7. invarlock/cli/app.py +15 -0
  8. invarlock/cli/commands/calibrate.py +576 -0
  9. invarlock/cli/commands/doctor.py +9 -3
  10. invarlock/cli/commands/explain_gates.py +53 -9
  11. invarlock/cli/commands/plugins.py +12 -2
  12. invarlock/cli/commands/run.py +181 -79
  13. invarlock/cli/commands/verify.py +40 -0
  14. invarlock/cli/config.py +11 -1
  15. invarlock/cli/determinism.py +252 -0
  16. invarlock/core/auto_tuning.py +215 -17
  17. invarlock/core/bootstrap.py +137 -5
  18. invarlock/core/registry.py +9 -4
  19. invarlock/core/runner.py +305 -35
  20. invarlock/eval/bench.py +467 -141
  21. invarlock/eval/bench_regression.py +12 -0
  22. invarlock/eval/bootstrap.py +3 -1
  23. invarlock/eval/data.py +29 -7
  24. invarlock/eval/primary_metric.py +20 -5
  25. invarlock/guards/rmt.py +536 -46
  26. invarlock/guards/spectral.py +217 -10
  27. invarlock/guards/variance.py +124 -42
  28. invarlock/reporting/certificate.py +476 -45
  29. invarlock/reporting/certificate_schema.py +4 -1
  30. invarlock/reporting/guards_analysis.py +108 -10
  31. invarlock/reporting/normalizer.py +24 -1
  32. invarlock/reporting/policy_utils.py +97 -15
  33. invarlock/reporting/primary_metric_utils.py +17 -0
  34. invarlock/reporting/validate.py +10 -10
  35. {invarlock-0.3.1.dist-info → invarlock-0.3.3.dist-info}/METADATA +12 -10
  36. {invarlock-0.3.1.dist-info → invarlock-0.3.3.dist-info}/RECORD +40 -33
  37. {invarlock-0.3.1.dist-info → invarlock-0.3.3.dist-info}/WHEEL +0 -0
  38. {invarlock-0.3.1.dist-info → invarlock-0.3.3.dist-info}/entry_points.txt +0 -0
  39. {invarlock-0.3.1.dist-info → invarlock-0.3.3.dist-info}/licenses/LICENSE +0 -0
  40. {invarlock-0.3.1.dist-info → invarlock-0.3.3.dist-info}/top_level.txt +0 -0
@@ -26,6 +26,80 @@ from invarlock.core.api import Guard
26
26
  from ._contracts import guard_assert
27
27
 
28
28
 
29
+ def _z_to_two_sided_pvalue(z: Any) -> float:
30
+ try:
31
+ zf = float(z)
32
+ if not math.isfinite(zf):
33
+ return 1.0
34
+ return float(math.erfc(abs(zf) / math.sqrt(2.0)))
35
+ except Exception:
36
+ return 1.0
37
+
38
+
39
+ def _finite01(value: Any) -> bool:
40
+ try:
41
+ f = float(value)
42
+ return math.isfinite(f) and 0.0 <= f <= 1.0
43
+ except Exception:
44
+ return False
45
+
46
+
47
+ def _bh_reject_families(
48
+ family_pvals: dict[str, float], *, alpha: float, m: int
49
+ ) -> set[str]:
50
+ """BH family selection with denominator `m` (conservative if m >= #families)."""
51
+ if not family_pvals:
52
+ return set()
53
+ try:
54
+ alpha_f = float(alpha)
55
+ except Exception:
56
+ alpha_f = 0.05
57
+ if not (0.0 < alpha_f <= 1.0):
58
+ return set()
59
+
60
+ names = list(family_pvals.keys())
61
+ pvals = [family_pvals[n] for n in names]
62
+ n = len(pvals)
63
+ m_eff = max(int(m) if isinstance(m, int) else 0, n, 1)
64
+
65
+ order = sorted(
66
+ range(n),
67
+ key=lambda idx: (float("inf") if not _finite01(pvals[idx]) else pvals[idx]),
68
+ )
69
+ max_k = 0
70
+ for rank, idx in enumerate(order, start=1):
71
+ p = pvals[idx]
72
+ if not _finite01(p):
73
+ continue
74
+ if p <= (alpha_f * rank) / m_eff:
75
+ max_k = rank
76
+ if max_k <= 0:
77
+ return set()
78
+ cutoff = (alpha_f * max_k) / m_eff
79
+ selected: set[str] = set()
80
+ for idx in order:
81
+ p = pvals[idx]
82
+ if _finite01(p) and p <= cutoff:
83
+ selected.add(names[idx])
84
+ return selected
85
+
86
+
87
+ def _bonferroni_reject_families(
88
+ family_pvals: dict[str, float], *, alpha: float, m: int
89
+ ) -> set[str]:
90
+ if not family_pvals:
91
+ return set()
92
+ try:
93
+ alpha_f = float(alpha)
94
+ except Exception:
95
+ alpha_f = 0.05
96
+ if not (0.0 < alpha_f <= 1.0):
97
+ return set()
98
+ m_eff = max(int(m) if isinstance(m, int) else 0, len(family_pvals), 1)
99
+ cutoff = alpha_f / m_eff
100
+ return {fam for fam, p in family_pvals.items() if _finite01(p) and p <= cutoff}
101
+
102
+
29
103
  class SpectralPolicy(TypedDict, total=False):
30
104
  """Type definition for spectral guard policy configuration."""
31
105
 
@@ -426,7 +500,7 @@ class SpectralGuard(Guard):
426
500
  if self.ignore_preview_inflation and phase == "after_edit":
427
501
  continue
428
502
 
429
- if z_score > kappa_cap:
503
+ if abs(z_score) > kappa_cap:
430
504
  violations.append(
431
505
  {
432
506
  "type": "family_z_cap",
@@ -567,6 +641,121 @@ class SpectralGuard(Guard):
567
641
 
568
642
  return family_quantiles, top_z_scores
569
643
 
644
def _select_budgeted_violations(
    self, budgeted_violations: list[dict[str, Any]]
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    """Apply BH/Bonferroni selection at the family level.

    Each violation's ``z_score`` is converted to a two-sided p-value and a
    family's p-value is the smallest one among its violations. Families are
    then selected with the method configured in ``self.multiple_testing``
    (``"bh"`` by default; any name that is not a BH alias uses Bonferroni).
    A violation is kept when its family is selected, or when no p-value can
    be computed for it (fail closed).

    Side effects:
        * Every violation dict is updated in place: ``family`` is filled in
          when missing or falsy (from ``self.module_family_map``, else
          ``"other"``), and ``p_value`` / ``selected`` are written.
        * When ``self.multiple_testing`` is a dict, ``"m"`` is set via
          ``setdefault`` to the effective hypothesis count.

    Args:
        budgeted_violations: Candidate (non-fatal) violations to filter.

    Returns:
        (selected_violations, selection_metrics)
    """
    mt = self.multiple_testing if isinstance(self.multiple_testing, dict) else {}
    method = str(mt.get("method", "bh")).lower()
    try:
        alpha = float(mt.get("alpha", 0.05) or 0.05)
    except Exception:
        alpha = 0.05
    m_raw = mt.get("m")
    m = None
    try:
        if m_raw is not None:
            m = int(m_raw)
    except Exception:
        m = None

    # Fill in missing family assignments deterministically.
    for violation in budgeted_violations:
        if violation.get("family"):
            continue
        module = violation.get("module")
        if isinstance(module, str):
            family = self.module_family_map.get(module)
            if isinstance(family, str) and family:
                violation["family"] = family
                continue
        violation["family"] = "other"

    # Family p-values derived from the most significant (min p) module in each family.
    family_pvals: dict[str, float] = {}
    family_max_abs_z: dict[str, float] = {}
    family_counts: dict[str, int] = {}
    for violation in budgeted_violations:
        fam = violation.get("family")
        if fam is None:
            continue
        family = str(fam)
        z_val = violation.get("z_score")
        try:
            zf = float(z_val)
        except Exception:
            continue
        if not math.isfinite(zf):
            continue
        p = _z_to_two_sided_pvalue(zf)
        family_counts[family] = family_counts.get(family, 0) + 1
        cur = family_pvals.get(family)
        if cur is None or p < cur:
            # The smallest p-value corresponds to the largest |z|.
            family_pvals[family] = p
            family_max_abs_z[family] = abs(zf)

    families_tested = sorted(family_pvals.keys())
    # The denominator is never smaller than the number of families tested.
    m_eff = m if isinstance(m, int) and m > 0 else len(families_tested)
    m_eff = max(m_eff, len(families_tested), 1)
    if isinstance(self.multiple_testing, dict):
        self.multiple_testing.setdefault("m", m_eff)

    if method in {"bh", "benjamini-hochberg", "benjamini_hochberg"}:
        selected_families = _bh_reject_families(family_pvals, alpha=alpha, m=m_eff)
        applied_method = "bh"
    else:
        # "bonferroni"/"bonf" and any unrecognized method name share the
        # Bonferroni path (the stricter of the two procedures).
        selected_families = _bonferroni_reject_families(
            family_pvals, alpha=alpha, m=m_eff
        )
        applied_method = "bonferroni"

    selected: list[dict[str, Any]] = []
    default_selected_without_pvalue = 0
    for violation in budgeted_violations:
        fam = violation.get("family")
        family = str(fam) if fam is not None else ""
        z_val = violation.get("z_score")
        p_val: float | None = None
        try:
            zf = float(z_val)
        except Exception:
            zf = None
        if zf is not None and math.isfinite(zf):
            p_val = _z_to_two_sided_pvalue(zf)
            is_selected = family in selected_families
        else:
            # If we cannot compute a p-value, fail closed: keep the violation.
            is_selected = True
            default_selected_without_pvalue += 1
        violation["p_value"] = p_val
        violation["selected"] = is_selected
        if is_selected:
            selected.append(violation)

    selection_metrics = {
        "method": applied_method,
        "alpha": alpha,
        "m": int(m_eff),
        "families_tested": families_tested,
        "families_selected": sorted(selected_families),
        "family_pvalues": {k: float(family_pvals[k]) for k in families_tested},
        "family_max_abs_z": {
            k: float(family_max_abs_z[k]) for k in families_tested
        },
        "family_violation_counts": dict(family_counts),
        "default_selected_without_pvalue": int(default_selected_without_pvalue),
    }
    return selected, selection_metrics
758
+
570
759
  def validate(
571
760
  self, model: Any, adapter: Any, context: dict[str, Any]
572
761
  ) -> dict[str, Any]:
@@ -607,7 +796,13 @@ class SpectralGuard(Guard):
607
796
  if violation.get("type") in fatal_violation_types
608
797
  ]
609
798
 
610
- caps_applied = len(budgeted_violations)
799
+ selected_budgeted, mt_selection = self._select_budgeted_violations(
800
+ budgeted_violations
801
+ )
802
+ selected_violations = [*fatal_violations, *selected_budgeted]
803
+ candidate_budgeted = len(budgeted_violations)
804
+
805
+ caps_applied = len(selected_budgeted)
611
806
  caps_exceeded = caps_applied > int(self.max_caps)
612
807
  passed = not fatal_violations and not caps_exceeded
613
808
  if fatal_violations or caps_exceeded:
@@ -623,8 +818,9 @@ class SpectralGuard(Guard):
623
818
  )
624
819
  metrics = {
625
820
  "modules_checked": len(current_metrics),
626
- "violations_found": len(violations),
821
+ "violations_found": len(selected_violations),
627
822
  "budgeted_violations": caps_applied,
823
+ "candidate_budgeted_violations": candidate_budgeted,
628
824
  "fatal_violations": len(fatal_violations),
629
825
  "max_spectral_norm": max(current_metrics.values())
630
826
  if current_metrics
@@ -642,6 +838,7 @@ class SpectralGuard(Guard):
642
838
  "caps_applied": caps_applied,
643
839
  "caps_exceeded": caps_exceeded,
644
840
  "multiple_testing": self.multiple_testing,
841
+ "multiple_testing_selection": mt_selection,
645
842
  }
646
843
 
647
844
  family_quantiles, top_z_scores = self._compute_family_observability()
@@ -653,7 +850,7 @@ class SpectralGuard(Guard):
653
850
  if passed:
654
851
  message = (
655
852
  "Spectral validation passed with "
656
- f"{len(violations)} violations "
853
+ f"{len(selected_violations)} violations "
657
854
  f"(caps_applied={caps_applied}, max_caps={self.max_caps})"
658
855
  )
659
856
  else:
@@ -683,7 +880,7 @@ class SpectralGuard(Guard):
683
880
  "passed": passed,
684
881
  "action": action,
685
882
  "metrics": metrics,
686
- "violations": violations,
883
+ "violations": selected_violations,
687
884
  "message": message,
688
885
  "policy": self._serialize_policy(),
689
886
  "final_z_scores": self.latest_z_scores.copy(),
@@ -743,15 +940,23 @@ class SpectralGuard(Guard):
743
940
  if violation.get("type") in fatal_violation_types
744
941
  ]
745
942
 
746
- caps_applied = len(budgeted_violations)
943
+ selected_budgeted, mt_selection = self._select_budgeted_violations(
944
+ budgeted_violations
945
+ )
946
+ selected_final_violations = [*fatal_violations, *selected_budgeted]
947
+ candidate_budgeted = len(budgeted_violations)
948
+
949
+ caps_applied = len(selected_budgeted)
747
950
  caps_exceeded = caps_applied > int(self.max_caps)
748
951
  passed = not fatal_violations and not caps_exceeded
749
952
 
750
953
  # Compute comprehensive metrics
751
954
  metrics = {
752
955
  "modules_analyzed": len(final_metrics),
753
- "violations_detected": len(final_violations),
956
+ "violations_detected": len(selected_final_violations),
754
957
  "budgeted_violations": caps_applied,
958
+ "candidate_violations_detected": len(final_violations),
959
+ "candidate_budgeted_violations": candidate_budgeted,
755
960
  "fatal_violations": len(fatal_violations),
756
961
  "baseline_modules": len(self.baseline_metrics),
757
962
  "scope": self.scope,
@@ -764,7 +969,8 @@ class SpectralGuard(Guard):
764
969
  "spectral_stability_score": 1.0
765
970
  - min(len(final_violations) / max(len(final_metrics), 1), 1.0),
766
971
  "target_sigma": self.target_sigma,
767
- "correction_applied": len(final_violations) > 0 and self.correction_enabled,
972
+ "correction_applied": len(selected_final_violations) > 0
973
+ and self.correction_enabled,
768
974
  "family_caps": self.family_caps,
769
975
  "family_z_summary": final_z_summary,
770
976
  "family_stats": final_family_stats,
@@ -774,6 +980,7 @@ class SpectralGuard(Guard):
774
980
  "caps_applied": caps_applied,
775
981
  "caps_exceeded": caps_exceeded,
776
982
  "multiple_testing": self.multiple_testing,
983
+ "multiple_testing_selection": mt_selection,
777
984
  "family_z_quantiles": family_quantiles,
778
985
  "top_z_scores": top_z_scores,
779
986
  }
@@ -782,7 +989,7 @@ class SpectralGuard(Guard):
782
989
  warnings = []
783
990
  errors = []
784
991
 
785
- for violation in final_violations:
992
+ for violation in selected_final_violations:
786
993
  if violation["type"] in ["max_spectral_norm", "ill_conditioned"]:
787
994
  errors.append(violation["message"])
788
995
  else:
@@ -793,7 +1000,7 @@ class SpectralGuard(Guard):
793
1000
  "metrics": metrics,
794
1001
  "warnings": warnings,
795
1002
  "errors": errors,
796
- "violations": final_violations,
1003
+ "violations": selected_final_violations,
797
1004
  "events": self.events,
798
1005
  "baseline_metrics": self.baseline_metrics,
799
1006
  "final_metrics": final_metrics,
@@ -403,26 +403,36 @@ def _predictive_gate_outcome(
403
403
  ):
404
404
  return False, "ci_unavailable"
405
405
 
406
- lower, upper = float(delta_ci[0]), float(delta_ci[1])
406
+ lower = float(delta_ci[0])
407
+ upper = float(delta_ci[1])
407
408
  min_effect = float(min_effect or 0.0)
408
409
 
410
+ # CI must clear zero (and the min-effect band when provided).
409
411
  if one_sided:
410
- if lower >= 0.0:
412
+ if upper >= 0.0:
411
413
  return False, "ci_contains_zero"
412
414
  if mean_delta >= 0.0:
413
415
  return False, "mean_not_negative"
414
- if min_effect > 0.0 and (-mean_delta) < min_effect:
416
+ if upper > -min_effect:
417
+ return False, "gain_below_threshold"
418
+ if mean_delta > -min_effect:
415
419
  return False, "gain_below_threshold"
416
420
  return True, "ci_gain_met"
417
421
 
418
- # Two-sided improvement: CI must be strictly below zero.
419
- if upper >= 0.0:
422
+ # Two-sided: detect regressions outside the +min_effect band, but only
423
+ # enable VE for negative improvements.
424
+ if lower <= 0.0 <= upper:
420
425
  return False, "ci_contains_zero"
421
-
422
- gain_lower_bound = -upper # Convert ΔlogNLL CI to gain CI lower bound.
423
- if gain_lower_bound < min_effect:
426
+ if lower > 0.0:
427
+ if lower >= min_effect and mean_delta >= min_effect:
428
+ return False, "regression_detected"
429
+ return False, "mean_not_negative"
430
+ if upper > -min_effect:
431
+ return False, "gain_below_threshold"
432
+ if mean_delta >= 0.0:
433
+ return False, "mean_not_negative"
434
+ if mean_delta > -min_effect:
424
435
  return False, "gain_below_threshold"
425
-
426
436
  return True, "ci_gain_met"
427
437
 
428
438
 
@@ -1438,12 +1448,17 @@ class VarianceGuard(Guard):
1438
1448
 
1439
1449
  device = next(model.parameters()).device
1440
1450
  torch.manual_seed(calib_seed)
1441
- ppl_no_ve_samples, loss_no_ve_samples = self._compute_ppl_for_batches(
1442
- model, calibration_batches, device
1451
+ (
1452
+ ppl_no_ve_samples,
1453
+ loss_no_ve_samples,
1454
+ token_counts,
1455
+ ) = self._compute_ppl_for_batches(
1456
+ model, calibration_batches, device, return_counts=True
1443
1457
  )
1444
1458
  coverage = min(len(calibration_batches), len(ppl_no_ve_samples))
1445
1459
  ppl_with_ve_samples: list[float] = []
1446
1460
  loss_with_ve_samples: list[float] = []
1461
+ token_counts_with: list[int] = []
1447
1462
  ratio_ci: tuple[float, float] | None = None
1448
1463
 
1449
1464
  enable_success = False
@@ -1459,10 +1474,12 @@ class VarianceGuard(Guard):
1459
1474
  try:
1460
1475
  torch.manual_seed(calib_seed)
1461
1476
  if enable_success:
1462
- ppl_with_ve_samples, loss_with_ve_samples = (
1463
- self._compute_ppl_for_batches(
1464
- model, calibration_batches, device
1465
- )
1477
+ (
1478
+ ppl_with_ve_samples,
1479
+ loss_with_ve_samples,
1480
+ token_counts_with,
1481
+ ) = self._compute_ppl_for_batches(
1482
+ model, calibration_batches, device, return_counts=True
1466
1483
  )
1467
1484
  finally:
1468
1485
  if enable_success:
@@ -1475,6 +1492,8 @@ class VarianceGuard(Guard):
1475
1492
  coverage,
1476
1493
  len(ppl_with_ve_samples) if ppl_with_ve_samples else coverage,
1477
1494
  len(loss_with_ve_samples) if loss_with_ve_samples else coverage,
1495
+ len(token_counts) if token_counts else coverage,
1496
+ len(token_counts_with) if token_counts_with else coverage,
1478
1497
  )
1479
1498
  self._calibration_stats.update(
1480
1499
  {
@@ -1543,6 +1562,7 @@ class VarianceGuard(Guard):
1543
1562
  loss_no_ve_samples = loss_no_ve_samples[:coverage]
1544
1563
  ppl_with_ve_samples = ppl_with_ve_samples[:coverage]
1545
1564
  loss_with_ve_samples = loss_with_ve_samples[:coverage]
1565
+ token_counts = token_counts[:coverage]
1546
1566
 
1547
1567
  ratios = [
1548
1568
  with_val / no_val
@@ -1599,6 +1619,7 @@ class VarianceGuard(Guard):
1599
1619
  delta_ci = compute_paired_delta_log_ci(
1600
1620
  loss_with_ve_samples,
1601
1621
  loss_no_ve_samples,
1622
+ weights=token_counts,
1602
1623
  method="bca",
1603
1624
  replicates=500,
1604
1625
  alpha=self._policy.get("alpha", 0.05),
@@ -1614,18 +1635,31 @@ class VarianceGuard(Guard):
1614
1635
  )
1615
1636
 
1616
1637
  predictive_state["evaluated"] = True
1617
- mean_delta = float(
1618
- np.mean(
1619
- [
1620
- with_loss - no_loss
1621
- for with_loss, no_loss in zip(
1622
- loss_with_ve_samples,
1623
- loss_no_ve_samples,
1624
- strict=False,
1625
- )
1626
- ]
1638
+ if token_counts:
1639
+ sw = 0.0
1640
+ swx = 0.0
1641
+ for with_loss, no_loss, weight in zip(
1642
+ loss_with_ve_samples,
1643
+ loss_no_ve_samples,
1644
+ token_counts,
1645
+ strict=False,
1646
+ ):
1647
+ sw += float(weight)
1648
+ swx += float(weight) * (with_loss - no_loss)
1649
+ mean_delta = float(swx / sw) if sw > 0 else float("nan")
1650
+ else:
1651
+ mean_delta = float(
1652
+ np.mean(
1653
+ [
1654
+ with_loss - no_loss
1655
+ for with_loss, no_loss in zip(
1656
+ loss_with_ve_samples,
1657
+ loss_no_ve_samples,
1658
+ strict=False,
1659
+ )
1660
+ ]
1661
+ )
1627
1662
  )
1628
- )
1629
1663
  predictive_state["mean_delta"] = mean_delta
1630
1664
 
1631
1665
  if delta_ci is not None and all(
@@ -1872,12 +1906,19 @@ class VarianceGuard(Guard):
1872
1906
  model: nn.Module,
1873
1907
  batches: list[Any],
1874
1908
  device: torch.device,
1875
- ) -> tuple[list[float], list[float]]:
1909
+ *,
1910
+ return_counts: bool = False,
1911
+ ) -> tuple[list[float], list[float]] | tuple[list[float], list[float], list[int]]:
1876
1912
  """Compute per-batch perplexity and log-loss values for deterministic calibration."""
1877
1913
  ppl_values: list[float] = []
1878
1914
  loss_values: list[float] = []
1915
+ token_counts: list[int] = []
1879
1916
  if not batches:
1880
- return ppl_values, loss_values
1917
+ return (
1918
+ (ppl_values, loss_values, token_counts)
1919
+ if return_counts
1920
+ else (ppl_values, loss_values)
1921
+ )
1881
1922
 
1882
1923
  model_was_training = model.training
1883
1924
  model.eval()
@@ -1916,12 +1957,29 @@ class VarianceGuard(Guard):
1916
1957
  if math.isfinite(ppl):
1917
1958
  ppl_values.append(ppl)
1918
1959
  loss_values.append(loss)
1960
+ if return_counts:
1961
+ count = None
1962
+ try:
1963
+ if labels is not None and isinstance(
1964
+ labels, torch.Tensor
1965
+ ):
1966
+ count = int((labels != -100).sum().item())
1967
+ except Exception:
1968
+ count = None
1969
+ if count is None:
1970
+ try:
1971
+ count = int(inputs.numel())
1972
+ except Exception:
1973
+ count = 0
1974
+ token_counts.append(int(max(count, 0)))
1919
1975
  except Exception:
1920
1976
  continue
1921
1977
 
1922
1978
  if model_was_training:
1923
1979
  model.train()
1924
1980
 
1981
+ if return_counts:
1982
+ return ppl_values, loss_values, token_counts
1925
1983
  return ppl_values, loss_values
1926
1984
 
1927
1985
  def _bootstrap_mean_ci(
@@ -2108,12 +2166,17 @@ class VarianceGuard(Guard):
2108
2166
  if calibration_batches:
2109
2167
  device = next(model.parameters()).device
2110
2168
  torch.manual_seed(calib_seed)
2111
- ppl_no_ve_samples, loss_no_ve_samples = self._compute_ppl_for_batches(
2112
- model, calibration_batches, device
2169
+ (
2170
+ ppl_no_ve_samples,
2171
+ loss_no_ve_samples,
2172
+ token_counts,
2173
+ ) = self._compute_ppl_for_batches(
2174
+ model, calibration_batches, device, return_counts=True
2113
2175
  )
2114
2176
  coverage = min(len(calibration_batches), len(ppl_no_ve_samples))
2115
2177
  ppl_with_ve_samples: list[float] = []
2116
2178
  loss_with_ve_samples: list[float] = []
2179
+ token_counts_with: list[int] = []
2117
2180
  ratio_ci: tuple[float, float] | None = None
2118
2181
 
2119
2182
  enable_success = False
@@ -2132,8 +2195,9 @@ class VarianceGuard(Guard):
2132
2195
  (
2133
2196
  ppl_with_ve_samples,
2134
2197
  loss_with_ve_samples,
2198
+ token_counts_with,
2135
2199
  ) = self._compute_ppl_for_batches(
2136
- model, calibration_batches, device
2200
+ model, calibration_batches, device, return_counts=True
2137
2201
  )
2138
2202
  finally:
2139
2203
  if enable_success:
@@ -2146,6 +2210,8 @@ class VarianceGuard(Guard):
2146
2210
  coverage,
2147
2211
  len(ppl_with_ve_samples) if ppl_with_ve_samples else coverage,
2148
2212
  len(loss_with_ve_samples) if loss_with_ve_samples else coverage,
2213
+ len(token_counts) if token_counts else coverage,
2214
+ len(token_counts_with) if token_counts_with else coverage,
2149
2215
  )
2150
2216
  self._calibration_stats.update(
2151
2217
  {"coverage": coverage, "status": "insufficient"}
@@ -2178,6 +2244,8 @@ class VarianceGuard(Guard):
2178
2244
  loss_no_ve_samples = loss_no_ve_samples[:coverage]
2179
2245
  ppl_with_ve_samples = ppl_with_ve_samples[:coverage]
2180
2246
  loss_with_ve_samples = loss_with_ve_samples[:coverage]
2247
+ token_counts = token_counts[:coverage]
2248
+ token_counts_with = token_counts_with[:coverage]
2181
2249
 
2182
2250
  ratios = [
2183
2251
  with_val / no_val
@@ -2216,6 +2284,7 @@ class VarianceGuard(Guard):
2216
2284
  delta_ci = compute_paired_delta_log_ci(
2217
2285
  loss_with_ve_samples,
2218
2286
  loss_no_ve_samples,
2287
+ weights=token_counts,
2219
2288
  method="bca",
2220
2289
  replicates=500,
2221
2290
  alpha=self._policy.get("alpha", 0.05),
@@ -2231,18 +2300,31 @@ class VarianceGuard(Guard):
2231
2300
  )
2232
2301
 
2233
2302
  predictive_state["evaluated"] = True
2234
- mean_delta = float(
2235
- np.mean(
2236
- [
2237
- with_loss - no_loss
2238
- for with_loss, no_loss in zip(
2239
- loss_with_ve_samples,
2240
- loss_no_ve_samples,
2241
- strict=False,
2242
- )
2243
- ]
2303
+ if token_counts:
2304
+ sw = 0.0
2305
+ swx = 0.0
2306
+ for with_loss, no_loss, weight in zip(
2307
+ loss_with_ve_samples,
2308
+ loss_no_ve_samples,
2309
+ token_counts,
2310
+ strict=False,
2311
+ ):
2312
+ sw += float(weight)
2313
+ swx += float(weight) * (with_loss - no_loss)
2314
+ mean_delta = float(swx / sw) if sw > 0 else float("nan")
2315
+ else:
2316
+ mean_delta = float(
2317
+ np.mean(
2318
+ [
2319
+ with_loss - no_loss
2320
+ for with_loss, no_loss in zip(
2321
+ loss_with_ve_samples,
2322
+ loss_no_ve_samples,
2323
+ strict=False,
2324
+ )
2325
+ ]
2326
+ )
2244
2327
  )
2245
- )
2246
2328
  predictive_state["mean_delta"] = mean_delta
2247
2329
 
2248
2330
  if delta_ci is not None and all(