invarlock 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,767 @@
1
+ """
2
+ Primary metric abstraction and minimal ppl_causal implementation (Phase 1).
3
+
4
+ This module introduces a light-weight, task-agnostic metric interface and a
5
+ registry so the runner/certificate can evolve beyond causal-LM perplexity.
6
+
7
+ Phase 1 goal: provide a ppl_causal metric and a helper that can compute point
8
+ estimates directly from evaluation window aggregates already present in run
9
+ reports. This is the canonical path; no env flag toggles.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ from collections.abc import Iterable, Sequence
16
+ from dataclasses import dataclass
17
+ from typing import Any, Protocol
18
+
19
+ import numpy as np
20
+
21
+ from invarlock.core.bootstrap import compute_paired_delta_log_ci
22
+ from invarlock.core.exceptions import ValidationError
23
+
24
+
25
@dataclass
class MetricDefaults:
    """Default bootstrap settings shared by metric implementations."""

    # Bootstrap replicate count used when callers do not pass `reps`.
    reps: int = 2000
    # Two-sided confidence level in (0, 1) used when `ci_level` is not given.
    ci_level: float = 0.95
29
+
30
+
31
@dataclass
class MetricContribution:
    """Per-example contribution to a metric.

    For ppl_* metrics, `value` is per-token mean log-loss for the example and
    `weight` is the number of target tokens. For accuracy, `value` is 0/1 and
    `weight` is ignored.
    """

    # Native-space contribution (log-loss for ppl_*, 0/1 flag for accuracy).
    value: float
    # Aggregation weight (token count for ppl_*); ignored by accuracy.
    weight: float = 1.0
    # Optional stable example identifier for pairing/debugging.
    id: str | None = None
43
+
44
+
45
class PrimaryMetric(Protocol):
    """Protocol for task-agnostic primary metrics.

    Implementations should describe their semantics and provide helpers to
    compute point estimates (and optionally CIs) from per-example aggregates.
    """

    # Metadata every implementation must expose as class/instance attributes.
    kind: str
    unit: str
    direction: str  # "lower" | "higher"
    aggregation_scope: str  # "token" | "sequence" | "example"
    paired: bool
    gating_basis: str  # "point" | "upper" | "lower"
    supports_bootstrap: bool
    defaults: MetricDefaults

    def display_transform(self, x: float) -> float:
        """Map native comparison space to display space.

        ppl_*: exp(x) maps Δlog-loss → ratio
        accuracy: x*100 maps proportion Δ → percentage points
        """

    def point_from_windows(self, *, windows: dict[str, Any]) -> float:
        """Compute a single point estimate from evaluation windows.

        For token-aggregated loss metrics (like ppl), this expects:
        windows = {"logloss": [...], "token_counts": [...]} with matching lengths.
        """

    def accumulate(self, contrib: MetricContribution) -> None:
        """Accumulate a per-example contribution for finalize()."""

    def finalize(self) -> float:
        """Return a point estimate from accumulated contributions (display space)."""

    def paired_compare(
        self,
        subject: Iterable[dict[str, Any] | MetricContribution],
        baseline: Iterable[dict[str, Any] | MetricContribution],
        *,
        reps: int | None = None,
        seed: int | None = None,
        ci_level: float | None = None,
    ) -> dict[str, Any]:
        """Paired compare subject vs baseline with bootstrap CI.

        Returns a dict with native-space delta and display-space values:
        {
            'delta': float,   # native compare space (Δlog-loss for ppl, Δ proportion for accuracy)
            'ci': (lo, hi),   # native-space CI
            'display': float, # display space (ratio for ppl, pp for accuracy)
            'display_ci': (lo, hi),
            'subject_point': float,  # point estimate in display space
            'baseline_point': float, # point estimate in display space
            'reps': int, 'ci_level': float, 'paired': True, meta...
        }
        """
103
+
104
+
105
@dataclass
class MetricInfo:
    """Plain metadata snapshot of a metric's declared semantics.

    NOTE(review): not referenced elsewhere in this module — presumably
    consumed by reporting/registry callers; confirm before removing.
    """

    kind: str
    unit: str
    direction: str
    aggregation_scope: str
    paired: bool
    gating_basis: str
    supports_bootstrap: bool
114
+
115
+
116
class _PPLCausal(PrimaryMetric):
    """Token-aggregated perplexity for causal LMs.

    point_from_windows computes ppl = exp(weighted_mean(logloss)).
    """

    kind = "ppl_causal"
    unit = "ppl"
    direction = "lower"
    aggregation_scope = "token"
    paired = True
    gating_basis = "upper"  # typical gate on ratio upper-bound
    supports_bootstrap = True
    defaults = MetricDefaults()

    def __init__(self) -> None:
        # Per-example mean log-loss values and token-count weights for finalize().
        self._values: list[float] = []
        self._weights: list[float] = []

    def display_transform(self, x: float) -> float:
        """Map Δlog-loss (native space) to a perplexity ratio (display space)."""
        try:
            return float(math.exp(x))
        except OverflowError:
            return float("inf")

    def point_from_windows(self, *, windows: dict[str, Any]) -> float:
        """Compute ppl = exp(weighted mean log-loss) from window aggregates.

        Expects windows = {"logloss": [...], "token_counts": [...]} with
        matching lengths. Non-numeric, non-finite, or non-positive-weight
        entries are skipped; returns NaN when nothing usable remains.
        """
        logloss = list(windows.get("logloss", []) or [])
        token_counts = list(windows.get("token_counts", []) or [])
        if not logloss or not token_counts or len(logloss) != len(token_counts):
            return float("nan")
        sum_w = 0.0
        sum_wx = 0.0
        for ll, w in zip(logloss, token_counts, strict=False):
            try:
                llv = float(ll)
                wv = float(w)
            except Exception:
                continue
            if not math.isfinite(llv) or not math.isfinite(wv) or wv <= 0:
                continue
            sum_w += wv
            sum_wx += wv * llv
        if sum_w <= 0.0:
            return float("nan")
        mean_ll = sum_wx / sum_w
        try:
            return float(math.exp(mean_ll))
        except OverflowError:
            return float("inf")

    def accumulate(self, contrib: MetricContribution) -> None:
        """Record one example's (log-loss, token-count) pair; ignore bad input."""
        try:
            v = float(contrib.value)
            w = float(contrib.weight)
        except Exception:
            return
        if not math.isfinite(v) or not math.isfinite(w) or w <= 0:
            return
        self._values.append(v)
        self._weights.append(w)

    def finalize(self) -> float:
        """Return exp(weighted mean log-loss) over accumulated contributions."""
        if (
            not self._values
            or not self._weights
            or len(self._values) != len(self._weights)
        ):
            return float("nan")
        sw = 0.0
        swx = 0.0
        for v, w in zip(self._values, self._weights, strict=False):
            sw += w
            swx += w * v
        if sw <= 0.0:
            return float("nan")
        return self.display_transform(swx / sw)

    def _coerce_contrib_array(
        self, items: Iterable[dict[str, Any] | MetricContribution]
    ) -> list[tuple[float, float]]:
        """Normalize contributions to (value, weight) pairs.

        Robustness fix: dicts with non-numeric value/weight are now skipped
        instead of raising from float().
        """
        out: list[tuple[float, float]] = []
        for it in items:
            if isinstance(it, MetricContribution):
                out.append((float(it.value), float(it.weight)))
            elif isinstance(it, dict) and "value" in it:
                try:
                    v = float(it.get("value"))
                    w = float(it.get("weight", 1.0))
                except (TypeError, ValueError):
                    continue
                out.append((v, w))
        return out

    def paired_compare(
        self,
        subject: Iterable[dict[str, Any] | MetricContribution],
        baseline: Iterable[dict[str, Any] | MetricContribution],
        *,
        reps: int | None = None,
        seed: int | None = None,
        ci_level: float | None = None,
    ) -> dict[str, Any]:
        """Paired bootstrap compare of subject vs baseline in log-loss space.

        Returns the protocol payload with Δlog-loss (`delta`), its CI, and the
        exp-transformed ratio (`display`, `display_ci`).
        """
        subj = self._coerce_contrib_array(subject)
        base = self._coerce_contrib_array(baseline)
        # Compute simple (unweighted) per-example arrays in log space; weights ignored for bootstrap here
        subj_vals = [v for (v, _w) in subj]
        base_vals = [v for (v, _w) in base]

        ci_level_eff = (
            float(ci_level) if (ci_level is not None) else self.defaults.ci_level
        )
        if not subj_vals or not base_vals:
            # Robustness fix: do not bootstrap empty arrays; return a NaN
            # payload, mirroring the accuracy metric's empty-input behavior.
            nan = float("nan")
            return {
                "kind": self.kind,
                "unit": self.unit,
                "direction": self.direction,
                "aggregation_scope": self.aggregation_scope,
                "paired": True,
                "gating_basis": self.gating_basis,
                "supports_bootstrap": self.supports_bootstrap,
                "reps": 0,
                "ci_level": ci_level_eff,
                "subject_point": nan,
                "baseline_point": nan,
                "delta": nan,
                "ci": (nan, nan),
                "display": nan,
                "display_ci": (nan, nan),
            }

        # Points in display space
        def _point(
            vals: Sequence[float], weights: Sequence[float] | None = None
        ) -> float:
            """Weighted (or plain) mean of vals, mapped through display_transform."""
            if not vals:
                return float("nan")
            if weights and len(weights) == len(vals):
                sw = 0.0
                swx = 0.0
                for v, w in zip(vals, weights, strict=False):
                    sw += w
                    swx += w * v
                if sw <= 0:
                    return float("nan")
                return self.display_transform(swx / sw)
            else:
                return self.display_transform(sum(vals) / float(len(vals)))

        subj_point = _point([v for v, _ in subj], [w for _, w in subj])
        base_point = _point([v for v, _ in base], [w for _, w in base])

        # Bootstrap Δlog-loss → CI, then display-transform → ratio CI
        reps_eff = int(reps) if (reps is not None and reps > 0) else self.defaults.reps
        seed_eff = int(seed) if (seed is not None) else 0
        alpha = 1.0 - ci_level_eff
        dlog_lo, dlog_hi = compute_paired_delta_log_ci(
            subj_vals,
            base_vals,
            method="bca",
            replicates=reps_eff,
            alpha=alpha,
            seed=seed_eff,
        )
        delta_log = float(
            sum((s - b) for s, b in zip(subj_vals, base_vals, strict=False))
            / max(1, min(len(subj_vals), len(base_vals)))
        )
        ratio = self.display_transform(delta_log)
        return {
            "kind": self.kind,
            "unit": self.unit,
            "direction": self.direction,
            "aggregation_scope": self.aggregation_scope,
            "paired": True,
            "gating_basis": self.gating_basis,
            "supports_bootstrap": self.supports_bootstrap,
            "reps": reps_eff,
            "ci_level": ci_level_eff,
            "subject_point": subj_point,
            "baseline_point": base_point,
            "delta": delta_log,
            "ci": (dlog_lo, dlog_hi),
            "display": ratio,
            "display_ci": (
                self.display_transform(dlog_lo),
                self.display_transform(dlog_hi),
            ),
        }
