themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. themis/cli/__init__.py +5 -0
  2. themis/cli/__main__.py +6 -0
  3. themis/cli/commands/__init__.py +19 -0
  4. themis/cli/commands/benchmarks.py +221 -0
  5. themis/cli/commands/comparison.py +394 -0
  6. themis/cli/commands/config_commands.py +244 -0
  7. themis/cli/commands/cost.py +214 -0
  8. themis/cli/commands/demo.py +68 -0
  9. themis/cli/commands/info.py +90 -0
  10. themis/cli/commands/leaderboard.py +362 -0
  11. themis/cli/commands/math_benchmarks.py +318 -0
  12. themis/cli/commands/mcq_benchmarks.py +207 -0
  13. themis/cli/commands/sample_run.py +244 -0
  14. themis/cli/commands/visualize.py +299 -0
  15. themis/cli/main.py +93 -0
  16. themis/cli/new_project.py +33 -0
  17. themis/cli/utils.py +51 -0
  18. themis/config/__init__.py +19 -0
  19. themis/config/loader.py +27 -0
  20. themis/config/registry.py +34 -0
  21. themis/config/runtime.py +214 -0
  22. themis/config/schema.py +112 -0
  23. themis/core/__init__.py +5 -0
  24. themis/core/conversation.py +354 -0
  25. themis/core/entities.py +164 -0
  26. themis/core/serialization.py +231 -0
  27. themis/core/tools.py +393 -0
  28. themis/core/types.py +141 -0
  29. themis/datasets/__init__.py +273 -0
  30. themis/datasets/base.py +264 -0
  31. themis/datasets/commonsense_qa.py +174 -0
  32. themis/datasets/competition_math.py +265 -0
  33. themis/datasets/coqa.py +133 -0
  34. themis/datasets/gpqa.py +190 -0
  35. themis/datasets/gsm8k.py +123 -0
  36. themis/datasets/gsm_symbolic.py +124 -0
  37. themis/datasets/math500.py +122 -0
  38. themis/datasets/med_qa.py +179 -0
  39. themis/datasets/medmcqa.py +169 -0
  40. themis/datasets/mmlu_pro.py +262 -0
  41. themis/datasets/piqa.py +146 -0
  42. themis/datasets/registry.py +201 -0
  43. themis/datasets/schema.py +245 -0
  44. themis/datasets/sciq.py +150 -0
  45. themis/datasets/social_i_qa.py +151 -0
  46. themis/datasets/super_gpqa.py +263 -0
  47. themis/evaluation/__init__.py +1 -0
  48. themis/evaluation/conditional.py +410 -0
  49. themis/evaluation/extractors/__init__.py +19 -0
  50. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  51. themis/evaluation/extractors/exceptions.py +7 -0
  52. themis/evaluation/extractors/identity_extractor.py +29 -0
  53. themis/evaluation/extractors/json_field_extractor.py +45 -0
  54. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  55. themis/evaluation/extractors/regex_extractor.py +43 -0
  56. themis/evaluation/math_verify_utils.py +87 -0
  57. themis/evaluation/metrics/__init__.py +21 -0
  58. themis/evaluation/metrics/composite_metric.py +47 -0
  59. themis/evaluation/metrics/consistency_metric.py +80 -0
  60. themis/evaluation/metrics/exact_match.py +51 -0
  61. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  62. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  63. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  64. themis/evaluation/metrics/response_length.py +33 -0
  65. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  66. themis/evaluation/pipeline.py +49 -0
  67. themis/evaluation/pipelines/__init__.py +15 -0
  68. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  69. themis/evaluation/pipelines/standard_pipeline.py +288 -0
  70. themis/evaluation/reports.py +293 -0
  71. themis/evaluation/statistics/__init__.py +53 -0
  72. themis/evaluation/statistics/bootstrap.py +79 -0
  73. themis/evaluation/statistics/confidence_intervals.py +121 -0
  74. themis/evaluation/statistics/distributions.py +207 -0
  75. themis/evaluation/statistics/effect_sizes.py +124 -0
  76. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  77. themis/evaluation/statistics/types.py +139 -0
  78. themis/evaluation/strategies/__init__.py +13 -0
  79. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  80. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  81. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  82. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  83. themis/experiment/__init__.py +5 -0
  84. themis/experiment/builder.py +151 -0
  85. themis/experiment/cache_manager.py +129 -0
  86. themis/experiment/comparison.py +631 -0
  87. themis/experiment/cost.py +310 -0
  88. themis/experiment/definitions.py +62 -0
  89. themis/experiment/export.py +690 -0
  90. themis/experiment/export_csv.py +159 -0
  91. themis/experiment/integration_manager.py +104 -0
  92. themis/experiment/math.py +192 -0
  93. themis/experiment/mcq.py +169 -0
  94. themis/experiment/orchestrator.py +373 -0
  95. themis/experiment/pricing.py +317 -0
  96. themis/experiment/storage.py +255 -0
  97. themis/experiment/visualization.py +588 -0
  98. themis/generation/__init__.py +1 -0
  99. themis/generation/agentic_runner.py +420 -0
  100. themis/generation/batching.py +254 -0
  101. themis/generation/clients.py +143 -0
  102. themis/generation/conversation_runner.py +236 -0
  103. themis/generation/plan.py +456 -0
  104. themis/generation/providers/litellm_provider.py +221 -0
  105. themis/generation/providers/vllm_provider.py +135 -0
  106. themis/generation/router.py +34 -0
  107. themis/generation/runner.py +207 -0
  108. themis/generation/strategies.py +98 -0
  109. themis/generation/templates.py +71 -0
  110. themis/generation/turn_strategies.py +393 -0
  111. themis/generation/types.py +9 -0
  112. themis/integrations/__init__.py +0 -0
  113. themis/integrations/huggingface.py +61 -0
  114. themis/integrations/wandb.py +65 -0
  115. themis/interfaces/__init__.py +83 -0
  116. themis/project/__init__.py +20 -0
  117. themis/project/definitions.py +98 -0
  118. themis/project/patterns.py +230 -0
  119. themis/providers/__init__.py +5 -0
  120. themis/providers/registry.py +39 -0
  121. themis/utils/api_generator.py +379 -0
  122. themis/utils/cost_tracking.py +376 -0
  123. themis/utils/dashboard.py +452 -0
  124. themis/utils/logging_utils.py +41 -0
  125. themis/utils/progress.py +58 -0
  126. themis/utils/tracing.py +320 -0
  127. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
  128. themis_eval-0.1.1.dist-info/RECORD +134 -0
  129. themis_eval-0.1.0.dist-info/RECORD +0 -8
  130. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
  131. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
  132. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
