invarlock 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +33 -0
- invarlock/__main__.py +10 -0
- invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
- invarlock/_data/runtime/profiles/release.yaml +23 -0
- invarlock/_data/runtime/tiers.yaml +76 -0
- invarlock/adapters/__init__.py +102 -0
- invarlock/adapters/_capabilities.py +45 -0
- invarlock/adapters/auto.py +99 -0
- invarlock/adapters/base.py +530 -0
- invarlock/adapters/base_types.py +85 -0
- invarlock/adapters/hf_bert.py +852 -0
- invarlock/adapters/hf_gpt2.py +403 -0
- invarlock/adapters/hf_llama.py +485 -0
- invarlock/adapters/hf_mixin.py +383 -0
- invarlock/adapters/hf_onnx.py +112 -0
- invarlock/adapters/hf_t5.py +137 -0
- invarlock/adapters/py.typed +1 -0
- invarlock/assurance/__init__.py +43 -0
- invarlock/cli/__init__.py +8 -0
- invarlock/cli/__main__.py +8 -0
- invarlock/cli/_evidence.py +25 -0
- invarlock/cli/_json.py +75 -0
- invarlock/cli/adapter_auto.py +162 -0
- invarlock/cli/app.py +287 -0
- invarlock/cli/commands/__init__.py +26 -0
- invarlock/cli/commands/certify.py +403 -0
- invarlock/cli/commands/doctor.py +1358 -0
- invarlock/cli/commands/explain_gates.py +151 -0
- invarlock/cli/commands/export_html.py +100 -0
- invarlock/cli/commands/plugins.py +1331 -0
- invarlock/cli/commands/report.py +354 -0
- invarlock/cli/commands/run.py +4146 -0
- invarlock/cli/commands/verify.py +1040 -0
- invarlock/cli/config.py +396 -0
- invarlock/cli/constants.py +68 -0
- invarlock/cli/device.py +92 -0
- invarlock/cli/doctor_helpers.py +74 -0
- invarlock/cli/errors.py +6 -0
- invarlock/cli/overhead_utils.py +60 -0
- invarlock/cli/provenance.py +66 -0
- invarlock/cli/utils.py +41 -0
- invarlock/config.py +56 -0
- invarlock/core/__init__.py +62 -0
- invarlock/core/abi.py +15 -0
- invarlock/core/api.py +274 -0
- invarlock/core/auto_tuning.py +317 -0
- invarlock/core/bootstrap.py +226 -0
- invarlock/core/checkpoint.py +221 -0
- invarlock/core/contracts.py +73 -0
- invarlock/core/error_utils.py +64 -0
- invarlock/core/events.py +298 -0
- invarlock/core/exceptions.py +95 -0
- invarlock/core/registry.py +481 -0
- invarlock/core/retry.py +146 -0
- invarlock/core/runner.py +2041 -0
- invarlock/core/types.py +154 -0
- invarlock/edits/__init__.py +12 -0
- invarlock/edits/_edit_utils.py +249 -0
- invarlock/edits/_external_utils.py +268 -0
- invarlock/edits/noop.py +47 -0
- invarlock/edits/py.typed +1 -0
- invarlock/edits/quant_rtn.py +801 -0
- invarlock/edits/registry.py +166 -0
- invarlock/eval/__init__.py +23 -0
- invarlock/eval/bench.py +1207 -0
- invarlock/eval/bootstrap.py +50 -0
- invarlock/eval/data.py +2052 -0
- invarlock/eval/metrics.py +2167 -0
- invarlock/eval/primary_metric.py +767 -0
- invarlock/eval/probes/__init__.py +24 -0
- invarlock/eval/probes/fft.py +139 -0
- invarlock/eval/probes/mi.py +213 -0
- invarlock/eval/probes/post_attention.py +323 -0
- invarlock/eval/providers/base.py +67 -0
- invarlock/eval/providers/seq2seq.py +111 -0
- invarlock/eval/providers/text_lm.py +113 -0
- invarlock/eval/providers/vision_text.py +93 -0
- invarlock/eval/py.typed +1 -0
- invarlock/guards/__init__.py +18 -0
- invarlock/guards/_contracts.py +9 -0
- invarlock/guards/invariants.py +640 -0
- invarlock/guards/policies.py +805 -0
- invarlock/guards/py.typed +1 -0
- invarlock/guards/rmt.py +2097 -0
- invarlock/guards/spectral.py +1419 -0
- invarlock/guards/tier_config.py +354 -0
- invarlock/guards/variance.py +3298 -0
- invarlock/guards_ref/__init__.py +15 -0
- invarlock/guards_ref/rmt_ref.py +40 -0
- invarlock/guards_ref/spectral_ref.py +135 -0
- invarlock/guards_ref/variance_ref.py +60 -0
- invarlock/model_profile.py +353 -0
- invarlock/model_utils.py +221 -0
- invarlock/observability/__init__.py +10 -0
- invarlock/observability/alerting.py +535 -0
- invarlock/observability/core.py +546 -0
- invarlock/observability/exporters.py +565 -0
- invarlock/observability/health.py +588 -0
- invarlock/observability/metrics.py +457 -0
- invarlock/observability/py.typed +1 -0
- invarlock/observability/utils.py +553 -0
- invarlock/plugins/__init__.py +12 -0
- invarlock/plugins/hello_guard.py +33 -0
- invarlock/plugins/hf_awq_adapter.py +82 -0
- invarlock/plugins/hf_bnb_adapter.py +79 -0
- invarlock/plugins/hf_gptq_adapter.py +78 -0
- invarlock/plugins/py.typed +1 -0
- invarlock/py.typed +1 -0
- invarlock/reporting/__init__.py +7 -0
- invarlock/reporting/certificate.py +3221 -0
- invarlock/reporting/certificate_schema.py +244 -0
- invarlock/reporting/dataset_hashing.py +215 -0
- invarlock/reporting/guards_analysis.py +948 -0
- invarlock/reporting/html.py +32 -0
- invarlock/reporting/normalizer.py +235 -0
- invarlock/reporting/policy_utils.py +517 -0
- invarlock/reporting/primary_metric_utils.py +265 -0
- invarlock/reporting/render.py +1442 -0
- invarlock/reporting/report.py +903 -0
- invarlock/reporting/report_types.py +278 -0
- invarlock/reporting/utils.py +175 -0
- invarlock/reporting/validate.py +631 -0
- invarlock/security.py +176 -0
- invarlock/sparsity_utils.py +323 -0
- invarlock/utils/__init__.py +150 -0
- invarlock/utils/digest.py +45 -0
- invarlock-0.2.0.dist-info/METADATA +586 -0
- invarlock-0.2.0.dist-info/RECORD +132 -0
- invarlock-0.2.0.dist-info/WHEEL +5 -0
- invarlock-0.2.0.dist-info/entry_points.txt +20 -0
- invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
- invarlock-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,767 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Primary metric abstraction and minimal ppl_causal implementation (Phase 1).
|
|
3
|
+
|
|
4
|
+
This module introduces a light-weight, task-agnostic metric interface and a
|
|
5
|
+
registry so the runner/certificate can evolve beyond causal-LM perplexity.
|
|
6
|
+
|
|
7
|
+
Phase 1 goal: provide a ppl_causal metric and a helper that can compute point
|
|
8
|
+
estimates directly from evaluation window aggregates already present in run
|
|
9
|
+
reports. This is the canonical path; no env flag toggles.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import math
|
|
15
|
+
from collections.abc import Iterable, Sequence
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from typing import Any, Protocol
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
from invarlock.core.bootstrap import compute_paired_delta_log_ci
|
|
22
|
+
from invarlock.core.exceptions import ValidationError
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class MetricDefaults:
    """Default bootstrap configuration shared by all metrics."""

    # Number of bootstrap replicates used for paired comparisons.
    reps: int = 2000
    # Two-sided confidence level for bootstrap CIs.
    ci_level: float = 0.95
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class MetricContribution:
    """Per-example contribution to a metric.

    For ppl_* metrics, `value` is per-token mean log-loss for the example and
    `weight` is the number of target tokens. For accuracy, `value` is 0/1 and
    `weight` is ignored.
    """

    # Native-space contribution (mean log-loss, or 0/1 correctness flag).
    value: float
    # Aggregation weight (token count for ppl_*; ignored for accuracy).
    weight: float = 1.0
    # Optional stable example identifier (name shadows builtin `id`; kept for API compatibility).
    id: str | None = None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class PrimaryMetric(Protocol):
    """Protocol for task-agnostic primary metrics.

    Implementations should describe their semantics and provide helpers to
    compute point estimates (and optionally CIs) from per-example aggregates.
    """

    # Metric metadata advertised to the runner / reporting layers.
    kind: str
    unit: str
    direction: str  # "lower" | "higher"
    aggregation_scope: str  # "token" | "sequence" | "example"
    paired: bool
    gating_basis: str  # "point" | "upper" | "lower"
    supports_bootstrap: bool
    defaults: MetricDefaults

    def display_transform(self, x: float) -> float:
        """Map native comparison space to display space.

        ppl_*: exp(x) maps Δlog-loss → ratio
        accuracy: x*100 maps proportion Δ → percentage points
        """

    def point_from_windows(self, *, windows: dict[str, Any]) -> float:
        """Compute a single point estimate from evaluation windows.

        For token-aggregated loss metrics (like ppl), this expects:
        windows = {"logloss": [...], "token_counts": [...]} with matching lengths.
        """

    def accumulate(self, contrib: MetricContribution) -> None:
        """Accumulate a per-example contribution for finalize()."""

    def finalize(self) -> float:
        """Return a point estimate from accumulated contributions (display space)."""

    def paired_compare(
        self,
        subject: Iterable[dict[str, Any] | MetricContribution],
        baseline: Iterable[dict[str, Any] | MetricContribution],
        *,
        reps: int | None = None,
        seed: int | None = None,
        ci_level: float | None = None,
    ) -> dict[str, Any]:
        """Paired compare subject vs baseline with bootstrap CI.

        Returns a dict with native-space delta and display-space values:
        {
            'delta': float, # native compare space (Δlog-loss for ppl, Δ proportion for accuracy)
            'ci': (lo, hi), # native-space CI
            'display': float, # display space (ratio for ppl, pp for accuracy)
            'display_ci': (lo, hi),
            'subject_point': float, # point estimate in display space
            'baseline_point': float, # point estimate in display space
            'reps': int, 'ci_level': float, 'paired': True, meta...
        }
        """
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@dataclass
class MetricInfo:
    """Plain-data snapshot of a metric's metadata (mirrors PrimaryMetric attributes)."""

    kind: str
    unit: str
    direction: str
    aggregation_scope: str
    paired: bool
    gating_basis: str
    supports_bootstrap: bool
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class _PPLCausal(PrimaryMetric):
    """Token-aggregated perplexity for causal LMs.

    point_from_windows computes ppl = exp(weighted_mean(logloss)).
    """

    kind = "ppl_causal"
    unit = "ppl"
    direction = "lower"
    aggregation_scope = "token"
    paired = True
    gating_basis = "upper"  # typical gate on ratio upper-bound
    supports_bootstrap = True
    defaults = MetricDefaults()

    def __init__(self) -> None:
        # Per-example mean log-losses and their token-count weights,
        # gathered via accumulate() and consumed by finalize().
        self._values: list[float] = []
        self._weights: list[float] = []

    def display_transform(self, x: float) -> float:
        """exp(x): maps a Δlog-loss (or mean log-loss) to a ratio/perplexity."""
        try:
            return float(math.exp(x))
        except OverflowError:
            # Saturate instead of raising for very large log-losses.
            return float("inf")

    def point_from_windows(self, *, windows: dict[str, Any]) -> float:
        """Return exp(token-weighted mean logloss) from window aggregates.

        Expects windows["logloss"] and windows["token_counts"] of equal length.
        Non-finite or non-positive-weight entries are skipped; returns NaN when
        no usable entries remain.
        """
        logloss = list(windows.get("logloss", []) or [])
        token_counts = list(windows.get("token_counts", []) or [])
        if not logloss or not token_counts or len(logloss) != len(token_counts):
            return float("nan")
        sum_w = 0.0
        sum_wx = 0.0
        for ll, w in zip(logloss, token_counts, strict=False):
            try:
                llv = float(ll)
                wv = float(w)
            except Exception:
                # Non-numeric entries are silently skipped (best-effort aggregation).
                continue
            if not math.isfinite(llv) or not math.isfinite(wv) or wv <= 0:
                continue
            sum_w += wv
            sum_wx += wv * llv
        if sum_w <= 0.0:
            return float("nan")
        mean_ll = sum_wx / sum_w
        try:
            return float(math.exp(mean_ll))
        except OverflowError:
            return float("inf")

    def accumulate(self, contrib: MetricContribution) -> None:
        """Record one (mean log-loss, token-count) pair; invalid input is ignored."""
        try:
            v = float(contrib.value)
            w = float(contrib.weight)
        except Exception:
            return
        if not math.isfinite(v) or not math.isfinite(w) or w <= 0:
            return
        self._values.append(v)
        self._weights.append(w)

    def finalize(self) -> float:
        """Return perplexity (display space) from accumulated contributions."""
        if (
            not self._values
            or not self._weights
            or len(self._values) != len(self._weights)
        ):
            return float("nan")
        sw = 0.0
        swx = 0.0
        for v, w in zip(self._values, self._weights, strict=False):
            sw += w
            swx += w * v
        if sw <= 0.0:
            return float("nan")
        return self.display_transform(swx / sw)

    def _coerce_contrib_array(
        self, items: Iterable[dict[str, Any] | MetricContribution]
    ) -> list[tuple[float, float]]:
        # Normalize heterogeneous contributions to (value, weight) pairs;
        # items of unknown shape are skipped.
        out: list[tuple[float, float]] = []
        for it in items:
            if isinstance(it, MetricContribution):
                out.append((float(it.value), float(it.weight)))
            elif isinstance(it, dict) and "value" in it:
                v = float(it.get("value"))
                w = float(it.get("weight", 1.0))
                out.append((v, w))
        return out

    def paired_compare(
        self,
        subject: Iterable[dict[str, Any] | MetricContribution],
        baseline: Iterable[dict[str, Any] | MetricContribution],
        *,
        reps: int | None = None,
        seed: int | None = None,
        ci_level: float | None = None,
    ) -> dict[str, Any]:
        """Paired subject-vs-baseline comparison with a BCa bootstrap CI.

        Native space is Δlog-loss; display space is the perplexity ratio
        exp(Δlog-loss). Point estimates are token-weighted; the bootstrap
        uses unweighted per-example log-losses (see inline note).
        """
        subj = self._coerce_contrib_array(subject)
        base = self._coerce_contrib_array(baseline)
        # Compute simple (unweighted) per-example arrays in log space; weights ignored for bootstrap here
        subj_vals = [v for (v, _w) in subj]
        base_vals = [v for (v, _w) in base]

        # Points in display space
        def _point(
            vals: Sequence[float], weights: Sequence[float] | None = None
        ) -> float:
            # Weighted mean when usable weights are supplied, else simple mean.
            if not vals:
                return float("nan")
            if weights and len(weights) == len(vals):
                sw = 0.0
                swx = 0.0
                for v, w in zip(vals, weights, strict=False):
                    sw += w
                    swx += w * v
                if sw <= 0:
                    return float("nan")
                return self.display_transform(swx / sw)
            else:
                return self.display_transform(sum(vals) / float(len(vals)))

        subj_point = _point([v for v, _ in subj], [w for _, w in subj])
        base_point = _point([v for v, _ in base], [w for _, w in base])

        # Bootstrap Δlog-loss → CI, then display-transform → ratio CI
        reps_eff = int(reps) if (reps is not None and reps > 0) else self.defaults.reps
        seed_eff = int(seed) if (seed is not None) else 0
        ci_level_eff = (
            float(ci_level) if (ci_level is not None) else self.defaults.ci_level
        )
        alpha = 1.0 - ci_level_eff
        dlog_lo, dlog_hi = compute_paired_delta_log_ci(
            subj_vals,
            base_vals,
            method="bca",
            replicates=reps_eff,
            alpha=alpha,
            seed=seed_eff,
        )
        # Point Δlog-loss over the paired prefix (extra unpaired examples ignored).
        delta_log = float(
            sum((s - b) for s, b in zip(subj_vals, base_vals, strict=False))
            / max(1, min(len(subj_vals), len(base_vals)))
        )
        ratio = self.display_transform(delta_log)
        return {
            "kind": self.kind,
            "unit": self.unit,
            "direction": self.direction,
            "aggregation_scope": self.aggregation_scope,
            "paired": True,
            "gating_basis": self.gating_basis,
            "supports_bootstrap": self.supports_bootstrap,
            "reps": reps_eff,
            "ci_level": ci_level_eff,
            "subject_point": subj_point,
            "baseline_point": base_point,
            "delta": delta_log,
            "ci": (dlog_lo, dlog_hi),
            "display": ratio,
            "display_ci": (
                self.display_transform(dlog_lo),
                self.display_transform(dlog_hi),
            ),
        }
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
# ── Additional metric kinds (simple registry defined further below) ──────────
|
|
285
|
+
class _PPLMLM(_PPLCausal):
    """Masked-LM perplexity.

    Identical to the causal-LM computation except that, when the windows
    carry ``masked_token_counts``, those counts are substituted for
    ``token_counts`` as the per-window weights.
    """

    kind = "ppl_mlm"

    def point_from_windows(self, *, windows: dict[str, Any]) -> float:
        """Point estimate preferring masked-token counts when present."""
        masked_counts = list(windows.get("masked_token_counts", []) or [])
        if not masked_counts:
            # No MLM-specific counts: fall through to the causal-LM path.
            return super().point_from_windows(windows=windows)
        substituted = {
            "logloss": windows.get("logloss", []),
            "token_counts": masked_counts,
        }
        return super().point_from_windows(windows=substituted)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