282
+
283
+
284
# ── Additional metric implementations (the registry itself is defined below) ──
285
class _PPLMLM(_PPLCausal):
    """Masked LM perplexity.

    Uses masked_token_counts when available; falls back to token_counts.
    """

    kind = "ppl_mlm"

    def point_from_windows(self, *, windows: dict[str, Any]) -> float:
        """Weight log-loss by masked-token counts when the window provides them."""
        masked_counts = list(windows.get("masked_token_counts", []) or [])
        if not masked_counts:
            # No MLM-specific counts — defer to the causal implementation as-is.
            return super().point_from_windows(windows=windows)
        reweighted = {
            "logloss": windows.get("logloss", []),
            "token_counts": masked_counts,
        }
        return super().point_from_windows(windows=reweighted)
300
+
301
+
302
class _PPLSeq2Seq(_PPLCausal):
    """Perplexity over decoder labels for seq2seq models (token-aggregated).

    Only the registry `kind` differs; all computation is inherited from the
    causal implementation.
    """

    kind = "ppl_seq2seq"
306
+
307
+
308
class _Accuracy:
    """Example-aggregated accuracy (0..1).

    Accepts either per-example flags (example_correct) or aggregate counts
    (correct_total/total or correct_counts/total_counts).
    """

    kind = "accuracy"
    unit = "accuracy"
    direction = "higher"
    aggregation_scope = "example"
    paired = True
    gating_basis = "lower"
    supports_bootstrap = True
    defaults = MetricDefaults()

    def __init__(self) -> None:
        # Per-example 0/1 correctness flags for finalize().
        self._values: list[float] = []

    def display_transform(self, x: float) -> float:  # proportion → percentage points
        return float(x * 100.0)

    def point_from_windows(self, *, windows: dict[str, Any]) -> float:
        """Compute accuracy from per-example flags or aggregate counts.

        Prefers the `example_correct` list; otherwise falls back to
        (correct_total, total) or (correct_counts, total_counts) pairs,
        applying the optional abstain/tie policy in windows["policy"].
        Returns NaN when no usable inputs are present.
        """
        # Per-example path
        ex = list(windows.get("example_correct", []) or [])
        if ex:
            s = 0.0
            n = 0.0
            for v in ex:
                try:
                    # Values above 0.5 count as correct; non-numeric entries skipped.
                    s += 1.0 if float(v) > 0.5 else 0.0
                    n += 1.0
                except Exception:
                    continue
            if n > 0:
                return s / n
            return float("nan")
        # Aggregate counts path
        for c_key, t_key in (
            ("correct_total", "total"),
            ("correct_counts", "total_counts"),
        ):
            c = windows.get(c_key)
            t = windows.get(t_key)
            if isinstance(c, int | float) and isinstance(t, int | float) and t > 0:
                total = float(t)
                # Optional abstain/tie handling with documented policy
                try:
                    policy = (
                        windows.get("policy", {})
                        if isinstance(windows.get("policy"), dict)
                        else {}
                    )
                    abstain = windows.get("abstain_total")
                    ties = windows.get("ties_total")
                    exclude_abstain = bool(policy.get("exclude_abstain", True))
                    count_ties_as_correct = bool(
                        policy.get("ties_count_as_correct", False)
                    )
                    count_ties_as_incorrect = bool(
                        policy.get("ties_count_as_incorrect", False)
                    )
                    # Apply abstain exclusion from denominator if requested
                    # (denominator is floored at 1 to avoid division by zero).
                    if (
                        exclude_abstain
                        and isinstance(abstain, int | float)
                        and abstain > 0
                    ):
                        total = max(1.0, total - float(abstain))
                    # Apply tie policy
                    if isinstance(ties, int | float) and ties > 0:
                        if count_ties_as_correct:
                            c = float(c) + float(ties)
                        elif count_ties_as_incorrect:
                            # leave c unchanged; implicit in denominator
                            pass
                        else:
                            # default: treat ties as abstain (exclude if exclude_abstain True)
                            if exclude_abstain:
                                total = max(1.0, total - float(ties))
                except Exception:
                    # Best-effort policy application; fall back to raw counts.
                    pass
                return float(c) / float(total)
        return float("nan")

    def accumulate(self, contrib: MetricContribution) -> None:
        """Record a per-example correctness flag (contribution weight is ignored)."""
        try:
            v = float(contrib.value)
        except Exception:
            return
        if not math.isfinite(v):
            return
        # Binarize at the 0.5 threshold → exactly 0.0 or 1.0.
        v = 1.0 if v >= 0.5 else 0.0
        self._values.append(v)

    def finalize(self) -> float:
        """Return the mean of accumulated 0/1 flags (a proportion in 0..1)."""
        if not self._values:
            return float("nan")
        return float(sum(self._values) / float(len(self._values)))

    def _coerce_vals(
        self, items: Iterable[dict[str, Any] | MetricContribution]
    ) -> list[float]:
        """Normalize contributions to 0/1 flags via the 0.5 threshold."""
        out: list[float] = []
        for it in items:
            if isinstance(it, MetricContribution):
                out.append(1.0 if float(it.value) >= 0.5 else 0.0)
            elif isinstance(it, dict) and "value" in it:
                v = float(it.get("value"))
                out.append(1.0 if v >= 0.5 else 0.0)
        return out

    def paired_compare(
        self,
        subject: Iterable[dict[str, Any] | MetricContribution],
        baseline: Iterable[dict[str, Any] | MetricContribution],
        *,
        reps: int | None = None,
        seed: int | None = None,
        ci_level: float | None = None,
    ) -> dict[str, Any]:
        """Paired percentile-bootstrap compare of subject vs baseline accuracy.

        Pairs are truncated to the shorter of the two inputs; with no pairs,
        a NaN payload (reps=0) is returned. `delta` is in proportion space;
        `display`/`display_ci` are in percentage points.
        """
        subj = self._coerce_vals(subject)
        base = self._coerce_vals(baseline)
        m = min(len(subj), len(base))
        subj = subj[:m]
        base = base[:m]
        if m == 0:
            return {
                "kind": self.kind,
                "unit": self.unit,
                "direction": self.direction,
                "aggregation_scope": self.aggregation_scope,
                "paired": True,
                "gating_basis": self.gating_basis,
                "supports_bootstrap": self.supports_bootstrap,
                "reps": 0,
                "ci_level": ci_level or self.defaults.ci_level,
                "subject_point": float("nan"),
                "baseline_point": float("nan"),
                "delta": float("nan"),
                "ci": (float("nan"), float("nan")),
                "display": float("nan"),
                "display_ci": (float("nan"), float("nan")),
            }
        # Points in display space for subject/baseline (proportions, no transform)
        subj_point = float(sum(subj) / float(m))
        base_point = float(sum(base) / float(m))
        # Δ in native (proportion) space
        diffs = [float(s - b) for s, b in zip(subj, base, strict=False)]
        delta = float(sum(diffs) / float(m))
        reps_eff = int(reps) if (reps is not None and reps > 0) else self.defaults.reps
        seed_eff = int(seed) if (seed is not None) else 0
        ci_level_eff = (
            float(ci_level) if (ci_level is not None) else self.defaults.ci_level
        )
        alpha = 1.0 - ci_level_eff
        # Percentile bootstrap on paired diffs
        rng = np.random.default_rng(seed_eff)  # type: ignore[name-defined]
        stats = []
        for _ in range(reps_eff):
            # Resample pair indices with replacement; mean of resampled diffs.
            idx = rng.integers(0, m, size=m)
            s = 0.0
            for i in idx:
                s += diffs[i]
            stats.append(s / float(m))
        stats.sort()
        lo = float(np.percentile(stats, 100.0 * (alpha / 2.0)))  # type: ignore[name-defined]
        hi = float(np.percentile(stats, 100.0 * (1.0 - alpha / 2.0)))  # type: ignore[name-defined]
        return {
            "kind": self.kind,
            "unit": self.unit,
            "direction": self.direction,
            "aggregation_scope": self.aggregation_scope,
            "paired": True,
            "gating_basis": self.gating_basis,
            "supports_bootstrap": self.supports_bootstrap,
            "reps": reps_eff,
            "ci_level": ci_level_eff,
            "subject_point": subj_point,
            "baseline_point": base_point,
            "delta": delta,
            "ci": (lo, hi),
            "display": self.display_transform(delta),
            "display_ci": (self.display_transform(lo), self.display_transform(hi)),
        }
