themis-eval 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. themis/__init__.py +12 -1
  2. themis/_version.py +2 -2
  3. themis/api.py +343 -0
  4. themis/backends/__init__.py +17 -0
  5. themis/backends/execution.py +197 -0
  6. themis/backends/storage.py +260 -0
  7. themis/cli/__init__.py +5 -0
  8. themis/cli/__main__.py +6 -0
  9. themis/cli/commands/__init__.py +19 -0
  10. themis/cli/commands/benchmarks.py +221 -0
  11. themis/cli/commands/comparison.py +394 -0
  12. themis/cli/commands/config_commands.py +244 -0
  13. themis/cli/commands/cost.py +214 -0
  14. themis/cli/commands/demo.py +68 -0
  15. themis/cli/commands/info.py +90 -0
  16. themis/cli/commands/leaderboard.py +362 -0
  17. themis/cli/commands/math_benchmarks.py +318 -0
  18. themis/cli/commands/mcq_benchmarks.py +207 -0
  19. themis/cli/commands/results.py +252 -0
  20. themis/cli/commands/sample_run.py +244 -0
  21. themis/cli/commands/visualize.py +299 -0
  22. themis/cli/main.py +463 -0
  23. themis/cli/new_project.py +33 -0
  24. themis/cli/utils.py +51 -0
  25. themis/comparison/__init__.py +25 -0
  26. themis/comparison/engine.py +348 -0
  27. themis/comparison/reports.py +283 -0
  28. themis/comparison/statistics.py +402 -0
  29. themis/config/__init__.py +19 -0
  30. themis/config/loader.py +27 -0
  31. themis/config/registry.py +34 -0
  32. themis/config/runtime.py +214 -0
  33. themis/config/schema.py +112 -0
  34. themis/core/__init__.py +5 -0
  35. themis/core/conversation.py +354 -0
  36. themis/core/entities.py +184 -0
  37. themis/core/serialization.py +231 -0
  38. themis/core/tools.py +393 -0
  39. themis/core/types.py +141 -0
  40. themis/datasets/__init__.py +273 -0
  41. themis/datasets/base.py +264 -0
  42. themis/datasets/commonsense_qa.py +174 -0
  43. themis/datasets/competition_math.py +265 -0
  44. themis/datasets/coqa.py +133 -0
  45. themis/datasets/gpqa.py +190 -0
  46. themis/datasets/gsm8k.py +123 -0
  47. themis/datasets/gsm_symbolic.py +124 -0
  48. themis/datasets/math500.py +122 -0
  49. themis/datasets/med_qa.py +179 -0
  50. themis/datasets/medmcqa.py +169 -0
  51. themis/datasets/mmlu_pro.py +262 -0
  52. themis/datasets/piqa.py +146 -0
  53. themis/datasets/registry.py +201 -0
  54. themis/datasets/schema.py +245 -0
  55. themis/datasets/sciq.py +150 -0
  56. themis/datasets/social_i_qa.py +151 -0
  57. themis/datasets/super_gpqa.py +263 -0
  58. themis/evaluation/__init__.py +1 -0
  59. themis/evaluation/conditional.py +410 -0
  60. themis/evaluation/extractors/__init__.py +19 -0
  61. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  62. themis/evaluation/extractors/exceptions.py +7 -0
  63. themis/evaluation/extractors/identity_extractor.py +29 -0
  64. themis/evaluation/extractors/json_field_extractor.py +45 -0
  65. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  66. themis/evaluation/extractors/regex_extractor.py +43 -0
  67. themis/evaluation/math_verify_utils.py +87 -0
  68. themis/evaluation/metrics/__init__.py +21 -0
  69. themis/evaluation/metrics/code/__init__.py +19 -0
  70. themis/evaluation/metrics/code/codebleu.py +144 -0
  71. themis/evaluation/metrics/code/execution.py +280 -0
  72. themis/evaluation/metrics/code/pass_at_k.py +181 -0
  73. themis/evaluation/metrics/composite_metric.py +47 -0
  74. themis/evaluation/metrics/consistency_metric.py +80 -0
  75. themis/evaluation/metrics/exact_match.py +51 -0
  76. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  77. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  78. themis/evaluation/metrics/nlp/__init__.py +21 -0
  79. themis/evaluation/metrics/nlp/bertscore.py +138 -0
  80. themis/evaluation/metrics/nlp/bleu.py +129 -0
  81. themis/evaluation/metrics/nlp/meteor.py +153 -0
  82. themis/evaluation/metrics/nlp/rouge.py +136 -0
  83. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  84. themis/evaluation/metrics/response_length.py +33 -0
  85. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  86. themis/evaluation/pipeline.py +49 -0
  87. themis/evaluation/pipelines/__init__.py +15 -0
  88. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  89. themis/evaluation/pipelines/standard_pipeline.py +348 -0
  90. themis/evaluation/reports.py +293 -0
  91. themis/evaluation/statistics/__init__.py +53 -0
  92. themis/evaluation/statistics/bootstrap.py +79 -0
  93. themis/evaluation/statistics/confidence_intervals.py +121 -0
  94. themis/evaluation/statistics/distributions.py +207 -0
  95. themis/evaluation/statistics/effect_sizes.py +124 -0
  96. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  97. themis/evaluation/statistics/types.py +139 -0
  98. themis/evaluation/strategies/__init__.py +13 -0
  99. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  100. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  101. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  102. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  103. themis/experiment/__init__.py +5 -0
  104. themis/experiment/builder.py +151 -0
  105. themis/experiment/cache_manager.py +134 -0
  106. themis/experiment/comparison.py +631 -0
  107. themis/experiment/cost.py +310 -0
  108. themis/experiment/definitions.py +62 -0
  109. themis/experiment/export.py +798 -0
  110. themis/experiment/export_csv.py +159 -0
  111. themis/experiment/integration_manager.py +104 -0
  112. themis/experiment/math.py +192 -0
  113. themis/experiment/mcq.py +169 -0
  114. themis/experiment/orchestrator.py +415 -0
  115. themis/experiment/pricing.py +317 -0
  116. themis/experiment/storage.py +1458 -0
  117. themis/experiment/visualization.py +588 -0
  118. themis/generation/__init__.py +1 -0
  119. themis/generation/agentic_runner.py +420 -0
  120. themis/generation/batching.py +254 -0
  121. themis/generation/clients.py +143 -0
  122. themis/generation/conversation_runner.py +236 -0
  123. themis/generation/plan.py +456 -0
  124. themis/generation/providers/litellm_provider.py +221 -0
  125. themis/generation/providers/vllm_provider.py +135 -0
  126. themis/generation/router.py +34 -0
  127. themis/generation/runner.py +207 -0
  128. themis/generation/strategies.py +98 -0
  129. themis/generation/templates.py +71 -0
  130. themis/generation/turn_strategies.py +393 -0
  131. themis/generation/types.py +9 -0
  132. themis/integrations/__init__.py +0 -0
  133. themis/integrations/huggingface.py +72 -0
  134. themis/integrations/wandb.py +77 -0
  135. themis/interfaces/__init__.py +169 -0
  136. themis/presets/__init__.py +10 -0
  137. themis/presets/benchmarks.py +354 -0
  138. themis/presets/models.py +190 -0
  139. themis/project/__init__.py +20 -0
  140. themis/project/definitions.py +98 -0
  141. themis/project/patterns.py +230 -0
  142. themis/providers/__init__.py +5 -0
  143. themis/providers/registry.py +39 -0
  144. themis/server/__init__.py +28 -0
  145. themis/server/app.py +337 -0
  146. themis/utils/api_generator.py +379 -0
  147. themis/utils/cost_tracking.py +376 -0
  148. themis/utils/dashboard.py +452 -0
  149. themis/utils/logging_utils.py +41 -0
  150. themis/utils/progress.py +58 -0
  151. themis/utils/tracing.py +320 -0
  152. themis_eval-0.2.0.dist-info/METADATA +596 -0
  153. themis_eval-0.2.0.dist-info/RECORD +157 -0
  154. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
  155. themis_eval-0.1.0.dist-info/METADATA +0 -758
  156. themis_eval-0.1.0.dist-info/RECORD +0 -8
  157. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
  158. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/datasets/commonsense_qa.py
