eval-toolkit 0.27.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eval_toolkit/config.py ADDED
@@ -0,0 +1,112 @@
1
+ """Frozen-dataclass config pattern + YAML loader.
2
+
3
+ The toolkit's recommended config pattern: ``@frozen_config`` wraps
4
+ ``@dataclass(frozen=True, slots=True)``; the decorated class is expected to
5
+ define ``__post_init__`` for validation (a convention, not enforced by the decorator).
6
+
7
+ YAML loading is optional — it requires the ``yaml`` extra
8
+ (``pip install eval-toolkit[yaml]``). If ``pyyaml`` is not installed,
9
+ :func:`from_yaml` raises :class:`ImportError` with a helpful message.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from dataclasses import dataclass, fields, is_dataclass
15
+ from pathlib import Path
16
+ from typing import Any
17
+
18
+ __all__ = ["frozen_config", "from_yaml"]
19
+
20
+
21
+ def frozen_config[T](cls: type[T]) -> type[T]:
22
+ """Decorator: apply ``@dataclass(frozen=True, slots=True)`` and validate.
23
+
24
+ Subclasses must implement ``__post_init__`` for field validation; the
25
+ decorator does not validate field values directly (that's
26
+ domain-specific) but it does ensure the class becomes a frozen, slotted
27
+ dataclass.
28
+
29
+ Parameters
30
+ ----------
31
+ cls : type
32
+ The class to decorate.
33
+
34
+ Returns
35
+ -------
36
+ type
37
+ A frozen, slotted dataclass.
38
+
39
+ Examples
40
+ --------
41
+ >>> @frozen_config
42
+ ... class TrainConfig:
43
+ ... lr: float
44
+ ... batch_size: int = 16
45
+ ... def __post_init__(self) -> None:
46
+ ... if self.lr <= 0:
47
+ ... raise ValueError(f"lr must be > 0, got {self.lr}")
48
+ >>> cfg = TrainConfig(lr=1e-3)
49
+ >>> cfg.lr
50
+ 0.001
51
+ >>> try:
52
+ ... cfg.lr = 1e-4
53
+ ... except (AttributeError, Exception) as e:
54
+ ... print(type(e).__name__)
55
+ FrozenInstanceError
56
+ """
57
+ return dataclass(frozen=True, slots=True)(cls)
58
+
59
+
60
+ def from_yaml[T](path: Path | str, cls: type[T]) -> T:
61
+ """Load a YAML file into an instance of ``cls`` (a frozen dataclass).
62
+
63
+ Requires the ``yaml`` extra: ``pip install eval-toolkit[yaml]``.
64
+
65
+ Parameters
66
+ ----------
67
+ path : pathlib.Path or str
68
+ cls : type
69
+ Frozen dataclass type. The YAML's top-level keys must be a subset of
70
+ ``cls``'s fields.
71
+
72
+ Returns
73
+ -------
74
+ T
75
+ Instance of ``cls`` constructed from the YAML.
76
+
77
+ Raises
78
+ ------
79
+ ImportError
80
+ If pyyaml is not installed.
81
+ FileNotFoundError
82
+ If ``path`` does not exist.
83
+ TypeError
84
+ If ``cls`` is not a dataclass.
85
+ KeyError
86
+ If the YAML contains an unknown key (not a field of ``cls``).
87
+ """
88
+ try:
89
+ import yaml # noqa: PLC0415
90
+ except ImportError as exc:
91
+ raise ImportError(
92
+ "from_yaml requires pyyaml; install with `pip install eval-toolkit[yaml]`"
93
+ ) from exc
94
+
95
+ if not is_dataclass(cls):
96
+ raise TypeError(f"cls must be a dataclass, got {cls.__name__}")
97
+
98
+ p = Path(path)
99
+ if not p.exists():
100
+ raise FileNotFoundError(f"config file not found: {p}")
101
+ raw: Any = yaml.safe_load(p.read_text())
102
+ if not isinstance(raw, dict):
103
+ raise TypeError(f"YAML root must be a mapping, got {type(raw).__name__}")
104
+
105
+ field_names = {f.name for f in fields(cls)}
106
+ unknown = set(raw.keys()) - field_names
107
+ if unknown:
108
+ raise KeyError(
109
+ f"unknown config keys: {sorted(unknown)}; expected subset of {sorted(field_names)}"
110
+ )
111
+
112
+ return cls(**raw)
eval_toolkit/docs.py ADDED
@@ -0,0 +1,305 @@
1
+ """Anchor-based markdown rendering with formatter registry.
2
+
3
+ Anchor format: ``<!-- begin:KEY -->old_value<!-- end:KEY -->``.
4
+
5
+ ``KEY`` is a dot-path into a metrics dict, e.g.
6
+ ``slices.test.scorers.deberta.pr_auc``. Compound formatters (only ``...lift`` by
7
+ default; callers may register others such as ``...lift_with_mde``) address the parent dict, not the leaf.
8
+
9
+ Pure functions: :func:`walk_path`, :func:`render_text`.
10
+ IO wrapper: :func:`render_files` (mode='check' returns drift report; mode='apply' writes).
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import difflib
16
+ import re
17
+ from collections.abc import Callable
18
+ from pathlib import Path
19
+ from typing import Any
20
+
21
+ __all__ = [
22
+ "ANCHOR_RE",
23
+ "DEFAULT_FORMATTERS",
24
+ "render_files",
25
+ "render_text",
26
+ "walk_path",
27
+ ]
28
+
29
# Matches ``<!-- begin:KEY -->BODY<!-- end:KEY -->``. Group 1 is the opening
# tag, group 4 the closing tag; ``(?P=key)`` forces the closing KEY to equal
# the opening one. DOTALL lets BODY span multiple lines.
ANCHOR_RE = re.compile(
    r"(<!--\s*begin:(?P<key>[^\s>]+)\s*-->)"  # group 1 (key = group 2)
    r"(?P<body>.*?)"  # group 3; non-greedy so adjacent anchors stay separate
    r"(<!--\s*end:(?P=key)\s*-->)",  # group 4
    flags=re.S,
)
33
+
34
+
35
def walk_path(data: Any, dotted_path: str) -> Any:
    """Walk a dot-path into nested dicts/lists.

    Parameters
    ----------
    data : Any
        Root mapping or list to walk.
    dotted_path : str
        Path like ``"a.b.c"`` or ``"a.0.x"``. Segments applied to a list are
        parsed with ``int`` and used as Python indices, so negative values
        count from the end.

    Returns
    -------
    Any
        The value at the path.

    Raises
    ------
    KeyError
        If a dict key is missing, a list index is non-integer or out of range,
        or a non-dict/non-list value is encountered mid-walk. Every lookup
        failure is reported as ``KeyError`` so callers need catch only one type.

    Examples
    --------
    >>> walk_path({"a": {"b": 42}}, "a.b")
    42
    >>> walk_path({"a": [10, 20, 30]}, "a.1")
    20
    >>> try:
    ...     walk_path({"a": {}}, "a.b")
    ... except KeyError as e:
    ...     print("KeyError")
    KeyError
    """
    cur: Any = data
    for part in dotted_path.split("."):
        if isinstance(cur, dict):
            if part not in cur:
                raise KeyError(f"missing key {part!r} at {dotted_path!r}")
            cur = cur[part]
        elif isinstance(cur, list):
            try:
                idx = int(part)
            except ValueError as exc:
                raise KeyError(f"non-integer index {part!r} at {dotted_path!r}") from exc
            try:
                cur = cur[idx]
            except IndexError as exc:
                # Normalize to KeyError to honour the documented contract.
                raise KeyError(
                    f"index {idx} out of range (len {len(cur)}) at {dotted_path!r}"
                ) from exc
        else:
            raise KeyError(f"cannot descend into {type(cur).__name__} at {dotted_path!r}")
    return cur
