invarlock 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +1 -1
- invarlock/_data/runtime/tiers.yaml +61 -0
- invarlock/adapters/hf_loading.py +97 -0
- invarlock/calibration/__init__.py +6 -0
- invarlock/calibration/spectral_null.py +301 -0
- invarlock/calibration/variance_ve.py +154 -0
- invarlock/cli/app.py +15 -0
- invarlock/cli/commands/calibrate.py +576 -0
- invarlock/cli/commands/doctor.py +9 -3
- invarlock/cli/commands/explain_gates.py +53 -9
- invarlock/cli/commands/plugins.py +12 -2
- invarlock/cli/commands/run.py +181 -79
- invarlock/cli/commands/verify.py +40 -0
- invarlock/cli/config.py +11 -1
- invarlock/cli/determinism.py +252 -0
- invarlock/core/auto_tuning.py +215 -17
- invarlock/core/bootstrap.py +137 -5
- invarlock/core/registry.py +9 -4
- invarlock/core/runner.py +305 -35
- invarlock/eval/bench.py +467 -141
- invarlock/eval/bench_regression.py +12 -0
- invarlock/eval/bootstrap.py +3 -1
- invarlock/eval/data.py +29 -7
- invarlock/eval/primary_metric.py +20 -5
- invarlock/guards/rmt.py +536 -46
- invarlock/guards/spectral.py +217 -10
- invarlock/guards/variance.py +124 -42
- invarlock/reporting/certificate.py +476 -45
- invarlock/reporting/certificate_schema.py +4 -1
- invarlock/reporting/guards_analysis.py +108 -10
- invarlock/reporting/normalizer.py +24 -1
- invarlock/reporting/policy_utils.py +97 -15
- invarlock/reporting/primary_metric_utils.py +17 -0
- invarlock/reporting/validate.py +10 -10
- {invarlock-0.3.1.dist-info → invarlock-0.3.3.dist-info}/METADATA +12 -10
- {invarlock-0.3.1.dist-info → invarlock-0.3.3.dist-info}/RECORD +40 -33
- {invarlock-0.3.1.dist-info → invarlock-0.3.3.dist-info}/WHEEL +0 -0
- {invarlock-0.3.1.dist-info → invarlock-0.3.3.dist-info}/entry_points.txt +0 -0
- {invarlock-0.3.1.dist-info → invarlock-0.3.3.dist-info}/licenses/LICENSE +0 -0
- {invarlock-0.3.1.dist-info → invarlock-0.3.3.dist-info}/top_level.txt +0 -0
invarlock/guards/spectral.py
CHANGED
|
@@ -26,6 +26,80 @@ from invarlock.core.api import Guard
|
|
|
26
26
|
from ._contracts import guard_assert
|
|
27
27
|
|
|
28
28
|
|
|
29
|
+
def _z_to_two_sided_pvalue(z: Any) -> float:
|
|
30
|
+
try:
|
|
31
|
+
zf = float(z)
|
|
32
|
+
if not math.isfinite(zf):
|
|
33
|
+
return 1.0
|
|
34
|
+
return float(math.erfc(abs(zf) / math.sqrt(2.0)))
|
|
35
|
+
except Exception:
|
|
36
|
+
return 1.0
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _finite01(value: Any) -> bool:
|
|
40
|
+
try:
|
|
41
|
+
f = float(value)
|
|
42
|
+
return math.isfinite(f) and 0.0 <= f <= 1.0
|
|
43
|
+
except Exception:
|
|
44
|
+
return False
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _bh_reject_families(
    family_pvals: dict[str, float], *, alpha: float, m: int
) -> set[str]:
    """Select families with the Benjamini-Hochberg step-up rule.

    The denominator is ``max(m, number of families, 1)``, so passing an
    ``m`` larger than the number of families only makes the test more
    conservative. P-values that are not finite numbers in [0, 1] never get
    selected. An ``alpha`` outside (0, 1] selects nothing; an ``alpha``
    that cannot be converted to float falls back to 0.05.

    Returns:
        Names of the families whose p-value is at or below the BH cutoff.
    """
    if not family_pvals:
        return set()

    try:
        level = float(alpha)
    except Exception:
        level = 0.05
    if not (0.0 < level <= 1.0):
        return set()

    denominator = max(int(m) if isinstance(m, int) else 0, len(family_pvals), 1)

    # Only usable p-values take part in the ranking; sorted() is stable, so
    # ties keep the dict's insertion order.
    ranked = sorted(
        ((name, p) for name, p in family_pvals.items() if _finite01(p)),
        key=lambda pair: pair[1],
    )

    # Step-up: find the largest rank k with p_(k) <= alpha * k / m.
    largest_rank = 0
    for rank, (_, p) in enumerate(ranked, start=1):
        if p <= (level * rank) / denominator:
            largest_rank = rank
    if largest_rank <= 0:
        return set()

    threshold = (level * largest_rank) / denominator
    return {name for name, p in ranked if p <= threshold}
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _bonferroni_reject_families(
    family_pvals: dict[str, float], *, alpha: float, m: int
) -> set[str]:
    """Select families whose p-value clears the Bonferroni cutoff ``alpha / m``.

    The divisor is ``max(m, number of families, 1)``. P-values that are not
    finite numbers in [0, 1] are never selected. An ``alpha`` outside (0, 1]
    selects nothing; an ``alpha`` that cannot be converted to float falls
    back to 0.05.
    """
    if not family_pvals:
        return set()

    try:
        level = float(alpha)
    except Exception:
        level = 0.05
    if not (0.0 < level <= 1.0):
        return set()

    divisor = max(int(m) if isinstance(m, int) else 0, len(family_pvals), 1)
    threshold = level / divisor

    rejected: set[str] = set()
    for name, p in family_pvals.items():
        if _finite01(p) and p <= threshold:
            rejected.add(name)
    return rejected
