themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/datasets/med_qa.py
@@ -0,0 +1,179 @@
+"""Helpers for working with the bigbio/med_qa dataset."""
+
+from __future__ import annotations
+
+import json
+import string
+from pathlib import Path
+from typing import Any, Iterable, Iterator, List, Sequence
+
+from pydantic import BaseModel, Field, ValidationInfo, field_validator
+
+_DATASET_NAME = "bigbio/med_qa"
+_CHOICE_LABELS = tuple(string.ascii_uppercase)
+
+
+class MedQaSample(BaseModel):
+    unique_id: str
+    question: str
+    choices: list[str]
+    answer: str
+    subject: str = Field(default="medicine")
+    metadata: dict[str, Any] = Field(default_factory=dict)
+    choice_labels: list[str] = Field(default_factory=list)
+
+    @field_validator("choices", mode="before")
+    @classmethod
+    def _ensure_choices(cls, value: Any) -> list[str]:
+        if value is None:
+            return []
+        if isinstance(value, dict):
+            return [str(v) for _, v in sorted(value.items())]
+        if isinstance(value, (list, tuple)):
+            return [str(item) for item in value]
+        raise TypeError("choices must be a sequence or mapping")
+
+    @field_validator("choice_labels", mode="before")
+    @classmethod
+    def _build_choice_labels(cls, value: Any, info: ValidationInfo) -> list[str]:
+        if value:
+            return [str(item) for item in value]
+        choices = info.data.get("choices") if hasattr(info, "data") else None
+        total = len(choices) if isinstance(choices, list) else 0
+        return [*_CHOICE_LABELS[:total]]
+
+    def to_generation_example(self) -> dict[str, Any]:
+        effective_labels = (
+            list(self.choice_labels)
+            if self.choice_labels
+            else list(_CHOICE_LABELS[: len(self.choices)])
+        )
+        return {
+            "unique_id": self.unique_id,
+            "question": self.question,
+            "choices": list(self.choices),
+            "choice_labels": effective_labels,
+            "answer": self.answer,
+            "subject": self.subject,
+            "metadata": dict(self.metadata),
+        }
+
+
+def load_med_qa(
+    *,
+    split: str = "test",
+    limit: int | None = None,
+    source: str = "huggingface",
+    data_dir: str | Path | None = None,
+    subset: str = "med_qa_en_bigbio_qa",
+) -> List[MedQaSample]:
+    """Load MedQA samples from Hugging Face or a local directory."""
+
+    if source not in {"huggingface", "local"}:
+        raise ValueError(
+            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+        )
+
+    if source == "huggingface":
+        rows = _load_from_huggingface(split=split, subset=subset)
+    else:
+        if data_dir is None:
+            raise ValueError(
+                "data_dir must be provided when source='local'. "
+                "Pass dataset.data_dir in configs or --data-dir on the CLI."
+            )
+        rows = _load_from_local(Path(data_dir))
+
+    samples: list[MedQaSample] = []
+    for index, row in enumerate(rows, start=1):
+        sample = _row_to_sample(row, index=index)
+        samples.append(sample)
+        if limit is not None and len(samples) >= limit:
+            break
+    return samples
+
+
+def _row_to_sample(row: dict[str, Any], *, index: int) -> MedQaSample:
+    unique_id = (
+        row.get("id")
+        or row.get("unique_id")
+        or f"med-qa-{index:05d}"
+    )
+    question = row.get("question") or ""
+
+    # BigBio MedQA format:
+    # choices: [{'key': 'A', 'text': '...'}, {'key': 'B', 'text': '...'}]
+    # answer: 'A' (or similar)
+
+    choices_data = row.get("choices") or []
+    choices = []
+    choice_labels = []
+
+    if isinstance(choices_data, list):
+        # Sort by key to ensure order
+        try:
+            sorted_choices = sorted(choices_data, key=lambda x: x.get("key", ""))
+            for c in sorted_choices:
+                choices.append(str(c.get("text", "")))
+                choice_labels.append(str(c.get("key", "")))
+        except (TypeError, AttributeError):
+            # Fallback if structure is different
+            choices = [str(c) for c in choices_data]
+
+    answer = ""
+    answer_data = row.get("answer")
+    if isinstance(answer_data, list) and answer_data:
+        answer = str(answer_data[0])  # Usually a list with one element
+    elif isinstance(answer_data, str):
+        answer = answer_data
+
+    metadata_keys = {
+        "question", "choices", "answer", "id"
+    }
+    metadata = {key: value for key, value in row.items() if key not in metadata_keys}
+
+    return MedQaSample(
+        unique_id=str(unique_id),
+        question=str(question),
+        choices=choices,
+        choice_labels=choice_labels,
+        answer=answer,
+        metadata=metadata,
+    )
+
+
+def _load_from_huggingface(*, split: str, subset: str) -> Iterable[dict[str, Any]]:
+    try:
+        from datasets import load_dataset
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "datasets is required to load MedQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
+        ) from exc
+
+    dataset = load_dataset(_DATASET_NAME, subset, split=split)
+    for row in dataset:
+        yield dict(row)
+
+
+def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+    if not root.exists():
+        raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+    for path in root.rglob("*"):
+        if path.suffix.lower() == ".json":
+            with path.open("r", encoding="utf-8") as handle:
+                row = json.load(handle)
+            row.setdefault("id", path.stem)
+            yield row
+        elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+            with path.open("r", encoding="utf-8") as handle:
+                for line_num, line in enumerate(handle, start=1):
+                    line = line.strip()
+                    if not line:
+                        continue
+                    row = json.loads(line)
+                    row.setdefault("id", f"{path.stem}-{line_num}")
+                    yield row
+
+
+__all__ = ["MedQaSample", "load_med_qa"]
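
For orientation, a minimal usage sketch of the new loader (not part of the diff; it assumes the optional `datasets` dependency is installed, and `./data/med_qa` is a hypothetical local path):

    from themis.datasets.med_qa import load_med_qa

    # Fetch a few test-split samples from bigbio/med_qa via Hugging Face.
    samples = load_med_qa(split="test", limit=3)
    for sample in samples:
        example = sample.to_generation_example()
        print(example["unique_id"], example["choice_labels"], example["answer"])

    # Or read JSON/JSONL files from disk instead of Hugging Face.
    local_samples = load_med_qa(source="local", data_dir="./data/med_qa", limit=3)
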
themis/datasets/medmcqa.py
@@ -0,0 +1,169 @@
+"""Helpers for working with the openlifescienceai/medmcqa dataset."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, Iterable, Iterator, List, Sequence
+
+from pydantic import BaseModel, Field, field_validator
+
+_DATASET_NAME = "openlifescienceai/medmcqa"
+_CHOICE_LABELS = ["A", "B", "C", "D"]
+
+
+class MedMcqaSample(BaseModel):
+    unique_id: str
+    question: str
+    choices: list[str]
+    answer: str
+    subject: str = Field(default="unknown")
+    metadata: dict[str, Any] = Field(default_factory=dict)
+    choice_labels: list[str] = Field(default_factory=list)
+
+    @field_validator("choices", mode="before")
+    @classmethod
+    def _ensure_choices(cls, value: Any) -> list[str]:
+        if value is None:
+            return []
+        if isinstance(value, dict):
+            return [str(v) for _, v in sorted(value.items())]
+        if isinstance(value, (list, tuple)):
+            return [str(item) for item in value]
+        raise TypeError("choices must be a sequence or mapping")
+
+    def to_generation_example(self) -> dict[str, Any]:
+        effective_labels = (
+            list(self.choice_labels)
+            if self.choice_labels
+            else list(_CHOICE_LABELS[: len(self.choices)])
+        )
+        return {
+            "unique_id": self.unique_id,
+            "question": self.question,
+            "choices": list(self.choices),
+            "choice_labels": effective_labels,
+            "answer": self.answer,
+            "subject": self.subject,
+            "metadata": dict(self.metadata),
+        }
+
+
+def load_medmcqa(
+    *,
+    split: str = "test",
+    limit: int | None = None,
+    source: str = "huggingface",
+    data_dir: str | Path | None = None,
+    subset: str | None = None,
+) -> List[MedMcqaSample]:
+    """Load MedMCQA samples from Hugging Face or a local directory."""
+
+    if source not in {"huggingface", "local"}:
+        raise ValueError(
+            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+        )
+
+    if source == "huggingface":
+        rows = _load_from_huggingface(split=split, subset=subset)
+    else:
+        if data_dir is None:
+            raise ValueError(
+                "data_dir must be provided when source='local'. "
+                "Pass dataset.data_dir in configs or --data-dir on the CLI."
+            )
+        rows = _load_from_local(Path(data_dir))
+
+    samples: list[MedMcqaSample] = []
+    for index, row in enumerate(rows, start=1):
+        sample = _row_to_sample(row, index=index)
+        samples.append(sample)
+        if limit is not None and len(samples) >= limit:
+            break
+    return samples
+
+
+def _row_to_sample(row: dict[str, Any], *, index: int) -> MedMcqaSample:
+    unique_id = (
+        row.get("id")
+        or row.get("unique_id")
+        or f"medmcqa-{index:05d}"
+    )
+    question = row.get("question") or ""
+
+    # MedMCQA has 'opa', 'opb', 'opc', 'opd'
+    choices = [
+        str(row.get("opa") or ""),
+        str(row.get("opb") or ""),
+        str(row.get("opc") or ""),
+        str(row.get("opd") or ""),
+    ]
+
+    # 'cop' encodes the correct option as an integer index. Reports differ on
+    # whether it is 0-3 or 1-4 depending on the dataset version, but the HF
+    # dataset info documents it as "cop: Correct option (1-4)".
+    # We therefore follow the 1-4 convention and map
+    # 1 -> A, 2 -> B, 3 -> C, 4 -> D.
+
+    cop = row.get("cop")
+    answer = ""
+    if cop is not None:
+        try:
+            cop_int = int(cop)
+            if 1 <= cop_int <= 4:
+                answer = _CHOICE_LABELS[cop_int - 1]
+        except (ValueError, TypeError):
+            pass
+
+    subject = row.get("subject_name") or "medicine"
+
+    metadata_keys = {
+        "question", "opa", "opb", "opc", "opd", "cop", "subject_name", "id"
+    }
+    metadata = {key: value for key, value in row.items() if key not in metadata_keys}
+
+    return MedMcqaSample(
+        unique_id=str(unique_id),
+        question=str(question),
+        choices=choices,
+        answer=answer,
+        subject=str(subject),
+        metadata=metadata,
+    )
+
+
+def _load_from_huggingface(*, split: str, subset: str | None) -> Iterable[dict[str, Any]]:
+    try:
+        from datasets import load_dataset
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "datasets is required to load MedMCQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
+        ) from exc
+
+    dataset = load_dataset(_DATASET_NAME, subset, split=split)
+    for row in dataset:
+        yield dict(row)
+
+
+def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+    if not root.exists():
+        raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+    for path in root.rglob("*"):
+        if path.suffix.lower() == ".json":
+            with path.open("r", encoding="utf-8") as handle:
+                row = json.load(handle)
+            row.setdefault("id", path.stem)
+            yield row
+        elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+            with path.open("r", encoding="utf-8") as handle:
+                for line_num, line in enumerate(handle, start=1):
+                    line = line.strip()
+                    if not line:
+                        continue
+                    row = json.loads(line)
+                    row.setdefault("id", f"{path.stem}-{line_num}")
+                    yield row
+
+
+__all__ = ["MedMcqaSample", "load_medmcqa"]
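
A similar sketch for the MedMCQA loader, showing how the `cop` mapping above surfaces as a letter answer (illustrative only; actual rows come from openlifescienceai/medmcqa):

    from themis.datasets.medmcqa import load_medmcqa

    # A row with cop=2 yields answer "B" under the documented 1-4 convention.
    samples = load_medmcqa(split="validation", limit=1)
    if samples:
        example = samples[0].to_generation_example()
        for label, choice in zip(example["choice_labels"], example["choices"]):
            print(f"{label}. {choice}")
        print("answer:", example["answer"], "subject:", example["subject"])
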
themis/datasets/mmlu_pro.py
@@ -0,0 +1,262 @@
+"""Helpers for working with the TIGER-Lab/MMLU-Pro dataset."""
+
+from __future__ import annotations
+
+import json
+import string
+from pathlib import Path
+from typing import Any, Iterable, Iterator, List, Sequence
+
+from pydantic import BaseModel, Field, ValidationInfo, field_validator
+
+_DATASET_NAME = "TIGER-Lab/MMLU-Pro"
+_CHOICE_LABELS = tuple(string.ascii_uppercase)
+
+
+def _index_to_label(index: int) -> str:
+    if 0 <= index < len(_CHOICE_LABELS):
+        return _CHOICE_LABELS[index]
+    return str(index)
+
+
+def _normalize_answer(value: Any, *, total_choices: int | None = None) -> str:
+    if isinstance(value, int):
+        return _index_to_label(value)
+    if isinstance(value, float):
+        as_int = int(value)
+        if as_int == value:
+            return _index_to_label(as_int)
+    text = str(value or "").strip()
+    if not text:
+        return ""
+    lowered = text.lower()
+    if lowered.startswith("option "):
+        text = text.split(" ", 1)[-1]
+    if lowered.startswith("choice "):
+        text = text.split(" ", 1)[-1]
+    if text.isdigit():
+        index = int(text)
+        if total_choices is None or 0 <= index < total_choices:
+            return _index_to_label(index)
+    text = text.strip().rstrip(".")
+    if len(text) == 1 and text.isalpha():
+        return text.upper()
+    if total_choices is not None:
+        mapping = {str(idx): _index_to_label(idx) for idx in range(total_choices)}
+        normalized = mapping.get(text)
+        if normalized:
+            return normalized
+    return text
+
+
+class MmluProSample(BaseModel):
+    unique_id: str
+    question: str
+    choices: list[str]
+    answer: str
+    subject: str = Field(default="unknown")
+    metadata: dict[str, Any] = Field(default_factory=dict)
+    choice_labels: list[str] = Field(default_factory=list)
+
+    @field_validator("choices", mode="before")
+    @classmethod
+    def _ensure_choices(cls, value: Any) -> list[str]:
+        if value is None:
+            return []
+        if isinstance(value, dict):
+            return [str(v) for _, v in sorted(value.items())]
+        if isinstance(value, (list, tuple)):
+            return [str(item) for item in value]
+        raise TypeError("choices must be a sequence or mapping")
+
+    @field_validator("choice_labels", mode="before")
+    @classmethod
+    def _build_choice_labels(cls, value: Any, info: ValidationInfo) -> list[str]:
+        if value:
+            return [str(item) for item in value]
+        choices = info.data.get("choices") if hasattr(info, "data") else None
+        total = len(choices) if isinstance(choices, list) else 0
+        return [*_CHOICE_LABELS[:total]]
+
+    @field_validator("answer", mode="before")
+    @classmethod
+    def _normalize_answer_field(cls, value: Any, info: ValidationInfo) -> str:
+        choices = info.data.get("choices") if hasattr(info, "data") else None
+        total = len(choices) if isinstance(choices, list) else None
+        return _normalize_answer(value, total_choices=total)
+
+    def to_generation_example(self) -> dict[str, Any]:
+        effective_labels = (
+            list(self.choice_labels)
+            if self.choice_labels
+            else list(_CHOICE_LABELS[: len(self.choices)])
+        )
+        return {
+            "unique_id": self.unique_id,
+            "question": self.question,
+            "choices": list(self.choices),
+            "choice_labels": effective_labels,
+            "answer": self.answer,
+            "subject": self.subject,
+            "metadata": dict(self.metadata),
+        }
+
+
+def load_mmlu_pro(
+    *,
+    split: str = "test",
+    limit: int | None = None,
+    source: str = "huggingface",
+    data_dir: str | Path | None = None,
+    subjects: Sequence[str] | None = None,
+) -> List[MmluProSample]:
+    """Load MMLU-Pro samples from Hugging Face or a local directory."""
+
+    if source not in {"huggingface", "local"}:
+        raise ValueError(
+            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
+        )
+
+    if source == "huggingface":
+        rows = _load_from_huggingface(split=split)
+    else:
+        if data_dir is None:
+            raise ValueError(
+                "data_dir must be provided when source='local'. "
+                "Pass dataset.data_dir in configs or --data-dir on the CLI."
+            )
+        rows = _load_from_local(Path(data_dir))
+
+    samples: list[MmluProSample] = []
+    selected_subjects = {s.lower() for s in subjects} if subjects else None
+    for index, row in enumerate(rows, start=1):
+        subject = _extract_subject(row)
+        if selected_subjects and subject.lower() not in selected_subjects:
+            continue
+        sample = _row_to_sample(row, index=index, subject=subject)
+        samples.append(sample)
+        if limit is not None and len(samples) >= limit:
+            break
+    return samples
+
+
+def _extract_subject(row: dict[str, Any]) -> str:
+    for key in ("subject", "category", "group", "topic", "field"):
+        value = row.get(key)
+        if value:
+            return str(value)
+    return "unknown"
+
+
+def _row_to_sample(row: dict[str, Any], *, index: int, subject: str) -> MmluProSample:
+    unique_id = (
+        row.get("id")
+        or row.get("question_id")
+        or row.get("unique_id")
+        or f"mmlu-pro-{index:05d}"
+    )
+    question = row.get("question") or row.get("Question") or row.get("prompt") or ""
+    choices = _extract_choices(row)
+    answer = (
+        row.get("answer")
+        or row.get("Answer")
+        or row.get("correct_answer")
+        or row.get("correct")
+        or ""
+    )
+    metadata_keys = {
+        "question",
+        "Question",
+        "prompt",
+        "choices",
+        "options",
+        "answer",
+        "Answer",
+        "correct_answer",
+        "correct",
+    }
+    metadata = {key: value for key, value in row.items() if key not in metadata_keys}
+    sample = MmluProSample.model_validate(
+        {
+            "unique_id": str(unique_id),
+            "question": str(question),
+            "choices": choices,
+            "answer": answer,
+            "subject": str(subject),
+            "metadata": metadata,
+        }
+    )
+    return sample
+
+
+def _extract_choices(row: dict[str, Any]) -> list[str]:
+    candidates = row.get("choices") or row.get("options") or row.get("Choices")
+    if isinstance(candidates, dict):
+        return [str(value) for _, value in sorted(candidates.items())]
+    if isinstance(candidates, (list, tuple)):
+        return [str(item) for item in candidates]
+
+    choice_keys = [
+        "A",
+        "B",
+        "C",
+        "D",
+        "E",
+        "F",
+        "choice_a",
+        "choice_b",
+        "choice_c",
+        "choice_d",
+        "choice_e",
+        "choice_f",
+        "option_a",
+        "option_b",
+        "option_c",
+        "option_d",
+        "option_e",
+        "option_f",
+    ]
+    collected: list[str] = []
+    for key in choice_keys:
+        if key in row:
+            collected.append(str(row[key]))
+    if collected:
+        return collected
+    return []
+
+
+def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
+    try:
+        from datasets import load_dataset
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "datasets is required to load MMLU-Pro from Hugging Face. Install it via `uv pip install '.[hf]'`."
+        ) from exc
+
+    dataset = load_dataset(_DATASET_NAME, split=split)
+    for row in dataset:
+        yield dict(row)
+
+
+def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
+    if not root.exists():
+        raise FileNotFoundError(f"Local dataset directory not found: {root}")
+
+    for path in root.rglob("*"):
+        if path.suffix.lower() == ".json":
+            with path.open("r", encoding="utf-8") as handle:
+                row = json.load(handle)
+            row.setdefault("id", path.stem)
+            yield row
+        elif path.suffix.lower() in {".jsonl", ".ndjson"}:
+            with path.open("r", encoding="utf-8") as handle:
+                for line_num, line in enumerate(handle, start=1):
+                    line = line.strip()
+                    if not line:
+                        continue
+                    row = json.loads(line)
+                    row.setdefault("id", f"{path.stem}-{line_num}")
+                    yield row
+
+
+__all__ = ["MmluProSample", "load_mmlu_pro"]
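
And for MMLU-Pro, a sketch of subject filtering plus the answer normalization above (subject names are illustrative; the filter is case-insensitive):

    from themis.datasets.mmlu_pro import load_mmlu_pro

    # Keep only two subjects and cap the number of samples.
    samples = load_mmlu_pro(split="test", subjects=["law", "physics"], limit=5)
    for sample in samples:
        # Answers are normalized to letter labels by the model validator:
        # 0 -> "A", "2" -> "C", "Option 2" -> "C", "b." -> "B".
        print(sample.subject, sample.answer)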