82
+
83
+
84
# Default leaf-name formatters. Callers extend via the ``formatters`` arg of
# :func:`render_text` / :func:`render_files`.


def _fmt_float(v: Any, spec: str) -> str:
    """Format ``v`` as a float with format ``spec``; ``None`` renders as ``N/A``."""
    if v is None:
        return "N/A"
    return format(float(v), spec)


def _fmt_signed_3(v: Any) -> str:
    """Signed, 3 decimals (``+0.097``); ``None`` renders as ``N/A``."""
    return _fmt_float(v, "+.3f")


def _fmt_signed_4(v: Any) -> str:
    """Signed, 4 decimals (``+0.0123``); ``None`` renders as ``N/A``."""
    return _fmt_float(v, "+.4f")


def _fmt_3(v: Any) -> str:
    """Unsigned, 3 decimals (``0.951``); ``None`` renders as ``N/A``."""
    return _fmt_float(v, ".3f")


def _fmt_4(v: Any) -> str:
    """Unsigned, 4 decimals (``1.2345``); ``None`` renders as ``N/A``."""
    return _fmt_float(v, ".4f")


def _fmt_int(v: Any) -> str:
    """Integer via ``int()`` (truncates floats); ``None`` renders as ``N/A``."""
    if v is None:
        return "N/A"
    return str(int(v))