@@ -0,0 +1,174 @@
+ """Helpers for working with the tau/commonsense_qa dataset."""
+
+ from __future__ import annotations
+
+ import json
+ import string
+ from pathlib import Path
+ from typing import Any, Iterable, Iterator, List, Sequence
+
+ from pydantic import BaseModel, Field, ValidationInfo, field_validator
+
+ _DATASET_NAME = "tau/commonsense_qa"
+ _CHOICE_LABELS = tuple(string.ascii_uppercase)
+
+
+ class CommonsenseQaSample(BaseModel):
+     unique_id: str
+     question: str
+     choices: list[str]
+     answer: str
+     concept: str = Field(default="")
+     metadata: dict[str, Any] = Field(default_factory=dict)
+     choice_labels: list[str] = Field(default_factory=list)
+
+     @field_validator("choices", mode="before")
+     @classmethod
+     def _ensure_choices(cls, value: Any) -> list[str]:
+         if value is None:
+             return []
+         if isinstance(value, dict):
+             return [str(v) for _, v in sorted(value.items())]
+         if isinstance(value, (list, tuple)):
+             return [str(item) for item in value]
+         raise TypeError("choices must be a sequence or mapping")
+
+     @field_validator("choice_labels", mode="before")
+     @classmethod
+     def _build_choice_labels(cls, value: Any, info: ValidationInfo) -> list[str]:
+         if value:
+             return [str(item) for item in value]
+         choices = info.data.get("choices") if hasattr(info, "data") else None
+         total = len(choices) if isinstance(choices, list) else 0
+         return [*_CHOICE_LABELS[:total]]
+
+     def to_generation_example(self) -> dict[str, Any]:
+         effective_labels = (
+             list(self.choice_labels)
+             if self.choice_labels
+             else list(_CHOICE_LABELS[: len(self.choices)])
+         )
+         return {
+             "unique_id": self.unique_id,
+             "question": self.question,
+             "choices": list(self.choices),
+             "choice_labels": effective_labels,
+             "answer": self.answer,
+             "concept": self.concept,
+             "metadata": dict(self.metadata),
+         }
+
+
+ def load_commonsense_qa(
+     *,
+     split: str = "validation",  # Test set usually has no labels
+     limit: int | None = None,
+     source: str = "huggingface",
+     data_dir: str | Path | None = None,
+ ) -> List[CommonsenseQaSample]:
+     """Load CommonsenseQA samples from Hugging Face or a local directory."""
+
+     if source not in {"huggingface", "local"}:
+         raise ValueError(
+             f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+         )
+
+     if source == "huggingface":
+         rows = _load_from_huggingface(split=split)
+     else:
+         if data_dir is None:
+             raise ValueError(
+                 "data_dir must be provided when source='local'. "
+                 "Pass dataset.data_dir in configs or --data-dir on the CLI."
+             )
+         rows = _load_from_local(Path(data_dir))
+
+     samples: list[CommonsenseQaSample] = []
+     for index, row in enumerate(rows, start=1):
+         sample = _row_to_sample(row, index=index)
+         samples.append(sample)
+         if limit is not None and len(samples) >= limit:
+             break
+     return samples
+
+
+ def _row_to_sample(row: dict[str, Any], *, index: int) -> CommonsenseQaSample:
+     unique_id = (
+         row.get("id")
+         or row.get("unique_id")
+         or f"csqa-{index:05d}"
+     )
+     question = row.get("question") or ""
+
+     # CommonsenseQA format:
+     #   choices: {'label': ['A', 'B', ...], 'text': ['text1', 'text2', ...]}
+     #   answerKey: 'A'
+
+     choices_data = row.get("choices") or {}
+     choices = []
+     choice_labels = []
+
+     if isinstance(choices_data, dict):
+         labels = choices_data.get("label") or []
+         texts = choices_data.get("text") or []
+
+         # Zip and sort by label
+         zipped = sorted(zip(labels, texts), key=lambda x: x[0])
+         for label, text in zipped:
+             choices.append(str(text))
+             choice_labels.append(str(label))
+
+     answer = str(row.get("answerKey") or "")
+     concept = str(row.get("question_concept") or "")
+
+     metadata_keys = {
+         "question", "choices", "answerKey", "question_concept", "id"
+     }
+     metadata = {key: value for key, value in row.items() if key not in metadata_keys}
+
+     return CommonsenseQaSample(
+         unique_id=str(unique_id),
+         question=str(question),
+         choices=choices,
+         choice_labels=choice_labels,
+         answer=answer,
+         concept=concept,
+         metadata=metadata,
+     )
+
+
+ def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
+     try:
+         from datasets import load_dataset
+     except ImportError as exc:  # pragma: no cover - optional dependency
+         raise RuntimeError(
+             "datasets is required to load CommonsenseQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
+         ) from exc
+
+     dataset = load_dataset(_DATASET_NAME, split=split)
+     for row in dataset:
+         yield dict(row)
+
+
+ def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+     if not root.exists():
+         raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+     for path in root.rglob("*"):
+         if path.suffix.lower() == ".json":
+             with path.open("r", encoding="utf-8") as handle:
+                 row = json.load(handle)
+             row.setdefault("id", path.stem)
+             yield row
+         elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+             with path.open("r", encoding="utf-8") as handle:
+                 for line_num, line in enumerate(handle, start=1):
+                     line = line.strip()
+                     if not line:
+                         continue
+                     row = json.loads(line)
+                     row.setdefault("id", f"{path.stem}-{line_num}")
+                     yield row
+
+
+ __all__ = ["CommonsenseQaSample", "load_commonsense_qa"]
themis/datasets/competition_math.py
@@ -0,0 +1,265 @@
+ """Helpers for competition-style math benchmarks from Hugging Face."""
+
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from typing import Any, Iterable, Iterator, List, Sequence
+
+ from pydantic import BaseModel, Field, field_validator
+
+
+ class CompetitionMathSample(BaseModel):
+     unique_id: str
+     problem: str
+     solution: str
+     answer: str
+     subject: str = Field(default="unknown")
+     level: str | int = Field(default="unknown")
+     metadata: dict[str, Any] = Field(default_factory=dict)
+
+     @field_validator("metadata", mode="before")
+     @classmethod
+     def _ensure_metadata(cls, value: Any) -> dict[str, Any]:
+         return dict(value or {})
+
+     @field_validator("level", mode="before")
+     @classmethod
+     def _normalize_level(cls, value: Any) -> str | int:
+         if value is None or value == "":
+             return "unknown"
+         try:
+             return int(value)
+         except (TypeError, ValueError):
+             return str(value)
+
+     def to_generation_example(self) -> dict[str, Any]:
+         payload = {
+             "unique_id": self.unique_id,
+             "problem": self.problem,
+             "solution": self.solution,
+             "answer": self.answer,
+             "subject": self.subject,
+             "level": self.level,
+         }
+         payload.update(self.metadata)
+         return payload
+
+
+ def load_competition_math(
+     *,
+     dataset: str,
+     split: str = "test",
+     limit: int | None = None,
+     source: str = "huggingface",
+     data_dir: str | Path | None = None,
+     subjects: Sequence[str] | None = None,
+     subset: str | None = None,
+ ) -> List[CompetitionMathSample]:
+     """Load math competition samples from Hugging Face or a local directory."""
+
+     if source not in {"huggingface", "local"}:
+         raise ValueError(
+             f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+         )
+
+     if source == "huggingface":
+         rows = _load_from_huggingface(dataset=dataset, split=split, subset=subset)
+     else:
+         if data_dir is None:
+             raise ValueError(
+                 "data_dir must be provided when source='local'. "
+                 "Pass dataset.data_dir in configs or --data-dir on the CLI."
+             )
+         rows = _load_from_local(Path(data_dir))
+
+     samples: list[CompetitionMathSample] = []
+     selected_subjects = {s.lower() for s in subjects} if subjects else None
+     for index, row in enumerate(rows, start=1):
+         subject = _extract_subject(row) or "unknown"
+         if selected_subjects and subject.lower() not in selected_subjects:
+             continue
+         sample = _row_to_sample(
+             row=row,
+             index=index,
+             dataset=dataset,
+             subject=subject,
+         )
+         samples.append(sample)
+         if limit is not None and len(samples) >= limit:
+             break
+     return samples
+
+
+ def _load_from_huggingface(
+     *, dataset: str, split: str, subset: str | None
+ ) -> Iterable[dict[str, Any]]:
+     try:
+         from datasets import load_dataset
+     except ImportError as exc:  # pragma: no cover - optional dependency
+         raise RuntimeError(
+             "datasets is required to load competition math benchmarks from Hugging Face. "
+             "Install it via `uv pip install '.[hf]'`."
+         ) from exc
+
+     if subset:
+         hf_dataset = load_dataset(dataset, subset, split=split)
+     else:
+         hf_dataset = load_dataset(dataset, split=split)
+     for row in hf_dataset:
+         yield dict(row)
+
+
+ def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+     if not root.exists():
+         raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+     for path in root.rglob("*"):
+         if path.suffix.lower() == ".json":
+             with path.open("r", encoding="utf-8") as handle:
+                 row = json.load(handle)
+             row.setdefault("id", path.stem)
+             yield row
+         elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+             with path.open("r", encoding="utf-8") as handle:
+                 for line_num, line in enumerate(handle, start=1):
+                     line = line.strip()
+                     if not line:
+                         continue
+                     row = json.loads(line)
+                     row.setdefault("id", f"{path.stem}-{line_num}")
+                     yield row
+
+
+ def _extract_subject(row: dict[str, Any]) -> str | None:
+     for key in (
+         "subject",
+         "category",
+         "topic",
+         "domain",
+         "contest",
+         "source",
+         "level",
+     ):
+         value = row.get(key)
+         if value:
+             return str(value)
+     return None
+
+
+ def _extract_problem(row: dict[str, Any]) -> str:
+     for key in (
+         "problem",
+         "problem_text",
+         "problem_statement",
+         "question",
+         "prompt",
+         "problem_markdown",
+     ):
+         value = row.get(key)
+         if value:
+             return str(value)
+     return ""
+
+
+ def _extract_solution(row: dict[str, Any]) -> str:
+     for key in (
+         "solution",
+         "solution_text",
+         "solution_markdown",
+         "answer_explanation",
+         "worked_solution",
+         "reasoning",
+     ):
+         value = row.get(key)
+         if value:
+             return str(value)
+     return ""
+
+
+ def _extract_answer(row: dict[str, Any]) -> str:
+     for key in (
+         "answer",
+         "final_answer",
+         "ground_truth",
+         "answer_text",
+         "answer_value",
+     ):
+         value = row.get(key)
+         if value is not None:
+             return str(value).strip()
+     return ""
+
+
+ def _extract_level(row: dict[str, Any]) -> str | int:
+     for key in ("difficulty", "level", "year"):
+         value = row.get(key)
+         if value:
+             return value
+     return "unknown"
+
+
+ def _row_to_sample(
+     *,
+     row: dict[str, Any],
+     index: int,
+     dataset: str,
+     subject: str,
+ ) -> CompetitionMathSample:
+     unique_id = (
+         row.get("id")
+         or row.get("problem_id")
+         or row.get("unique_id")
+         or f"{dataset.replace('/', '-')}-{index:05d}"
+     )
+     problem = _extract_problem(row)
+     solution = _extract_solution(row)
+     answer = _extract_answer(row)
+     level = _extract_level(row)
+     core_keys = {
+         "id",
+         "problem_id",
+         "unique_id",
+         "problem",
+         "problem_text",
+         "problem_statement",
+         "question",
+         "prompt",
+         "problem_markdown",
+         "solution",
+         "solution_text",
+         "solution_markdown",
+         "answer_explanation",
+         "worked_solution",
+         "reasoning",
+         "answer",
+         "final_answer",
+         "ground_truth",
+         "answer_text",
+         "answer_value",
+         "difficulty",
+         "level",
+         "year",
+         "subject",
+         "category",
+         "topic",
+         "domain",
+         "contest",
+         "source",
+     }
+     metadata = {key: value for key, value in row.items() if key not in core_keys}
+     sample = CompetitionMathSample.model_validate(
+         {
+             "unique_id": str(unique_id),
+             "problem": problem,
+             "solution": solution,
+             "answer": answer,
+             "subject": str(subject),
+             "level": level,
+             "metadata": metadata,
+         }
+     )
+     return sample
+
+
+ __all__ = ["CompetitionMathSample", "load_competition_math"]
themis/datasets/coqa.py
@@ -0,0 +1,133 @@
+ """Helpers for working with the stanfordnlp/coqa dataset."""
+
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from typing import Any, Iterable, Iterator, List, Sequence
+
+ from pydantic import BaseModel, Field, field_validator
+
+ _DATASET_NAME = "stanfordnlp/coqa"
+
+
+ class CoQaSample(BaseModel):
+     unique_id: str
+     story: str
+     question: str
+     answer: str
+     metadata: dict[str, Any] = Field(default_factory=dict)
+
+     @field_validator("metadata", mode="before")
+     @classmethod
+     def _ensure_metadata(cls, value: Any) -> dict[str, Any]:
+         return dict(value or {})
+
+     def to_generation_example(self) -> dict[str, Any]:
+         return {
+             "unique_id": self.unique_id,
+             "story": self.story,
+             "question": self.question,
+             "answer": self.answer,
+             "metadata": dict(self.metadata),
+         }
+
+
+ def load_coqa(
+     *,
+     split: str = "validation",  # Test set usually has no labels
+     limit: int | None = None,
+     source: str = "huggingface",
+     data_dir: str | Path | None = None,
+ ) -> List[CoQaSample]:
+     """Load CoQA samples from Hugging Face or a local directory."""
+
+     if source not in {"huggingface", "local"}:
+         raise ValueError(
+             f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+         )
+
+     if source == "huggingface":
+         rows = _load_from_huggingface(split=split)
+     else:
+         if data_dir is None:
+             raise ValueError(
+                 "data_dir must be provided when source='local'. "
+                 "Pass dataset.data_dir in configs or --data-dir on the CLI."
+             )
+         rows = _load_from_local(Path(data_dir))
+
+     samples: list[CoQaSample] = []
+     for index, row in enumerate(rows, start=1):
+         # CoQA stores several questions per story; flatten so that each
+         # (story, question, answer) turn becomes its own sample. Row structure:
+         #   'questions': ['q1', 'q2', ...]
+         #   'answers': {'input_text': ['a1', 'a2', ...], ...}
+
+         story = row.get("story") or ""
+         questions = row.get("questions") or []
+         answers_data = row.get("answers") or {}
+         answers = answers_data.get("input_text") or []
+
+         if len(questions) != len(answers):
+             # Question and answer counts can disagree; truncate both lists
+             # to the shorter length rather than dropping the whole story.
+             min_len = min(len(questions), len(answers))
+             questions = questions[:min_len]
+             answers = answers[:min_len]
+
+         for i, (q, a) in enumerate(zip(questions, answers)):
+             sample = CoQaSample(
+                 unique_id=f"coqa-{index:05d}-{i:02d}",
+                 story=story,
+                 question=str(q),
+                 answer=str(a),
+                 metadata={"turn": i, "source": row.get("source")},
+             )
+             samples.append(sample)
+             if limit is not None and len(samples) >= limit:
+                 break
+
+         if limit is not None and len(samples) >= limit:
+             break
+
+     return samples
+
+
+ def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
+     try:
+         from datasets import load_dataset
+     except ImportError as exc:  # pragma: no cover - optional dependency
+         raise RuntimeError(
+             "datasets is required to load CoQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
+         ) from exc
+
+     dataset = load_dataset(_DATASET_NAME, split=split)
+     for row in dataset:
+         yield dict(row)
+
+
+ def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+     if not root.exists():
+         raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+     for path in root.rglob("*"):
+         if path.suffix.lower() == ".json":
+             with path.open("r", encoding="utf-8") as handle:
+                 row = json.load(handle)
+             row.setdefault("id", path.stem)
+             yield row
+         elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+             with path.open("r", encoding="utf-8") as handle:
+                 for line_num, line in enumerate(handle, start=1):
+                     line = line.strip()
+                     if not line:
+                         continue
+                     row = json.loads(line)
+                     row.setdefault("id", f"{path.stem}-{line_num}")
+                     yield row
+
+
+ __all__ = ["CoQaSample", "load_coqa"]