invarlock 0.3.5-py3-none-any.whl → 0.3.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +2 -2
- invarlock/_data/runtime/tiers.yaml +57 -30
- invarlock/adapters/__init__.py +11 -15
- invarlock/adapters/auto.py +35 -40
- invarlock/adapters/capabilities.py +2 -2
- invarlock/adapters/hf_causal.py +418 -0
- invarlock/adapters/{hf_onnx.py → hf_causal_onnx.py} +3 -3
- invarlock/adapters/hf_mixin.py +25 -4
- invarlock/adapters/{hf_bert.py → hf_mlm.py} +4 -11
- invarlock/adapters/{hf_t5.py → hf_seq2seq.py} +9 -9
- invarlock/calibration/spectral_null.py +15 -10
- invarlock/calibration/variance_ve.py +0 -2
- invarlock/cli/adapter_auto.py +31 -21
- invarlock/cli/app.py +73 -2
- invarlock/cli/commands/calibrate.py +6 -2
- invarlock/cli/commands/certify.py +651 -91
- invarlock/cli/commands/doctor.py +11 -11
- invarlock/cli/commands/explain_gates.py +57 -8
- invarlock/cli/commands/plugins.py +13 -9
- invarlock/cli/commands/report.py +233 -69
- invarlock/cli/commands/run.py +1066 -244
- invarlock/cli/commands/verify.py +154 -15
- invarlock/cli/config.py +22 -6
- invarlock/cli/doctor_helpers.py +4 -5
- invarlock/cli/output.py +193 -0
- invarlock/cli/provenance.py +1 -1
- invarlock/core/api.py +45 -5
- invarlock/core/auto_tuning.py +65 -20
- invarlock/core/bootstrap.py +1 -1
- invarlock/core/contracts.py +7 -1
- invarlock/core/registry.py +11 -13
- invarlock/core/runner.py +425 -75
- invarlock/edits/quant_rtn.py +65 -37
- invarlock/eval/bench.py +3 -16
- invarlock/eval/data.py +82 -51
- invarlock/eval/metrics.py +63 -2
- invarlock/eval/primary_metric.py +23 -0
- invarlock/eval/tail_stats.py +230 -0
- invarlock/eval/tasks/__init__.py +12 -0
- invarlock/eval/tasks/classification.py +48 -0
- invarlock/eval/tasks/qa.py +36 -0
- invarlock/eval/tasks/text_generation.py +102 -0
- invarlock/guards/_estimators.py +154 -0
- invarlock/guards/invariants.py +19 -10
- invarlock/guards/policies.py +16 -6
- invarlock/guards/rmt.py +627 -546
- invarlock/guards/spectral.py +348 -110
- invarlock/guards/tier_config.py +32 -30
- invarlock/guards/variance.py +7 -31
- invarlock/guards_ref/rmt_ref.py +23 -23
- invarlock/model_profile.py +90 -42
- invarlock/observability/health.py +6 -6
- invarlock/observability/metrics.py +108 -0
- invarlock/reporting/certificate.py +384 -55
- invarlock/reporting/certificate_schema.py +3 -2
- invarlock/reporting/dataset_hashing.py +15 -2
- invarlock/reporting/guards_analysis.py +350 -277
- invarlock/reporting/html.py +55 -5
- invarlock/reporting/normalizer.py +13 -0
- invarlock/reporting/policy_utils.py +38 -36
- invarlock/reporting/primary_metric_utils.py +71 -17
- invarlock/reporting/render.py +852 -431
- invarlock/reporting/report.py +40 -4
- invarlock/reporting/report_types.py +11 -3
- invarlock/reporting/telemetry.py +86 -0
- invarlock/reporting/validate.py +1 -18
- {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/METADATA +27 -13
- {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/RECORD +72 -65
- {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/WHEEL +1 -1
- {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/entry_points.txt +5 -3
- invarlock/adapters/hf_gpt2.py +0 -404
- invarlock/adapters/hf_llama.py +0 -487
- {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/licenses/LICENSE +0 -0
- {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/top_level.txt +0 -0
invarlock/eval/primary_metric.py
CHANGED
@@ -623,6 +623,9 @@ def compute_primary_metric_from_report(
             "preview": float("nan"),
             "final": float("nan"),
             "ratio_vs_baseline": float("nan"),
+            "invalid": True,
+            "degraded": True,
+            "degraded_reason": "non_finite_pm",
         }
     # For accuracy kinds, derive counts from input_ids if aggregates are missing
     if kind in {"accuracy", "vqa_accuracy"}:
@@ -661,6 +664,11 @@ def compute_primary_metric_from_report(
     final_point = metric.point_from_windows(windows=final_win)

     ratio_vs_baseline = float("nan")
+    baseline_has_reference = False
+
+    def _is_finite(value: Any) -> bool:
+        return isinstance(value, (int, float)) and math.isfinite(float(value))
+
     if isinstance(baseline, dict):
         try:
             base_metrics = (
@@ -686,14 +694,25 @@ def compute_primary_metric_from_report(
             is_ppl_like = str(kind).lower().startswith("ppl")
             if is_ppl_like and base_ref > 0:
                 ratio_vs_baseline = float(final_point) / float(base_ref)
+                baseline_has_reference = True
             elif (
                 str(kind).lower() in {"accuracy", "vqa_accuracy"}
                 and 0 <= base_ref <= 1
             ):
                 ratio_vs_baseline = float(final_point) - float(base_ref)
+                baseline_has_reference = True
         except Exception:
             ratio_vs_baseline = float("nan")

+    invalid = not (_is_finite(preview_point) and _is_finite(final_point))
+    degraded_reason = None
+    if invalid:
+        degraded_reason = "non_finite_pm"
+    elif baseline_has_reference and not _is_finite(ratio_vs_baseline):
+        degraded_reason = "non_finite_delta"
+
+    degraded = bool(degraded_reason or invalid)
+
     payload = {
         "kind": metric.kind,
         "unit": metric.unit,
@@ -705,7 +724,11 @@ def compute_primary_metric_from_report(
         "preview": preview_point,
         "final": final_point,
         "ratio_vs_baseline": ratio_vs_baseline,
+        "invalid": invalid,
+        "degraded": degraded,
     }
+    if degraded and degraded_reason:
+        payload["degraded_reason"] = degraded_reason
     # Carry counts for accuracy to aid gating
     if kind in {"accuracy", "vqa_accuracy"}:
         if "n_prev" in locals() and n_prev is not None:
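The added fields follow a small decision tree: a non-finite preview or final point marks the primary metric invalid and degraded with reason `non_finite_pm`; a finite metric whose baseline delta is non-finite is only degraded, with reason `non_finite_delta`. A minimal standalone sketch of that classification; the helper name `classify_pm_health` is illustrative and not part of the package:

```python
import math


def classify_pm_health(preview, final, ratio_vs_baseline, baseline_has_reference):
    """Illustrative restatement of the invalid/degraded logic added above."""

    def _finite(x):
        return isinstance(x, (int, float)) and math.isfinite(float(x))

    invalid = not (_finite(preview) and _finite(final))
    reason = None
    if invalid:
        reason = "non_finite_pm"        # the metric itself is unusable
    elif baseline_has_reference and not _finite(ratio_vs_baseline):
        reason = "non_finite_delta"     # metric is fine, comparison is not
    out = {"invalid": invalid, "degraded": bool(reason or invalid)}
    if out["degraded"] and reason:
        out["degraded_reason"] = reason
    return out


# classify_pm_health(12.3, float("nan"), 1.01, True)
# -> {"invalid": True, "degraded": True, "degraded_reason": "non_finite_pm"}
```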
invarlock/eval/tail_stats.py
ADDED

from __future__ import annotations

import math
from collections.abc import Mapping, Sequence
from typing import Any

__all__ = [
    "compute_tail_summary",
    "evaluate_metric_tail",
]


def _as_finite_float(value: Any) -> float | None:
    try:
        out = float(value)
    except Exception:
        return None
    return out if math.isfinite(out) else None


def _linear_quantile(sorted_values: Sequence[float], q: float) -> float:
    """Deterministic linear-interpolated quantile on sorted values (q in [0, 1])."""
    n = len(sorted_values)
    if n == 0:
        return float("nan")
    if n == 1:
        return float(sorted_values[0])
    if q <= 0.0:
        return float(sorted_values[0])
    if q >= 1.0:
        return float(sorted_values[-1])
    pos = float(q) * float(n - 1)
    lo = int(math.floor(pos))
    hi = int(math.ceil(pos))
    if lo == hi:
        return float(sorted_values[lo])
    frac = pos - float(lo)
    a = float(sorted_values[lo])
    b = float(sorted_values[hi])
    return a + frac * (b - a)


def compute_tail_summary(
    deltas: Sequence[float] | Sequence[Any],
    *,
    quantiles: Sequence[float] = (0.5, 0.9, 0.95, 0.99),
    epsilon: float = 1e-4,
    weights: Sequence[float] | Sequence[Any] | None = None,
) -> dict[str, Any]:
    """Compute deterministic tail summaries for Δlog-loss samples.

    - Quantiles are computed unweighted using linear interpolation on sorted values.
    - tail_mass is Pr[delta > epsilon] (unweighted).
    - tail_mass_weighted is included when weights are provided and finite.
    """
    eps = _as_finite_float(epsilon)
    if eps is None or eps < 0.0:
        eps = 0.0

    values: list[float] = []
    paired_weights: list[float] | None = [] if weights is not None else None

    if weights is None:
        for d in deltas:
            dv = _as_finite_float(d)
            if dv is None:
                continue
            values.append(float(dv))
    else:
        for d, w in zip(deltas, weights, strict=False):
            dv = _as_finite_float(d)
            if dv is None:
                continue
            wv = _as_finite_float(w)
            if wv is None or wv < 0.0:
                wv = 0.0
            values.append(float(dv))
            if paired_weights is not None:
                paired_weights.append(float(wv))

    n = int(len(values))
    values_sorted = sorted(values)

    summary: dict[str, Any] = {
        "n": n,
        "epsilon": float(eps),
    }
    if n == 0:
        summary.update({"max": float("nan"), "tail_mass": 0.0})
        for q in quantiles:
            try:
                qf = float(q)
            except Exception:
                continue
            label = f"q{int(round(100.0 * max(0.0, min(1.0, qf))))}"
            summary[label] = float("nan")
        return summary

    summary["max"] = float(values_sorted[-1])
    tail_ct = sum(1 for v in values if v > eps)
    summary["tail_mass"] = float(tail_ct / n)

    if paired_weights is not None:
        total_w = 0.0
        tail_w = 0.0
        for v, w in zip(values, paired_weights, strict=False):
            total_w += float(w)
            if v > eps:
                tail_w += float(w)
        if total_w > 0.0:
            summary["tail_mass_weighted"] = float(tail_w / total_w)
            summary["tail_mass_weighted_by"] = "weights"

    for q in quantiles:
        try:
            qf = float(q)
        except Exception:
            continue
        qf = max(0.0, min(1.0, qf))
        label = f"q{int(round(100.0 * qf))}"
        summary[label] = float(_linear_quantile(values_sorted, qf))

    return summary


def evaluate_metric_tail(
    *,
    deltas: Sequence[float] | Sequence[Any],
    policy: Mapping[str, Any] | None = None,
    weights: Sequence[float] | Sequence[Any] | None = None,
) -> dict[str, Any]:
    """Evaluate a tail policy against Δlog-loss samples.

    Policy keys:
    - mode: "off" | "warn" | "fail" (default: "warn")
    - min_windows: int (default: 1)
    - quantile: float in [0, 1] (default: 0.95)
    - quantile_max: float threshold in Δlog-loss (optional)
    - epsilon: float deadband for tail_mass (default: 1e-4)
    - mass_max: float in [0, 1] (optional)
    """
    pol = dict(policy or {})
    mode = str(pol.get("mode", "warn") or "warn").strip().lower()
    if mode not in {"off", "warn", "fail"}:
        mode = "warn"

    min_windows = pol.get("min_windows", 1)
    try:
        min_windows_i = int(min_windows)
    except Exception:
        min_windows_i = 1
    min_windows_i = max(1, min_windows_i)

    q = _as_finite_float(pol.get("quantile", 0.95))
    if q is None:
        q = 0.95
    q = max(0.0, min(1.0, float(q)))

    eps = _as_finite_float(pol.get("epsilon", 1e-4))
    if eps is None or eps < 0.0:
        eps = 0.0

    qmax = _as_finite_float(pol.get("quantile_max"))
    mmax = _as_finite_float(pol.get("mass_max"))
    if mmax is not None:
        mmax = max(0.0, min(1.0, float(mmax)))

    quantiles = sorted({0.5, 0.9, 0.95, 0.99, float(q)})
    stats = compute_tail_summary(
        deltas, quantiles=tuple(quantiles), epsilon=float(eps), weights=weights
    )
    n = int(stats.get("n", 0) or 0)

    thresholds_present = (qmax is not None) or (mmax is not None)
    evaluated = bool(mode != "off" and thresholds_present and n >= min_windows_i)

    violations: list[dict[str, Any]] = []
    passed = True
    if evaluated:
        passed = True
        q_label = f"q{int(round(100.0 * q))}"
        q_obs = stats.get(q_label)
        if not (isinstance(q_obs, int | float) and math.isfinite(float(q_obs))):
            q_obs = float("nan")
        if qmax is not None and math.isfinite(q_obs) and q_obs > float(qmax):
            passed = False
            violations.append(
                {
                    "type": "quantile_max_exceeded",
                    "quantile": float(q),
                    "observed": float(q_obs),
                    "threshold": float(qmax),
                }
            )

        tail_mass = stats.get("tail_mass")
        if (
            mmax is not None
            and isinstance(tail_mass, int | float)
            and math.isfinite(float(tail_mass))
            and float(tail_mass) > float(mmax)
        ):
            passed = False
            violations.append(
                {
                    "type": "tail_mass_exceeded",
                    "epsilon": float(eps),
                    "observed": float(tail_mass),
                    "threshold": float(mmax),
                }
            )

    warned = bool(evaluated and (not passed) and mode == "warn")

    return {
        "mode": mode,
        "evaluated": evaluated,
        "passed": bool(passed),
        "warned": warned,
        "violations": violations,
        "policy": {
            "mode": mode,
            "min_windows": int(min_windows_i),
            "quantile": float(q),
            "quantile_max": float(qmax) if qmax is not None else None,
            "epsilon": float(eps),
            "mass_max": float(mmax) if mmax is not None else None,
        },
        "stats": stats,
    }
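Usage sketch for the new tail policy helper. The import path is inferred from the file list above and the numbers are illustrative:

```python
from invarlock.eval.tail_stats import evaluate_metric_tail  # path assumed from file list

deltas = [0.001, 0.002, -0.0005, 0.03, 0.0002]  # per-window Δlog-loss samples
policy = {"mode": "warn", "quantile": 0.95, "quantile_max": 0.02, "mass_max": 0.1}

result = evaluate_metric_tail(deltas=deltas, policy=policy)
# With these numbers the interpolated q95 (≈ 0.024) exceeds quantile_max, and
# tail_mass is 0.8 (4 of 5 deltas exceed the 1e-4 deadband), which exceeds
# mass_max, so result["passed"] is False and, since mode is "warn",
# result["warned"] is True.
```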
invarlock/eval/tasks/__init__.py
ADDED

from __future__ import annotations

from .classification import accuracy_from_records
from .qa import exact_match_from_records
from .text_generation import bleu1_from_records, rouge_l_from_records

__all__ = [
    "accuracy_from_records",
    "exact_match_from_records",
    "bleu1_from_records",
    "rouge_l_from_records",
]
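If the wheel's package layout matches the file list above, the four task metrics can be imported from the subpackage root:

```python
from invarlock.eval.tasks import (  # path assumed from the file list above
    accuracy_from_records,
    bleu1_from_records,
    exact_match_from_records,
    rouge_l_from_records,
)
```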
invarlock/eval/tasks/classification.py
ADDED

from __future__ import annotations

from collections.abc import Iterable
from typing import Any


def _iter_pairs(record: dict[str, Any]) -> list[tuple[Any, Any]]:
    if "correct" in record:
        return [(bool(record.get("correct")), True)]

    label = record.get("label")
    pred = record.get("prediction")
    if label is None:
        label = record.get("labels")
    if pred is None:
        pred = record.get("pred")
    if pred is None:
        pred = record.get("predictions")

    if isinstance(label, list) and isinstance(pred, list):
        return list(zip(label, pred, strict=False))
    if label is None or pred is None:
        return []
    return [(label, pred)]


def accuracy_from_records(records: Iterable[dict[str, Any]]) -> float:
    """Compute accuracy from records with labels/predictions.

    Accepted record shapes:
    - {"label": <label>, "prediction": <label>}
    - {"labels": [...], "predictions": [...]}
    - {"correct": <bool>}
    """
    total = 0
    correct = 0
    for record in records:
        if not isinstance(record, dict):
            continue
        for label, pred in _iter_pairs(record):
            total += 1
            if isinstance(label, bool):
                correct += int(label is pred)
            else:
                correct += int(label == pred)
    if total == 0:
        return float("nan")
    return float(correct / total)
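A small usage sketch covering the three accepted record shapes (values are illustrative; import path assumed from the file list):

```python
from invarlock.eval.tasks.classification import accuracy_from_records

records = [
    {"label": "cat", "prediction": "cat"},            # single label/prediction pair
    {"labels": [0, 1, 1], "predictions": [0, 0, 1]},  # aligned lists -> three pairs
    {"correct": True},                                 # pre-judged record
]
acc = accuracy_from_records(records)
# Pairs counted: 1 + 3 + 1 = 5; correct: 1 + 2 + 1 = 4, so acc == 0.8.
```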
invarlock/eval/tasks/qa.py
ADDED

from __future__ import annotations

from collections.abc import Iterable
from typing import Any


def _normalize(text: str) -> str:
    return " ".join(str(text).strip().lower().split())


def exact_match_from_records(records: Iterable[dict[str, Any]]) -> float:
    """Compute exact-match accuracy for QA-style records.

    Accepted record shapes:
    - {"prediction": "...", "answer": "..."}
    - {"prediction": "...", "answers": ["...", ...]}
    """
    total = 0
    correct = 0
    for record in records:
        if not isinstance(record, dict):
            continue
        pred = record.get("prediction")
        answers = record.get("answers")
        if answers is None and "answer" in record:
            answers = [record.get("answer")]
        if pred is None or answers is None:
            continue
        pred_norm = _normalize(pred)
        answer_list = answers if isinstance(answers, list) else [answers]
        total += 1
        if any(_normalize(a) == pred_norm for a in answer_list if a is not None):
            correct += 1
    if total == 0:
        return float("nan")
    return float(correct / total)
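Usage sketch (import path assumed from the file list). Normalization lower-cases and collapses whitespace before comparing against any listed reference answer:

```python
from invarlock.eval.tasks.qa import exact_match_from_records

records = [
    {"prediction": "Paris ", "answer": "paris"},           # matches after normalization
    {"prediction": "42", "answers": ["forty-two", "42"]},  # any listed answer may match
    {"prediction": "blue", "answers": ["red"]},            # no match
]
em = exact_match_from_records(records)  # 2 of 3 records match -> em ≈ 0.667
```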
invarlock/eval/tasks/text_generation.py
ADDED

from __future__ import annotations

from collections import Counter
from collections.abc import Iterable
from typing import Any


def _tokenize(text: str) -> list[str]:
    return [tok for tok in str(text).strip().lower().split() if tok]


def _bleu1(pred: str, ref: str) -> float:
    pred_tokens = _tokenize(pred)
    ref_tokens = _tokenize(ref)
    if not pred_tokens or not ref_tokens:
        return 0.0
    pred_counts = Counter(pred_tokens)
    ref_counts = Counter(ref_tokens)
    overlap = sum(min(pred_counts[tok], ref_counts.get(tok, 0)) for tok in pred_counts)
    precision = overlap / float(len(pred_tokens))
    bp = 1.0
    if len(pred_tokens) < len(ref_tokens):
        bp = pow(2.718281828, 1.0 - (len(ref_tokens) / float(len(pred_tokens))))
    return float(precision * bp)


def bleu1_from_records(records: Iterable[dict[str, Any]]) -> float:
    """Compute BLEU-1 from records with predictions and references."""
    scores: list[float] = []
    for record in records:
        if not isinstance(record, dict):
            continue
        pred = record.get("prediction")
        refs = record.get("references")
        if pred is None:
            continue
        if refs is None and "reference" in record:
            refs = [record.get("reference")]
        if refs is None:
            continue
        ref_list = refs if isinstance(refs, list) else [refs]
        best = 0.0
        for ref in ref_list:
            if ref is None:
                continue
            best = max(best, _bleu1(str(pred), str(ref)))
        scores.append(best)
    if not scores:
        return float("nan")
    return float(sum(scores) / float(len(scores)))


def _lcs_len(a: list[str], b: list[str]) -> int:
    if not a or not b:
        return 0
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, tok_a in enumerate(a, start=1):
        for j, tok_b in enumerate(b, start=1):
            if tok_a == tok_b:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
    return dp[-1][-1]


def _rouge_l(pred: str, ref: str) -> float:
    pred_tokens = _tokenize(pred)
    ref_tokens = _tokenize(ref)
    if not pred_tokens or not ref_tokens:
        return 0.0
    lcs = _lcs_len(pred_tokens, ref_tokens)
    prec = lcs / float(len(pred_tokens))
    rec = lcs / float(len(ref_tokens))
    if prec + rec == 0:
        return 0.0
    return float(2 * prec * rec / (prec + rec))


def rouge_l_from_records(records: Iterable[dict[str, Any]]) -> float:
    """Compute ROUGE-L (F1) from records with predictions and references."""
    scores: list[float] = []
    for record in records:
        if not isinstance(record, dict):
            continue
        pred = record.get("prediction")
        refs = record.get("references")
        if pred is None:
            continue
        if refs is None and "reference" in record:
            refs = [record.get("reference")]
        if refs is None:
            continue
        ref_list = refs if isinstance(refs, list) else [refs]
        best = 0.0
        for ref in ref_list:
            if ref is None:
                continue
            best = max(best, _rouge_l(str(pred), str(ref)))
        scores.append(best)
    if not scores:
        return float("nan")
    return float(sum(scores) / float(len(scores)))
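Both scorers take the best score over the available references per record and average across records. A worked example under the same whitespace tokenization; the import path is assumed from the file list:

```python
from invarlock.eval.tasks.text_generation import bleu1_from_records, rouge_l_from_records

records = [{"prediction": "the cat sat", "reference": "the cat sat on the mat"}]

bleu1 = bleu1_from_records(records)
# Unigram precision is 3/3 = 1.0; the prediction (3 tokens) is shorter than the
# reference (6 tokens), so the brevity penalty is exp(1 - 6/3) ≈ 0.368 -> bleu1 ≈ 0.368.

rouge = rouge_l_from_records(records)
# LCS length 3 -> precision 3/3, recall 3/6, F1 = 2 * 1.0 * 0.5 / 1.5 ≈ 0.667.
```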
invarlock/guards/_estimators.py
ADDED

from __future__ import annotations

import math
from typing import Any

import torch

__all__ = [
    "power_iter_sigma_max",
    "frobenius_norm_sq",
    "row_col_norm_extrema",
    "stable_rank_estimate",
]


def _as_matrix(tensor: torch.Tensor) -> torch.Tensor:
    if tensor.ndim == 2:
        return tensor
    return tensor.view(tensor.shape[0], -1)


def power_iter_sigma_max(
    matrix: Any,
    *,
    iters: int,
    init: str = "ones",
    eps: float = 1e-12,
) -> float:
    """Estimate the largest singular value (spectral norm) via fixed-iter power iteration.

    Contract properties (vNext):
    - fixed iteration budget (no convergence stopping)
    - deterministic initialization (`init`)
    - device-resident matvecs (no `.cpu()` transfers)
    """
    try:
        iters_i = int(iters)
    except Exception:
        iters_i = 4
    if iters_i < 1:
        iters_i = 1

    if not isinstance(matrix, torch.Tensor):
        return 0.0
    if matrix.numel() == 0:
        return 0.0
    if matrix.dtype in {torch.int8, torch.uint8}:
        return 0.0

    W = _as_matrix(matrix.detach())
    if W.numel() == 0 or W.shape[0] == 0 or W.shape[1] == 0:
        return 0.0

    device = W.device
    dtype = W.dtype
    n = int(W.shape[1])

    with torch.no_grad():
        if init == "ones":
            v = torch.ones((n,), device=device, dtype=dtype)
        else:
            # Deterministic fallback: unit vector e0.
            v = torch.zeros((n,), device=device, dtype=dtype)
            v[0] = 1

        v_norm = torch.linalg.vector_norm(v.float()).clamp_min(eps)
        v = v / v_norm.to(dtype)

        sigma = 0.0
        for _ in range(iters_i):
            u = W @ v
            u_norm = torch.linalg.vector_norm(u.float()).clamp_min(eps)
            sigma_val = float(u_norm.item())
            if not math.isfinite(sigma_val):
                return 0.0
            u = u / u_norm.to(dtype)
            v = W.T @ u
            v_norm = torch.linalg.vector_norm(v.float()).clamp_min(eps)
            v = v / v_norm.to(dtype)
            sigma = sigma_val
    return float(sigma)


def frobenius_norm_sq(matrix: torch.Tensor) -> float:
    """Return ||matrix||_F^2 with float32 accumulation (device-resident)."""
    W = _as_matrix(matrix.detach())
    if W.numel() == 0:
        return 0.0
    with torch.no_grad():
        # Use a fused reduction to avoid materializing a W*W intermediate.
        norm = torch.linalg.vector_norm(W.reshape(-1), ord=2, dtype=torch.float32)
        out = float((norm * norm).item())
    return out if math.isfinite(out) else 0.0


def row_col_norm_extrema(
    matrix: torch.Tensor, *, eps: float = 1e-12
) -> dict[str, float]:
    """Compute min/median/max of row/col L2 norms with float32 accumulation."""
    W = _as_matrix(matrix.detach())
    if W.numel() == 0 or W.shape[0] == 0 or W.shape[1] == 0:
        return {
            "row_min": 0.0,
            "row_median": 0.0,
            "row_max": 0.0,
            "col_min": 0.0,
            "col_median": 0.0,
            "col_max": 0.0,
        }
    with torch.no_grad():
        # Avoid materializing W*W: use fused reductions.
        row = torch.linalg.vector_norm(W, ord=2, dim=1, dtype=torch.float32).clamp_min(
            eps
        )
        col = torch.linalg.vector_norm(W, ord=2, dim=0, dtype=torch.float32).clamp_min(
            eps
        )

        row_sorted, _ = torch.sort(row)
        col_sorted, _ = torch.sort(col)

        def _median(sorted_vec: torch.Tensor) -> float:
            n = int(sorted_vec.numel())
            if n <= 0:
                return 0.0
            mid = n // 2
            if n % 2 == 1:
                return float(sorted_vec[mid].item())
            return float((sorted_vec[mid - 1] + sorted_vec[mid]).mul(0.5).item())

        return {
            "row_min": float(row_sorted[0].item()),
            "row_median": _median(row_sorted),
            "row_max": float(row_sorted[-1].item()),
            "col_min": float(col_sorted[0].item()),
            "col_median": _median(col_sorted),
            "col_max": float(col_sorted[-1].item()),
        }


def stable_rank_estimate(
    matrix: torch.Tensor, *, sigma_max: float, eps: float = 1e-12
) -> float:
    """Estimate stable rank: ||W||_F^2 / ||W||_2^2, using a provided σ̂max."""
    try:
        denom = float(sigma_max) ** 2
    except Exception:
        return 0.0
    if not math.isfinite(denom) or denom <= 0.0:
        return 0.0
    denom = max(denom, eps)
    num = frobenius_norm_sq(matrix)
    out = float(num) / denom if denom > 0 else 0.0
    return out if math.isfinite(out) else 0.0
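Usage sketch on a matrix with a known spectrum (import path assumed from the file list; comments show what the estimators should converge to for these inputs):

```python
import torch

from invarlock.guards._estimators import (  # path assumed from the file list above
    frobenius_norm_sq,
    power_iter_sigma_max,
    stable_rank_estimate,
)

W = torch.diag(torch.tensor([3.0, 1.0, 0.5]))     # singular values: 3, 1, 0.5
sigma = power_iter_sigma_max(W, iters=16)         # -> ~3.0 (deterministic "ones" init)
fro_sq = frobenius_norm_sq(W)                     # -> 9 + 1 + 0.25 = 10.25
srank = stable_rank_estimate(W, sigma_max=sigma)  # -> ~10.25 / 9 ≈ 1.14
```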