eval-toolkit 0.27.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,376 @@
1
+ """Strict artifact helpers for JSON outputs and prediction references."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import dataclasses
6
+ import json
7
+ import math
8
+ from collections.abc import Mapping, Sequence
9
+ from dataclasses import asdict, dataclass, field
10
+ from importlib import resources
11
+ from pathlib import Path
12
+ from typing import Literal
13
+
14
+ import numpy as np
15
+
16
# Names exported by ``from <this module> import *``; every entry is defined
# further down in this module.
__all__ = [
    "MetricState",
    "PredictionArtifactRef",
    "PredictionColumns",
    "error_metric",
    "sanitize_for_json",
    "skipped_metric",
    "validate_manifest",
    "validate_payload",
    "validate_prediction_artifact_ref",
    "validate_results",
    "write_json_strict",
]
29
+
30
+ MetricStatus = Literal["ok", "skipped", "error"]
31
+
32
+
33
+ @dataclass(frozen=True, slots=True)
34
+ class MetricState:
35
+ """Structured state for metrics that may be unavailable or invalid."""
36
+
37
+ value: object | None
38
+ status: MetricStatus
39
+ reason: str = ""
40
+ details: dict[str, object] = field(default_factory=dict)
41
+
42
+ def to_dict(self) -> dict[str, object]:
43
+ """JSON-serializable representation."""
44
+ out: dict[str, object] = {
45
+ "value": sanitize_for_json(self.value),
46
+ "status": self.status,
47
+ "reason": self.reason,
48
+ }
49
+ if self.details:
50
+ out["details"] = sanitize_for_json(self.details)
51
+ return out
52
+
53
+
54
def skipped_metric(reason: str, **details: object) -> dict[str, object]:
    """Build the dict form of a metric that was skipped.

    ``reason`` explains the skip; any extra keyword arguments are handed to
    :class:`MetricState` as its ``details`` mapping.
    """
    state = MetricState(value=None, status="skipped", reason=reason, details=details)
    return state.to_dict()
57
+
58
+
59
def error_metric(reason: str, **details: object) -> dict[str, object]:
    """Build the dict form of a metric that failed with an error.

    ``reason`` describes the failure; any extra keyword arguments are handed
    to :class:`MetricState` as its ``details`` mapping.
    """
    state = MetricState(value=None, status="error", reason=reason, details=details)
    return state.to_dict()
62
+
63
+
64
+ @dataclass(frozen=True, slots=True)
65
+ class PredictionColumns:
66
+ """Column mapping for a retained prediction artifact."""
67
+
68
+ label: str
69
+ score: str
70
+ row_id: str | None = None
71
+ content_hash: str | None = None
72
+ scorer: str | None = None
73
+ slice: str | None = None
74
+ text: str | None = None
75
+ provenance: dict[str, str] = field(default_factory=dict)
76
+
77
+ def __post_init__(self) -> None:
78
+ """Validate required prediction columns."""
79
+ if not self.label:
80
+ raise ValueError("PredictionColumns.label must be non-empty")
81
+ if not self.score:
82
+ raise ValueError("PredictionColumns.score must be non-empty")
83
+
84
+ def to_dict(self) -> dict[str, object]:
85
+ """JSON-serializable representation with absent optional columns omitted."""
86
+ out: dict[str, object] = {"label": self.label, "score": self.score}
87
+ for key in ("row_id", "content_hash", "scorer", "slice", "text"):
88
+ value = getattr(self, key)
89
+ if value is not None:
90
+ out[key] = value
91
+ if self.provenance:
92
+ out["provenance"] = dict(self.provenance)
93
+ return out
94
+
95
+
96
+ @dataclass(frozen=True, slots=True)
97
+ class PredictionArtifactRef:
98
+ """Manifest reference to a retained prediction artifact.
99
+
100
+ ``role`` accepts ``str`` or ``list[str]`` since v0.15.0 (F5.2): a
101
+ single artifact that covers multiple slices / fold-roles can name them
102
+ explicitly instead of carrying a synthetic single-string role plus a
103
+ ``metadata["slices"]`` list. The schema accepts both shapes.
104
+ """
105
+
106
+ uri: str
107
+ media_type: str
108
+ columns: PredictionColumns | Mapping[str, object]
109
+ sha256: str = ""
110
+ n_rows: int | None = None
111
+ role: str | list[str] = "predictions"
112
+ metadata: dict[str, object] = field(default_factory=dict)
113
+
114
+ def __post_init__(self) -> None:
115
+ """Validate the stable prediction-reference shape."""
116
+ if not self.uri:
117
+ raise ValueError("PredictionArtifactRef.uri must be non-empty")
118
+ if not self.media_type:
119
+ raise ValueError("PredictionArtifactRef.media_type must be non-empty")
120
+ if self.n_rows is not None and (isinstance(self.n_rows, bool) or self.n_rows < 0):
121
+ raise ValueError("PredictionArtifactRef.n_rows must be non-negative when present")
122
+ if isinstance(self.role, list):
123
+ if not self.role or not all(isinstance(r, str) and r.strip() for r in self.role):
124
+ raise ValueError(
125
+ "PredictionArtifactRef.role list must be non-empty and contain "
126
+ "only non-empty strings"
127
+ )
128
+ elif isinstance(self.role, str):
129
+ if not self.role.strip():
130
+ raise ValueError("PredictionArtifactRef.role must be non-empty")
131
+ else:
132
+ raise TypeError(
133
+ f"PredictionArtifactRef.role must be str or list[str], "
134
+ f"got {type(self.role).__name__}"
135
+ )
136
+ columns = self.columns
137
+ if isinstance(columns, Mapping):
138
+ if not isinstance(columns.get("label"), str) or not columns.get("label"):
139
+ raise ValueError("PredictionArtifactRef.columns must include a label column")
140
+ if not isinstance(columns.get("score"), str) or not columns.get("score"):
141
+ raise ValueError("PredictionArtifactRef.columns must include a score column")
142
+
143
+ def to_dict(self) -> dict[str, object]:
144
+ """JSON-serializable representation."""
145
+ if isinstance(self.columns, PredictionColumns):
146
+ columns: object = self.columns.to_dict()
147
+ else:
148
+ columns = dict(self.columns)
149
+ # Preserve role as a list when caller passed a list; otherwise a string.
150
+ role: object = list(self.role) if isinstance(self.role, list) else self.role
151
+ out: dict[str, object] = {
152
+ "uri": self.uri,
153
+ "media_type": self.media_type,
154
+ "role": role,
155
+ "columns": sanitize_for_json(columns),
156
+ }
157
+ if self.sha256:
158
+ out["sha256"] = self.sha256
159
+ if self.n_rows is not None:
160
+ out["n_rows"] = self.n_rows
161
+ if self.metadata:
162
+ out["metadata"] = sanitize_for_json(self.metadata)
163
+ return out
164
+
165
+
166
NanStrategy = Literal["skipped", "null", "raise"]


def sanitize_for_json(
    payload: object,
    *,
    nan_strategy: NanStrategy = "skipped",
) -> object:
    """Return a strict-JSON-safe copy of ``payload``.

    Parameters
    ----------
    payload : object
        Arbitrary nested structure of primitives, mappings, sequences,
        dataclasses, numpy scalars / arrays, or objects with ``to_dict``.
    nan_strategy : {"skipped", "null", "raise"}, optional
        How to handle non-finite floats (NaN / ±Inf). Default ``"skipped"``
        replaces them with a structured
        :func:`skipped_metric` dict (keeps the reason auditable; preserves
        pre-v0.13 behavior). ``"null"`` replaces with JSON ``null`` (use when
        downstream consumers expect numeric-or-null without the structured
        sentinel). ``"raise"`` raises ``ValueError`` on first non-finite
        value, surfacing scoring bugs that would otherwise pass silently.
        Closes F4.1 from the V4 consumer feedback log.

    Returns
    -------
    object
        A nested structure that ``json.dumps(..., allow_nan=False)`` will
        accept. RFC 8259 compliant. Mapping keys are converted with ``str``;
        sequences (other than ``str`` / ``bytes`` / ``bytearray``) become
        lists; anything unrecognized is converted with ``str``.

    Raises
    ------
    ValueError
        If ``nan_strategy`` is not one of the supported values (checked up
        front, regardless of whether ``payload`` contains a non-finite
        float), or if ``nan_strategy="raise"`` and ``payload`` contains a
        non-finite float anywhere in the structure.
    """
    # Validate eagerly: otherwise a misspelled strategy goes unnoticed until
    # the first NaN/Inf happens to show up in real data.
    if nan_strategy not in ("skipped", "null", "raise"):
        raise ValueError(
            f"nan_strategy must be one of 'skipped', 'null', 'raise'; " f"got {nan_strategy!r}"
        )
    return _sanitize_value(payload, nan_strategy)


def _sanitize_value(payload: object, nan_strategy: str) -> object:
    """Recursive worker for :func:`sanitize_for_json` (strategy already validated)."""
    if payload is None or isinstance(payload, (str, bool, int)):
        return payload
    if isinstance(payload, float):
        if math.isfinite(payload):
            return payload
        if nan_strategy == "null":
            return None
        if nan_strategy == "raise":
            raise ValueError(f"non-finite numeric value: {payload!r}")
        return skipped_metric(f"non-finite numeric value: {payload!r}")
    if isinstance(payload, np.generic):
        # Unwrap the numpy scalar and re-dispatch on the unwrapped value.
        return _sanitize_value(payload.item(), nan_strategy)
    if isinstance(payload, np.ndarray):
        return _sanitize_value(payload.tolist(), nan_strategy)
    if not isinstance(payload, type):
        # Class objects are excluded: calling an unbound ``to_dict`` on a
        # class would fail with a missing-``self`` TypeError, and ``asdict``
        # only accepts dataclass instances. Classes fall through to ``str``.
        to_dict = getattr(payload, "to_dict", None)
        if callable(to_dict):
            return _sanitize_value(to_dict(), nan_strategy)
        if dataclasses.is_dataclass(payload):
            return _sanitize_value(asdict(payload), nan_strategy)
    if isinstance(payload, Mapping):
        return {str(key): _sanitize_value(value, nan_strategy) for key, value in payload.items()}
    if isinstance(payload, Sequence) and not isinstance(payload, (str, bytes, bytearray)):
        return [_sanitize_value(value, nan_strategy) for value in payload]
    return str(payload)
233
+
234
+
235
def write_json_strict(
    payload: object,
    path: Path | str,
    *,
    indent: int = 2,
    sort_keys: bool = False,
) -> Path:
    """Write strict RFC 8259 JSON after sanitizing non-finite values.

    Parameters
    ----------
    payload : object
        Structure to serialize; it is first passed through
        :func:`sanitize_for_json` with its default ``nan_strategy``.
    path : Path or str
        Destination file. Missing parent directories are created.
    indent : int, optional
        Indentation forwarded to ``json.dumps``.
    sort_keys : bool, optional
        Forwarded to ``json.dumps``.

    Returns
    -------
    Path
        The destination path as a ``Path`` object.
    """
    out_path = Path(path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    sanitized = sanitize_for_json(payload)
    text = json.dumps(sanitized, indent=indent, sort_keys=sort_keys, allow_nan=False)
    # Pin the encoding so the bytes on disk do not depend on the platform's
    # locale default.
    out_path.write_text(text, encoding="utf-8")
    return out_path
248
+
249
+
250
def validate_payload(payload: object, schema_name: str) -> None:
    """Validate a payload against a bundled schema.

    ``jsonschema`` is a hard dependency since v0.16.0 (closing F9.1):
    schema validation is the NeurIPS-aligned manifest contract, not an
    optional polish, so consumers no longer need to install
    ``eval-toolkit[validation]`` to use this helper.

    Parameters
    ----------
    payload : object
        Structure to validate; it is passed through
        :func:`sanitize_for_json` before being handed to the validator.
    schema_name : str
        File name of the schema inside the ``schemas`` directory of the
        ``eval_toolkit`` package (for example ``"results.v1.json"``).

    Raises
    ------
    jsonschema.ValidationError
        If the sanitized payload does not conform to the schema.
    """
    from jsonschema import Draft202012Validator  # type: ignore[import-untyped]

    schema_path = resources.files("eval_toolkit") / "schemas" / schema_name
    # Decode explicitly as UTF-8: the locale default encoding could otherwise
    # mis-decode non-ASCII text in a schema on some platforms.
    schema = json.loads(schema_path.read_text(encoding="utf-8"))
    Draft202012Validator(schema).validate(sanitize_for_json(payload))
263
+
264
+
265
_KNOWN_MANIFEST_VERSIONS: frozenset[str] = frozenset({"v1", "v2", "v3"})


def validate_manifest(payload: Mapping[str, object]) -> None:
    """Validate a serialized ``RunManifest`` payload.

    Dispatches on ``payload["schema_version"]`` (``"v1"``, ``"v2"``, or
    ``"v3"``); falls back to ``"v2"`` when the field is absent or ``None``
    (preserves V4.2/V4.3-era manifest readability). v3 (v0.23.0+) adds the
    required ``contamination_flags`` field for per-scorer contamination
    posture.

    Parameters
    ----------
    payload : Mapping[str, object]
        Serialized manifest dict (typically ``RunManifest.to_dict()``).

    Raises
    ------
    jsonschema.ValidationError
        If the payload does not conform to the schema for its declared
        ``schema_version``.
    ValueError
        If ``schema_version`` is set but unrecognized. This includes
        non-string values such as ``3``, which are rejected rather than
        being silently validated against the v2 schema.
    """
    raw_version = payload.get("schema_version")
    if raw_version is None:
        # Field absent (or explicitly null): legacy manifests are treated as v2.
        version = "v2"
    elif isinstance(raw_version, str) and raw_version in _KNOWN_MANIFEST_VERSIONS:
        # The isinstance test runs first so unhashable values (e.g. a list)
        # never reach the frozenset membership check.
        version = raw_version
    else:
        raise ValueError(
            f"unknown manifest schema_version {raw_version!r}; "
            f"expected one of {sorted(_KNOWN_MANIFEST_VERSIONS)}"
        )
    validate_payload(payload, f"manifest.{version}.json")
299
+
300
+
301
def validate_results(payload: Mapping[str, object]) -> None:
    """Check a serialized ``RunResult`` payload against ``results.v1.json``.

    Convenience wrapper around :func:`validate_payload` that supplies the
    schema file name, so call sites need no magic schema-name string (F9.2).
    """
    schema_name = "results.v1.json"
    validate_payload(payload, schema_name)
308
+
309
+
310
# JSON Schema (draft 2020-12) for one serialized prediction artifact reference.
_PREDICTION_ARTIFACT_REF_SCHEMA: dict[str, object] = {
    "$schema": "https://json-schema.org/draft/2020-12/schema",
    "type": "object",
    "required": ["uri", "media_type", "columns"],
    "properties": {
        "uri": {"type": "string", "minLength": 1},
        "media_type": {"type": "string", "minLength": 1},
        "sha256": {"type": "string"},
        "n_rows": {"type": "integer", "minimum": 0},
        # v0.15.0 (F5.2): role accepts a string or an array of strings.
        "role": {
            "oneOf": [
                {"type": "string", "minLength": 1},
                {
                    "type": "array",
                    "minItems": 1,
                    "items": {"type": "string", "minLength": 1},
                },
            ],
        },
        "metadata": {"type": "object"},
        "columns": {
            "type": "object",
            "required": ["label", "score"],
            "properties": {
                # Every named column is a non-empty string; each gets its own
                # freshly built sub-schema dict.
                **{
                    column: {"type": "string", "minLength": 1}
                    for column in (
                        "row_id",
                        "content_hash",
                        "label",
                        "score",
                        "scorer",
                        "slice",
                        "text",
                    )
                },
                "provenance": {
                    "type": "object",
                    "additionalProperties": {"type": "string"},
                },
            },
            "additionalProperties": True,
        },
    },
    "additionalProperties": True,
}


def validate_prediction_artifact_ref(payload: Mapping[str, object]) -> None:
    """Validate one serialized :class:`PredictionArtifactRef` payload.

    Lets callers check a single reference without validating a whole
    manifest (closes F9.2). The schema used here is meant to mirror the
    inline ``prediction_artifacts`` item schema in ``manifest.v2.json``;
    that correspondence cannot be verified from this module alone.

    Parameters
    ----------
    payload : Mapping[str, object]
        Serialized prediction artifact reference (typically
        ``PredictionArtifactRef.to_dict()``). It is passed through
        :func:`sanitize_for_json` before validation.

    Raises
    ------
    jsonschema.ValidationError
        If the payload does not conform.
    """
    from jsonschema import Draft202012Validator

    validator = Draft202012Validator(_PREDICTION_ARTIFACT_REF_SCHEMA)
    validator.validate(sanitize_for_json(payload))