themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. themis/__init__.py +12 -1
  2. themis/_version.py +2 -2
  3. themis/api.py +343 -0
  4. themis/backends/__init__.py +17 -0
  5. themis/backends/execution.py +197 -0
  6. themis/backends/storage.py +260 -0
  7. themis/cli/__init__.py +5 -0
  8. themis/cli/__main__.py +6 -0
  9. themis/cli/commands/__init__.py +19 -0
  10. themis/cli/commands/benchmarks.py +221 -0
  11. themis/cli/commands/comparison.py +394 -0
  12. themis/cli/commands/config_commands.py +244 -0
  13. themis/cli/commands/cost.py +214 -0
  14. themis/cli/commands/demo.py +68 -0
  15. themis/cli/commands/info.py +90 -0
  16. themis/cli/commands/leaderboard.py +362 -0
  17. themis/cli/commands/math_benchmarks.py +318 -0
  18. themis/cli/commands/mcq_benchmarks.py +207 -0
  19. themis/cli/commands/results.py +252 -0
  20. themis/cli/commands/sample_run.py +244 -0
  21. themis/cli/commands/visualize.py +299 -0
  22. themis/cli/main.py +463 -0
  23. themis/cli/new_project.py +33 -0
  24. themis/cli/utils.py +51 -0
  25. themis/comparison/__init__.py +25 -0
  26. themis/comparison/engine.py +348 -0
  27. themis/comparison/reports.py +283 -0
  28. themis/comparison/statistics.py +402 -0
  29. themis/config/__init__.py +19 -0
  30. themis/config/loader.py +27 -0
  31. themis/config/registry.py +34 -0
  32. themis/config/runtime.py +214 -0
  33. themis/config/schema.py +112 -0
  34. themis/core/__init__.py +5 -0
  35. themis/core/conversation.py +354 -0
  36. themis/core/entities.py +184 -0
  37. themis/core/serialization.py +231 -0
  38. themis/core/tools.py +393 -0
  39. themis/core/types.py +141 -0
  40. themis/datasets/__init__.py +273 -0
  41. themis/datasets/base.py +264 -0
  42. themis/datasets/commonsense_qa.py +174 -0
  43. themis/datasets/competition_math.py +265 -0
  44. themis/datasets/coqa.py +133 -0
  45. themis/datasets/gpqa.py +190 -0
  46. themis/datasets/gsm8k.py +123 -0
  47. themis/datasets/gsm_symbolic.py +124 -0
  48. themis/datasets/math500.py +122 -0
  49. themis/datasets/med_qa.py +179 -0
  50. themis/datasets/medmcqa.py +169 -0
  51. themis/datasets/mmlu_pro.py +262 -0
  52. themis/datasets/piqa.py +146 -0
  53. themis/datasets/registry.py +201 -0
  54. themis/datasets/schema.py +245 -0
  55. themis/datasets/sciq.py +150 -0
  56. themis/datasets/social_i_qa.py +151 -0
  57. themis/datasets/super_gpqa.py +263 -0
  58. themis/evaluation/__init__.py +1 -0
  59. themis/evaluation/conditional.py +410 -0
  60. themis/evaluation/extractors/__init__.py +19 -0
  61. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  62. themis/evaluation/extractors/exceptions.py +7 -0
  63. themis/evaluation/extractors/identity_extractor.py +29 -0
  64. themis/evaluation/extractors/json_field_extractor.py +45 -0
  65. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  66. themis/evaluation/extractors/regex_extractor.py +43 -0
  67. themis/evaluation/math_verify_utils.py +87 -0
  68. themis/evaluation/metrics/__init__.py +21 -0
  69. themis/evaluation/metrics/code/__init__.py +19 -0
  70. themis/evaluation/metrics/code/codebleu.py +144 -0
  71. themis/evaluation/metrics/code/execution.py +280 -0
  72. themis/evaluation/metrics/code/pass_at_k.py +181 -0
  73. themis/evaluation/metrics/composite_metric.py +47 -0
  74. themis/evaluation/metrics/consistency_metric.py +80 -0
  75. themis/evaluation/metrics/exact_match.py +51 -0
  76. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  77. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  78. themis/evaluation/metrics/nlp/__init__.py +21 -0
  79. themis/evaluation/metrics/nlp/bertscore.py +138 -0
  80. themis/evaluation/metrics/nlp/bleu.py +129 -0
  81. themis/evaluation/metrics/nlp/meteor.py +153 -0
  82. themis/evaluation/metrics/nlp/rouge.py +136 -0
  83. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  84. themis/evaluation/metrics/response_length.py +33 -0
  85. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  86. themis/evaluation/pipeline.py +49 -0
  87. themis/evaluation/pipelines/__init__.py +15 -0
  88. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  89. themis/evaluation/pipelines/standard_pipeline.py +348 -0
  90. themis/evaluation/reports.py +293 -0
  91. themis/evaluation/statistics/__init__.py +53 -0
  92. themis/evaluation/statistics/bootstrap.py +79 -0
  93. themis/evaluation/statistics/confidence_intervals.py +121 -0
  94. themis/evaluation/statistics/distributions.py +207 -0
  95. themis/evaluation/statistics/effect_sizes.py +124 -0
  96. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  97. themis/evaluation/statistics/types.py +139 -0
  98. themis/evaluation/strategies/__init__.py +13 -0
  99. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  100. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  101. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  102. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  103. themis/experiment/__init__.py +5 -0
  104. themis/experiment/builder.py +151 -0
  105. themis/experiment/cache_manager.py +134 -0
  106. themis/experiment/comparison.py +631 -0
  107. themis/experiment/cost.py +310 -0
  108. themis/experiment/definitions.py +62 -0
  109. themis/experiment/export.py +798 -0
  110. themis/experiment/export_csv.py +159 -0
  111. themis/experiment/integration_manager.py +104 -0
  112. themis/experiment/math.py +192 -0
  113. themis/experiment/mcq.py +169 -0
  114. themis/experiment/orchestrator.py +415 -0
  115. themis/experiment/pricing.py +317 -0
  116. themis/experiment/storage.py +1458 -0
  117. themis/experiment/visualization.py +588 -0
  118. themis/generation/__init__.py +1 -0
  119. themis/generation/agentic_runner.py +420 -0
  120. themis/generation/batching.py +254 -0
  121. themis/generation/clients.py +143 -0
  122. themis/generation/conversation_runner.py +236 -0
  123. themis/generation/plan.py +456 -0
  124. themis/generation/providers/litellm_provider.py +221 -0
  125. themis/generation/providers/vllm_provider.py +135 -0
  126. themis/generation/router.py +34 -0
  127. themis/generation/runner.py +207 -0
  128. themis/generation/strategies.py +98 -0
  129. themis/generation/templates.py +71 -0
  130. themis/generation/turn_strategies.py +393 -0
  131. themis/generation/types.py +9 -0
  132. themis/integrations/__init__.py +0 -0
  133. themis/integrations/huggingface.py +72 -0
  134. themis/integrations/wandb.py +77 -0
  135. themis/interfaces/__init__.py +169 -0
  136. themis/presets/__init__.py +10 -0
  137. themis/presets/benchmarks.py +354 -0
  138. themis/presets/models.py +190 -0
  139. themis/project/__init__.py +20 -0
  140. themis/project/definitions.py +98 -0
  141. themis/project/patterns.py +230 -0
  142. themis/providers/__init__.py +5 -0
  143. themis/providers/registry.py +39 -0
  144. themis/server/__init__.py +28 -0
  145. themis/server/app.py +337 -0
  146. themis/utils/api_generator.py +379 -0
  147. themis/utils/cost_tracking.py +376 -0
  148. themis/utils/dashboard.py +452 -0
  149. themis/utils/logging_utils.py +41 -0
  150. themis/utils/progress.py +58 -0
  151. themis/utils/tracing.py +320 -0
  152. themis_eval-0.2.0.dist-info/METADATA +596 -0
  153. themis_eval-0.2.0.dist-info/RECORD +157 -0
  154. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
  155. themis_eval-0.1.0.dist-info/METADATA +0 -758
  156. themis_eval-0.1.0.dist-info/RECORD +0 -8
  157. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
  158. {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/datasets/med_qa.py
@@ -0,0 +1,179 @@
+ """Helpers for working with the bigbio/med_qa dataset."""
+
+ from __future__ import annotations
+
+ import json
+ import string
+ from pathlib import Path
+ from typing import Any, Iterable, Iterator, List, Sequence
+
+ from pydantic import BaseModel, Field, ValidationInfo, field_validator
+
+ _DATASET_NAME = "bigbio/med_qa"
+ _CHOICE_LABELS = tuple(string.ascii_uppercase)
+
+
+ class MedQaSample(BaseModel):
+     unique_id: str
+     question: str
+     choices: list[str]
+     answer: str
+     subject: str = Field(default="medicine")
+     metadata: dict[str, Any] = Field(default_factory=dict)
+     choice_labels: list[str] = Field(default_factory=list)
+
+     @field_validator("choices", mode="before")
+     @classmethod
+     def _ensure_choices(cls, value: Any) -> list[str]:
+         if value is None:
+             return []
+         if isinstance(value, dict):
+             return [str(v) for _, v in sorted(value.items())]
+         if isinstance(value, (list, tuple)):
+             return [str(item) for item in value]
+         raise TypeError("choices must be a sequence or mapping")
+
+     @field_validator("choice_labels", mode="before")
+     @classmethod
+     def _build_choice_labels(cls, value: Any, info: ValidationInfo) -> list[str]:
+         if value:
+             return [str(item) for item in value]
+         choices = info.data.get("choices") if hasattr(info, "data") else None
+         total = len(choices) if isinstance(choices, list) else 0
+         return [*_CHOICE_LABELS[:total]]
+
+     def to_generation_example(self) -> dict[str, Any]:
+         effective_labels = (
+             list(self.choice_labels)
+             if self.choice_labels
+             else list(_CHOICE_LABELS[: len(self.choices)])
+         )
+         return {
+             "unique_id": self.unique_id,
+             "question": self.question,
+             "choices": list(self.choices),
+             "choice_labels": effective_labels,
+             "answer": self.answer,
+             "subject": self.subject,
+             "metadata": dict(self.metadata),
+         }
+
+
+ def load_med_qa(
+     *,
+     split: str = "test",
+     limit: int | None = None,
+     source: str = "huggingface",
+     data_dir: str | Path | None = None,
+     subset: str = "med_qa_en_bigbio_qa",
+ ) -> List[MedQaSample]:
+     """Load MedQA samples from Hugging Face or a local directory."""
+
+     if source not in {"huggingface", "local"}:
+         raise ValueError(
+             f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+         )
+
+     if source == "huggingface":
+         rows = _load_from_huggingface(split=split, subset=subset)
+     else:
+         if data_dir is None:
+             raise ValueError(
+                 "data_dir must be provided when source='local'. "
+                 "Pass dataset.data_dir in configs or --data-dir on the CLI."
+             )
+         rows = _load_from_local(Path(data_dir))
+
+     samples: list[MedQaSample] = []
+     for index, row in enumerate(rows, start=1):
+         sample = _row_to_sample(row, index=index)
+         samples.append(sample)
+         if limit is not None and len(samples) >= limit:
+             break
+     return samples
+
+
+ def _row_to_sample(row: dict[str, Any], *, index: int) -> MedQaSample:
+     unique_id = (
+         row.get("id")
+         or row.get("unique_id")
+         or f"med-qa-{index:05d}"
+     )
+     question = row.get("question") or ""
+
+     # BigBio MedQA format:
+     #   choices: [{'key': 'A', 'text': '...'}, {'key': 'B', 'text': '...'}]
+     #   answer: 'A' (or similar)
+
+     choices_data = row.get("choices") or []
+     choices = []
+     choice_labels = []
+
+     if isinstance(choices_data, list):
+         # Sort by key to keep the options in label order.
+         try:
+             sorted_choices = sorted(choices_data, key=lambda x: x.get("key", ""))
+             for c in sorted_choices:
+                 choices.append(str(c.get("text", "")))
+                 choice_labels.append(str(c.get("key", "")))
+         except (TypeError, AttributeError):
+             # Fall back to plain strings if the structure differs.
+             choices = [str(c) for c in choices_data]
+
+     answer = ""
+     answer_data = row.get("answer")
+     if isinstance(answer_data, list) and answer_data:
+         answer = str(answer_data[0])  # Usually a list with one element.
+     elif isinstance(answer_data, str):
+         answer = answer_data
+
+     metadata_keys = {
+         "question", "choices", "answer", "id"
+     }
+     metadata = {key: value for key, value in row.items() if key not in metadata_keys}
+
+     return MedQaSample(
+         unique_id=str(unique_id),
+         question=str(question),
+         choices=choices,
+         choice_labels=choice_labels,
+         answer=answer,
+         metadata=metadata,
+     )
+
+
+ def _load_from_huggingface(*, split: str, subset: str) -> Iterable[dict[str, Any]]:
+     try:
+         from datasets import load_dataset
+     except ImportError as exc:  # pragma: no cover - optional dependency
+         raise RuntimeError(
+             "datasets is required to load MedQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
+         ) from exc
+
+     dataset = load_dataset(_DATASET_NAME, subset, split=split)
+     for row in dataset:
+         yield dict(row)
+
+
+ def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+     if not root.exists():
+         raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+     for path in root.rglob("*"):
+         if path.suffix.lower() == ".json":
+             with path.open("r", encoding="utf-8") as handle:
+                 row = json.load(handle)
+                 row.setdefault("id", path.stem)
+                 yield row
+         elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+             with path.open("r", encoding="utf-8") as handle:
+                 for line_num, line in enumerate(handle, start=1):
+                     line = line.strip()
+                     if not line:
+                         continue
+                     row = json.loads(line)
+                     row.setdefault("id", f"{path.stem}-{line_num}")
+                     yield row
+
+
+ __all__ = ["MedQaSample", "load_med_qa"]
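
Below is a minimal usage sketch for the new MedQA loader (not part of the package itself; the limit and directory path are illustrative, and the Hugging Face path assumes the optional `datasets` dependency is installed):

    from themis.datasets.med_qa import load_med_qa

    # Pull a handful of test-split samples from the Hub (requires `datasets`).
    samples = load_med_qa(split="test", limit=5)
    for sample in samples:
        example = sample.to_generation_example()
        print(example["unique_id"], example["choice_labels"], example["answer"])

    # Or read *.json / *.jsonl files from a local directory instead
    # (the directory path here is hypothetical).
    local_samples = load_med_qa(source="local", data_dir="data/med_qa")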
themis/datasets/medmcqa.py
@@ -0,0 +1,169 @@
+ """Helpers for working with the openlifescienceai/medmcqa dataset."""
+
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from typing import Any, Iterable, Iterator, List, Sequence
+
+ from pydantic import BaseModel, Field, field_validator
+
+ _DATASET_NAME = "openlifescienceai/medmcqa"
+ _CHOICE_LABELS = ["A", "B", "C", "D"]
+
+
+ class MedMcqaSample(BaseModel):
+     unique_id: str
+     question: str
+     choices: list[str]
+     answer: str
+     subject: str = Field(default="unknown")
+     metadata: dict[str, Any] = Field(default_factory=dict)
+     choice_labels: list[str] = Field(default_factory=list)
+
+     @field_validator("choices", mode="before")
+     @classmethod
+     def _ensure_choices(cls, value: Any) -> list[str]:
+         if value is None:
+             return []
+         if isinstance(value, dict):
+             return [str(v) for _, v in sorted(value.items())]
+         if isinstance(value, (list, tuple)):
+             return [str(item) for item in value]
+         raise TypeError("choices must be a sequence or mapping")
+
+     def to_generation_example(self) -> dict[str, Any]:
+         effective_labels = (
+             list(self.choice_labels)
+             if self.choice_labels
+             else list(_CHOICE_LABELS[: len(self.choices)])
+         )
+         return {
+             "unique_id": self.unique_id,
+             "question": self.question,
+             "choices": list(self.choices),
+             "choice_labels": effective_labels,
+             "answer": self.answer,
+             "subject": self.subject,
+             "metadata": dict(self.metadata),
+         }
+
+
+ def load_medmcqa(
+     *,
+     split: str = "test",
+     limit: int | None = None,
+     source: str = "huggingface",
+     data_dir: str | Path | None = None,
+     subset: str | None = None,
+ ) -> List[MedMcqaSample]:
+     """Load MedMCQA samples from Hugging Face or a local directory."""
+
+     if source not in {"huggingface", "local"}:
+         raise ValueError(
+             f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+         )
+
+     if source == "huggingface":
+         rows = _load_from_huggingface(split=split, subset=subset)
+     else:
+         if data_dir is None:
+             raise ValueError(
+                 "data_dir must be provided when source='local'. "
+                 "Pass dataset.data_dir in configs or --data-dir on the CLI."
+             )
+         rows = _load_from_local(Path(data_dir))
+
+     samples: list[MedMcqaSample] = []
+     for index, row in enumerate(rows, start=1):
+         sample = _row_to_sample(row, index=index)
+         samples.append(sample)
+         if limit is not None and len(samples) >= limit:
+             break
+     return samples
+
+
+ def _row_to_sample(row: dict[str, Any], *, index: int) -> MedMcqaSample:
+     unique_id = (
+         row.get("id")
+         or row.get("unique_id")
+         or f"medmcqa-{index:05d}"
+     )
+     question = row.get("question") or ""
+
+     # MedMCQA stores the four options as 'opa', 'opb', 'opc', 'opd'.
+     choices = [
+         str(row.get("opa") or ""),
+         str(row.get("opb") or ""),
+         str(row.get("opc") or ""),
+         str(row.get("opd") or ""),
+     ]
+
+     # The 'cop' field holds the correct option. The dataset card documents
+     # it as 1-4 (some mirrors use 0-3), so map
+     # 1 -> A, 2 -> B, 3 -> C, 4 -> D.
+     # Values outside 1-4, or non-numeric values, leave the answer empty
+     # so downstream consumers can detect unanswered rows.
+     cop = row.get("cop")
+     answer = ""
+     if cop is not None:
+         try:
+             cop_int = int(cop)
+             if 1 <= cop_int <= 4:
+                 answer = _CHOICE_LABELS[cop_int - 1]
+         except (ValueError, TypeError):
+             pass
+
+     subject = row.get("subject_name") or "medicine"
+
+     metadata_keys = {
+         "question", "opa", "opb", "opc", "opd", "cop", "subject_name", "id"
+     }
+     metadata = {key: value for key, value in row.items() if key not in metadata_keys}
+
+     return MedMcqaSample(
+         unique_id=str(unique_id),
+         question=str(question),
+         choices=choices,
+         answer=answer,
+         subject=str(subject),
+         metadata=metadata,
+     )
+
+
+ def _load_from_huggingface(*, split: str, subset: str | None) -> Iterable[dict[str, Any]]:
+     try:
+         from datasets import load_dataset
+     except ImportError as exc:  # pragma: no cover - optional dependency
+         raise RuntimeError(
+             "datasets is required to load MedMCQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
+         ) from exc
+
+     dataset = load_dataset(_DATASET_NAME, subset, split=split)
+     for row in dataset:
+         yield dict(row)
+
+
+ def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+     if not root.exists():
+         raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+     for path in root.rglob("*"):
+         if path.suffix.lower() == ".json":
+             with path.open("r", encoding="utf-8") as handle:
+                 row = json.load(handle)
+                 row.setdefault("id", path.stem)
+                 yield row
+         elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+             with path.open("r", encoding="utf-8") as handle:
+                 for line_num, line in enumerate(handle, start=1):
+                     line = line.strip()
+                     if not line:
+                         continue
+                     row = json.loads(line)
+                     row.setdefault("id", f"{path.stem}-{line_num}")
+                     yield row
+
+
+ __all__ = ["MedMcqaSample", "load_medmcqa"]
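
To illustrate the 'cop' mapping above, a small sketch against a hand-written row (hypothetical demo data, not dataset output; `_row_to_sample` is the private helper shown in this hunk):

    from themis.datasets.medmcqa import _row_to_sample

    row = {
        "id": "demo-1",
        "question": "Deficiency of which vitamin causes scurvy?",
        "opa": "Vitamin A",
        "opb": "Vitamin B12",
        "opc": "Vitamin C",
        "opd": "Vitamin D",
        "cop": 3,  # documented as 1-4, so 3 maps to label "C"
        "subject_name": "Biochemistry",
    }
    sample = _row_to_sample(row, index=1)
    assert sample.answer == "C"
    assert sample.choices[2] == "Vitamin C"
    assert sample.subject == "Biochemistry"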
themis/datasets/mmlu_pro.py
@@ -0,0 +1,262 @@
+ """Helpers for working with the TIGER-Lab/MMLU-Pro dataset."""
+
+ from __future__ import annotations
+
+ import json
+ import string
+ from pathlib import Path
+ from typing import Any, Iterable, Iterator, List, Sequence
+
+ from pydantic import BaseModel, Field, ValidationInfo, field_validator
+
+ _DATASET_NAME = "TIGER-Lab/MMLU-Pro"
+ _CHOICE_LABELS = tuple(string.ascii_uppercase)
+
+
+ def _index_to_label(index: int) -> str:
+     if 0 <= index < len(_CHOICE_LABELS):
+         return _CHOICE_LABELS[index]
+     return str(index)
+
+
+ def _normalize_answer(value: Any, *, total_choices: int | None = None) -> str:
+     if isinstance(value, int):
+         return _index_to_label(value)
+     if isinstance(value, float):
+         as_int = int(value)
+         if as_int == value:
+             return _index_to_label(as_int)
+     text = str(value or "").strip()
+     if not text:
+         return ""
+     lowered = text.lower()
+     if lowered.startswith("option "):
+         text = text.split(" ", 1)[-1]
+     if lowered.startswith("choice "):
+         text = text.split(" ", 1)[-1]
+     if text.isdigit():
+         index = int(text)
+         if total_choices is None or 0 <= index < total_choices:
+             return _index_to_label(index)
+     text = text.strip().rstrip(".")
+     if len(text) == 1 and text.isalpha():
+         return text.upper()
+     if total_choices is not None:
+         mapping = {str(idx): _index_to_label(idx) for idx in range(total_choices)}
+         normalized = mapping.get(text)
+         if normalized:
+             return normalized
+     return text
+
+
+ class MmluProSample(BaseModel):
+     unique_id: str
+     question: str
+     choices: list[str]
+     answer: str
+     subject: str = Field(default="unknown")
+     metadata: dict[str, Any] = Field(default_factory=dict)
+     choice_labels: list[str] = Field(default_factory=list)
+
+     @field_validator("choices", mode="before")
+     @classmethod
+     def _ensure_choices(cls, value: Any) -> list[str]:
+         if value is None:
+             return []
+         if isinstance(value, dict):
+             return [str(v) for _, v in sorted(value.items())]
+         if isinstance(value, (list, tuple)):
+             return [str(item) for item in value]
+         raise TypeError("choices must be a sequence or mapping")
+
+     @field_validator("choice_labels", mode="before")
+     @classmethod
+     def _build_choice_labels(cls, value: Any, info: ValidationInfo) -> list[str]:
+         if value:
+             return [str(item) for item in value]
+         choices = info.data.get("choices") if hasattr(info, "data") else None
+         total = len(choices) if isinstance(choices, list) else 0
+         return [*_CHOICE_LABELS[:total]]
+
+     @field_validator("answer", mode="before")
+     @classmethod
+     def _normalize_answer_field(cls, value: Any, info: ValidationInfo) -> str:
+         choices = info.data.get("choices") if hasattr(info, "data") else None
+         total = len(choices) if isinstance(choices, list) else None
+         return _normalize_answer(value, total_choices=total)
+
+     def to_generation_example(self) -> dict[str, Any]:
+         effective_labels = (
+             list(self.choice_labels)
+             if self.choice_labels
+             else list(_CHOICE_LABELS[: len(self.choices)])
+         )
+         return {
+             "unique_id": self.unique_id,
+             "question": self.question,
+             "choices": list(self.choices),
+             "choice_labels": effective_labels,
+             "answer": self.answer,
+             "subject": self.subject,
+             "metadata": dict(self.metadata),
+         }
+
+
+ def load_mmlu_pro(
+     *,
+     split: str = "test",
+     limit: int | None = None,
+     source: str = "huggingface",
+     data_dir: str | Path | None = None,
+     subjects: Sequence[str] | None = None,
+ ) -> List[MmluProSample]:
+     """Load MMLU-Pro samples from Hugging Face or a local directory."""
+
+     if source not in {"huggingface", "local"}:
+         raise ValueError(
+             f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+         )
+
+     if source == "huggingface":
+         rows = _load_from_huggingface(split=split)
+     else:
+         if data_dir is None:
+             raise ValueError(
+                 "data_dir must be provided when source='local'. "
+                 "Pass dataset.data_dir in configs or --data-dir on the CLI."
+             )
+         rows = _load_from_local(Path(data_dir))
+
+     samples: list[MmluProSample] = []
+     selected_subjects = {s.lower() for s in subjects} if subjects else None
+     for index, row in enumerate(rows, start=1):
+         subject = _extract_subject(row)
+         if selected_subjects and subject.lower() not in selected_subjects:
+             continue
+         sample = _row_to_sample(row, index=index, subject=subject)
+         samples.append(sample)
+         if limit is not None and len(samples) >= limit:
+             break
+     return samples
+
+
+ def _extract_subject(row: dict[str, Any]) -> str:
+     for key in ("subject", "category", "group", "topic", "field"):
+         value = row.get(key)
+         if value:
+             return str(value)
+     return "unknown"
+
+
+ def _row_to_sample(row: dict[str, Any], *, index: int, subject: str) -> MmluProSample:
+     unique_id = (
+         row.get("id")
+         or row.get("question_id")
+         or row.get("unique_id")
+         or f"mmlu-pro-{index:05d}"
+     )
+     question = row.get("question") or row.get("Question") or row.get("prompt") or ""
+     choices = _extract_choices(row)
+     answer = (
+         row.get("answer")
+         or row.get("Answer")
+         or row.get("correct_answer")
+         or row.get("correct")
+         or ""
+     )
+     metadata_keys = {
+         "question",
+         "Question",
+         "prompt",
+         "choices",
+         "options",
+         "answer",
+         "Answer",
+         "correct_answer",
+         "correct",
+     }
+     metadata = {key: value for key, value in row.items() if key not in metadata_keys}
+     sample = MmluProSample.model_validate(
+         {
+             "unique_id": str(unique_id),
+             "question": str(question),
+             "choices": choices,
+             "answer": answer,
+             "subject": str(subject),
+             "metadata": metadata,
+         }
+     )
+     return sample
+
+
+ def _extract_choices(row: dict[str, Any]) -> list[str]:
+     candidates = row.get("choices") or row.get("options") or row.get("Choices")
+     if isinstance(candidates, dict):
+         return [str(value) for _, value in sorted(candidates.items())]
+     if isinstance(candidates, (list, tuple)):
+         return [str(item) for item in candidates]
+
+     choice_keys = [
+         "A",
+         "B",
+         "C",
+         "D",
+         "E",
+         "F",
+         "choice_a",
+         "choice_b",
+         "choice_c",
+         "choice_d",
+         "choice_e",
+         "choice_f",
+         "option_a",
+         "option_b",
+         "option_c",
+         "option_d",
+         "option_e",
+         "option_f",
+     ]
+     collected: list[str] = []
+     for key in choice_keys:
+         if key in row:
+             collected.append(str(row[key]))
+     if collected:
+         return collected
+     return []
+
+
+ def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
+     try:
+         from datasets import load_dataset
+     except ImportError as exc:  # pragma: no cover - optional dependency
+         raise RuntimeError(
+             "datasets is required to load MMLU-Pro from Hugging Face. Install it via `uv pip install '.[hf]'`."
+         ) from exc
+
+     dataset = load_dataset(_DATASET_NAME, split=split)
+     for row in dataset:
+         yield dict(row)
+
+
+ def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+     if not root.exists():
+         raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+     for path in root.rglob("*"):
+         if path.suffix.lower() == ".json":
+             with path.open("r", encoding="utf-8") as handle:
+                 row = json.load(handle)
+                 row.setdefault("id", path.stem)
+                 yield row
+         elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+             with path.open("r", encoding="utf-8") as handle:
+                 for line_num, line in enumerate(handle, start=1):
+                     line = line.strip()
+                     if not line:
+                         continue
+                     row = json.loads(line)
+                     row.setdefault("id", f"{path.stem}-{line_num}")
+                     yield row
+
+
+ __all__ = ["MmluProSample", "load_mmlu_pro"]
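
The answer normalizer is the main difference from the other MCQ loaders; a short sketch of the encodings it accepts (inputs are illustrative, checked against the code above):

    from themis.datasets.mmlu_pro import _normalize_answer, load_mmlu_pro

    # Integer and digit-string indices become letters; letters pass through.
    assert _normalize_answer(2, total_choices=10) == "C"
    assert _normalize_answer("3", total_choices=10) == "D"
    assert _normalize_answer("option B", total_choices=10) == "B"
    assert _normalize_answer("d.", total_choices=10) == "D"

    # Subject filtering is case-insensitive and applied before the limit,
    # so this would load at most 10 math questions (requires `datasets`):
    # samples = load_mmlu_pro(split="test", subjects=["math"], limit=10)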