|
|
101
|
+
|
|
102
|
+
|
|
29
103
|
class SpectralPolicy(TypedDict, total=False):
|
|
30
104
|
"""Type definition for spectral guard policy configuration."""
|
|
31
105
|
|
|
@@ -426,7 +500,7 @@ class SpectralGuard(Guard):
|
|
|
426
500
|
if self.ignore_preview_inflation and phase == "after_edit":
|
|
427
501
|
continue
|
|
428
502
|
|
|
429
|
-
if z_score > kappa_cap:
|
|
503
|
+
if abs(z_score) > kappa_cap:
|
|
430
504
|
violations.append(
|
|
431
505
|
{
|
|
432
506
|
"type": "family_z_cap",
|
|
@@ -567,6 +641,121 @@ class SpectralGuard(Guard):
|
|
|
567
641
|
|
|
568
642
|
return family_quantiles, top_z_scores
|
|
569
643
|
|
|
644
|
+
def _select_budgeted_violations(
    self, budgeted_violations: list[dict[str, Any]]
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    """Apply BH/Bonferroni selection at the family level.

    Each violation dict is updated in place: a missing ``family`` is filled
    from ``self.module_family_map`` (falling back to ``"other"``), and
    ``p_value`` / ``selected`` keys are written. When ``self.multiple_testing``
    is a dict without an ``"m"`` entry, the effective ``m`` is stored into it.

    Args:
        budgeted_violations: Candidate violations; ``z_score`` is read from
            each entry to derive a two-sided p-value.

    Returns:
        (selected_violations, selection_metrics)
    """
    mt = self.multiple_testing if isinstance(self.multiple_testing, dict) else {}
    method = str(mt.get("method", "bh")).lower()
    try:
        alpha = float(mt.get("alpha", 0.05) or 0.05)
    except Exception:
        alpha = 0.05
    m_raw = mt.get("m")
    m = None
    try:
        if m_raw is not None:
            m = int(m_raw)
    except Exception:
        m = None

    # Fill in missing family assignments deterministically.
    for violation in budgeted_violations:
        if violation.get("family"):
            continue
        module = violation.get("module")
        if isinstance(module, str):
            family = self.module_family_map.get(module)
            if isinstance(family, str) and family:
                violation["family"] = family
                continue
        violation["family"] = "other"

    # Family p-values derived from the most significant (min p) module in each family.
    family_pvals: dict[str, float] = {}
    family_max_abs_z: dict[str, float] = {}
    family_counts: dict[str, int] = {}
    for violation in budgeted_violations:
        fam = violation.get("family")
        if fam is None:
            continue
        family = str(fam)
        z_val = violation.get("z_score")
        try:
            zf = float(z_val)
        except Exception:
            continue
        if not math.isfinite(zf):
            continue
        p = _z_to_two_sided_pvalue(zf)
        family_counts[family] = family_counts.get(family, 0) + 1
        cur = family_pvals.get(family)
        if cur is None or p < cur:
            family_pvals[family] = p
        # Track max |z| independently of the p-value comparison: for very
        # large |z| the p-value underflows to 0.0, so several modules can tie
        # on p and a "p < cur" test would leave a smaller |z| recorded.
        abs_z = abs(zf)
        if abs_z > family_max_abs_z.get(family, 0.0) or family not in family_max_abs_z:
            family_max_abs_z[family] = abs_z

    families_tested = sorted(family_pvals.keys())
    m_eff = m if isinstance(m, int) and m > 0 else len(families_tested)
    m_eff = max(m_eff, len(families_tested), 1)
    if isinstance(self.multiple_testing, dict):
        self.multiple_testing.setdefault("m", m_eff)

    if method in {"bh", "benjamini-hochberg", "benjamini_hochberg"}:
        selected_families = _bh_reject_families(family_pvals, alpha=alpha, m=m_eff)
        applied_method = "bh"
    else:
        # "bonferroni"/"bonf" and any unrecognised method name both use
        # Bonferroni.
        selected_families = _bonferroni_reject_families(
            family_pvals, alpha=alpha, m=m_eff
        )
        applied_method = "bonferroni"

    selected: list[dict[str, Any]] = []
    default_selected_without_pvalue = 0
    for violation in budgeted_violations:
        fam = violation.get("family")
        family = str(fam) if fam is not None else ""
        z_val = violation.get("z_score")
        p_val: float | None = None
        try:
            zf = float(z_val)
        except Exception:
            zf = None
        if zf is not None and math.isfinite(zf):
            p_val = _z_to_two_sided_pvalue(zf)
            is_selected = family in selected_families
        else:
            # If we cannot compute a p-value, fail closed: keep the violation.
            is_selected = True
            default_selected_without_pvalue += 1
        violation["p_value"] = p_val
        violation["selected"] = is_selected
        if is_selected:
            selected.append(violation)

    selection_metrics = {
        "method": applied_method,
        "alpha": alpha,
        "m": int(m_eff),
        "families_tested": families_tested,
        "families_selected": sorted(selected_families),
        "family_pvalues": {k: float(family_pvals[k]) for k in families_tested},
        "family_max_abs_z": {
            k: float(family_max_abs_z[k]) for k in families_tested
        },
        "family_violation_counts": dict(family_counts),
        "default_selected_without_pvalue": int(default_selected_without_pvalue),
    }
    return selected, selection_metrics
|
|
758
|
+
|
|
570
759
|
def validate(
|
|
571
760
|
self, model: Any, adapter: Any, context: dict[str, Any]
|
|
572
761
|
) -> dict[str, Any]:
|
|
@@ -607,7 +796,13 @@ class SpectralGuard(Guard):
|
|
|
607
796
|
if violation.get("type") in fatal_violation_types
|
|
608
797
|
]
|
|
609
798
|
|
|
610
|
-
|
|
799
|
+
selected_budgeted, mt_selection = self._select_budgeted_violations(
|
|
800
|
+
budgeted_violations
|
|
801
|
+
)
|
|
802
|
+
selected_violations = [*fatal_violations, *selected_budgeted]
|
|
803
|
+
candidate_budgeted = len(budgeted_violations)
|
|
804
|
+
|
|
805
|
+
caps_applied = len(selected_budgeted)
|
|
611
806
|
caps_exceeded = caps_applied > int(self.max_caps)
|
|
612
807
|
passed = not fatal_violations and not caps_exceeded
|
|
613
808
|
if fatal_violations or caps_exceeded:
|
|
@@ -623,8 +818,9 @@ class SpectralGuard(Guard):
|
|
|
623
818
|
)
|
|
624
819
|
metrics = {
|
|
625
820
|
"modules_checked": len(current_metrics),
|
|
626
|
-
"violations_found": len(
|
|
821
|
+
"violations_found": len(selected_violations),
|
|
627
822
|
"budgeted_violations": caps_applied,
|
|
823
|
+
"candidate_budgeted_violations": candidate_budgeted,
|
|
628
824
|
"fatal_violations": len(fatal_violations),
|
|
629
825
|
"max_spectral_norm": max(current_metrics.values())
|
|
630
826
|
if current_metrics
|
|
@@ -642,6 +838,7 @@ class SpectralGuard(Guard):
|
|
|
642
838
|
"caps_applied": caps_applied,
|
|
643
839
|
"caps_exceeded": caps_exceeded,
|
|
644
840
|
"multiple_testing": self.multiple_testing,
|
|
841
|
+
"multiple_testing_selection": mt_selection,
|
|
645
842
|
}
|
|
646
843
|
|
|
647
844
|
family_quantiles, top_z_scores = self._compute_family_observability()
|
|
@@ -653,7 +850,7 @@ class SpectralGuard(Guard):
|
|
|
653
850
|
if passed:
|
|
654
851
|
message = (
|
|
655
852
|
"Spectral validation passed with "
|
|
656
|
-
f"{len(
|
|
853
|
+
f"{len(selected_violations)} violations "
|
|
657
854
|
f"(caps_applied={caps_applied}, max_caps={self.max_caps})"
|
|
658
855
|
)
|
|
659
856
|
else:
|
|
@@ -683,7 +880,7 @@ class SpectralGuard(Guard):
|
|
|
683
880
|
"passed": passed,
|
|
684
881
|
"action": action,
|
|
685
882
|
"metrics": metrics,
|
|
686
|
-
"violations":
|
|
883
|
+
"violations": selected_violations,
|
|
687
884
|
"message": message,
|
|
688
885
|
"policy": self._serialize_policy(),
|
|
689
886
|
"final_z_scores": self.latest_z_scores.copy(),
|
|
@@ -743,15 +940,23 @@ class SpectralGuard(Guard):
|
|
|
743
940
|
if violation.get("type") in fatal_violation_types
|
|
744
941
|
]
|
|
745
942
|
|
|
746
|
-
|
|
943
|
+
selected_budgeted, mt_selection = self._select_budgeted_violations(
|
|
944
|
+
budgeted_violations
|
|
945
|
+
)
|
|
946
|
+
selected_final_violations = [*fatal_violations, *selected_budgeted]
|
|
947
|
+
candidate_budgeted = len(budgeted_violations)
|
|
948
|
+
|
|
949
|
+
caps_applied = len(selected_budgeted)
|
|
747
950
|
caps_exceeded = caps_applied > int(self.max_caps)
|
|
748
951
|
passed = not fatal_violations and not caps_exceeded
|
|
749
952
|
|
|
750
953
|
# Compute comprehensive metrics
|
|
751
954
|
metrics = {
|
|
752
955
|
"modules_analyzed": len(final_metrics),
|
|
753
|
-
"violations_detected": len(
|
|
956
|
+
"violations_detected": len(selected_final_violations),
|
|
754
957
|
"budgeted_violations": caps_applied,
|
|
958
|
+
"candidate_violations_detected": len(final_violations),
|
|
959
|
+
"candidate_budgeted_violations": candidate_budgeted,
|
|
755
960
|
"fatal_violations": len(fatal_violations),
|
|
756
961
|
"baseline_modules": len(self.baseline_metrics),
|
|
757
962
|
"scope": self.scope,
|
|
@@ -764,7 +969,8 @@ class SpectralGuard(Guard):
|
|
|
764
969
|
"spectral_stability_score": 1.0
|
|
765
970
|
- min(len(final_violations) / max(len(final_metrics), 1), 1.0),
|
|
766
971
|
"target_sigma": self.target_sigma,
|
|
767
|
-
"correction_applied": len(
|
|
972
|
+
"correction_applied": len(selected_final_violations) > 0
|
|
973
|
+
and self.correction_enabled,
|
|
768
974
|
"family_caps": self.family_caps,
|
|
769
975
|
"family_z_summary": final_z_summary,
|
|
770
976
|
"family_stats": final_family_stats,
|
|
@@ -774,6 +980,7 @@ class SpectralGuard(Guard):
|
|
|
774
980
|
"caps_applied": caps_applied,
|
|
775
981
|
"caps_exceeded": caps_exceeded,
|
|
776
982
|
"multiple_testing": self.multiple_testing,
|
|
983
|
+
"multiple_testing_selection": mt_selection,
|
|
777
984
|
"family_z_quantiles": family_quantiles,
|
|
778
985
|
"top_z_scores": top_z_scores,
|
|
779
986
|
}
|
|
@@ -782,7 +989,7 @@ class SpectralGuard(Guard):
|
|
|
782
989
|
warnings = []
|
|
783
990
|
errors = []
|
|
784
991
|
|
|
785
|
-
for violation in
|
|
992
|
+
for violation in selected_final_violations:
|
|
786
993
|
if violation["type"] in ["max_spectral_norm", "ill_conditioned"]:
|
|
787
994
|
errors.append(violation["message"])
|
|
788
995
|
else:
|
|
@@ -793,7 +1000,7 @@ class SpectralGuard(Guard):
|
|
|
793
1000
|
"metrics": metrics,
|
|
794
1001
|
"warnings": warnings,
|
|
795
1002
|
"errors": errors,
|
|
796
|
-
"violations":
|
|
1003
|
+
"violations": selected_final_violations,
|
|
797
1004
|
"events": self.events,
|
|
798
1005
|
"baseline_metrics": self.baseline_metrics,
|
|
799
1006
|
"final_metrics": final_metrics,
|
invarlock/guards/variance.py
CHANGED
|
@@ -403,26 +403,36 @@ def _predictive_gate_outcome(
|
|
|
403
403
|
):
|
|
404
404
|
return False, "ci_unavailable"
|
|
405
405
|
|
|
406
|
-
lower
|
|
406
|
+
lower = float(delta_ci[0])
|
|
407
|
+
upper = float(delta_ci[1])
|
|
407
408
|
min_effect = float(min_effect or 0.0)
|
|
408
409
|
|
|
410
|
+
# CI must clear zero (and the min-effect band when provided).
|
|
409
411
|
if one_sided:
|
|
410
|
-
if
|
|
412
|
+
if upper >= 0.0:
|
|
411
413
|
return False, "ci_contains_zero"
|
|
412
414
|
if mean_delta >= 0.0:
|
|
413
415
|
return False, "mean_not_negative"
|
|
414
|
-
if
|
|
416
|
+
if upper > -min_effect:
|
|
417
|
+
return False, "gain_below_threshold"
|
|
418
|
+
if mean_delta > -min_effect:
|
|
415
419
|
return False, "gain_below_threshold"
|
|
416
420
|
return True, "ci_gain_met"
|
|
417
421
|
|
|
418
|
-
# Two-sided
|
|
419
|
-
|
|
422
|
+
# Two-sided: detect regressions outside the +min_effect band, but only
|
|
423
|
+
# enable VE for negative improvements.
|
|
424
|
+
if lower <= 0.0 <= upper:
|
|
420
425
|
return False, "ci_contains_zero"
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
426
|
+
if lower > 0.0:
|
|
427
|
+
if lower >= min_effect and mean_delta >= min_effect:
|
|
428
|
+
return False, "regression_detected"
|
|
429
|
+
return False, "mean_not_negative"
|
|
430
|
+
if upper > -min_effect:
|
|
431
|
+
return False, "gain_below_threshold"
|
|
432
|
+
if mean_delta >= 0.0:
|
|
433
|
+
return False, "mean_not_negative"
|
|
434
|
+
if mean_delta > -min_effect:
|
|
424
435
|
return False, "gain_below_threshold"
|
|
425
|
-
|
|
426
436
|
return True, "ci_gain_met"
|
|
427
437
|
|
|
428
438
|
|
|
@@ -1438,12 +1448,17 @@ class VarianceGuard(Guard):
|
|
|
1438
1448
|
|
|
1439
1449
|
device = next(model.parameters()).device
|
|
1440
1450
|
torch.manual_seed(calib_seed)
|
|
1441
|
-
|
|
1442
|
-
|
|
1451
|
+
(
|
|
1452
|
+
ppl_no_ve_samples,
|
|
1453
|
+
loss_no_ve_samples,
|
|
1454
|
+
token_counts,
|
|
1455
|
+
) = self._compute_ppl_for_batches(
|
|
1456
|
+
model, calibration_batches, device, return_counts=True
|
|
1443
1457
|
)
|
|
1444
1458
|
coverage = min(len(calibration_batches), len(ppl_no_ve_samples))
|
|
1445
1459
|
ppl_with_ve_samples: list[float] = []
|
|
1446
1460
|
loss_with_ve_samples: list[float] = []
|
|
1461
|
+
token_counts_with: list[int] = []
|
|
1447
1462
|
ratio_ci: tuple[float, float] | None = None
|
|
1448
1463
|
|
|
1449
1464
|
enable_success = False
|
|
@@ -1459,10 +1474,12 @@ class VarianceGuard(Guard):
|
|
|
1459
1474
|
try:
|
|
1460
1475
|
torch.manual_seed(calib_seed)
|
|
1461
1476
|
if enable_success:
|
|
1462
|
-
|
|
1463
|
-
|
|
1464
|
-
|
|
1465
|
-
|
|
1477
|
+
(
|
|
1478
|
+
ppl_with_ve_samples,
|
|
1479
|
+
loss_with_ve_samples,
|
|
1480
|
+
token_counts_with,
|
|
1481
|
+
) = self._compute_ppl_for_batches(
|
|
1482
|
+
model, calibration_batches, device, return_counts=True
|
|
1466
1483
|
)
|
|
1467
1484
|
finally:
|
|
1468
1485
|
if enable_success:
|
|
@@ -1475,6 +1492,8 @@ class VarianceGuard(Guard):
|
|
|
1475
1492
|
coverage,
|
|
1476
1493
|
len(ppl_with_ve_samples) if ppl_with_ve_samples else coverage,
|
|
1477
1494
|
len(loss_with_ve_samples) if loss_with_ve_samples else coverage,
|
|
1495
|
+
len(token_counts) if token_counts else coverage,
|
|
1496
|
+
len(token_counts_with) if token_counts_with else coverage,
|
|
1478
1497
|
)
|
|
1479
1498
|
self._calibration_stats.update(
|
|
1480
1499
|
{
|
|
@@ -1543,6 +1562,7 @@ class VarianceGuard(Guard):
|
|
|
1543
1562
|
loss_no_ve_samples = loss_no_ve_samples[:coverage]
|
|
1544
1563
|
ppl_with_ve_samples = ppl_with_ve_samples[:coverage]
|
|
1545
1564
|
loss_with_ve_samples = loss_with_ve_samples[:coverage]
|
|
1565
|
+
token_counts = token_counts[:coverage]
|
|
1546
1566
|
|
|
1547
1567
|
ratios = [
|
|
1548
1568
|
with_val / no_val
|
|
@@ -1599,6 +1619,7 @@ class VarianceGuard(Guard):
|
|
|
1599
1619
|
delta_ci = compute_paired_delta_log_ci(
|
|
1600
1620
|
loss_with_ve_samples,
|
|
1601
1621
|
loss_no_ve_samples,
|
|
1622
|
+
weights=token_counts,
|
|
1602
1623
|
method="bca",
|
|
1603
1624
|
replicates=500,
|
|
1604
1625
|
alpha=self._policy.get("alpha", 0.05),
|
|
@@ -1614,18 +1635,31 @@ class VarianceGuard(Guard):
|
|
|
1614
1635
|
)
|
|
1615
1636
|
|
|
1616
1637
|
predictive_state["evaluated"] = True
|
|
1617
|
-
|
|
1618
|
-
|
|
1619
|
-
|
|
1620
|
-
|
|
1621
|
-
|
|
1622
|
-
|
|
1623
|
-
|
|
1624
|
-
|
|
1625
|
-
|
|
1626
|
-
|
|
1638
|
+
if token_counts:
|
|
1639
|
+
sw = 0.0
|
|
1640
|
+
swx = 0.0
|
|
1641
|
+
for with_loss, no_loss, weight in zip(
|
|
1642
|
+
loss_with_ve_samples,
|
|
1643
|
+
loss_no_ve_samples,
|
|
1644
|
+
token_counts,
|
|
1645
|
+
strict=False,
|
|
1646
|
+
):
|
|
1647
|
+
sw += float(weight)
|
|
1648
|
+
swx += float(weight) * (with_loss - no_loss)
|
|
1649
|
+
mean_delta = float(swx / sw) if sw > 0 else float("nan")
|
|
1650
|
+
else:
|
|
1651
|
+
mean_delta = float(
|
|
1652
|
+
np.mean(
|
|
1653
|
+
[
|
|
1654
|
+
with_loss - no_loss
|
|
1655
|
+
for with_loss, no_loss in zip(
|
|
1656
|
+
loss_with_ve_samples,
|
|
1657
|
+
loss_no_ve_samples,
|
|
1658
|
+
strict=False,
|
|
1659
|
+
)
|
|
1660
|
+
]
|
|
1661
|
+
)
|
|
1627
1662
|
)
|
|
1628
|
-
)
|
|
1629
1663
|
predictive_state["mean_delta"] = mean_delta
|
|
1630
1664
|
|
|
1631
1665
|
if delta_ci is not None and all(
|
|
@@ -1872,12 +1906,19 @@ class VarianceGuard(Guard):
|
|
|
1872
1906
|
model: nn.Module,
|
|
1873
1907
|
batches: list[Any],
|
|
1874
1908
|
device: torch.device,
|
|
1875
|
-
|
|
1909
|
+
*,
|
|
1910
|
+
return_counts: bool = False,
|
|
1911
|
+
) -> tuple[list[float], list[float]] | tuple[list[float], list[float], list[int]]:
|
|
1876
1912
|
"""Compute per-batch perplexity and log-loss values for deterministic calibration."""
|
|
1877
1913
|
ppl_values: list[float] = []
|
|
1878
1914
|
loss_values: list[float] = []
|
|
1915
|
+
token_counts: list[int] = []
|
|
1879
1916
|
if not batches:
|
|
1880
|
-
return
|
|
1917
|
+
return (
|
|
1918
|
+
(ppl_values, loss_values, token_counts)
|
|
1919
|
+
if return_counts
|
|
1920
|
+
else (ppl_values, loss_values)
|
|
1921
|
+
)
|
|
1881
1922
|
|
|
1882
1923
|
model_was_training = model.training
|
|
1883
1924
|
model.eval()
|
|
@@ -1916,12 +1957,29 @@ class VarianceGuard(Guard):
|
|
|
1916
1957
|
if math.isfinite(ppl):
|
|
1917
1958
|
ppl_values.append(ppl)
|
|
1918
1959
|
loss_values.append(loss)
|
|
1960
|
+
if return_counts:
|
|
1961
|
+
count = None
|
|
1962
|
+
try:
|
|
1963
|
+
if labels is not None and isinstance(
|
|
1964
|
+
labels, torch.Tensor
|
|
1965
|
+
):
|
|
1966
|
+
count = int((labels != -100).sum().item())
|
|
1967
|
+
except Exception:
|
|
1968
|
+
count = None
|
|
1969
|
+
if count is None:
|
|
1970
|
+
try:
|
|
1971
|
+
count = int(inputs.numel())
|
|
1972
|
+
except Exception:
|
|
1973
|
+
count = 0
|
|
1974
|
+
token_counts.append(int(max(count, 0)))
|
|
1919
1975
|
except Exception:
|
|
1920
1976
|
continue
|
|
1921
1977
|
|
|
1922
1978
|
if model_was_training:
|
|
1923
1979
|
model.train()
|
|
1924
1980
|
|
|
1981
|
+
if return_counts:
|
|
1982
|
+
return ppl_values, loss_values, token_counts
|
|
1925
1983
|
return ppl_values, loss_values
|
|
1926
1984
|
|
|
1927
1985
|
def _bootstrap_mean_ci(
|
|
@@ -2108,12 +2166,17 @@ class VarianceGuard(Guard):
|
|
|
2108
2166
|
if calibration_batches:
|
|
2109
2167
|
device = next(model.parameters()).device
|
|
2110
2168
|
torch.manual_seed(calib_seed)
|
|
2111
|
-
|
|
2112
|
-
|
|
2169
|
+
(
|
|
2170
|
+
ppl_no_ve_samples,
|
|
2171
|
+
loss_no_ve_samples,
|
|
2172
|
+
token_counts,
|
|
2173
|
+
) = self._compute_ppl_for_batches(
|
|
2174
|
+
model, calibration_batches, device, return_counts=True
|
|
2113
2175
|
)
|
|
2114
2176
|
coverage = min(len(calibration_batches), len(ppl_no_ve_samples))
|
|
2115
2177
|
ppl_with_ve_samples: list[float] = []
|
|
2116
2178
|
loss_with_ve_samples: list[float] = []
|
|
2179
|
+
token_counts_with: list[int] = []
|
|
2117
2180
|
ratio_ci: tuple[float, float] | None = None
|
|
2118
2181
|
|
|
2119
2182
|
enable_success = False
|
|
@@ -2132,8 +2195,9 @@ class VarianceGuard(Guard):
|
|
|
2132
2195
|
(
|
|
2133
2196
|
ppl_with_ve_samples,
|
|
2134
2197
|
loss_with_ve_samples,
|
|
2198
|
+
token_counts_with,
|
|
2135
2199
|
) = self._compute_ppl_for_batches(
|
|
2136
|
-
model, calibration_batches, device
|
|
2200
|
+
model, calibration_batches, device, return_counts=True
|
|
2137
2201
|
)
|
|
2138
2202
|
finally:
|
|
2139
2203
|
if enable_success:
|
|
@@ -2146,6 +2210,8 @@ class VarianceGuard(Guard):
|
|
|
2146
2210
|
coverage,
|
|
2147
2211
|
len(ppl_with_ve_samples) if ppl_with_ve_samples else coverage,
|
|
2148
2212
|
len(loss_with_ve_samples) if loss_with_ve_samples else coverage,
|
|
2213
|
+
len(token_counts) if token_counts else coverage,
|
|
2214
|
+
len(token_counts_with) if token_counts_with else coverage,
|
|
2149
2215
|
)
|
|
2150
2216
|
self._calibration_stats.update(
|
|
2151
2217
|
{"coverage": coverage, "status": "insufficient"}
|
|
@@ -2178,6 +2244,8 @@ class VarianceGuard(Guard):
|
|
|
2178
2244
|
loss_no_ve_samples = loss_no_ve_samples[:coverage]
|
|
2179
2245
|
ppl_with_ve_samples = ppl_with_ve_samples[:coverage]
|
|
2180
2246
|
loss_with_ve_samples = loss_with_ve_samples[:coverage]
|
|
2247
|
+
token_counts = token_counts[:coverage]
|
|
2248
|
+
token_counts_with = token_counts_with[:coverage]
|
|
2181
2249
|
|
|
2182
2250
|
ratios = [
|
|
2183
2251
|
with_val / no_val
|
|
@@ -2216,6 +2284,7 @@ class VarianceGuard(Guard):
|
|
|
2216
2284
|
delta_ci = compute_paired_delta_log_ci(
|
|
2217
2285
|
loss_with_ve_samples,
|
|
2218
2286
|
loss_no_ve_samples,
|
|
2287
|
+
weights=token_counts,
|
|
2219
2288
|
method="bca",
|
|
2220
2289
|
replicates=500,
|
|
2221
2290
|
alpha=self._policy.get("alpha", 0.05),
|
|
@@ -2231,18 +2300,31 @@ class VarianceGuard(Guard):
|
|
|
2231
2300
|
)
|
|
2232
2301
|
|
|
2233
2302
|
predictive_state["evaluated"] = True
|
|
2234
|
-
|
|
2235
|
-
|
|
2236
|
-
|
|
2237
|
-
|
|
2238
|
-
|
|
2239
|
-
|
|
2240
|
-
|
|
2241
|
-
|
|
2242
|
-
|
|
2243
|
-
|
|
2303
|
+
if token_counts:
|
|
2304
|
+
sw = 0.0
|
|
2305
|
+
swx = 0.0
|
|
2306
|
+
for with_loss, no_loss, weight in zip(
|
|
2307
|
+
loss_with_ve_samples,
|
|
2308
|
+
loss_no_ve_samples,
|
|
2309
|
+
token_counts,
|
|
2310
|
+
strict=False,
|
|
2311
|
+
):
|
|
2312
|
+
sw += float(weight)
|
|
2313
|
+
swx += float(weight) * (with_loss - no_loss)
|
|
2314
|
+
mean_delta = float(swx / sw) if sw > 0 else float("nan")
|
|
2315
|
+
else:
|
|
2316
|
+
mean_delta = float(
|
|
2317
|
+
np.mean(
|
|
2318
|
+
[
|
|
2319
|
+
with_loss - no_loss
|
|
2320
|
+
for with_loss, no_loss in zip(
|
|
2321
|
+
loss_with_ve_samples,
|
|
2322
|
+
loss_no_ve_samples,
|
|
2323
|
+
strict=False,
|
|
2324
|
+
)
|
|
2325
|
+
]
|
|
2326
|
+
)
|
|
2244
2327
|
)
|
|
2245
|
-
)
|
|
2246
2328
|
predictive_state["mean_delta"] = mean_delta
|
|
2247
2329
|
|
|
2248
2330
|
if delta_ci is not None and all(
|