invarlock-0.3.5-py3-none-any.whl → invarlock-0.3.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +1 -1
- invarlock/_data/runtime/tiers.yaml +57 -30
- invarlock/adapters/__init__.py +1 -1
- invarlock/calibration/spectral_null.py +15 -10
- invarlock/calibration/variance_ve.py +0 -2
- invarlock/cli/commands/calibrate.py +6 -2
- invarlock/cli/commands/certify.py +58 -39
- invarlock/cli/commands/doctor.py +3 -1
- invarlock/cli/commands/explain_gates.py +57 -8
- invarlock/cli/commands/report.py +1 -1
- invarlock/cli/commands/run.py +159 -61
- invarlock/cli/commands/verify.py +78 -4
- invarlock/cli/config.py +21 -5
- invarlock/core/api.py +45 -5
- invarlock/core/auto_tuning.py +65 -20
- invarlock/core/contracts.py +7 -1
- invarlock/core/registry.py +2 -2
- invarlock/core/runner.py +314 -50
- invarlock/eval/bench.py +0 -13
- invarlock/eval/data.py +14 -28
- invarlock/eval/metrics.py +4 -1
- invarlock/eval/primary_metric.py +23 -0
- invarlock/eval/tail_stats.py +230 -0
- invarlock/guards/_estimators.py +154 -0
- invarlock/guards/policies.py +16 -6
- invarlock/guards/rmt.py +625 -544
- invarlock/guards/spectral.py +348 -110
- invarlock/guards/tier_config.py +32 -30
- invarlock/guards/variance.py +5 -29
- invarlock/guards_ref/rmt_ref.py +23 -23
- invarlock/model_profile.py +42 -15
- invarlock/reporting/certificate.py +225 -46
- invarlock/reporting/certificate_schema.py +2 -1
- invarlock/reporting/dataset_hashing.py +15 -2
- invarlock/reporting/guards_analysis.py +197 -274
- invarlock/reporting/normalizer.py +6 -0
- invarlock/reporting/policy_utils.py +38 -36
- invarlock/reporting/primary_metric_utils.py +71 -17
- invarlock/reporting/render.py +61 -0
- invarlock/reporting/report.py +1 -1
- invarlock/reporting/report_types.py +5 -2
- invarlock/reporting/validate.py +1 -18
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/METADATA +6 -6
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/RECORD +48 -46
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/WHEEL +0 -0
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/entry_points.txt +0 -0
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/licenses/LICENSE +0 -0
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/top_level.txt +0 -0
invarlock/eval/data.py
CHANGED

@@ -189,8 +189,6 @@ class WikiText2Provider:
     def _validate_dependencies(self) -> None:
         """Check that required dependencies are available."""
         if not HAS_DATASETS:
-            if _LIGHT_IMPORT:
-                return
             raise _DepErr(
                 code="E301",
                 message=(
@@ -328,13 +326,6 @@ class WikiText2Provider:
         if cached is not None and len(cached) >= max_samples:
             return cached[:max_samples]
 
-        if not HAS_DATASETS and _LIGHT_IMPORT:
-            texts = ["hello world", "invarlock synthetic text"] * max(
-                1, max_samples // 2
-            )
-            self._texts_cache[split] = texts
-            return texts[:max_samples]
-
         # Load dataset with size limit for efficiency
         dataset_slice = f"{split}[:{max_samples}]" if max_samples > 0 else split
         dataset = load_dataset(
@@ -1062,14 +1053,13 @@ class HFTextProvider:
         max_samples: int = 2000,
     ):
         if not HAS_DATASETS:
-
-
-
-
-
-
-
-            )
+            raise _DepErr(
+                code="E301",
+                message=(
+                    "DEPENDENCY-MISSING: datasets library required for hf_text provider"
+                ),
+                details={"dependency": "datasets"},
+            )
         self.dataset_name = dataset_name or "wikitext"
         self.config_name = config_name or None
         self.text_field = text_field
@@ -1077,9 +1067,6 @@ class HFTextProvider:
         self.max_samples = int(max_samples)
 
     def load(self, split: str = "validation", **kwargs) -> list[str]:
-        if not HAS_DATASETS and _LIGHT_IMPORT:
-            return ["synthetic dataset text"] * int(self.max_samples or 1)
-
         ds = load_dataset(
             path=self.dataset_name,
             name=self.config_name,
@@ -1204,14 +1191,13 @@ class HFSeq2SeqProvider:
         max_samples: int = 2000,
     ) -> None:
         if not HAS_DATASETS:
-
-
-
-
-
-
-
-            )
+            raise _DepErr(
+                code="E301",
+                message=(
+                    "DEPENDENCY-MISSING: datasets library required for hf_seq2seq provider"
+                ),
+                details={"dependency": "datasets"},
+            )
         self.dataset_name = dataset_name
         self.config_name = config_name
         self.src_field = src_field
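With the `_LIGHT_IMPORT` synthetic-text fallbacks removed, every provider path now fails fast with a coded dependency error when the `datasets` library is absent. A minimal caller-side sketch; the `code` attribute on the raised error is an assumption, since the diff only shows the constructor keywords:

# Hedged sketch: handle the strict E301 dependency error that replaces the
# old synthetic-text fallback. Attribute access on the error is an assumption.
from invarlock.eval.data import HFTextProvider

try:
    texts = HFTextProvider(dataset_name="wikitext").load(split="validation")
except Exception as err:
    if getattr(err, "code", None) == "E301":  # DEPENDENCY-MISSING
        raise SystemExit("pip install datasets") from err
    raise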
invarlock/eval/metrics.py
CHANGED

@@ -723,7 +723,10 @@ def calculate_lens_metrics_for_model(
     except Exception as e:
         logger.error(f"Metrics calculation failed: {e}")
         if config.strict_validation:
-            raise MetricsError(
+            raise MetricsError(
+                code="E401",
+                message=f"METRICS-COMPUTE-FAILED: {e}",
+            ) from e
 
     finally:
         resource_manager.cleanup()
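Under `strict_validation`, metric failures now surface as a coded `MetricsError` chained to the original exception rather than an unstructured raise. A sketch of what a strict-mode caller observes; the call signature is abbreviated and the `.code` attribute is an assumption:

# Hedged sketch: the E401 error is chained via `raise ... from e`, so the
# original failure stays available on __cause__.
try:
    results = calculate_lens_metrics_for_model(model, config)
except MetricsError as err:
    assert getattr(err, "code", None) == "E401"  # assumed attribute name
    root_cause = err.__cause__                   # preserved by `from e`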
invarlock/eval/primary_metric.py
CHANGED

@@ -623,6 +623,9 @@ def compute_primary_metric_from_report(
         "preview": float("nan"),
         "final": float("nan"),
         "ratio_vs_baseline": float("nan"),
+        "invalid": True,
+        "degraded": True,
+        "degraded_reason": "non_finite_pm",
     }
     # For accuracy kinds, derive counts from input_ids if aggregates are missing
     if kind in {"accuracy", "vqa_accuracy"}:
@@ -661,6 +664,11 @@ def compute_primary_metric_from_report(
     final_point = metric.point_from_windows(windows=final_win)
 
     ratio_vs_baseline = float("nan")
+    baseline_has_reference = False
+
+    def _is_finite(value: Any) -> bool:
+        return isinstance(value, (int, float)) and math.isfinite(float(value))
+
     if isinstance(baseline, dict):
         try:
             base_metrics = (
@@ -686,14 +694,25 @@ def compute_primary_metric_from_report(
             is_ppl_like = str(kind).lower().startswith("ppl")
             if is_ppl_like and base_ref > 0:
                 ratio_vs_baseline = float(final_point) / float(base_ref)
+                baseline_has_reference = True
             elif (
                 str(kind).lower() in {"accuracy", "vqa_accuracy"}
                 and 0 <= base_ref <= 1
             ):
                 ratio_vs_baseline = float(final_point) - float(base_ref)
+                baseline_has_reference = True
         except Exception:
             ratio_vs_baseline = float("nan")
 
+    invalid = not (_is_finite(preview_point) and _is_finite(final_point))
+    degraded_reason = None
+    if invalid:
+        degraded_reason = "non_finite_pm"
+    elif baseline_has_reference and not _is_finite(ratio_vs_baseline):
+        degraded_reason = "non_finite_delta"
+
+    degraded = bool(degraded_reason or invalid)
+
     payload = {
         "kind": metric.kind,
         "unit": metric.unit,
@@ -705,7 +724,11 @@ def compute_primary_metric_from_report(
         "preview": preview_point,
         "final": final_point,
         "ratio_vs_baseline": ratio_vs_baseline,
+        "invalid": invalid,
+        "degraded": degraded,
     }
+    if degraded and degraded_reason:
+        payload["degraded_reason"] = degraded_reason
     # Carry counts for accuracy to aid gating
     if kind in {"accuracy", "vqa_accuracy"}:
         if "n_prev" in locals() and n_prev is not None:
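The payload now distinguishes an invalid primary metric (`invalid`, reason `non_finite_pm`) from a finite metric whose baseline comparison failed (`degraded`, reason `non_finite_delta`). An illustrative consumer of the new fields; the gate itself is not part of this diff:

# Illustrative gate over the payload fields introduced above.
def pm_verdict(payload: dict) -> str:
    if payload.get("invalid"):
        return "fail: non-finite primary metric"          # degraded_reason == "non_finite_pm"
    if payload.get("degraded"):
        return f"warn: {payload.get('degraded_reason')}"  # e.g. "non_finite_delta"
    return "pass"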
invarlock/eval/tail_stats.py
ADDED

@@ -0,0 +1,230 @@
+from __future__ import annotations
+
+import math
+from collections.abc import Mapping, Sequence
+from typing import Any
+
+__all__ = [
+    "compute_tail_summary",
+    "evaluate_metric_tail",
+]
+
+
+def _as_finite_float(value: Any) -> float | None:
+    try:
+        out = float(value)
+    except Exception:
+        return None
+    return out if math.isfinite(out) else None
+
+
+def _linear_quantile(sorted_values: Sequence[float], q: float) -> float:
+    """Deterministic linear-interpolated quantile on sorted values (q in [0, 1])."""
+    n = len(sorted_values)
+    if n == 0:
+        return float("nan")
+    if n == 1:
+        return float(sorted_values[0])
+    if q <= 0.0:
+        return float(sorted_values[0])
+    if q >= 1.0:
+        return float(sorted_values[-1])
+    pos = float(q) * float(n - 1)
+    lo = int(math.floor(pos))
+    hi = int(math.ceil(pos))
+    if lo == hi:
+        return float(sorted_values[lo])
+    frac = pos - float(lo)
+    a = float(sorted_values[lo])
+    b = float(sorted_values[hi])
+    return a + frac * (b - a)
+
+
+def compute_tail_summary(
+    deltas: Sequence[float] | Sequence[Any],
+    *,
+    quantiles: Sequence[float] = (0.5, 0.9, 0.95, 0.99),
+    epsilon: float = 1e-4,
+    weights: Sequence[float] | Sequence[Any] | None = None,
+) -> dict[str, Any]:
+    """Compute deterministic tail summaries for Δlog-loss samples.
+
+    - Quantiles are computed unweighted using linear interpolation on sorted values.
+    - tail_mass is Pr[delta > epsilon] (unweighted).
+    - tail_mass_weighted is included when weights are provided and finite.
+    """
+    eps = _as_finite_float(epsilon)
+    if eps is None or eps < 0.0:
+        eps = 0.0
+
+    values: list[float] = []
+    paired_weights: list[float] | None = [] if weights is not None else None
+
+    if weights is None:
+        for d in deltas:
+            dv = _as_finite_float(d)
+            if dv is None:
+                continue
+            values.append(float(dv))
+    else:
+        for d, w in zip(deltas, weights, strict=False):
+            dv = _as_finite_float(d)
+            if dv is None:
+                continue
+            wv = _as_finite_float(w)
+            if wv is None or wv < 0.0:
+                wv = 0.0
+            values.append(float(dv))
+            if paired_weights is not None:
+                paired_weights.append(float(wv))
+
+    n = int(len(values))
+    values_sorted = sorted(values)
+
+    summary: dict[str, Any] = {
+        "n": n,
+        "epsilon": float(eps),
+    }
+    if n == 0:
+        summary.update({"max": float("nan"), "tail_mass": 0.0})
+        for q in quantiles:
+            try:
+                qf = float(q)
+            except Exception:
+                continue
+            label = f"q{int(round(100.0 * max(0.0, min(1.0, qf))))}"
+            summary[label] = float("nan")
+        return summary
+
+    summary["max"] = float(values_sorted[-1])
+    tail_ct = sum(1 for v in values if v > eps)
+    summary["tail_mass"] = float(tail_ct / n)
+
+    if paired_weights is not None:
+        total_w = 0.0
+        tail_w = 0.0
+        for v, w in zip(values, paired_weights, strict=False):
+            total_w += float(w)
+            if v > eps:
+                tail_w += float(w)
+        if total_w > 0.0:
+            summary["tail_mass_weighted"] = float(tail_w / total_w)
+            summary["tail_mass_weighted_by"] = "weights"
+
+    for q in quantiles:
+        try:
+            qf = float(q)
+        except Exception:
+            continue
+        qf = max(0.0, min(1.0, qf))
+        label = f"q{int(round(100.0 * qf))}"
+        summary[label] = float(_linear_quantile(values_sorted, qf))
+
+    return summary
+
+
+def evaluate_metric_tail(
+    *,
+    deltas: Sequence[float] | Sequence[Any],
+    policy: Mapping[str, Any] | None = None,
+    weights: Sequence[float] | Sequence[Any] | None = None,
+) -> dict[str, Any]:
+    """Evaluate a tail policy against Δlog-loss samples.
+
+    Policy keys:
+    - mode: "off" | "warn" | "fail" (default: "warn")
+    - min_windows: int (default: 1)
+    - quantile: float in [0, 1] (default: 0.95)
+    - quantile_max: float threshold in Δlog-loss (optional)
+    - epsilon: float deadband for tail_mass (default: 1e-4)
+    - mass_max: float in [0, 1] (optional)
+    """
+    pol = dict(policy or {})
+    mode = str(pol.get("mode", "warn") or "warn").strip().lower()
+    if mode not in {"off", "warn", "fail"}:
+        mode = "warn"
+
+    min_windows = pol.get("min_windows", 1)
+    try:
+        min_windows_i = int(min_windows)
+    except Exception:
+        min_windows_i = 1
+    min_windows_i = max(1, min_windows_i)
+
+    q = _as_finite_float(pol.get("quantile", 0.95))
+    if q is None:
+        q = 0.95
+    q = max(0.0, min(1.0, float(q)))
+
+    eps = _as_finite_float(pol.get("epsilon", 1e-4))
+    if eps is None or eps < 0.0:
+        eps = 0.0
+
+    qmax = _as_finite_float(pol.get("quantile_max"))
+    mmax = _as_finite_float(pol.get("mass_max"))
+    if mmax is not None:
+        mmax = max(0.0, min(1.0, float(mmax)))
+
+    quantiles = sorted({0.5, 0.9, 0.95, 0.99, float(q)})
+    stats = compute_tail_summary(
+        deltas, quantiles=tuple(quantiles), epsilon=float(eps), weights=weights
+    )
+    n = int(stats.get("n", 0) or 0)
+
+    thresholds_present = (qmax is not None) or (mmax is not None)
+    evaluated = bool(mode != "off" and thresholds_present and n >= min_windows_i)
+
+    violations: list[dict[str, Any]] = []
+    passed = True
+    if evaluated:
+        passed = True
+        q_label = f"q{int(round(100.0 * q))}"
+        q_obs = stats.get(q_label)
+        if not (isinstance(q_obs, int | float) and math.isfinite(float(q_obs))):
+            q_obs = float("nan")
+        if qmax is not None and math.isfinite(q_obs) and q_obs > float(qmax):
+            passed = False
+            violations.append(
+                {
+                    "type": "quantile_max_exceeded",
+                    "quantile": float(q),
+                    "observed": float(q_obs),
+                    "threshold": float(qmax),
+                }
+            )
+
+        tail_mass = stats.get("tail_mass")
+        if (
+            mmax is not None
+            and isinstance(tail_mass, int | float)
+            and math.isfinite(float(tail_mass))
+            and float(tail_mass) > float(mmax)
+        ):
+            passed = False
+            violations.append(
+                {
+                    "type": "tail_mass_exceeded",
+                    "epsilon": float(eps),
+                    "observed": float(tail_mass),
+                    "threshold": float(mmax),
+                }
+            )
+
+    warned = bool(evaluated and (not passed) and mode == "warn")
+
+    return {
+        "mode": mode,
+        "evaluated": evaluated,
+        "passed": bool(passed),
+        "warned": warned,
+        "violations": violations,
+        "policy": {
+            "mode": mode,
+            "min_windows": int(min_windows_i),
+            "quantile": float(q),
+            "quantile_max": float(qmax) if qmax is not None else None,
+            "epsilon": float(eps),
+            "mass_max": float(mmax) if mmax is not None else None,
+        },
+        "stats": stats,
+    }
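A worked example for the new module; the commented values follow directly from the implementation above (non-finite samples are dropped, quantiles use linear interpolation, and tail_mass counts samples strictly above epsilon):

from invarlock.eval.tail_stats import compute_tail_summary, evaluate_metric_tail

deltas = [0.0, 0.01, 0.2, float("nan")]  # the NaN sample is ignored

summary = compute_tail_summary(deltas, epsilon=1e-4)
# summary["n"] == 3, summary["max"] == 0.2
# summary["tail_mass"] ≈ 2/3   (0.01 and 0.2 exceed epsilon)
# summary["q95"] ≈ 0.181       (interpolated on sorted [0.0, 0.01, 0.2])

verdict = evaluate_metric_tail(
    deltas=deltas,
    policy={"mode": "fail", "quantile": 0.95, "quantile_max": 0.1, "mass_max": 0.5},
)
# verdict["evaluated"] is True and verdict["passed"] is False: both
# quantile_max_exceeded (≈0.181 > 0.1) and tail_mass_exceeded (≈0.667 > 0.5) fire.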
invarlock/guards/_estimators.py
ADDED

@@ -0,0 +1,154 @@
+from __future__ import annotations
+
+import math
+from typing import Any
+
+import torch
+
+__all__ = [
+    "power_iter_sigma_max",
+    "frobenius_norm_sq",
+    "row_col_norm_extrema",
+    "stable_rank_estimate",
+]
+
+
+def _as_matrix(tensor: torch.Tensor) -> torch.Tensor:
+    if tensor.ndim == 2:
+        return tensor
+    return tensor.view(tensor.shape[0], -1)
+
+
+def power_iter_sigma_max(
+    matrix: Any,
+    *,
+    iters: int,
+    init: str = "ones",
+    eps: float = 1e-12,
+) -> float:
+    """Estimate the largest singular value (spectral norm) via fixed-iter power iteration.
+
+    Contract properties (vNext):
+    - fixed iteration budget (no convergence stopping)
+    - deterministic initialization (`init`)
+    - device-resident matvecs (no `.cpu()` transfers)
+    """
+    try:
+        iters_i = int(iters)
+    except Exception:
+        iters_i = 4
+    if iters_i < 1:
+        iters_i = 1
+
+    if not isinstance(matrix, torch.Tensor):
+        return 0.0
+    if matrix.numel() == 0:
+        return 0.0
+    if matrix.dtype in {torch.int8, torch.uint8}:
+        return 0.0
+
+    W = _as_matrix(matrix.detach())
+    if W.numel() == 0 or W.shape[0] == 0 or W.shape[1] == 0:
+        return 0.0
+
+    device = W.device
+    dtype = W.dtype
+    n = int(W.shape[1])
+
+    with torch.no_grad():
+        if init == "ones":
+            v = torch.ones((n,), device=device, dtype=dtype)
+        else:
+            # Deterministic fallback: unit vector e0.
+            v = torch.zeros((n,), device=device, dtype=dtype)
+            v[0] = 1
+
+        v_norm = torch.linalg.vector_norm(v.float()).clamp_min(eps)
+        v = v / v_norm.to(dtype)
+
+        sigma = 0.0
+        for _ in range(iters_i):
+            u = W @ v
+            u_norm = torch.linalg.vector_norm(u.float()).clamp_min(eps)
+            sigma_val = float(u_norm.item())
+            if not math.isfinite(sigma_val):
+                return 0.0
+            u = u / u_norm.to(dtype)
+            v = W.T @ u
+            v_norm = torch.linalg.vector_norm(v.float()).clamp_min(eps)
+            v = v / v_norm.to(dtype)
+            sigma = sigma_val
+    return float(sigma)
+
+
+def frobenius_norm_sq(matrix: torch.Tensor) -> float:
+    """Return ||matrix||_F^2 with float32 accumulation (device-resident)."""
+    W = _as_matrix(matrix.detach())
+    if W.numel() == 0:
+        return 0.0
+    with torch.no_grad():
+        # Use a fused reduction to avoid materializing a W*W intermediate.
+        norm = torch.linalg.vector_norm(W.reshape(-1), ord=2, dtype=torch.float32)
+        out = float((norm * norm).item())
+    return out if math.isfinite(out) else 0.0
+
+
+def row_col_norm_extrema(
+    matrix: torch.Tensor, *, eps: float = 1e-12
+) -> dict[str, float]:
+    """Compute min/median/max of row/col L2 norms with float32 accumulation."""
+    W = _as_matrix(matrix.detach())
+    if W.numel() == 0 or W.shape[0] == 0 or W.shape[1] == 0:
+        return {
+            "row_min": 0.0,
+            "row_median": 0.0,
+            "row_max": 0.0,
+            "col_min": 0.0,
+            "col_median": 0.0,
+            "col_max": 0.0,
+        }
+    with torch.no_grad():
+        # Avoid materializing W*W: use fused reductions.
+        row = torch.linalg.vector_norm(W, ord=2, dim=1, dtype=torch.float32).clamp_min(
+            eps
+        )
+        col = torch.linalg.vector_norm(W, ord=2, dim=0, dtype=torch.float32).clamp_min(
+            eps
+        )
+
+        row_sorted, _ = torch.sort(row)
+        col_sorted, _ = torch.sort(col)
+
+        def _median(sorted_vec: torch.Tensor) -> float:
+            n = int(sorted_vec.numel())
+            if n <= 0:
+                return 0.0
+            mid = n // 2
+            if n % 2 == 1:
+                return float(sorted_vec[mid].item())
+            return float((sorted_vec[mid - 1] + sorted_vec[mid]).mul(0.5).item())
+
+        return {
+            "row_min": float(row_sorted[0].item()),
+            "row_median": _median(row_sorted),
+            "row_max": float(row_sorted[-1].item()),
+            "col_min": float(col_sorted[0].item()),
+            "col_median": _median(col_sorted),
+            "col_max": float(col_sorted[-1].item()),
+        }
+
+
+def stable_rank_estimate(
+    matrix: torch.Tensor, *, sigma_max: float, eps: float = 1e-12
+) -> float:
+    """Estimate stable rank: ||W||_F^2 / ||W||_2^2, using a provided σ̂max."""
+    try:
+        denom = float(sigma_max) ** 2
+    except Exception:
+        return 0.0
+    if not math.isfinite(denom) or denom <= 0.0:
+        return 0.0
+    denom = max(denom, eps)
+    num = frobenius_norm_sq(matrix)
+    out = float(num) / denom if denom > 0 else 0.0
+    return out if math.isfinite(out) else 0.0
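A quick usage sketch for the new estimator helpers; the numeric comments are approximate and follow from the definitions above (power iteration with a fixed budget converges toward the top singular value):

import torch
from invarlock.guards._estimators import power_iter_sigma_max, stable_rank_estimate

W = torch.diag(torch.tensor([3.0, 1.0]))  # singular values 3 and 1

sigma = power_iter_sigma_max(W, iters=8)  # -> ~3.0 (deterministic "ones" init)
srank = stable_rank_estimate(W, sigma_max=sigma)
# ||W||_F^2 / sigma_max^2 = (9 + 1) / 9 ≈ 1.11, i.e. energy concentrated in one direction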
invarlock/guards/policies.py
CHANGED

@@ -15,7 +15,7 @@ from typing import Any, Literal
 
 try:  # Python 3.12+
     from typing import NotRequired, TypedDict
-except ImportError:  #
+except ImportError:  # Python <3.12 fallback
     from typing import NotRequired
 
     from typing_extensions import TypedDict
@@ -40,6 +40,7 @@ SPECTRAL_CONSERVATIVE: SpectralPolicy = {
     "scope": "ffn",  # FFN layers only (safest)
     "correction_enabled": True,
     "max_caps": 3,
+    "max_spectral_norm": None,
     "multiple_testing": {"method": "bonferroni", "alpha": 0.02, "m": 4},
 }
 
@@ -50,6 +51,7 @@ SPECTRAL_BALANCED: SpectralPolicy = {
     "scope": "ffn",  # FFN layers only
     "correction_enabled": False,
     "max_caps": 5,
+    "max_spectral_norm": None,
    "multiple_testing": {"method": "bh", "alpha": 0.05, "m": 4},
 }
 
@@ -60,6 +62,7 @@ SPECTRAL_AGGRESSIVE: SpectralPolicy = {
     "scope": "all",  # All layers including attention
     "correction_enabled": True,
     "max_caps": 8,
+    "max_spectral_norm": None,
     "multiple_testing": {"method": "bh", "alpha": 0.1, "m": 4},
 }
 
@@ -70,6 +73,7 @@ SPECTRAL_ATTN_AWARE: SpectralPolicy = {
     "scope": "attn",  # Attention layers only
     "correction_enabled": False,
     "max_caps": 5,
+    "max_spectral_norm": None,
     "multiple_testing": {"method": "bh", "alpha": 0.05, "m": 4},
 }
 
@@ -81,7 +85,8 @@ RMT_CONSERVATIVE: RMTPolicyDict = {
     "deadband": 0.05,  # 5% deadband - strict threshold
     "margin": 1.3,  # Lower margin for conservative detection
     "correct": True,  # Enable automatic correction
-    "
+    "epsilon_default": 0.06,
+    "epsilon_by_family": {"attn": 0.05, "ffn": 0.06, "embed": 0.07, "other": 0.07},
 }
 
 # Balanced RMT policy - good for most use cases
@@ -90,7 +95,8 @@ RMT_BALANCED: RMTPolicyDict = {
     "deadband": 0.10,  # 10% deadband - reasonable tolerance
     "margin": 1.5,  # Standard margin for outlier detection
     "correct": False,  # Monitor-only by default
-    "
+    "epsilon_default": 0.10,
+    "epsilon_by_family": {"attn": 0.08, "ffn": 0.10, "embed": 0.12, "other": 0.12},
 }
 
 # Aggressive RMT policy - for research/experimental use
@@ -99,7 +105,8 @@ RMT_AGGRESSIVE: RMTPolicyDict = {
     "deadband": 0.15,  # 15% deadband - more permissive
     "margin": 1.8,  # Higher margin allows more deviation
     "correct": True,  # Enable automatic correction
-    "
+    "epsilon_default": 0.15,
+    "epsilon_by_family": {"attn": 0.15, "ffn": 0.15, "embed": 0.15, "other": 0.15},
 }
 
 # === Variance Guard Policies ===
@@ -276,6 +283,8 @@ def get_spectral_policy(
             policy["scope"] = tier_config["scope"]
         if "max_caps" in tier_config:
             policy["max_caps"] = tier_config["max_caps"]
+        if "max_spectral_norm" in tier_config:
+            policy["max_spectral_norm"] = tier_config["max_spectral_norm"]
         if "family_caps" in tier_config:
            policy["family_caps"] = tier_config["family_caps"]
         if "multiple_testing" in tier_config:
@@ -390,9 +399,10 @@ def get_rmt_policy(name: str = "balanced", *, use_yaml: bool = True) -> RMTPolic
             policy["deadband"] = tier_config["deadband"]
         if "margin" in tier_config:
             policy["margin"] = tier_config["margin"]
-
+        if "epsilon_default" in tier_config:
+            policy["epsilon_default"] = tier_config["epsilon_default"]
         if "epsilon_by_family" in tier_config:
-            policy["
+            policy["epsilon_by_family"] = tier_config["epsilon_by_family"]
     except Exception:
         # Fallback to hardcoded values on any error
        pass