def _fmt_lift(d: dict[str, Any]) -> str:
    """Render a paired-bootstrap CI as ``+0.097 [+0.020, +0.199]``.

    Each component is formatted like the ``delta``/``ci_low``/``ci_high`` leaf
    formatters, so a ``None`` component renders as ``N/A`` instead of failing.
    A missing key still raises ``KeyError``.
    """
    delta, low, high = (_fmt_signed_3(d[k]) for k in ("delta", "ci_low", "ci_high"))
    return f"{delta} [{low}, {high}]"


DEFAULT_FORMATTERS: dict[str, Callable[[Any], str]] = {
    "pr_auc": _fmt_3,
    "roc_auc": _fmt_3,
    "f1": _fmt_3,
    "precision": _fmt_3,
    "recall": _fmt_3,
    "delta": _fmt_signed_3,
    "ci_low": _fmt_signed_3,
    "ci_high": _fmt_signed_3,
    "ece": _fmt_3,
    "ece_equal_width": _fmt_3,
    "ece_equal_mass": _fmt_3,
    "temperature": _fmt_4,
    "nll_pre": _fmt_3,
    "nll_post": _fmt_3,
    "improvement": _fmt_signed_4,
    "threshold": _fmt_4,
    "n": _fmt_int,
    "n_positive": _fmt_int,
    "n_negative": _fmt_int,
    "mean": _fmt_3,
    "std": _fmt_3,
    "min": _fmt_3,
    "max": _fmt_3,
    # compound formatters (called with the parent dict, not the leaf)
    "lift": _fmt_lift,
}
150
+
151
+
152
def _render_value(
    metrics: dict[str, Any],
    key: str,
    formatters: dict[str, Callable[[Any], str]],
    compound_keys: frozenset[str],
) -> str:
    """Apply the matching formatter to ``key`` looked up in ``metrics``.

    The formatter is chosen by the last dot-segment (the *leaf*) of ``key``.
    If the leaf is in ``compound_keys``, the formatter receives the *parent*
    container of that leaf (the root ``metrics`` when ``key`` has no dot).
    Otherwise it receives the leaf value. A leaf with no registered formatter
    is rendered with ``str()``.

    Raises
    ------
    KeyError
        If the path is missing (via :func:`walk_path`) or a compound key has
        no registered formatter.
    """
    parent_path, _, leaf = key.rpartition(".")
    if leaf in compound_keys:
        # A dot-less compound key addresses the root mapping itself; walking
        # ``key`` here would wrongly hand the leaf value to the formatter.
        parent = walk_path(metrics, parent_path) if parent_path else metrics
        try:
            compound_fmt = formatters[leaf]
        except KeyError:
            raise KeyError(f"no formatter registered for compound key {leaf!r}") from None
        return compound_fmt(parent)
    value = walk_path(metrics, key)
    fmt = formatters.get(leaf)
    return str(value) if fmt is None else fmt(value)
