invarlock 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +1 -1
- invarlock/_data/runtime/tiers.yaml +57 -30
- invarlock/adapters/__init__.py +1 -1
- invarlock/calibration/spectral_null.py +15 -10
- invarlock/calibration/variance_ve.py +0 -2
- invarlock/cli/commands/calibrate.py +6 -2
- invarlock/cli/commands/certify.py +58 -39
- invarlock/cli/commands/doctor.py +3 -1
- invarlock/cli/commands/explain_gates.py +57 -8
- invarlock/cli/commands/report.py +1 -1
- invarlock/cli/commands/run.py +159 -61
- invarlock/cli/commands/verify.py +78 -4
- invarlock/cli/config.py +21 -5
- invarlock/core/api.py +45 -5
- invarlock/core/auto_tuning.py +65 -20
- invarlock/core/contracts.py +7 -1
- invarlock/core/registry.py +2 -2
- invarlock/core/runner.py +314 -50
- invarlock/eval/bench.py +0 -13
- invarlock/eval/data.py +73 -283
- invarlock/eval/metrics.py +134 -4
- invarlock/eval/primary_metric.py +23 -0
- invarlock/eval/tail_stats.py +230 -0
- invarlock/guards/_estimators.py +154 -0
- invarlock/guards/policies.py +16 -6
- invarlock/guards/rmt.py +625 -544
- invarlock/guards/spectral.py +348 -110
- invarlock/guards/tier_config.py +32 -30
- invarlock/guards/variance.py +5 -29
- invarlock/guards_ref/rmt_ref.py +23 -23
- invarlock/model_profile.py +42 -15
- invarlock/reporting/certificate.py +225 -46
- invarlock/reporting/certificate_schema.py +2 -1
- invarlock/reporting/dataset_hashing.py +15 -2
- invarlock/reporting/guards_analysis.py +197 -274
- invarlock/reporting/normalizer.py +6 -0
- invarlock/reporting/policy_utils.py +38 -36
- invarlock/reporting/primary_metric_utils.py +71 -17
- invarlock/reporting/render.py +61 -0
- invarlock/reporting/report.py +1 -1
- invarlock/reporting/report_types.py +5 -2
- invarlock/reporting/validate.py +1 -18
- {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/METADATA +6 -6
- {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/RECORD +48 -46
- {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/WHEEL +0 -0
- {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/entry_points.txt +0 -0
- {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/licenses/LICENSE +0 -0
- {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/top_level.txt +0 -0
invarlock/eval/primary_metric.py
CHANGED
|
@@ -623,6 +623,9 @@ def compute_primary_metric_from_report(
|
|
|
623
623
|
"preview": float("nan"),
|
|
624
624
|
"final": float("nan"),
|
|
625
625
|
"ratio_vs_baseline": float("nan"),
|
|
626
|
+
"invalid": True,
|
|
627
|
+
"degraded": True,
|
|
628
|
+
"degraded_reason": "non_finite_pm",
|
|
626
629
|
}
|
|
627
630
|
# For accuracy kinds, derive counts from input_ids if aggregates are missing
|
|
628
631
|
if kind in {"accuracy", "vqa_accuracy"}:
|
|
@@ -661,6 +664,11 @@ def compute_primary_metric_from_report(
|
|
|
661
664
|
final_point = metric.point_from_windows(windows=final_win)
|
|
662
665
|
|
|
663
666
|
ratio_vs_baseline = float("nan")
|
|
667
|
+
baseline_has_reference = False
|
|
668
|
+
|
|
669
|
+
def _is_finite(value: Any) -> bool:
|
|
670
|
+
return isinstance(value, (int, float)) and math.isfinite(float(value))
|
|
671
|
+
|
|
664
672
|
if isinstance(baseline, dict):
|
|
665
673
|
try:
|
|
666
674
|
base_metrics = (
|
|
@@ -686,14 +694,25 @@ def compute_primary_metric_from_report(
|
|
|
686
694
|
is_ppl_like = str(kind).lower().startswith("ppl")
|
|
687
695
|
if is_ppl_like and base_ref > 0:
|
|
688
696
|
ratio_vs_baseline = float(final_point) / float(base_ref)
|
|
697
|
+
baseline_has_reference = True
|
|
689
698
|
elif (
|
|
690
699
|
str(kind).lower() in {"accuracy", "vqa_accuracy"}
|
|
691
700
|
and 0 <= base_ref <= 1
|
|
692
701
|
):
|
|
693
702
|
ratio_vs_baseline = float(final_point) - float(base_ref)
|
|
703
|
+
baseline_has_reference = True
|
|
694
704
|
except Exception:
|
|
695
705
|
ratio_vs_baseline = float("nan")
|
|
696
706
|
|
|
707
|
+
invalid = not (_is_finite(preview_point) and _is_finite(final_point))
|
|
708
|
+
degraded_reason = None
|
|
709
|
+
if invalid:
|
|
710
|
+
degraded_reason = "non_finite_pm"
|
|
711
|
+
elif baseline_has_reference and not _is_finite(ratio_vs_baseline):
|
|
712
|
+
degraded_reason = "non_finite_delta"
|
|
713
|
+
|
|
714
|
+
degraded = bool(degraded_reason or invalid)
|
|
715
|
+
|
|
697
716
|
payload = {
|
|
698
717
|
"kind": metric.kind,
|
|
699
718
|
"unit": metric.unit,
|
|
@@ -705,7 +724,11 @@ def compute_primary_metric_from_report(
|
|
|
705
724
|
"preview": preview_point,
|
|
706
725
|
"final": final_point,
|
|
707
726
|
"ratio_vs_baseline": ratio_vs_baseline,
|
|
727
|
+
"invalid": invalid,
|
|
728
|
+
"degraded": degraded,
|
|
708
729
|
}
|
|
730
|
+
if degraded and degraded_reason:
|
|
731
|
+
payload["degraded_reason"] = degraded_reason
|
|
709
732
|
# Carry counts for accuracy to aid gating
|
|
710
733
|
if kind in {"accuracy", "vqa_accuracy"}:
|
|
711
734
|
if "n_prev" in locals() and n_prev is not None:
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from collections.abc import Mapping, Sequence
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"compute_tail_summary",
|
|
9
|
+
"evaluate_metric_tail",
|
|
10
|
+
]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _as_finite_float(value: Any) -> float | None:
|
|
14
|
+
try:
|
|
15
|
+
out = float(value)
|
|
16
|
+
except Exception:
|
|
17
|
+
return None
|
|
18
|
+
return out if math.isfinite(out) else None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _linear_quantile(sorted_values: Sequence[float], q: float) -> float:
|
|
22
|
+
"""Deterministic linear-interpolated quantile on sorted values (q in [0, 1])."""
|
|
23
|
+
n = len(sorted_values)
|
|
24
|
+
if n == 0:
|
|
25
|
+
return float("nan")
|
|
26
|
+
if n == 1:
|
|
27
|
+
return float(sorted_values[0])
|
|
28
|
+
if q <= 0.0:
|
|
29
|
+
return float(sorted_values[0])
|
|
30
|
+
if q >= 1.0:
|
|
31
|
+
return float(sorted_values[-1])
|
|
32
|
+
pos = float(q) * float(n - 1)
|
|
33
|
+
lo = int(math.floor(pos))
|
|
34
|
+
hi = int(math.ceil(pos))
|
|
35
|
+
if lo == hi:
|
|
36
|
+
return float(sorted_values[lo])
|
|
37
|
+
frac = pos - float(lo)
|
|
38
|
+
a = float(sorted_values[lo])
|
|
39
|
+
b = float(sorted_values[hi])
|
|
40
|
+
return a + frac * (b - a)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def compute_tail_summary(
|
|
44
|
+
deltas: Sequence[float] | Sequence[Any],
|
|
45
|
+
*,
|
|
46
|
+
quantiles: Sequence[float] = (0.5, 0.9, 0.95, 0.99),
|
|
47
|
+
epsilon: float = 1e-4,
|
|
48
|
+
weights: Sequence[float] | Sequence[Any] | None = None,
|
|
49
|
+
) -> dict[str, Any]:
|
|
50
|
+
"""Compute deterministic tail summaries for Δlog-loss samples.
|
|
51
|
+
|
|
52
|
+
- Quantiles are computed unweighted using linear interpolation on sorted values.
|
|
53
|
+
- tail_mass is Pr[delta > epsilon] (unweighted).
|
|
54
|
+
- tail_mass_weighted is included when weights are provided and finite.
|
|
55
|
+
"""
|
|
56
|
+
eps = _as_finite_float(epsilon)
|
|
57
|
+
if eps is None or eps < 0.0:
|
|
58
|
+
eps = 0.0
|
|
59
|
+
|
|
60
|
+
values: list[float] = []
|
|
61
|
+
paired_weights: list[float] | None = [] if weights is not None else None
|
|
62
|
+
|
|
63
|
+
if weights is None:
|
|
64
|
+
for d in deltas:
|
|
65
|
+
dv = _as_finite_float(d)
|
|
66
|
+
if dv is None:
|
|
67
|
+
continue
|
|
68
|
+
values.append(float(dv))
|
|
69
|
+
else:
|
|
70
|
+
for d, w in zip(deltas, weights, strict=False):
|
|
71
|
+
dv = _as_finite_float(d)
|
|
72
|
+
if dv is None:
|
|
73
|
+
continue
|
|
74
|
+
wv = _as_finite_float(w)
|
|
75
|
+
if wv is None or wv < 0.0:
|
|
76
|
+
wv = 0.0
|
|
77
|
+
values.append(float(dv))
|
|
78
|
+
if paired_weights is not None:
|
|
79
|
+
paired_weights.append(float(wv))
|
|
80
|
+
|
|
81
|
+
n = int(len(values))
|
|
82
|
+
values_sorted = sorted(values)
|
|
83
|
+
|
|
84
|
+
summary: dict[str, Any] = {
|
|
85
|
+
"n": n,
|
|
86
|
+
"epsilon": float(eps),
|
|
87
|
+
}
|
|
88
|
+
if n == 0:
|
|
89
|
+
summary.update({"max": float("nan"), "tail_mass": 0.0})
|
|
90
|
+
for q in quantiles:
|
|
91
|
+
try:
|
|
92
|
+
qf = float(q)
|
|
93
|
+
except Exception:
|
|
94
|
+
continue
|
|
95
|
+
label = f"q{int(round(100.0 * max(0.0, min(1.0, qf))))}"
|
|
96
|
+
summary[label] = float("nan")
|
|
97
|
+
return summary
|
|
98
|
+
|
|
99
|
+
summary["max"] = float(values_sorted[-1])
|
|
100
|
+
tail_ct = sum(1 for v in values if v > eps)
|
|
101
|
+
summary["tail_mass"] = float(tail_ct / n)
|
|
102
|
+
|
|
103
|
+
if paired_weights is not None:
|
|
104
|
+
total_w = 0.0
|
|
105
|
+
tail_w = 0.0
|
|
106
|
+
for v, w in zip(values, paired_weights, strict=False):
|
|
107
|
+
total_w += float(w)
|
|
108
|
+
if v > eps:
|
|
109
|
+
tail_w += float(w)
|
|
110
|
+
if total_w > 0.0:
|
|
111
|
+
summary["tail_mass_weighted"] = float(tail_w / total_w)
|
|
112
|
+
summary["tail_mass_weighted_by"] = "weights"
|
|
113
|
+
|
|
114
|
+
for q in quantiles:
|
|
115
|
+
try:
|
|
116
|
+
qf = float(q)
|
|
117
|
+
except Exception:
|
|
118
|
+
continue
|
|
119
|
+
qf = max(0.0, min(1.0, qf))
|
|
120
|
+
label = f"q{int(round(100.0 * qf))}"
|
|
121
|
+
summary[label] = float(_linear_quantile(values_sorted, qf))
|
|
122
|
+
|
|
123
|
+
return summary
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def evaluate_metric_tail(
|
|
127
|
+
*,
|
|
128
|
+
deltas: Sequence[float] | Sequence[Any],
|
|
129
|
+
policy: Mapping[str, Any] | None = None,
|
|
130
|
+
weights: Sequence[float] | Sequence[Any] | None = None,
|
|
131
|
+
) -> dict[str, Any]:
|
|
132
|
+
"""Evaluate a tail policy against Δlog-loss samples.
|
|
133
|
+
|
|
134
|
+
Policy keys:
|
|
135
|
+
- mode: "off" | "warn" | "fail" (default: "warn")
|
|
136
|
+
- min_windows: int (default: 1)
|
|
137
|
+
- quantile: float in [0, 1] (default: 0.95)
|
|
138
|
+
- quantile_max: float threshold in Δlog-loss (optional)
|
|
139
|
+
- epsilon: float deadband for tail_mass (default: 1e-4)
|
|
140
|
+
- mass_max: float in [0, 1] (optional)
|
|
141
|
+
"""
|
|
142
|
+
pol = dict(policy or {})
|
|
143
|
+
mode = str(pol.get("mode", "warn") or "warn").strip().lower()
|
|
144
|
+
if mode not in {"off", "warn", "fail"}:
|
|
145
|
+
mode = "warn"
|
|
146
|
+
|
|
147
|
+
min_windows = pol.get("min_windows", 1)
|
|
148
|
+
try:
|
|
149
|
+
min_windows_i = int(min_windows)
|
|
150
|
+
except Exception:
|
|
151
|
+
min_windows_i = 1
|
|
152
|
+
min_windows_i = max(1, min_windows_i)
|
|
153
|
+
|
|
154
|
+
q = _as_finite_float(pol.get("quantile", 0.95))
|
|
155
|
+
if q is None:
|
|
156
|
+
q = 0.95
|
|
157
|
+
q = max(0.0, min(1.0, float(q)))
|
|
158
|
+
|
|
159
|
+
eps = _as_finite_float(pol.get("epsilon", 1e-4))
|
|
160
|
+
if eps is None or eps < 0.0:
|
|
161
|
+
eps = 0.0
|
|
162
|
+
|
|
163
|
+
qmax = _as_finite_float(pol.get("quantile_max"))
|
|
164
|
+
mmax = _as_finite_float(pol.get("mass_max"))
|
|
165
|
+
if mmax is not None:
|
|
166
|
+
mmax = max(0.0, min(1.0, float(mmax)))
|
|
167
|
+
|
|
168
|
+
quantiles = sorted({0.5, 0.9, 0.95, 0.99, float(q)})
|
|
169
|
+
stats = compute_tail_summary(
|
|
170
|
+
deltas, quantiles=tuple(quantiles), epsilon=float(eps), weights=weights
|
|
171
|
+
)
|
|
172
|
+
n = int(stats.get("n", 0) or 0)
|
|
173
|
+
|
|
174
|
+
thresholds_present = (qmax is not None) or (mmax is not None)
|
|
175
|
+
evaluated = bool(mode != "off" and thresholds_present and n >= min_windows_i)
|
|
176
|
+
|
|
177
|
+
violations: list[dict[str, Any]] = []
|
|
178
|
+
passed = True
|
|
179
|
+
if evaluated:
|
|
180
|
+
passed = True
|
|
181
|
+
q_label = f"q{int(round(100.0 * q))}"
|
|
182
|
+
q_obs = stats.get(q_label)
|
|
183
|
+
if not (isinstance(q_obs, int | float) and math.isfinite(float(q_obs))):
|
|
184
|
+
q_obs = float("nan")
|
|
185
|
+
if qmax is not None and math.isfinite(q_obs) and q_obs > float(qmax):
|
|
186
|
+
passed = False
|
|
187
|
+
violations.append(
|
|
188
|
+
{
|
|
189
|
+
"type": "quantile_max_exceeded",
|
|
190
|
+
"quantile": float(q),
|
|
191
|
+
"observed": float(q_obs),
|
|
192
|
+
"threshold": float(qmax),
|
|
193
|
+
}
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
tail_mass = stats.get("tail_mass")
|
|
197
|
+
if (
|
|
198
|
+
mmax is not None
|
|
199
|
+
and isinstance(tail_mass, int | float)
|
|
200
|
+
and math.isfinite(float(tail_mass))
|
|
201
|
+
and float(tail_mass) > float(mmax)
|
|
202
|
+
):
|
|
203
|
+
passed = False
|
|
204
|
+
violations.append(
|
|
205
|
+
{
|
|
206
|
+
"type": "tail_mass_exceeded",
|
|
207
|
+
"epsilon": float(eps),
|
|
208
|
+
"observed": float(tail_mass),
|
|
209
|
+
"threshold": float(mmax),
|
|
210
|
+
}
|
|
211
|
+
)
|
|
212
|
+
|
|
213
|
+
warned = bool(evaluated and (not passed) and mode == "warn")
|
|
214
|
+
|
|
215
|
+
return {
|
|
216
|
+
"mode": mode,
|
|
217
|
+
"evaluated": evaluated,
|
|
218
|
+
"passed": bool(passed),
|
|
219
|
+
"warned": warned,
|
|
220
|
+
"violations": violations,
|
|
221
|
+
"policy": {
|
|
222
|
+
"mode": mode,
|
|
223
|
+
"min_windows": int(min_windows_i),
|
|
224
|
+
"quantile": float(q),
|
|
225
|
+
"quantile_max": float(qmax) if qmax is not None else None,
|
|
226
|
+
"epsilon": float(eps),
|
|
227
|
+
"mass_max": float(mmax) if mmax is not None else None,
|
|
228
|
+
},
|
|
229
|
+
"stats": stats,
|
|
230
|
+
}
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"power_iter_sigma_max",
|
|
10
|
+
"frobenius_norm_sq",
|
|
11
|
+
"row_col_norm_extrema",
|
|
12
|
+
"stable_rank_estimate",
|
|
13
|
+
]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _as_matrix(tensor: torch.Tensor) -> torch.Tensor:
|
|
17
|
+
if tensor.ndim == 2:
|
|
18
|
+
return tensor
|
|
19
|
+
return tensor.view(tensor.shape[0], -1)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def power_iter_sigma_max(
|
|
23
|
+
matrix: Any,
|
|
24
|
+
*,
|
|
25
|
+
iters: int,
|
|
26
|
+
init: str = "ones",
|
|
27
|
+
eps: float = 1e-12,
|
|
28
|
+
) -> float:
|
|
29
|
+
"""Estimate the largest singular value (spectral norm) via fixed-iter power iteration.
|
|
30
|
+
|
|
31
|
+
Contract properties (vNext):
|
|
32
|
+
- fixed iteration budget (no convergence stopping)
|
|
33
|
+
- deterministic initialization (`init`)
|
|
34
|
+
- device-resident matvecs (no `.cpu()` transfers)
|
|
35
|
+
"""
|
|
36
|
+
try:
|
|
37
|
+
iters_i = int(iters)
|
|
38
|
+
except Exception:
|
|
39
|
+
iters_i = 4
|
|
40
|
+
if iters_i < 1:
|
|
41
|
+
iters_i = 1
|
|
42
|
+
|
|
43
|
+
if not isinstance(matrix, torch.Tensor):
|
|
44
|
+
return 0.0
|
|
45
|
+
if matrix.numel() == 0:
|
|
46
|
+
return 0.0
|
|
47
|
+
if matrix.dtype in {torch.int8, torch.uint8}:
|
|
48
|
+
return 0.0
|
|
49
|
+
|
|
50
|
+
W = _as_matrix(matrix.detach())
|
|
51
|
+
if W.numel() == 0 or W.shape[0] == 0 or W.shape[1] == 0:
|
|
52
|
+
return 0.0
|
|
53
|
+
|
|
54
|
+
device = W.device
|
|
55
|
+
dtype = W.dtype
|
|
56
|
+
n = int(W.shape[1])
|
|
57
|
+
|
|
58
|
+
with torch.no_grad():
|
|
59
|
+
if init == "ones":
|
|
60
|
+
v = torch.ones((n,), device=device, dtype=dtype)
|
|
61
|
+
else:
|
|
62
|
+
# Deterministic fallback: unit vector e0.
|
|
63
|
+
v = torch.zeros((n,), device=device, dtype=dtype)
|
|
64
|
+
v[0] = 1
|
|
65
|
+
|
|
66
|
+
v_norm = torch.linalg.vector_norm(v.float()).clamp_min(eps)
|
|
67
|
+
v = v / v_norm.to(dtype)
|
|
68
|
+
|
|
69
|
+
sigma = 0.0
|
|
70
|
+
for _ in range(iters_i):
|
|
71
|
+
u = W @ v
|
|
72
|
+
u_norm = torch.linalg.vector_norm(u.float()).clamp_min(eps)
|
|
73
|
+
sigma_val = float(u_norm.item())
|
|
74
|
+
if not math.isfinite(sigma_val):
|
|
75
|
+
return 0.0
|
|
76
|
+
u = u / u_norm.to(dtype)
|
|
77
|
+
v = W.T @ u
|
|
78
|
+
v_norm = torch.linalg.vector_norm(v.float()).clamp_min(eps)
|
|
79
|
+
v = v / v_norm.to(dtype)
|
|
80
|
+
sigma = sigma_val
|
|
81
|
+
return float(sigma)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def frobenius_norm_sq(matrix: torch.Tensor) -> float:
|
|
85
|
+
"""Return ||matrix||_F^2 with float32 accumulation (device-resident)."""
|
|
86
|
+
W = _as_matrix(matrix.detach())
|
|
87
|
+
if W.numel() == 0:
|
|
88
|
+
return 0.0
|
|
89
|
+
with torch.no_grad():
|
|
90
|
+
# Use a fused reduction to avoid materializing a W*W intermediate.
|
|
91
|
+
norm = torch.linalg.vector_norm(W.reshape(-1), ord=2, dtype=torch.float32)
|
|
92
|
+
out = float((norm * norm).item())
|
|
93
|
+
return out if math.isfinite(out) else 0.0
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def row_col_norm_extrema(
|
|
97
|
+
matrix: torch.Tensor, *, eps: float = 1e-12
|
|
98
|
+
) -> dict[str, float]:
|
|
99
|
+
"""Compute min/median/max of row/col L2 norms with float32 accumulation."""
|
|
100
|
+
W = _as_matrix(matrix.detach())
|
|
101
|
+
if W.numel() == 0 or W.shape[0] == 0 or W.shape[1] == 0:
|
|
102
|
+
return {
|
|
103
|
+
"row_min": 0.0,
|
|
104
|
+
"row_median": 0.0,
|
|
105
|
+
"row_max": 0.0,
|
|
106
|
+
"col_min": 0.0,
|
|
107
|
+
"col_median": 0.0,
|
|
108
|
+
"col_max": 0.0,
|
|
109
|
+
}
|
|
110
|
+
with torch.no_grad():
|
|
111
|
+
# Avoid materializing W*W: use fused reductions.
|
|
112
|
+
row = torch.linalg.vector_norm(W, ord=2, dim=1, dtype=torch.float32).clamp_min(
|
|
113
|
+
eps
|
|
114
|
+
)
|
|
115
|
+
col = torch.linalg.vector_norm(W, ord=2, dim=0, dtype=torch.float32).clamp_min(
|
|
116
|
+
eps
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
row_sorted, _ = torch.sort(row)
|
|
120
|
+
col_sorted, _ = torch.sort(col)
|
|
121
|
+
|
|
122
|
+
def _median(sorted_vec: torch.Tensor) -> float:
|
|
123
|
+
n = int(sorted_vec.numel())
|
|
124
|
+
if n <= 0:
|
|
125
|
+
return 0.0
|
|
126
|
+
mid = n // 2
|
|
127
|
+
if n % 2 == 1:
|
|
128
|
+
return float(sorted_vec[mid].item())
|
|
129
|
+
return float((sorted_vec[mid - 1] + sorted_vec[mid]).mul(0.5).item())
|
|
130
|
+
|
|
131
|
+
return {
|
|
132
|
+
"row_min": float(row_sorted[0].item()),
|
|
133
|
+
"row_median": _median(row_sorted),
|
|
134
|
+
"row_max": float(row_sorted[-1].item()),
|
|
135
|
+
"col_min": float(col_sorted[0].item()),
|
|
136
|
+
"col_median": _median(col_sorted),
|
|
137
|
+
"col_max": float(col_sorted[-1].item()),
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def stable_rank_estimate(
|
|
142
|
+
matrix: torch.Tensor, *, sigma_max: float, eps: float = 1e-12
|
|
143
|
+
) -> float:
|
|
144
|
+
"""Estimate stable rank: ||W||_F^2 / ||W||_2^2, using a provided σ̂max."""
|
|
145
|
+
try:
|
|
146
|
+
denom = float(sigma_max) ** 2
|
|
147
|
+
except Exception:
|
|
148
|
+
return 0.0
|
|
149
|
+
if not math.isfinite(denom) or denom <= 0.0:
|
|
150
|
+
return 0.0
|
|
151
|
+
denom = max(denom, eps)
|
|
152
|
+
num = frobenius_norm_sq(matrix)
|
|
153
|
+
out = float(num) / denom if denom > 0 else 0.0
|
|
154
|
+
return out if math.isfinite(out) else 0.0
|
invarlock/guards/policies.py
CHANGED
|
@@ -15,7 +15,7 @@ from typing import Any, Literal
|
|
|
15
15
|
|
|
16
16
|
try: # Python 3.12+
|
|
17
17
|
from typing import NotRequired, TypedDict
|
|
18
|
-
except ImportError: #
|
|
18
|
+
except ImportError: # Python <3.12 fallback
|
|
19
19
|
from typing import NotRequired
|
|
20
20
|
|
|
21
21
|
from typing_extensions import TypedDict
|
|
@@ -40,6 +40,7 @@ SPECTRAL_CONSERVATIVE: SpectralPolicy = {
|
|
|
40
40
|
"scope": "ffn", # FFN layers only (safest)
|
|
41
41
|
"correction_enabled": True,
|
|
42
42
|
"max_caps": 3,
|
|
43
|
+
"max_spectral_norm": None,
|
|
43
44
|
"multiple_testing": {"method": "bonferroni", "alpha": 0.02, "m": 4},
|
|
44
45
|
}
|
|
45
46
|
|
|
@@ -50,6 +51,7 @@ SPECTRAL_BALANCED: SpectralPolicy = {
|
|
|
50
51
|
"scope": "ffn", # FFN layers only
|
|
51
52
|
"correction_enabled": False,
|
|
52
53
|
"max_caps": 5,
|
|
54
|
+
"max_spectral_norm": None,
|
|
53
55
|
"multiple_testing": {"method": "bh", "alpha": 0.05, "m": 4},
|
|
54
56
|
}
|
|
55
57
|
|
|
@@ -60,6 +62,7 @@ SPECTRAL_AGGRESSIVE: SpectralPolicy = {
|
|
|
60
62
|
"scope": "all", # All layers including attention
|
|
61
63
|
"correction_enabled": True,
|
|
62
64
|
"max_caps": 8,
|
|
65
|
+
"max_spectral_norm": None,
|
|
63
66
|
"multiple_testing": {"method": "bh", "alpha": 0.1, "m": 4},
|
|
64
67
|
}
|
|
65
68
|
|
|
@@ -70,6 +73,7 @@ SPECTRAL_ATTN_AWARE: SpectralPolicy = {
|
|
|
70
73
|
"scope": "attn", # Attention layers only
|
|
71
74
|
"correction_enabled": False,
|
|
72
75
|
"max_caps": 5,
|
|
76
|
+
"max_spectral_norm": None,
|
|
73
77
|
"multiple_testing": {"method": "bh", "alpha": 0.05, "m": 4},
|
|
74
78
|
}
|
|
75
79
|
|
|
@@ -81,7 +85,8 @@ RMT_CONSERVATIVE: RMTPolicyDict = {
|
|
|
81
85
|
"deadband": 0.05, # 5% deadband - strict threshold
|
|
82
86
|
"margin": 1.3, # Lower margin for conservative detection
|
|
83
87
|
"correct": True, # Enable automatic correction
|
|
84
|
-
"
|
|
88
|
+
"epsilon_default": 0.06,
|
|
89
|
+
"epsilon_by_family": {"attn": 0.05, "ffn": 0.06, "embed": 0.07, "other": 0.07},
|
|
85
90
|
}
|
|
86
91
|
|
|
87
92
|
# Balanced RMT policy - good for most use cases
|
|
@@ -90,7 +95,8 @@ RMT_BALANCED: RMTPolicyDict = {
|
|
|
90
95
|
"deadband": 0.10, # 10% deadband - reasonable tolerance
|
|
91
96
|
"margin": 1.5, # Standard margin for outlier detection
|
|
92
97
|
"correct": False, # Monitor-only by default
|
|
93
|
-
"
|
|
98
|
+
"epsilon_default": 0.10,
|
|
99
|
+
"epsilon_by_family": {"attn": 0.08, "ffn": 0.10, "embed": 0.12, "other": 0.12},
|
|
94
100
|
}
|
|
95
101
|
|
|
96
102
|
# Aggressive RMT policy - for research/experimental use
|
|
@@ -99,7 +105,8 @@ RMT_AGGRESSIVE: RMTPolicyDict = {
|
|
|
99
105
|
"deadband": 0.15, # 15% deadband - more permissive
|
|
100
106
|
"margin": 1.8, # Higher margin allows more deviation
|
|
101
107
|
"correct": True, # Enable automatic correction
|
|
102
|
-
"
|
|
108
|
+
"epsilon_default": 0.15,
|
|
109
|
+
"epsilon_by_family": {"attn": 0.15, "ffn": 0.15, "embed": 0.15, "other": 0.15},
|
|
103
110
|
}
|
|
104
111
|
|
|
105
112
|
# === Variance Guard Policies ===
|
|
@@ -276,6 +283,8 @@ def get_spectral_policy(
|
|
|
276
283
|
policy["scope"] = tier_config["scope"]
|
|
277
284
|
if "max_caps" in tier_config:
|
|
278
285
|
policy["max_caps"] = tier_config["max_caps"]
|
|
286
|
+
if "max_spectral_norm" in tier_config:
|
|
287
|
+
policy["max_spectral_norm"] = tier_config["max_spectral_norm"]
|
|
279
288
|
if "family_caps" in tier_config:
|
|
280
289
|
policy["family_caps"] = tier_config["family_caps"]
|
|
281
290
|
if "multiple_testing" in tier_config:
|
|
@@ -390,9 +399,10 @@ def get_rmt_policy(name: str = "balanced", *, use_yaml: bool = True) -> RMTPolic
|
|
|
390
399
|
policy["deadband"] = tier_config["deadband"]
|
|
391
400
|
if "margin" in tier_config:
|
|
392
401
|
policy["margin"] = tier_config["margin"]
|
|
393
|
-
|
|
402
|
+
if "epsilon_default" in tier_config:
|
|
403
|
+
policy["epsilon_default"] = tier_config["epsilon_default"]
|
|
394
404
|
if "epsilon_by_family" in tier_config:
|
|
395
|
-
policy["
|
|
405
|
+
policy["epsilon_by_family"] = tier_config["epsilon_by_family"]
|
|
396
406
|
except Exception:
|
|
397
407
|
# Fallback to hardcoded values on any error
|
|
398
408
|
pass
|