themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/datasets/gpqa.py
ADDED
@@ -0,0 +1,190 @@
+"""Helpers for working with the math-ai/gpqa dataset."""
+
+from __future__ import annotations
+
+import json
+import string
+from pathlib import Path
+from typing import Any, Iterable, Iterator, List, Sequence
+
+from pydantic import BaseModel, Field, ValidationInfo, field_validator
+
+_DATASET_NAME = "math-ai/gpqa"
+_CHOICE_LABELS = tuple(string.ascii_uppercase)
+
+
+class GpqaSample(BaseModel):
+    unique_id: str
+    question: str
+    choices: list[str]
+    answer: str
+    subject: str = Field(default="unknown")
+    metadata: dict[str, Any] = Field(default_factory=dict)
+    choice_labels: list[str] = Field(default_factory=list)
+
+    @field_validator("choices", mode="before")
+    @classmethod
+    def _ensure_choices(cls, value: Any) -> list[str]:
+        if value is None:
+            return []
+        if isinstance(value, dict):
+            return [str(v) for _, v in sorted(value.items())]
+        if isinstance(value, (list, tuple)):
+            return [str(item) for item in value]
+        raise TypeError("choices must be a sequence or mapping")
+
+    @field_validator("choice_labels", mode="before")
+    @classmethod
+    def _build_choice_labels(cls, value: Any, info: ValidationInfo) -> list[str]:
+        if value:
+            return [str(item) for item in value]
+        choices = info.data.get("choices") if hasattr(info, "data") else None
+        total = len(choices) if isinstance(choices, list) else 0
+        return [*_CHOICE_LABELS[:total]]
+
+    def to_generation_example(self) -> dict[str, Any]:
+        effective_labels = (
+            list(self.choice_labels)
+            if self.choice_labels
+            else list(_CHOICE_LABELS[: len(self.choices)])
+        )
+        return {
+            "unique_id": self.unique_id,
+            "question": self.question,
+            "choices": list(self.choices),
+            "choice_labels": effective_labels,
+            "answer": self.answer,
+            "subject": self.subject,
+            "metadata": dict(self.metadata),
+        }
+
+
+def load_gpqa(
+    *,
+    split: str = "test",
+    limit: int | None = None,
+    source: str = "huggingface",
+    data_dir: str | Path | None = None,
+    subset: str = "gpqa_diamond",
+) -> List[GpqaSample]:
+    """Load GPQA samples from Hugging Face or a local directory."""
+
+    if source not in {"huggingface", "local"}:
+        raise ValueError(
+            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+        )
+
+    if source == "huggingface":
+        rows = _load_from_huggingface(split=split, subset=subset)
+    else:
+        if data_dir is None:
+            raise ValueError(
+                "data_dir must be provided when source='local'. "
+                "Pass dataset.data_dir in configs or --data-dir on the CLI."
+            )
+        rows = _load_from_local(Path(data_dir))
+
+    samples: list[GpqaSample] = []
+    for index, row in enumerate(rows, start=1):
+        sample = _row_to_sample(row, index=index)
+        samples.append(sample)
+        if limit is not None and len(samples) >= limit:
+            break
+    return samples
+
+
+def _row_to_sample(row: dict[str, Any], *, index: int) -> GpqaSample:
+    unique_id = (
+        row.get("id")
+        or row.get("unique_id")
+        or f"gpqa-{index:05d}"
+    )
+    question = row.get("Question") or row.get("question") or ""
+
+    # Raw GPQA rows carry 'Correct Answer' plus 'Incorrect Answer 1'..'Incorrect Answer 3'.
+    correct_answer = row.get("Correct Answer") or row.get("correct_answer") or ""
+    incorrect_answers = []
+    for i in range(1, 4):
+        inc = row.get(f"Incorrect Answer {i}") or row.get(f"incorrect_answer_{i}")
+        if inc:
+            incorrect_answers.append(str(inc))
+
+    # If choices are already present (e.g. a processed version), use them as-is.
+    choices = row.get("choices") or row.get("options")
+    if not choices:
+        # Otherwise combine the correct and incorrect answers and sort them
+        # alphabetically, so the correct option lands at a deterministic but
+        # non-fixed position without requiring a seeded shuffle.
+        choices = [correct_answer] + incorrect_answers
+        choices.sort()
+
+    # Derive the answer label (A, B, C, ...) from the correct answer's position.
+    try:
+        answer_idx = choices.index(correct_answer)
+        answer = _CHOICE_LABELS[answer_idx]
+    except ValueError:
+        answer = ""  # Only possible if correct_answer is absent from choices.
+
+    metadata_keys = {
+        "Question", "question", "Correct Answer", "correct_answer",
+        "Incorrect Answer 1", "incorrect_answer_1",
+        "Incorrect Answer 2", "incorrect_answer_2",
+        "Incorrect Answer 3", "incorrect_answer_3",
+        "choices", "options",
+    }
+    metadata = {key: value for key, value in row.items() if key not in metadata_keys}
+
+    return GpqaSample(
+        unique_id=str(unique_id),
+        question=str(question),
+        choices=choices,
+        answer=answer,
+        metadata=metadata,
+    )
+
+
+def _load_from_huggingface(*, split: str, subset: str) -> Iterable[dict[str, Any]]:
+    try:
+        from datasets import load_dataset
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "datasets is required to load GPQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
+        ) from exc
+
+    dataset = load_dataset(_DATASET_NAME, subset, split=split)
+    for row in dataset:
+        yield dict(row)
+
+
+def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+    if not root.exists():
+        raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+    for path in root.rglob("*"):
+        if path.suffix.lower() == ".json":
+            with path.open("r", encoding="utf-8") as handle:
+                row = json.load(handle)
+            row.setdefault("id", path.stem)
+            yield row
+        elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+            with path.open("r", encoding="utf-8") as handle:
+                for line_num, line in enumerate(handle, start=1):
+                    line = line.strip()
+                    if not line:
+                        continue
+                    row = json.loads(line)
+                    row.setdefault("id", f"{path.stem}-{line_num}")
+                    yield row
+
+
+__all__ = ["GpqaSample", "load_gpqa"]
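
Usage note: the file's public surface is load_gpqa plus the GpqaSample model. A minimal sketch of driving the loader, assuming the optional `datasets` dependency (the `[hf]` extra) is installed and that the math-ai/gpqa repo exposes the gpqa_diamond config with a test split:

    from themis.datasets.gpqa import load_gpqa

    # Fetch a handful of GPQA Diamond questions from Hugging Face.
    samples = load_gpqa(split="test", subset="gpqa_diamond", limit=5)

    for sample in samples:
        example = sample.to_generation_example()
        # choice_labels falls back to "A", "B", ... sized to the choice list.
        for label, choice in zip(example["choice_labels"], example["choices"]):
            print(f"{label}. {choice}")
        print("gold:", example["answer"])

Because choices are sorted alphabetically whenever they have to be assembled from the raw answer columns, the gold label varies per question but is reproducible across runs.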
themis/datasets/gsm8k.py
ADDED
@@ -0,0 +1,123 @@
+"""Helpers for working with the openai/gsm8k dataset."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, Iterable, Iterator, List, Sequence
+
+from pydantic import BaseModel, Field, field_validator
+
+_DATASET_NAME = "openai/gsm8k"
+
+
+class Gsm8kSample(BaseModel):
+    unique_id: str
+    question: str
+    answer: str
+    metadata: dict[str, Any] = Field(default_factory=dict)
+
+    @field_validator("metadata", mode="before")
+    @classmethod
+    def _ensure_metadata(cls, value: Any) -> dict[str, Any]:
+        return dict(value or {})
+
+    def to_generation_example(self) -> dict[str, Any]:
+        payload = {
+            "unique_id": self.unique_id,
+            "question": self.question,
+            "answer": self.answer,
+        }
+        payload.update(self.metadata)
+        return payload
+
+
+def load_gsm8k(
+    *,
+    split: str = "test",
+    limit: int | None = None,
+    source: str = "huggingface",
+    data_dir: str | Path | None = None,
+    subset: str = "main",
+) -> List[Gsm8kSample]:
+    """Load GSM8K samples from Hugging Face or a local directory."""
+
+    if source not in {"huggingface", "local"}:
+        raise ValueError(
+            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+        )
+
+    if source == "huggingface":
+        rows = _load_from_huggingface(split=split, subset=subset)
+    else:
+        if data_dir is None:
+            raise ValueError(
+                "data_dir must be provided when source='local'. "
+                "Pass dataset.data_dir in configs or --data-dir on the CLI."
+            )
+        rows = _load_from_local(Path(data_dir))
+
+    samples: list[Gsm8kSample] = []
+    for index, row in enumerate(rows, start=1):
+        sample = _row_to_sample(row, index=index)
+        samples.append(sample)
+        if limit is not None and len(samples) >= limit:
+            break
+    return samples
+
+
+def _row_to_sample(row: dict[str, Any], *, index: int) -> Gsm8kSample:
+    unique_id = (
+        row.get("id")
+        or row.get("unique_id")
+        or f"gsm8k-{index:05d}"
+    )
+    question = row.get("question") or row.get("problem") or ""
+    answer = row.get("answer") or ""
+
+    core_keys = {"id", "unique_id", "question", "problem", "answer"}
+    metadata = {key: value for key, value in row.items() if key not in core_keys}
+
+    return Gsm8kSample(
+        unique_id=str(unique_id),
+        question=str(question),
+        answer=str(answer),
+        metadata=metadata,
+    )
+
+
+def _load_from_huggingface(*, split: str, subset: str) -> Iterable[dict[str, Any]]:
+    try:
+        from datasets import load_dataset
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "datasets is required to load GSM8K from Hugging Face. Install it via `uv pip install '.[hf]'`."
+        ) from exc
+
+    dataset = load_dataset(_DATASET_NAME, subset, split=split)
+    for row in dataset:
+        yield dict(row)
+
+
+def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+    if not root.exists():
+        raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+    for path in root.rglob("*"):
+        if path.suffix.lower() == ".json":
+            with path.open("r", encoding="utf-8") as handle:
+                row = json.load(handle)
+            row.setdefault("id", path.stem)
+            yield row
+        elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+            with path.open("r", encoding="utf-8") as handle:
+                for line_num, line in enumerate(handle, start=1):
+                    line = line.strip()
+                    if not line:
+                        continue
+                    row = json.loads(line)
+                    row.setdefault("id", f"{path.stem}-{line_num}")
+                    yield row
+
+
+__all__ = ["Gsm8kSample", "load_gsm8k"]
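
The local branch of the loader walks data_dir recursively and accepts both whole-file `.json` documents and line-delimited `.jsonl`/`.ndjson` files, synthesizing an `id` from the file stem (plus the line number for JSONL) when a row lacks one. A small sketch of that path; the directory and file names here are illustrative, not part of the package:

    import json
    from pathlib import Path

    from themis.datasets.gsm8k import load_gsm8k

    # Write two GSM8K-style rows to a hypothetical local JSONL file.
    data_dir = Path("gsm8k_local")
    data_dir.mkdir(exist_ok=True)
    with (data_dir / "samples.jsonl").open("w", encoding="utf-8") as handle:
        for row in [{"question": "2 + 2?", "answer": "4"},
                    {"question": "3 * 5?", "answer": "15"}]:
            handle.write(json.dumps(row) + "\n")

    samples = load_gsm8k(source="local", data_dir=data_dir)
    print([s.unique_id for s in samples])  # ['samples-1', 'samples-2']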
themis/datasets/gsm_symbolic.py
ADDED

@@ -0,0 +1,124 @@
+"""Helpers for working with the apple/GSM-Symbolic dataset."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, Iterable, Iterator, List, Sequence
+
+from pydantic import BaseModel, Field, field_validator
+
+_DATASET_NAME = "apple/GSM-Symbolic"
+
+
+class GsmSymbolicSample(BaseModel):
+    unique_id: str
+    question: str
+    answer: str
+    metadata: dict[str, Any] = Field(default_factory=dict)
+
+    @field_validator("metadata", mode="before")
+    @classmethod
+    def _ensure_metadata(cls, value: Any) -> dict[str, Any]:
+        return dict(value or {})
+
+    def to_generation_example(self) -> dict[str, Any]:
+        payload = {
+            "unique_id": self.unique_id,
+            "question": self.question,
+            "answer": self.answer,
+        }
+        payload.update(self.metadata)
+        return payload
+
+
+def load_gsm_symbolic(
+    *,
+    split: str = "test",
+    limit: int | None = None,
+    source: str = "huggingface",
+    data_dir: str | Path | None = None,
+    subset: str = "main",
+) -> List[GsmSymbolicSample]:
+    """Load GSM-Symbolic samples from Hugging Face or a local directory."""
+
+    if source not in {"huggingface", "local"}:
+        raise ValueError(
+            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+        )
+
+    if source == "huggingface":
+        rows = _load_from_huggingface(split=split, subset=subset)
+    else:
+        if data_dir is None:
+            raise ValueError(
+                "data_dir must be provided when source='local'. "
+                "Pass dataset.data_dir in configs or --data-dir on the CLI."
+            )
+        rows = _load_from_local(Path(data_dir))
+
+    samples: list[GsmSymbolicSample] = []
+    for index, row in enumerate(rows, start=1):
+        sample = _row_to_sample(row, index=index)
+        samples.append(sample)
+        if limit is not None and len(samples) >= limit:
+            break
+    return samples
+
+
+def _row_to_sample(row: dict[str, Any], *, index: int) -> GsmSymbolicSample:
+    unique_id = (
+        row.get("id")
+        or row.get("unique_id")
+        or f"gsm-symbolic-{index:05d}"
+    )
+    question = row.get("question") or row.get("problem") or ""
+    answer = row.get("answer") or ""
+
+    core_keys = {"id", "unique_id", "question", "problem", "answer"}
+    metadata = {key: value for key, value in row.items() if key not in core_keys}
+
+    return GsmSymbolicSample(
+        unique_id=str(unique_id),
+        question=str(question),
+        answer=str(answer),
+        metadata=metadata,
+    )
+
+
+def _load_from_huggingface(*, split: str, subset: str) -> Iterable[dict[str, Any]]:
+    try:
+        from datasets import load_dataset
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "datasets is required to load GSM-Symbolic from Hugging Face. Install it via `uv pip install '.[hf]'`."
+        ) from exc
+
+    # GSM-Symbolic ships multiple configs; "main" is used unless a subset is specified.
+    dataset = load_dataset(_DATASET_NAME, subset, split=split)
+    for row in dataset:
+        yield dict(row)
+
+
+def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+    if not root.exists():
+        raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+    for path in root.rglob("*"):
+        if path.suffix.lower() == ".json":
+            with path.open("r", encoding="utf-8") as handle:
+                row = json.load(handle)
+            row.setdefault("id", path.stem)
+            yield row
+        elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+            with path.open("r", encoding="utf-8") as handle:
+                for line_num, line in enumerate(handle, start=1):
+                    line = line.strip()
+                    if not line:
+                        continue
+                    row = json.loads(line)
+                    row.setdefault("id", f"{path.stem}-{line_num}")
+                    yield row
+
+
+__all__ = ["GsmSymbolicSample", "load_gsm_symbolic"]
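
Apart from the dataset name and the id prefix, this loader mirrors gsm8k.py; the `subset` argument selects the Hugging Face config. A brief sketch, with the caveat that config names other than the default "main" are an assumption about the upstream apple/GSM-Symbolic repo, not something this file pins down:

    from themis.datasets.gsm_symbolic import load_gsm_symbolic

    # "main" is the default config; swap in another config name if the
    # upstream repo provides it (the harder "p1"/"p2" variants are assumed here).
    samples = load_gsm_symbolic(split="test", subset="main", limit=3)
    print(samples[0].question)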
themis/datasets/math500.py
ADDED

@@ -0,0 +1,122 @@
+"""Helpers for working with the HuggingFaceH4/MATH-500 dataset."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, Iterable, Iterator, List, Sequence
+
+from pydantic import BaseModel, Field, field_validator
+
+_DATASET_NAME = "HuggingFaceH4/MATH-500"
+
+
+class MathSample(BaseModel):
+    unique_id: str
+    problem: str
+    solution: str
+    answer: str
+    subject: str = Field(default="unknown")
+    level: int | str = Field(default=0)
+    extra: dict[str, Any] = Field(default_factory=dict)
+
+    @field_validator("level", mode="before")
+    @classmethod
+    def _normalize_level(cls, value: Any) -> int | str:
+        if value in (None, ""):
+            return 0
+        try:
+            return int(value)
+        except (TypeError, ValueError):
+            return value
+
+    @field_validator("extra", mode="before")
+    @classmethod
+    def _ensure_extra(cls, value: Any) -> dict[str, Any]:
+        return dict(value or {})
+
+    def to_generation_example(self) -> dict[str, Any]:
+        payload = self.model_dump()
+        payload.pop("extra", None)
+        payload.update(self.extra)
+        return payload
+
+
+def load_math500(
+    *,
+    split: str = "test",
+    limit: int | None = None,
+    subjects: Sequence[str] | None = None,
+    source: str = "huggingface",
+    data_dir: str | Path | None = None,
+) -> List[MathSample]:
+    """Load MATH-500 samples from Hugging Face or a local directory."""
+
+    if source not in {"huggingface", "local"}:
+        raise ValueError(
+            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+        )
+
+    if source == "huggingface":
+        rows = _load_from_huggingface(split=split)
+    else:
+        if data_dir is None:
+            raise ValueError(
+                "data_dir must be provided when source='local'. "
+                "Pass dataset.data_dir in configs or --data-dir on the CLI."
+            )
+        rows = _load_from_local(Path(data_dir))
+
+    samples: list[MathSample] = []
+    selected_subjects = {s.lower() for s in subjects} if subjects else None
+    for row in rows:
+        if (
+            selected_subjects
+            and row.get("subject", "").lower() not in selected_subjects
+        ):
+            continue
+        samples.append(_row_to_sample(row))
+        if limit is not None and len(samples) >= limit:
+            break
+    return samples
+
+
+def _row_to_sample(row: dict[str, Any]) -> MathSample:
+    core_keys = {"unique_id", "problem", "solution", "answer", "subject", "level"}
+    extra = {key: value for key, value in row.items() if key not in core_keys}
+    data = {
+        "unique_id": str(row["unique_id"]),
+        "problem": row.get("problem", ""),
+        "solution": row.get("solution", ""),
+        "answer": row.get("answer", ""),
+        "subject": row.get("subject", "unknown"),
+        "level": row.get("level", 0),
+        "extra": extra,
+    }
+    return MathSample.model_validate(data)
+
+
+def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
+    try:
+        from datasets import load_dataset
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "datasets is required to load MATH-500 from Hugging Face. Install it via `uv pip install '.[math]'`."
+        ) from exc
+
+    dataset = load_dataset(_DATASET_NAME, split=split)
+    for row in dataset:
+        yield dict(row)
+
+
+def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+    if not root.exists():
+        raise FileNotFoundError(f"Local dataset directory not found: {root}")
+    for path in root.rglob("*.json"):
+        with path.open("r", encoding="utf-8") as handle:
+            row = json.load(handle)
+        row.setdefault("unique_id", path.stem)
+        yield row
+
+
+__all__ = ["MathSample", "load_math500"]
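
Unlike the other loaders in this diff, load_math500 exposes a `subjects` filter that is applied before the limit; matching is case-insensitive because both the filter set and each row's subject are lowercased. A minimal sketch, assuming the `[math]` extra (which brings in `datasets`) is installed:

    from themis.datasets.math500 import load_math500

    # Keep only algebra and geometry problems, capped at 10 samples.
    samples = load_math500(split="test", subjects=["Algebra", "Geometry"], limit=10)
    for sample in samples:
        print(sample.unique_id, sample.subject, sample.level)

Note that `limit` counts retained samples, so it composes cleanly with the subject filter.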