168
+
169
+
170
def render_text(
    text: str,
    metrics: dict[str, Any],
    formatters: dict[str, Callable[[Any], str]] | None = None,
    *,
    compound_keys: frozenset[str] = frozenset({"lift"}),
) -> tuple[str, list[str]]:
    """Replace anchored bodies in ``text`` with values looked up in ``metrics``.

    Parameters
    ----------
    text : str
        Markdown source containing ``<!-- begin:KEY -->...<!-- end:KEY -->`` anchors.
    metrics : dict[str, Any]
        Nested dict (or list) to look up KEYs in via :func:`walk_path`.
    formatters : dict[str, Callable] or None, optional
        Leaf-name → formatter callable. If ``None``, uses
        :data:`DEFAULT_FORMATTERS`. A supplied dict *replaces* the defaults,
        so to extend them, merge first (``{**DEFAULT_FORMATTERS, ...}``).
    compound_keys : frozenset[str], optional
        Leaf names that should be passed the *parent dict* rather than the
        leaf value (e.g., ``lift`` is a compound formatter that needs
        ``delta``, ``ci_low``, ``ci_high`` together).

    Returns
    -------
    tuple[str, list[str]]
        ``(rendered_text, errors)``. An anchor whose key cannot be resolved or
        formatted keeps its original body, and a ``"KEY: reason"`` diagnostic
        is appended to ``errors``. Lookup/formatting failures are collected
        rather than raised.

    Raises
    ------
    TypeError
        If ``text`` is not a ``str`` or ``metrics`` is not a ``dict``.

    Examples
    --------
    >>> text = "<!-- begin:metric.pr_auc -->X<!-- end:metric.pr_auc -->"
    >>> data = {"metric": {"pr_auc": 0.951}}
    >>> rendered, errs = render_text(text, data)
    >>> "0.951" in rendered
    True
    >>> errs
    []
    """
    if not isinstance(text, str):
        raise TypeError(f"text must be str, got {type(text).__name__}")
    if not isinstance(metrics, dict):
        raise TypeError(f"metrics must be a dict, got {type(metrics).__name__}")
    fmts = formatters if formatters is not None else DEFAULT_FORMATTERS
    errors: list[str] = []

    def _sub(m: re.Match[str]) -> str:
        key = m.group("key")
        try:
            new_body = _render_value(metrics, key, fmts, compound_keys)
        except (KeyError, IndexError, ValueError, TypeError, ArithmeticError) as exc:
            # IndexError can come from list lookups. ArithmeticError covers cases
            # such as OverflowError from int(float("inf")). Either would otherwise
            # abort the whole render instead of being reported per anchor.
            errors.append(f"{key}: {exc}")
            return m.group(0)
        # Groups 1 and 4 are the opening/closing tags; only the body is replaced.
        return f"{m.group(1)}{new_body}{m.group(4)}"

    return ANCHOR_RE.sub(_sub, text), errors
231
+
232
+
233
def render_files(
    targets: list[Path],
    metrics: dict[str, Any],
    *,
    mode: str = "apply",
    formatters: dict[str, Callable[[Any], str]] | None = None,
    compound_keys: frozenset[str] = frozenset({"lift"}),
) -> dict[str, Any]:
    """IO wrapper: read each target file, render anchors, optionally write.

    Files are read and written as UTF-8 with newline translation disabled, so
    each file keeps its own line endings (LF or CRLF) when rewritten.

    Parameters
    ----------
    targets : list of pathlib.Path
        Markdown files to render (plain ``str`` paths are also accepted).
    metrics : dict[str, Any]
        Nested metrics passed to :func:`render_text`.
    mode : {"apply", "check"}, optional
        ``"apply"`` writes rendered output back to each file (default).
        ``"check"`` does NOT write; instead returns the diff per file in the
        result dict (useful for CI to detect drift).
    formatters, compound_keys : see :func:`render_text`.

    Returns
    -------
    dict[str, Any]
        ``{"updated": [...], "unchanged": [...], "drift": {...}, "errors": {...}}``.
        ``updated`` lists files that were rewritten (apply mode) or would be
        (check mode); ``drift`` maps path → unified diff and is populated only
        in check mode; ``errors`` maps path → diagnostics (missing file or
        per-anchor render errors).

    Raises
    ------
    ValueError
        If ``mode`` is not one of ``{"apply", "check"}``.
    """
    if mode not in {"apply", "check"}:
        raise ValueError(f"mode must be 'apply' or 'check', got {mode!r}")

    updated: list[str] = []
    unchanged: list[str] = []
    drift: dict[str, str] = {}
    errors: dict[str, list[str]] = {}

    for target in targets:
        path = Path(target)
        name = str(path)
        if not path.exists():
            errors[name] = [f"file not found: {path}"]
            continue
        # newline="" disables universal-newline translation so untouched lines
        # keep their original endings when the file is written back.
        with path.open(encoding="utf-8", newline="") as fh:
            original = fh.read()
        rendered, errs = render_text(original, metrics, formatters, compound_keys=compound_keys)
        if errs:
            errors[name] = errs
        if rendered == original:
            unchanged.append(name)
            continue
        if mode == "apply":
            with path.open("w", encoding="utf-8", newline="") as fh:
                fh.write(rendered)
        else:
            drift[name] = "".join(
                difflib.unified_diff(
                    original.splitlines(keepends=True),
                    rendered.splitlines(keepends=True),
                    fromfile=name,
                    tofile=f"{path} (rendered)",
                    n=1,
                )
            )
        updated.append(name)

    return {
        "updated": updated,
        "unchanged": unchanged,
        "drift": drift,
        "errors": errors,
    }