494
+
495
+
496
+ class _AliasMetric:
497
+ """Light alias wrapper that re-labels `kind` while delegating behavior."""
498
+
499
+ def __init__(self, alias: str, base: PrimaryMetric) -> None:
500
+ self._alias = str(alias)
501
+ self._base = base
502
+ # Copy metadata
503
+ self.kind = self._alias
504
+ self.unit = base.unit
505
+ self.direction = base.direction
506
+ self.aggregation_scope = base.aggregation_scope
507
+ self.paired = base.paired
508
+ self.gating_basis = base.gating_basis
509
+ self.supports_bootstrap = base.supports_bootstrap
510
+
511
+ def point_from_windows(self, *, windows: dict[str, Any]) -> float:
512
+ return self._base.point_from_windows(windows=windows)
513
+
514
+
515
# Singleton metric instances keyed by lowercase kind; get_metric() looks here.
# NOTE(review): instances are shared across callers, so accumulate() state
# persists between get_metric() users — confirm callers reset or avoid it.
_REGISTRY: dict[str, PrimaryMetric] = {
    _PPLCausal.kind: _PPLCausal(),
    _PPLMLM.kind: _PPLMLM(),
    _PPLSeq2Seq.kind: _PPLSeq2Seq(),
    _Accuracy.kind: _Accuracy(),
    # Multimodal aliases
    "vqa_accuracy": _AliasMetric("vqa_accuracy", _Accuracy()),
}
523
+
524
+
525
def get_metric(kind: str) -> PrimaryMetric:
    """Look up a registered metric by kind (case-insensitive).

    Raises KeyError for unknown kinds.
    """
    normalized = str(kind).lower()
    metric = _REGISTRY.get(normalized)
    if metric is None:
        raise KeyError(f"Unknown metric kind: {kind}")
    return metric