class _PPLSeq2Seq(_PPLCausal):
    """Seq2Seq perplexity (token-aggregated over decoder labels).

    Inherits the causal-LM computation unchanged; only the registry `kind`
    label differs.
    """

    kind = "ppl_seq2seq"
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
class _Accuracy:
    """Example-aggregated accuracy (0..1).

    Accepts either per-example flags (example_correct) or aggregate counts
    (correct_total/total or correct_counts/total_counts).
    """

    kind = "accuracy"
    unit = "accuracy"
    direction = "higher"
    aggregation_scope = "example"
    paired = True
    gating_basis = "lower"
    supports_bootstrap = True
    defaults = MetricDefaults()

    def __init__(self) -> None:
        # Per-example 0/1 correctness flags gathered via accumulate().
        self._values: list[float] = []

    def display_transform(self, x: float) -> float:
        """Map a proportion (delta) to percentage points."""
        return float(x * 100.0)

    def point_from_windows(self, *, windows: dict[str, Any]) -> float:
        """Accuracy point estimate from per-example flags or aggregate counts.

        Per-example flags take precedence; otherwise the first usable
        (correct, total) key pair is used, with optional abstain/tie policy
        adjustments. Returns NaN when nothing is usable.
        """
        # Per-example path
        ex = list(windows.get("example_correct", []) or [])
        if ex:
            s = 0.0
            n = 0.0
            for v in ex:
                try:
                    s += 1.0 if float(v) > 0.5 else 0.0
                    n += 1.0
                except Exception:
                    continue
            if n > 0:
                return s / n
            return float("nan")
        # Aggregate counts path
        for c_key, t_key in (
            ("correct_total", "total"),
            ("correct_counts", "total_counts"),
        ):
            c = windows.get(c_key)
            t = windows.get(t_key)
            if isinstance(c, int | float) and isinstance(t, int | float) and t > 0:
                total = float(t)
                # Optional abstain/tie handling with documented policy
                try:
                    policy = (
                        windows.get("policy", {})
                        if isinstance(windows.get("policy"), dict)
                        else {}
                    )
                    abstain = windows.get("abstain_total")
                    ties = windows.get("ties_total")
                    exclude_abstain = bool(policy.get("exclude_abstain", True))
                    count_ties_as_correct = bool(
                        policy.get("ties_count_as_correct", False)
                    )
                    count_ties_as_incorrect = bool(
                        policy.get("ties_count_as_incorrect", False)
                    )
                    # Apply abstain exclusion from denominator if requested
                    if (
                        exclude_abstain
                        and isinstance(abstain, int | float)
                        and abstain > 0
                    ):
                        total = max(1.0, total - float(abstain))
                    # Apply tie policy
                    if isinstance(ties, int | float) and ties > 0:
                        if count_ties_as_correct:
                            c = float(c) + float(ties)
                        elif count_ties_as_incorrect:
                            # leave c unchanged; implicit in denominator
                            pass
                        else:
                            # default: treat ties as abstain (exclude if exclude_abstain True)
                            if exclude_abstain:
                                total = max(1.0, total - float(ties))
                except Exception:
                    # Best-effort policy handling: fall back to raw counts.
                    pass
                return float(c) / float(total)
        return float("nan")

    def accumulate(self, contrib: MetricContribution) -> None:
        """Record one correctness flag, binarized at the 0.5 threshold."""
        try:
            v = float(contrib.value)
        except Exception:
            return
        if not math.isfinite(v):
            return
        # Binarize to {0, 1}
        v = 1.0 if v >= 0.5 else 0.0
        self._values.append(v)

    def finalize(self) -> float:
        """Mean accuracy over accumulated flags (NaN when empty)."""
        if not self._values:
            return float("nan")
        return float(sum(self._values) / float(len(self._values)))

    def _coerce_vals(
        self, items: Iterable[dict[str, Any] | MetricContribution]
    ) -> list[float]:
        """Binarize contributions at the 0.5 threshold.

        Accepts dicts carrying "value" or any object exposing a `.value`
        attribute (e.g. MetricContribution); other items are skipped.
        """
        out: list[float] = []
        for it in items:
            if isinstance(it, dict):
                if "value" in it:
                    out.append(1.0 if float(it["value"]) >= 0.5 else 0.0)
            elif hasattr(it, "value"):
                # Duck-typed: covers MetricContribution without a hard isinstance.
                out.append(1.0 if float(it.value) >= 0.5 else 0.0)
        return out

    def paired_compare(
        self,
        subject: Iterable[dict[str, Any] | MetricContribution],
        baseline: Iterable[dict[str, Any] | MetricContribution],
        *,
        reps: int | None = None,
        seed: int | None = None,
        ci_level: float | None = None,
    ) -> dict[str, Any]:
        """Paired accuracy comparison with a percentile bootstrap CI.

        Native space is the Δ proportion (subject − baseline) over the paired
        prefix; display space is percentage points. With no paired examples,
        returns a NaN-valued result with reps=0.
        """
        subj = self._coerce_vals(subject)
        base = self._coerce_vals(baseline)
        m = min(len(subj), len(base))
        subj = subj[:m]
        base = base[:m]
        # Fix: previously `ci_level or default`, which mapped an explicit 0.0
        # to the default in the empty branch.
        ci_level_eff = (
            float(ci_level) if (ci_level is not None) else self.defaults.ci_level
        )
        meta = {
            "kind": self.kind,
            "unit": self.unit,
            "direction": self.direction,
            "aggregation_scope": self.aggregation_scope,
            "paired": True,
            "gating_basis": self.gating_basis,
            "supports_bootstrap": self.supports_bootstrap,
        }
        if m == 0:
            return {
                **meta,
                "reps": 0,
                "ci_level": ci_level_eff,
                "subject_point": float("nan"),
                "baseline_point": float("nan"),
                "delta": float("nan"),
                "ci": (float("nan"), float("nan")),
                "display": float("nan"),
                "display_ci": (float("nan"), float("nan")),
            }
        # Points in native (proportion) space; no transform for the points.
        subj_point = float(sum(subj) / float(m))
        base_point = float(sum(base) / float(m))
        # Δ in native (proportion) space
        diffs = [float(s - b) for s, b in zip(subj, base, strict=False)]
        delta = float(sum(diffs) / float(m))
        reps_eff = int(reps) if (reps is not None and reps > 0) else self.defaults.reps
        seed_eff = int(seed) if (seed is not None) else 0
        alpha = 1.0 - ci_level_eff
        # Percentile bootstrap on paired diffs (vectorized resampling;
        # np.percentile does not require pre-sorted input).
        rng = np.random.default_rng(seed_eff)
        diffs_arr = np.asarray(diffs, dtype=float)
        boot = np.empty(reps_eff, dtype=float)
        for r in range(reps_eff):
            idx = rng.integers(0, m, size=m)
            boot[r] = float(diffs_arr[idx].sum()) / float(m)
        lo = float(np.percentile(boot, 100.0 * (alpha / 2.0)))
        hi = float(np.percentile(boot, 100.0 * (1.0 - alpha / 2.0)))
        return {
            **meta,
            "reps": reps_eff,
            "ci_level": ci_level_eff,
            "subject_point": subj_point,
            "baseline_point": base_point,
            "delta": delta,
            "ci": (lo, hi),
            "display": self.display_transform(delta),
            "display_ci": (self.display_transform(lo), self.display_transform(hi)),
        }
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
class _AliasMetric:
|
|
497
|
+
"""Light alias wrapper that re-labels `kind` while delegating behavior."""
|
|
498
|
+
|
|
499
|
+
def __init__(self, alias: str, base: PrimaryMetric) -> None:
|
|
500
|
+
self._alias = str(alias)
|
|
501
|
+
self._base = base
|
|
502
|
+
# Copy metadata
|
|
503
|
+
self.kind = self._alias
|
|
504
|
+
self.unit = base.unit
|
|
505
|
+
self.direction = base.direction
|
|
506
|
+
self.aggregation_scope = base.aggregation_scope
|
|
507
|
+
self.paired = base.paired
|
|
508
|
+
self.gating_basis = base.gating_basis
|
|
509
|
+
self.supports_bootstrap = base.supports_bootstrap
|
|
510
|
+
|
|
511
|
+
def point_from_windows(self, *, windows: dict[str, Any]) -> float:
|
|
512
|
+
return self._base.point_from_windows(windows=windows)
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
# Process-wide metric registry keyed by metric kind (lower-case strings;
# see get_metric for lookup).
_REGISTRY: dict[str, PrimaryMetric] = {
    _PPLCausal.kind: _PPLCausal(),
    _PPLMLM.kind: _PPLMLM(),
    _PPLSeq2Seq.kind: _PPLSeq2Seq(),
    _Accuracy.kind: _Accuracy(),
    # Multimodal aliases
    "vqa_accuracy": _AliasMetric("vqa_accuracy", _Accuracy()),
}
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def get_metric(kind: str) -> PrimaryMetric:
    """Look up a registered PrimaryMetric by kind (case-insensitive).

    Raises:
        KeyError: when no metric is registered under the given kind.
    """
    try:
        return _REGISTRY[str(kind).lower()]
    except KeyError:
        raise KeyError(f"Unknown metric kind: {kind}") from None
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
def compute_primary_metric_from_report(
    report: dict[str, Any],
    *,
    kind: str = "ppl_causal",
    baseline: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Compute a primary metric snapshot from a run report (Phase 1 helper).

    Returns a dict that can be attached to report["metrics"]["primary_metric"].
    Includes preview/final points and (when baseline is present) a simple ratio
    vs baseline based on baseline ppl_final (for accuracy kinds, the value is
    a difference in proportions rather than a ratio).

    Args:
        report: Run report; windows are read from report["evaluation_windows"]
            and, for accuracy kinds, report["metrics"]["classification"].
        kind: Registered metric kind (see get_metric).
        baseline: Optional baseline report used for ratio_vs_baseline.
    """
    metric = get_metric(kind)
    windows = report.get("evaluation_windows") if isinstance(report, dict) else None
    is_accuracy = kind in {"accuracy", "vqa_accuracy"}

    # Chosen window sections plus counts metadata; initialized up front so no
    # branch-dependent locals() probing is needed later (previous version used
    # a fragile `"n_prev" in locals()` check).
    preview_win: dict[str, Any] = {}
    final_win: dict[str, Any] = {}
    n_prev: int | None = None
    n_fin: int | None = None
    counts_source_tag: str | None = None

    if is_accuracy:
        # Prefer classification aggregates if provided (may not have evaluation_windows)
        metrics = (
            report.get("metrics", {}) if isinstance(report.get("metrics"), dict) else {}
        )
        clf = metrics.get("classification") if isinstance(metrics, dict) else None
        if isinstance(clf, dict) and clf:
            prev = (
                clf.get("preview", {}) if isinstance(clf.get("preview"), dict) else {}
            )
            fin = clf.get("final", {}) if isinstance(clf.get("final"), dict) else {}
            preview_win = prev
            final_win = fin
            # Record example counts to help downstream gating decisions.
            try:
                if isinstance(prev.get("total"), int | float):
                    n_prev = int(prev.get("total"))
                elif isinstance(prev.get("example_correct"), list):
                    n_prev = len(prev.get("example_correct"))
                if isinstance(fin.get("total"), int | float):
                    n_fin = int(fin.get("total"))
                elif isinstance(fin.get("example_correct"), list):
                    n_fin = len(fin.get("example_correct"))
            except Exception:
                n_prev = None
                n_fin = None
            # Propagate counts source tagging when present
            counts_source = clf.get("counts_source")
            if isinstance(counts_source, str) and counts_source:
                counts_source_tag = counts_source

    # Fall back to generic evaluation windows when no aggregates were found.
    if not preview_win and not final_win and isinstance(windows, dict):
        preview_win = (
            windows.get("preview", {})
            if isinstance(windows.get("preview"), dict)
            else {}
        )
        final_win = (
            windows.get("final", {}) if isinstance(windows.get("final"), dict) else {}
        )

    if not preview_win and not final_win:
        # Nothing to compute from: emit a NaN-valued snapshot with metadata.
        return {
            "kind": metric.kind,
            "unit": metric.unit,
            "direction": metric.direction,
            "aggregation_scope": metric.aggregation_scope,
            "paired": metric.paired,
            "gating_basis": metric.gating_basis,
            "supports_bootstrap": metric.supports_bootstrap,
            "preview": float("nan"),
            "final": float("nan"),
            "ratio_vs_baseline": float("nan"),
        }

    # For accuracy kinds, derive counts from input_ids if aggregates are missing
    if is_accuracy:

        def _ensure_counts(win: dict[str, Any]) -> dict[str, Any]:
            # Pass through windows that already carry counts; otherwise derive
            # deterministic counts from raw input_ids (smoke path).
            if not isinstance(win, dict):
                return {}
            has_counts = (
                isinstance(win.get("correct_total"), int | float)
                and isinstance(win.get("total"), int | float)
                and win.get("total") > 0
            )
            if has_counts:
                return win
            # Try to derive from input_ids deterministically
            recs = []
            seqs = (
                win.get("input_ids") if isinstance(win.get("input_ids"), list) else []
            )
            if isinstance(seqs, list) and seqs:
                for seq in seqs:
                    if isinstance(seq, list):
                        recs.append({"input_ids": seq})
            if recs:
                try:
                    c, n = compute_accuracy_counts(recs)
                    return {"correct_total": int(c), "total": int(n)}
                except Exception:
                    return win
            return win

        preview_win = _ensure_counts(preview_win)
        final_win = _ensure_counts(final_win)

    preview_point = metric.point_from_windows(windows=preview_win)
    final_point = metric.point_from_windows(windows=final_win)

    ratio_vs_baseline = float("nan")
    if isinstance(baseline, dict):
        try:
            base_metrics = (
                baseline.get("metrics", {})
                if isinstance(baseline.get("metrics"), dict)
                else {}
            )
            pm_base = base_metrics.get("primary_metric")
            base_kind = (
                str(pm_base.get("kind", "")).lower()
                if isinstance(pm_base, dict)
                else ""
            )
            kind_l = str(kind).lower()
            ppl_kinds = {"ppl_causal", "ppl_mlm", "ppl_seq2seq"}
            acc_kinds = {"accuracy", "vqa_accuracy"}
            # Comparisons are allowed within the same metric family only.
            same_family = (kind_l in ppl_kinds and base_kind in ppl_kinds) or (
                kind_l in acc_kinds and base_kind in acc_kinds
            )
            if isinstance(pm_base, dict) and (base_kind == kind_l or same_family):
                base_ref = pm_base.get("final")
                if isinstance(base_ref, int | float):
                    if kind_l.startswith("ppl") and base_ref > 0:
                        # ppl family: ratio vs baseline final perplexity.
                        ratio_vs_baseline = float(final_point) / float(base_ref)
                    elif kind_l in acc_kinds and 0 <= base_ref <= 1:
                        # accuracy family: difference in proportions.
                        ratio_vs_baseline = float(final_point) - float(base_ref)
        except Exception:
            ratio_vs_baseline = float("nan")

    payload = {
        "kind": metric.kind,
        "unit": metric.unit,
        "direction": metric.direction,
        "aggregation_scope": metric.aggregation_scope,
        "paired": metric.paired,
        "gating_basis": metric.gating_basis,
        "supports_bootstrap": metric.supports_bootstrap,
        "preview": preview_point,
        "final": final_point,
        "ratio_vs_baseline": ratio_vs_baseline,
    }
    # Carry counts for accuracy to aid gating
    if is_accuracy:
        if n_prev is not None:
            payload["n_preview"] = int(n_prev)
        if n_fin is not None:
            payload["n_final"] = int(n_fin)
    # Carry counts_source/estimated tag when available
    if isinstance(counts_source_tag, str) and counts_source_tag:
        payload["counts_source"] = counts_source_tag
        payload["estimated"] = counts_source_tag != "measured"
    return payload
|
|
705
|
+
|
|
706
|
+
|
|
707
|
+
def validate_primary_metric_block(block: dict[str, Any]) -> dict[str, Any]:
    """Validate that a primary_metric block has finite preview/final values.

    Raises ValidationError(E402) when preview or final are non-finite.

    Returns the input block on success to enable fluent usage.
    """
    try:
        preview_val = float(block.get("preview"))
        final_val = float(block.get("final"))
    except Exception as err:
        raise ValidationError(
            code="E402",
            message="METRICS-VALIDATION-FAILED",
            details={"reason": "missing preview/final"},
        ) from err
    if math.isfinite(preview_val) and math.isfinite(final_val):
        return block
    raise ValidationError(
        code="E402",
        message="METRICS-VALIDATION-FAILED",
        details={
            "reason": "non-finite primary_metric values",
            "preview": preview_val,
            "final": final_val,
        },
    )
|
|
733
|
+
|
|
734
|
+
|
|
735
|
+
# --- Classification helpers (deterministic smoke path) ----------------------
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
def infer_binary_label_from_ids(input_ids: list[int]) -> int:
    """Deterministic binary label from token ids (parity), for smoke usage.

    This is a placeholder for provider-driven labels; it enables a stable,
    model-agnostic accuracy path for tests and demos without dataset labels.
    Any error while coercing tokens yields the fallback label 0.
    """
    try:
        total = 0
        for tok in input_ids:
            total += int(tok)
        # Parity of the token-id sum (equivalent to `% 2` for any int).
        return total & 1
    except Exception:
        return 0
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
def compute_accuracy_counts(records: list[dict[str, Any]]) -> tuple[int, int]:
    """Compute accuracy counts from records with input_ids.

    Predicts the same as the inferred label for a perfect-accuracy smoke path.
    Records without a non-empty ``input_ids`` list are skipped.
    Returns (correct_total, total).
    """
    correct_total = 0
    seen = 0
    for record in records:
        ids = record.get("input_ids") if isinstance(record, dict) else None
        if not (isinstance(ids, list) and ids):
            continue
        label = infer_binary_label_from_ids(ids)
        prediction = label  # smoke path: prediction mirrors the label exactly
        if int(prediction) == int(label):
            correct_total += 1
        seen += 1
    return correct_total, seen