@@ -0,0 +1,90 @@
1
+ """Generic evidence metadata contracts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Literal
7
+
8
+ __all__ = [
9
+ "AggregateEvidence",
10
+ "EvidenceAxis",
11
+ "PairingMetadata",
12
+ "RECOMMENDED_SOURCE_ROLES",
13
+ ]
14
+
15
# Allowed values for ``AggregateEvidence.status``. This is a static-typing
# alias; nothing in this module checks the value at runtime.
AggregateStatus = Literal["inferential", "descriptive", "diagnostic", "unsupported"]

# Recommended source-role labels, exported for callers. This module does not
# validate anything against it. NOTE(review): presumably these label data
# sources by how they are used in an evaluation; confirm with callers.
RECOMMENDED_SOURCE_ROLES: tuple[str, ...] = (
    "train",
    "calibration",
    "locked_eval",
    "external_diagnostic",
    "excluded",
)
24
+
25
+
26
+ @dataclass(frozen=True, slots=True)
27
+ class EvidenceAxis:
28
+ """Named evidence axis such as fold, seed, source_out, view, or slice."""
29
+
30
+ name: str
31
+ value: str
32
+
33
+ def __post_init__(self) -> None:
34
+ """Validate a non-empty axis."""
35
+ if not self.name:
36
+ raise ValueError("EvidenceAxis.name must be non-empty")
37
+ if not self.value:
38
+ raise ValueError("EvidenceAxis.value must be non-empty")
39
+
40
+ def to_dict(self) -> dict[str, object]:
41
+ """JSON-serializable representation."""
42
+ return {"name": self.name, "value": self.value}
43
+
44
+
45
+ @dataclass(frozen=True, slots=True)
46
+ class PairingMetadata:
47
+ """Machine-readable description of comparison pairedness."""
48
+
49
+ paired: bool
50
+ unit: str = "row"
51
+ valid_scope: str = ""
52
+ notes: str = ""
53
+
54
+ def to_dict(self) -> dict[str, object]:
55
+ """JSON-serializable representation."""
56
+ return {
57
+ "paired": self.paired,
58
+ "unit": self.unit,
59
+ "valid_scope": self.valid_scope,
60
+ "notes": self.notes,
61
+ }
62
+
63
+
64
+ @dataclass(frozen=True, slots=True)
65
+ class AggregateEvidence:
66
+ """Typed status for aggregate rows or summaries."""
67
+
68
+ status: AggregateStatus
69
+ method: str
70
+ axes: tuple[EvidenceAxis, ...] = ()
71
+ notes: str = ""
72
+ metadata: dict[str, object] = field(default_factory=dict)
73
+
74
+ def __post_init__(self) -> None:
75
+ """Normalize axes to a tuple."""
76
+ object.__setattr__(self, "axes", tuple(self.axes))
77
+ if not self.method:
78
+ raise ValueError("AggregateEvidence.method must be non-empty")
79
+
80
+ def to_dict(self) -> dict[str, object]:
81
+ """JSON-serializable representation."""
82
+ out: dict[str, object] = {
83
+ "status": self.status,
84
+ "method": self.method,
85
+ "axes": [axis.to_dict() for axis in self.axes],
86
+ "notes": self.notes,
87
+ }
88
+ if self.metadata:
89
+ out["metadata"] = dict(self.metadata)
90
+ return out