530
+
531
+
532
def compute_primary_metric_from_report(
    report: dict[str, Any],
    *,
    kind: str = "ppl_causal",
    baseline: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Compute a primary metric snapshot from a run report (Phase 1 helper).

    Returns a dict that can be attached to report["metrics"]["primary_metric"].
    Includes preview/final points and (when baseline is present) a simple ratio
    vs baseline based on baseline ppl_final.

    Args:
        report: Run report carrying "evaluation_windows" and/or
            "metrics.classification" sections.
        kind: Registered metric kind (see get_metric).
        baseline: Optional baseline report used for ratio_vs_baseline.
    """
    metric = get_metric(kind)
    windows = report.get("evaluation_windows") if isinstance(report, dict) else None

    # Window sections that feed point_from_windows.
    preview_win: dict[str, Any] = {}
    final_win: dict[str, Any] = {}

    counts_source_tag: str | None = None
    # Fix: initialize example counts up-front instead of the previous fragile
    # `"n_prev" in locals()` probe at payload-assembly time.
    n_prev: int | None = None
    n_fin: int | None = None
    if kind in {"accuracy", "vqa_accuracy"}:
        # Prefer classification aggregates if provided (may not have evaluation_windows)
        metrics = (
            report.get("metrics", {}) if isinstance(report.get("metrics"), dict) else {}
        )
        clf = metrics.get("classification") if isinstance(metrics, dict) else None
        if isinstance(clf, dict) and clf:
            prev = (
                clf.get("preview", {}) if isinstance(clf.get("preview"), dict) else {}
            )
            fin = clf.get("final", {}) if isinstance(clf.get("final"), dict) else {}
            preview_win = prev
            final_win = fin
            # Derive example counts (to help gating) from totals or flag lists.
            try:
                if isinstance(prev.get("total"), int | float):
                    n_prev = int(prev.get("total"))
                elif isinstance(prev.get("example_correct"), list):
                    n_prev = len(prev.get("example_correct"))
                if isinstance(fin.get("total"), int | float):
                    n_fin = int(fin.get("total"))
                elif isinstance(fin.get("example_correct"), list):
                    n_fin = len(fin.get("example_correct"))
            except Exception:
                n_prev = None
                n_fin = None
            # Propagate counts source tagging when present
            counts_source = clf.get("counts_source")
            if isinstance(counts_source, str) and counts_source:
                counts_source_tag = counts_source

    if not preview_win and not final_win and isinstance(windows, dict):
        preview_win = (
            windows.get("preview", {})
            if isinstance(windows.get("preview"), dict)
            else {}
        )
        final_win = (
            windows.get("final", {}) if isinstance(windows.get("final"), dict) else {}
        )

    if not preview_win and not final_win:
        # Nothing to compute from
        return {
            "kind": metric.kind,
            "unit": metric.unit,
            "direction": metric.direction,
            "aggregation_scope": metric.aggregation_scope,
            "paired": metric.paired,
            "gating_basis": metric.gating_basis,
            "supports_bootstrap": metric.supports_bootstrap,
            "preview": float("nan"),
            "final": float("nan"),
            "ratio_vs_baseline": float("nan"),
        }

    # For accuracy kinds, derive counts from input_ids if aggregates are missing
    if kind in {"accuracy", "vqa_accuracy"}:

        def _ensure_counts(win: dict[str, Any]) -> dict[str, Any]:
            """Return win with correct_total/total, deriving from input_ids if absent."""
            if not isinstance(win, dict):
                return {}
            has_counts = (
                isinstance(win.get("correct_total"), int | float)
                and isinstance(win.get("total"), int | float)
                and win.get("total") > 0
            )
            if has_counts:
                return win
            # Try to derive from input_ids deterministically
            recs = []
            seqs = (
                win.get("input_ids") if isinstance(win.get("input_ids"), list) else []
            )
            if isinstance(seqs, list) and seqs:
                for seq in seqs:
                    if isinstance(seq, list):
                        recs.append({"input_ids": seq})
            if recs:
                try:
                    c, n = compute_accuracy_counts(recs)
                    return {"correct_total": int(c), "total": int(n)}
                except Exception:
                    return win
            return win

        preview_win = _ensure_counts(preview_win)
        final_win = _ensure_counts(final_win)

    preview_point = metric.point_from_windows(windows=preview_win)
    final_point = metric.point_from_windows(windows=final_win)

    ratio_vs_baseline = float("nan")
    if isinstance(baseline, dict):
        try:
            base_metrics = (
                baseline.get("metrics", {})
                if isinstance(baseline.get("metrics"), dict)
                else {}
            )
            pm_base = base_metrics.get("primary_metric")
            base_kind = (
                str(pm_base.get("kind", "")).lower()
                if isinstance(pm_base, dict)
                else ""
            )
            kind_l = str(kind).lower()
            ppl_kinds = {"ppl_causal", "ppl_mlm", "ppl_seq2seq"}
            acc_kinds = {"accuracy", "vqa_accuracy"}
            same_family = (kind_l in ppl_kinds and base_kind in ppl_kinds) or (
                kind_l in acc_kinds and base_kind in acc_kinds
            )
            if isinstance(pm_base, dict) and (base_kind == kind_l or same_family):
                base_ref = pm_base.get("final")
                if isinstance(base_ref, int | float):
                    if kind_l in ppl_kinds and base_ref > 0:
                        # Perplexity family: subject/baseline ratio.
                        ratio_vs_baseline = float(final_point) / float(base_ref)
                    elif kind_l in acc_kinds and 0 <= base_ref <= 1:
                        # Accuracy family: absolute Δ in proportion space.
                        ratio_vs_baseline = float(final_point) - float(base_ref)
        except Exception:
            ratio_vs_baseline = float("nan")

    payload = {
        "kind": metric.kind,
        "unit": metric.unit,
        "direction": metric.direction,
        "aggregation_scope": metric.aggregation_scope,
        "paired": metric.paired,
        "gating_basis": metric.gating_basis,
        "supports_bootstrap": metric.supports_bootstrap,
        "preview": preview_point,
        "final": final_point,
        "ratio_vs_baseline": ratio_vs_baseline,
    }
    # Carry counts for accuracy to aid gating
    if kind in {"accuracy", "vqa_accuracy"}:
        if n_prev is not None:
            payload["n_preview"] = int(n_prev)
        if n_fin is not None:
            payload["n_final"] = int(n_fin)
        # Carry counts_source/estimated tag when available
        if isinstance(counts_source_tag, str) and counts_source_tag:
            payload["counts_source"] = counts_source_tag
            payload["estimated"] = counts_source_tag != "measured"
    return payload
705
+
706
+
707
def validate_primary_metric_block(block: dict[str, Any]) -> dict[str, Any]:
    """Validate that a primary_metric block has finite preview/final values.

    Raises ValidationError(E402) when preview or final are missing or
    non-finite. Returns the input block on success to enable fluent usage.
    """
    try:
        values = {name: float(block.get(name)) for name in ("preview", "final")}
    except Exception as err:
        raise ValidationError(
            code="E402",
            message="METRICS-VALIDATION-FAILED",
            details={"reason": "missing preview/final"},
        ) from err
    if not all(math.isfinite(v) for v in values.values()):
        raise ValidationError(
            code="E402",
            message="METRICS-VALIDATION-FAILED",
            details={
                "reason": "non-finite primary_metric values",
                "preview": values["preview"],
                "final": values["final"],
            },
        )
    return block
733
+
734
+
735
+ # --- Classification helpers (deterministic smoke path) ----------------------
736
+
737
+
738
def infer_binary_label_from_ids(input_ids: list[int]) -> int:
    """Deterministic binary label from token ids (parity), for smoke usage.

    This is a placeholder for provider-driven labels; it enables a stable,
    model-agnostic accuracy path for tests and demos without dataset labels.
    Returns 0 for any input that cannot be coerced to integers.
    """
    try:
        total = 0
        for token in input_ids:
            total += int(token)
        return total % 2
    except Exception:
        return 0
748
+
749
+
750
def compute_accuracy_counts(records: list[dict[str, Any]]) -> tuple[int, int]:
    """Compute accuracy counts from records with input_ids.

    Predicts the same as the inferred label for a perfect-accuracy smoke path,
    so every counted record is correct by construction. Records that are not
    dicts, or whose "input_ids" is missing or an empty/non-list value, are
    skipped. Returns (correct_total, total).
    """
    counted = 0
    for rec in records:
        ids = rec.get("input_ids") if isinstance(rec, dict) else None
        if isinstance(ids, list) and ids:
            # Prediction mirrors the inferred parity label (which never
            # raises), so correct_total always equals total here.
            counted += 1
    return counted, counted