themis/datasets/gpqa.py
@@ -0,0 +1,190 @@
+"""Helpers for working with the math-ai/gpqa dataset."""
+
+from __future__ import annotations
+
+import json
+import string
+from pathlib import Path
+from typing import Any, Iterable, Iterator, List, Sequence
+
+from pydantic import BaseModel, Field, ValidationInfo, field_validator
+
+_DATASET_NAME = "math-ai/gpqa"
+_CHOICE_LABELS = tuple(string.ascii_uppercase)
+
+
+class GpqaSample(BaseModel):
+    unique_id: str
+    question: str
+    choices: list[str]
+    answer: str
+    subject: str = Field(default="unknown")
+    metadata: dict[str, Any] = Field(default_factory=dict)
+    choice_labels: list[str] = Field(default_factory=list)
+
+    @field_validator("choices", mode="before")
+    @classmethod
+    def _ensure_choices(cls, value: Any) -> list[str]:
+        if value is None:
+            return []
+        if isinstance(value, dict):
+            return [str(v) for _, v in sorted(value.items())]
+        if isinstance(value, (list, tuple)):
+            return [str(item) for item in value]
+        raise TypeError("choices must be a sequence or mapping")
+
+    @field_validator("choice_labels", mode="before")
+    @classmethod
+    def _build_choice_labels(cls, value: Any, info: ValidationInfo) -> list[str]:
+        if value:
+            return [str(item) for item in value]
+        choices = info.data.get("choices") if hasattr(info, "data") else None
+        total = len(choices) if isinstance(choices, list) else 0
+        return [*_CHOICE_LABELS[:total]]
+
+    def to_generation_example(self) -> dict[str, Any]:
+        effective_labels = (
+            list(self.choice_labels)
+            if self.choice_labels
+            else list(_CHOICE_LABELS[: len(self.choices)])
+        )
+        return {
+            "unique_id": self.unique_id,
+            "question": self.question,
+            "choices": list(self.choices),
+            "choice_labels": effective_labels,
+            "answer": self.answer,
+            "subject": self.subject,
+            "metadata": dict(self.metadata),
+        }
+
+
+def load_gpqa(
+    *,
+    split: str = "test",
+    limit: int | None = None,
+    source: str = "huggingface",
+    data_dir: str | Path | None = None,
+    subset: str = "gpqa_diamond",
+) -> List[GpqaSample]:
+    """Load GPQA samples from Hugging Face or a local directory."""
+
+    if source not in {"huggingface", "local"}:
+        raise ValueError(
+            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+        )
+
+    if source == "huggingface":
+        rows = _load_from_huggingface(split=split, subset=subset)
+    else:
+        if data_dir is None:
+            raise ValueError(
+                "data_dir must be provided when source='local'. "
+                "Pass dataset.data_dir in configs or --data-dir on the CLI."
+            )
+        rows = _load_from_local(Path(data_dir))
+
+    samples: list[GpqaSample] = []
+    for index, row in enumerate(rows, start=1):
+        sample = _row_to_sample(row, index=index)
+        samples.append(sample)
+        if limit is not None and len(samples) >= limit:
+            break
+    return samples
+
+
+def _row_to_sample(row: dict[str, Any], *, index: int) -> GpqaSample:
+    unique_id = (
+        row.get("id")
+        or row.get("unique_id")
+        or f"gpqa-{index:05d}"
+    )
+    question = row.get("Question") or row.get("question") or ""
+
+    # Raw GPQA rows on Hugging Face keep the correct answer separate from
+    # three distractors: 'Correct Answer' plus 'Incorrect Answer 1' through
+    # 'Incorrect Answer 3'.
+    correct_answer = row.get("Correct Answer") or row.get("correct_answer") or ""
+    incorrect_answers = []
+    for i in range(1, 4):
+        inc = row.get(f"Incorrect Answer {i}") or row.get(f"incorrect_answer_{i}")
+        if inc:
+            incorrect_answers.append(str(inc))
+
+    # Pre-processed variants may already ship an ordered choice list; prefer it.
+    choices = row.get("choices") or row.get("options")
+    if not choices:
+        # Otherwise assemble the list here. Sorting keeps the ordering
+        # deterministic without an RNG dependency and avoids always placing
+        # the correct answer first. A full evaluation pipeline may instead
+        # want to shuffle the choices and track the correct index.
+        choices = [correct_answer] + incorrect_answers
+        choices.sort()
+
+    # Derive the answer label (A, B, C, ...) from the correct answer's position.
+    try:
+        answer_idx = choices.index(correct_answer)
+        answer = _CHOICE_LABELS[answer_idx]
+    except ValueError:
+        answer = ""  # correct_answer was not among the provided choices
+
+    metadata_keys = {
+        "Question", "question", "Correct Answer", "correct_answer",
+        "Incorrect Answer 1", "incorrect_answer_1",
+        "Incorrect Answer 2", "incorrect_answer_2",
+        "Incorrect Answer 3", "incorrect_answer_3",
+        "choices", "options"
+    }
+    metadata = {key: value for key, value in row.items() if key not in metadata_keys}
+
+    return GpqaSample(
+        unique_id=str(unique_id),
+        question=str(question),
+        choices=choices,
+        answer=answer,
+        metadata=metadata,
+    )
+
+
+def _load_from_huggingface(*, split: str, subset: str) -> Iterable[dict[str, Any]]:
+    try:
+        from datasets import load_dataset
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "datasets is required to load GPQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
+        ) from exc
+
+    dataset = load_dataset(_DATASET_NAME, subset, split=split)
+    for row in dataset:
+        yield dict(row)
+
+
+def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+    if not root.exists():
+        raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+    for path in root.rglob("*"):
+        if path.suffix.lower() == ".json":
+            with path.open("r", encoding="utf-8") as handle:
+                row = json.load(handle)
+                row.setdefault("id", path.stem)
+                yield row
+        elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+            with path.open("r", encoding="utf-8") as handle:
+                for line_num, line in enumerate(handle, start=1):
+                    line = line.strip()
+                    if not line:
+                        continue
+                    row = json.loads(line)
+                    row.setdefault("id", f"{path.stem}-{line_num}")
+                    yield row
+
+
+__all__ = ["GpqaSample", "load_gpqa"]
themis/datasets/gsm8k.py
@@ -0,0 +1,123 @@
+"""Helpers for working with the openai/gsm8k dataset."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, Iterable, Iterator, List, Sequence
+
+from pydantic import BaseModel, Field, field_validator
+
+_DATASET_NAME = "openai/gsm8k"
+
+
+class Gsm8kSample(BaseModel):
+    unique_id: str
+    question: str
+    answer: str
+    metadata: dict[str, Any] = Field(default_factory=dict)
+
+    @field_validator("metadata", mode="before")
+    @classmethod
+    def _ensure_metadata(cls, value: Any) -> dict[str, Any]:
+        return dict(value or {})
+
+    def to_generation_example(self) -> dict[str, Any]:
+        payload = {
+            "unique_id": self.unique_id,
+            "question": self.question,
+            "answer": self.answer,
+        }
+        payload.update(self.metadata)
+        return payload
+
+
+def load_gsm8k(
+    *,
+    split: str = "test",
+    limit: int | None = None,
+    source: str = "huggingface",
+    data_dir: str | Path | None = None,
+    subset: str = "main",
+) -> List[Gsm8kSample]:
+    """Load GSM8K samples from Hugging Face or a local directory."""
+
+    if source not in {"huggingface", "local"}:
+        raise ValueError(
+            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+        )
+
+    if source == "huggingface":
+        rows = _load_from_huggingface(split=split, subset=subset)
+    else:
+        if data_dir is None:
+            raise ValueError(
+                "data_dir must be provided when source='local'. "
+                "Pass dataset.data_dir in configs or --data-dir on the CLI."
+            )
+        rows = _load_from_local(Path(data_dir))
+
+    samples: list[Gsm8kSample] = []
+    for index, row in enumerate(rows, start=1):
+        sample = _row_to_sample(row, index=index)
+        samples.append(sample)
+        if limit is not None and len(samples) >= limit:
+            break
+    return samples
+
+
+def _row_to_sample(row: dict[str, Any], *, index: int) -> Gsm8kSample:
+    unique_id = (
+        row.get("id")
+        or row.get("unique_id")
+        or f"gsm8k-{index:05d}"
+    )
+    question = row.get("question") or row.get("problem") or ""
+    answer = row.get("answer") or ""
+
+    core_keys = {"id", "unique_id", "question", "problem", "answer"}
+    metadata = {key: value for key, value in row.items() if key not in core_keys}
+
+    return Gsm8kSample(
+        unique_id=str(unique_id),
+        question=str(question),
+        answer=str(answer),
+        metadata=metadata,
+    )
+
+
+def _load_from_huggingface(*, split: str, subset: str) -> Iterable[dict[str, Any]]:
+    try:
+        from datasets import load_dataset
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "datasets is required to load GSM8K from Hugging Face. Install it via `uv pip install '.[hf]'`."
+        ) from exc
+
+    dataset = load_dataset(_DATASET_NAME, subset, split=split)
+    for row in dataset:
+        yield dict(row)
+
+
+def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+    if not root.exists():
+        raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+    for path in root.rglob("*"):
+        if path.suffix.lower() == ".json":
+            with path.open("r", encoding="utf-8") as handle:
+                row = json.load(handle)
+                row.setdefault("id", path.stem)
+                yield row
+        elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+            with path.open("r", encoding="utf-8") as handle:
+                for line_num, line in enumerate(handle, start=1):
+                    line = line.strip()
+                    if not line:
+                        continue
+                    row = json.loads(line)
+                    row.setdefault("id", f"{path.stem}-{line_num}")
+                    yield row
+
+
+__all__ = ["Gsm8kSample", "load_gsm8k"]
themis/datasets/gsm_symbolic.py
@@ -0,0 +1,124 @@
+"""Helpers for working with the apple/GSM-Symbolic dataset."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, Iterable, Iterator, List, Sequence
+
+from pydantic import BaseModel, Field, field_validator
+
+_DATASET_NAME = "apple/GSM-Symbolic"
+
+
+class GsmSymbolicSample(BaseModel):
+    unique_id: str
+    question: str
+    answer: str
+    metadata: dict[str, Any] = Field(default_factory=dict)
+
+    @field_validator("metadata", mode="before")
+    @classmethod
+    def _ensure_metadata(cls, value: Any) -> dict[str, Any]:
+        return dict(value or {})
+
+    def to_generation_example(self) -> dict[str, Any]:
+        payload = {
+            "unique_id": self.unique_id,
+            "question": self.question,
+            "answer": self.answer,
+        }
+        payload.update(self.metadata)
+        return payload
+
+
+def load_gsm_symbolic(
+    *,
+    split: str = "test",
+    limit: int | None = None,
+    source: str = "huggingface",
+    data_dir: str | Path | None = None,
+    subset: str = "main",
+) -> List[GsmSymbolicSample]:
+    """Load GSM-Symbolic samples from Hugging Face or a local directory."""
+
+    if source not in {"huggingface", "local"}:
+        raise ValueError(
+            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+        )
+
+    if source == "huggingface":
+        rows = _load_from_huggingface(split=split, subset=subset)
+    else:
+        if data_dir is None:
+            raise ValueError(
+                "data_dir must be provided when source='local'. "
+                "Pass dataset.data_dir in configs or --data-dir on the CLI."
+            )
+        rows = _load_from_local(Path(data_dir))
+
+    samples: list[GsmSymbolicSample] = []
+    for index, row in enumerate(rows, start=1):
+        sample = _row_to_sample(row, index=index)
+        samples.append(sample)
+        if limit is not None and len(samples) >= limit:
+            break
+    return samples
+
+
+def _row_to_sample(row: dict[str, Any], *, index: int) -> GsmSymbolicSample:
+    unique_id = (
+        row.get("id")
+        or row.get("unique_id")
+        or f"gsm-symbolic-{index:05d}"
+    )
+    question = row.get("question") or row.get("problem") or ""
+    answer = row.get("answer") or ""
+
+    core_keys = {"id", "unique_id", "question", "problem", "answer"}
+    metadata = {key: value for key, value in row.items() if key not in core_keys}
+
+    return GsmSymbolicSample(
+        unique_id=str(unique_id),
+        question=str(question),
+        answer=str(answer),
+        metadata=metadata,
+    )
+
+
+def _load_from_huggingface(*, split: str, subset: str) -> Iterable[dict[str, Any]]:
+    try:
+        from datasets import load_dataset
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "datasets is required to load GSM-Symbolic from Hugging Face. Install it via `uv pip install '.[hf]'`."
+        ) from exc
+
+    # GSM-Symbolic might have different configs, defaulting to main if not specified
+    dataset = load_dataset(_DATASET_NAME, subset, split=split)
+    for row in dataset:
+        yield dict(row)
+
+
+def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+    if not root.exists():
+        raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+    for path in root.rglob("*"):
+        if path.suffix.lower() == ".json":
+            with path.open("r", encoding="utf-8") as handle:
+                row = json.load(handle)
+                row.setdefault("id", path.stem)
+                yield row
+        elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+            with path.open("r", encoding="utf-8") as handle:
+                for line_num, line in enumerate(handle, start=1):
+                    line = line.strip()
+                    if not line:
+                        continue
+                    row = json.loads(line)
+                    row.setdefault("id", f"{path.stem}-{line_num}")
+                    yield row
+
+
+__all__ = ["GsmSymbolicSample", "load_gsm_symbolic"]
themis/datasets/math500.py
@@ -0,0 +1,122 @@
+"""Helpers for working with the HuggingFaceH4/MATH-500 dataset."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, Iterable, Iterator, List, Sequence
+
+from pydantic import BaseModel, Field, field_validator
+
+_DATASET_NAME = "HuggingFaceH4/MATH-500"
+
+
+class MathSample(BaseModel):
+    unique_id: str
+    problem: str
+    solution: str
+    answer: str
+    subject: str = Field(default="unknown")
+    level: int | str = Field(default=0)
+    extra: dict[str, Any] = Field(default_factory=dict)
+
+    @field_validator("level", mode="before")
+    @classmethod
+    def _normalize_level(cls, value: Any) -> int | str:
+        if value in (None, ""):
+            return 0
+        try:
+            return int(value)
+        except (TypeError, ValueError):
+            return value
+
+    @field_validator("extra", mode="before")
+    @classmethod
+    def _ensure_extra(cls, value: Any) -> dict[str, Any]:
+        return dict(value or {})
+
+    def to_generation_example(self) -> dict[str, Any]:
+        payload = self.model_dump()
+        payload.pop("extra", None)
+        payload.update(self.extra)
+        return payload
+
+
+def load_math500(
+    *,
+    split: str = "test",
+    limit: int | None = None,
+    subjects: Sequence[str] | None = None,
+    source: str = "huggingface",
+    data_dir: str | Path | None = None,
+) -> List[MathSample]:
+    """Load MATH-500 samples from Hugging Face or a local directory."""
+
+    if source not in {"huggingface", "local"}:
+        raise ValueError(
+            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+        )
+
+    if source == "huggingface":
+        rows = _load_from_huggingface(split=split)
+    else:
+        if data_dir is None:
+            raise ValueError(
+                "data_dir must be provided when source='local'. "
+                "Pass dataset.data_dir in configs or --data-dir on the CLI."
+            )
+        rows = _load_from_local(Path(data_dir))
+
+    samples: list[MathSample] = []
+    selected_subjects = {s.lower() for s in subjects} if subjects else None
+    for row in rows:
+        if (
+            selected_subjects
+            and row.get("subject", "").lower() not in selected_subjects
+        ):
+            continue
+        samples.append(_row_to_sample(row))
+        if limit is not None and len(samples) >= limit:
+            break
+    return samples
+
+
+def _row_to_sample(row: dict[str, Any]) -> MathSample:
+    core_keys = {"unique_id", "problem", "solution", "answer", "subject", "level"}
+    extra = {key: value for key, value in row.items() if key not in core_keys}
+    data = {
+        "unique_id": str(row["unique_id"]),
+        "problem": row.get("problem", ""),
+        "solution": row.get("solution", ""),
+        "answer": row.get("answer", ""),
+        "subject": row.get("subject", "unknown"),
+        "level": row.get("level", 0),
+        "extra": extra,
+    }
+    return MathSample.model_validate(data)
+
+
+def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
+    try:
+        from datasets import load_dataset
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "datasets is required to load MATH-500 from Hugging Face. Install it via `uv pip install '.[math]'`."
+        ) from exc
+
+    dataset = load_dataset(_DATASET_NAME, split=split)
+    for row in dataset:
+        yield dict(row)
+
+
+def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+    if not root.exists():
+        raise FileNotFoundError(f"Local dataset directory not found: {root}")
+    for path in root.rglob("*.json"):
+        with path.open("r", encoding="utf-8") as handle:
+            row = json.load(handle)
+            row.setdefault("unique_id", path.stem)
+            yield row
+
+
+__all__ = ["MathSample", "load_math500"]