themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +93 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +164 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +288 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +129 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +690 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +373 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +255 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +61 -0
- themis/integrations/wandb.py +65 -0
- themis/interfaces/__init__.py +83 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
- themis_eval-0.1.1.dist-info/RECORD +134 -0
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
themis/datasets/sciq.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""Helpers for working with the allenai/sciq dataset."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import random
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Iterable, Iterator, List, Sequence
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel, Field, field_validator
|
|
11
|
+
|
|
12
|
+
_DATASET_NAME = "allenai/sciq"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SciQSample(BaseModel):
    """One SciQ multiple-choice record in normalized form.

    ``choices`` holds the gold answer plus the distractors; ``answer`` is the
    gold answer text, and ``support`` carries the supporting passage when the
    dataset provides one.
    """

    unique_id: str
    question: str
    choices: list[str]
    answer: str
    support: str = Field(default="")
    metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("choices", mode="before")
    @classmethod
    def _ensure_choices(cls, value: Any) -> list[str]:
        # Coerce any sequence into a list of strings; None means "no choices".
        if isinstance(value, (list, tuple)):
            return [str(entry) for entry in value]
        if value is None:
            return []
        raise TypeError("choices must be a sequence")

    def to_generation_example(self) -> dict[str, Any]:
        """Render the sample as a plain dict for the generation pipeline.

        Containers are copied so callers can mutate the result freely.
        """
        example: dict[str, Any] = {
            "unique_id": self.unique_id,
            "question": self.question,
            "choices": list(self.choices),
            "answer": self.answer,
            "support": self.support,
            "metadata": dict(self.metadata),
        }
        return example
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def load_sciq(
    *,
    split: str = "test",
    limit: int | None = None,
    source: str = "huggingface",
    data_dir: str | Path | None = None,
) -> List[SciQSample]:
    """Load SciQ samples from Hugging Face or a local directory.

    Args:
        split: Dataset split to read when ``source`` is ``"huggingface"``.
        limit: Optional cap on the number of returned samples.
        source: Either ``"huggingface"`` or ``"local"``.
        data_dir: Directory of JSON/JSONL files; required for ``"local"``.

    Raises:
        ValueError: On an unknown ``source``, or ``source='local'`` without
            a ``data_dir``.
    """

    if source == "huggingface":
        rows = _load_from_huggingface(split=split)
    elif source == "local":
        if data_dir is None:
            raise ValueError(
                "data_dir must be provided when source='local'. "
                "Pass dataset.data_dir in configs or --data-dir on the CLI."
            )
        rows = _load_from_local(Path(data_dir))
    else:
        raise ValueError(
            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
        )

    samples: list[SciQSample] = []
    for position, record in enumerate(rows, start=1):
        samples.append(_row_to_sample(record, index=position))
        # Stop early once the requested number of samples is collected.
        if limit is not None and len(samples) >= limit:
            break
    return samples
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _row_to_sample(row: dict[str, Any], *, index: int) -> SciQSample:
    """Convert one raw SciQ row into a :class:`SciQSample`.

    ``index`` (1-based position in the stream) is used to synthesize a
    stable identifier when the row carries none.
    """
    unique_id = row.get("id") or row.get("unique_id") or f"sciq-{index:05d}"
    question = row.get("question") or ""

    # SciQ stores the gold answer and three distractors in separate columns;
    # empty distractors are dropped.
    correct = str(row.get("correct_answer") or "")
    distractors = [
        text
        for key in ("distractor1", "distractor2", "distractor3")
        if (text := str(row.get(key) or ""))
    ]

    # Alphabetical order keeps the choice layout deterministic.
    choices = sorted([correct] + distractors)

    support = str(row.get("support") or "")

    # Everything not consumed above is preserved as metadata.
    consumed = {
        "question", "correct_answer", "distractor1", "distractor2", "distractor3", "support", "id"
    }
    extra = {key: value for key, value in row.items() if key not in consumed}

    return SciQSample(
        unique_id=str(unique_id),
        question=str(question),
        choices=choices,
        answer=correct,
        support=support,
        metadata=extra,
    )
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
|
|
117
|
+
try:
|
|
118
|
+
from datasets import load_dataset
|
|
119
|
+
except ImportError as exc: # pragma: no cover - optional dependency
|
|
120
|
+
raise RuntimeError(
|
|
121
|
+
"datasets is required to load SciQ from Hugging Face. Install it via `uv pip install '.[hf]'`."
|
|
122
|
+
) from exc
|
|
123
|
+
|
|
124
|
+
dataset = load_dataset(_DATASET_NAME, split=split)
|
|
125
|
+
for row in dataset:
|
|
126
|
+
yield dict(row)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
|
|
130
|
+
if not root.exists():
|
|
131
|
+
raise FileNotFoundError(f"Local dataset directory not found: {root}")
|
|
132
|
+
|
|
133
|
+
for path in root.rglob("*"):
|
|
134
|
+
if path.suffix.lower() == ".json":
|
|
135
|
+
with path.open("r", encoding="utf-8") as handle:
|
|
136
|
+
row = json.load(handle)
|
|
137
|
+
row.setdefault("id", path.stem)
|
|
138
|
+
yield row
|
|
139
|
+
elif path.suffix.lower() in {".jsonl", ".ndjson"}:
|
|
140
|
+
with path.open("r", encoding="utf-8") as handle:
|
|
141
|
+
for line_num, line in enumerate(handle, start=1):
|
|
142
|
+
line = line.strip()
|
|
143
|
+
if not line:
|
|
144
|
+
continue
|
|
145
|
+
row = json.loads(line)
|
|
146
|
+
row.setdefault("id", f"{path.stem}-{line_num}")
|
|
147
|
+
yield row
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
# Explicit public API for `from themis.datasets.sciq import *` consumers.
__all__ = ["SciQSample", "load_sciq"]
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
"""Helpers for working with the allenai/social_i_qa dataset."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Iterable, Iterator, List, Sequence
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, Field, field_validator
|
|
10
|
+
|
|
11
|
+
_DATASET_NAME = "allenai/social_i_qa"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SocialIQaSample(BaseModel):
    """One Social IQA record in normalized form.

    ``context`` is the situation description; ``choices`` holds the three
    candidate answers and ``answer`` the gold one (empty when the split
    carries no labels).
    """

    unique_id: str
    context: str
    question: str
    choices: list[str]
    answer: str
    metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("choices", mode="before")
    @classmethod
    def _ensure_choices(cls, value: Any) -> list[str]:
        # Coerce any sequence into a list of strings; None means "no choices".
        if isinstance(value, (list, tuple)):
            return [str(entry) for entry in value]
        if value is None:
            return []
        raise TypeError("choices must be a sequence")

    def to_generation_example(self) -> dict[str, Any]:
        """Render the sample as a plain dict for the generation pipeline.

        Containers are copied so callers can mutate the result freely.
        """
        example: dict[str, Any] = {
            "unique_id": self.unique_id,
            "context": self.context,
            "question": self.question,
            "choices": list(self.choices),
            "answer": self.answer,
            "metadata": dict(self.metadata),
        }
        return example
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def load_social_i_qa(
    *,
    split: str = "validation",  # Test set usually has no labels
    limit: int | None = None,
    source: str = "huggingface",
    data_dir: str | Path | None = None,
) -> List[SocialIQaSample]:
    """Load Social IQA samples from Hugging Face or a local directory.

    Args:
        split: Dataset split to read when ``source`` is ``"huggingface"``.
        limit: Optional cap on the number of returned samples.
        source: Either ``"huggingface"`` or ``"local"``.
        data_dir: Directory of JSON/JSONL files; required for ``"local"``.

    Raises:
        ValueError: On an unknown ``source``, or ``source='local'`` without
            a ``data_dir``.
    """

    if source == "huggingface":
        rows = _load_from_huggingface(split=split)
    elif source == "local":
        if data_dir is None:
            raise ValueError(
                "data_dir must be provided when source='local'. "
                "Pass dataset.data_dir in configs or --data-dir on the CLI."
            )
        rows = _load_from_local(Path(data_dir))
    else:
        raise ValueError(
            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
        )

    samples: list[SocialIQaSample] = []
    for position, record in enumerate(rows, start=1):
        samples.append(_row_to_sample(record, index=position))
        # Stop early once the requested number of samples is collected.
        if limit is not None and len(samples) >= limit:
            break
    return samples
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _row_to_sample(row: dict[str, Any], *, index: int) -> SocialIQaSample:
    """Convert one raw Social IQA row into a :class:`SocialIQaSample`.

    ``index`` (1-based position in the stream) is used to synthesize a
    stable identifier when the row carries none.
    """
    unique_id = row.get("id") or row.get("unique_id") or f"social-iqa-{index:05d}"
    context = row.get("context") or ""
    question = row.get("question") or ""

    # The three candidate answers live in dedicated columns.
    choices = [str(row.get(key) or "") for key in ("answerA", "answerB", "answerC")]

    # 'label' is a 1-based position ("1".."3"); malformed or out-of-range
    # labels (e.g. the unlabeled test split) leave the answer empty.
    answer = ""
    raw_label = row.get("label")
    if raw_label is not None:
        try:
            position = int(raw_label)
        except (ValueError, TypeError):
            position = 0
        if 1 <= position <= len(choices):
            answer = choices[position - 1]

    # Everything not consumed above is preserved as metadata.
    consumed = {"context", "question", "answerA", "answerB", "answerC", "label", "id"}
    extra = {key: value for key, value in row.items() if key not in consumed}

    return SocialIQaSample(
        unique_id=str(unique_id),
        context=str(context),
        question=str(question),
        choices=choices,
        answer=answer,
        metadata=extra,
    )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
|
|
118
|
+
try:
|
|
119
|
+
from datasets import load_dataset
|
|
120
|
+
except ImportError as exc: # pragma: no cover - optional dependency
|
|
121
|
+
raise RuntimeError(
|
|
122
|
+
"datasets is required to load Social IQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
|
|
123
|
+
) from exc
|
|
124
|
+
|
|
125
|
+
dataset = load_dataset(_DATASET_NAME, split=split)
|
|
126
|
+
for row in dataset:
|
|
127
|
+
yield dict(row)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
|
|
131
|
+
if not root.exists():
|
|
132
|
+
raise FileNotFoundError(f"Local dataset directory not found: {root}")
|
|
133
|
+
|
|
134
|
+
for path in root.rglob("*"):
|
|
135
|
+
if path.suffix.lower() == ".json":
|
|
136
|
+
with path.open("r", encoding="utf-8") as handle:
|
|
137
|
+
row = json.load(handle)
|
|
138
|
+
row.setdefault("id", path.stem)
|
|
139
|
+
yield row
|
|
140
|
+
elif path.suffix.lower() in {".jsonl", ".ndjson"}:
|
|
141
|
+
with path.open("r", encoding="utf-8") as handle:
|
|
142
|
+
for line_num, line in enumerate(handle, start=1):
|
|
143
|
+
line = line.strip()
|
|
144
|
+
if not line:
|
|
145
|
+
continue
|
|
146
|
+
row = json.loads(line)
|
|
147
|
+
row.setdefault("id", f"{path.stem}-{line_num}")
|
|
148
|
+
yield row
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# Explicit public API for `from themis.datasets.social_i_qa import *` consumers.
__all__ = ["SocialIQaSample", "load_social_i_qa"]
|
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
"""Helpers for working with the m-a-p/SuperGPQA dataset."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import string
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Iterable, Iterator, List, Sequence
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
|
11
|
+
|
|
12
|
+
_DATASET_NAME = "m-a-p/SuperGPQA"
|
|
13
|
+
_CHOICE_LABELS = tuple(string.ascii_uppercase)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _index_to_label(index: int) -> str:
|
|
17
|
+
if 0 <= index < len(_CHOICE_LABELS):
|
|
18
|
+
return _CHOICE_LABELS[index]
|
|
19
|
+
return str(index)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _normalize_answer(value: Any, *, total_choices: int | None = None) -> str:
|
|
23
|
+
if isinstance(value, int):
|
|
24
|
+
return _index_to_label(value)
|
|
25
|
+
if isinstance(value, float):
|
|
26
|
+
as_int = int(value)
|
|
27
|
+
if as_int == value:
|
|
28
|
+
return _index_to_label(as_int)
|
|
29
|
+
text = str(value or "").strip()
|
|
30
|
+
if not text:
|
|
31
|
+
return ""
|
|
32
|
+
lowered = text.lower()
|
|
33
|
+
if lowered.startswith("option "):
|
|
34
|
+
text = text.split(" ", 1)[-1]
|
|
35
|
+
if lowered.startswith("choice "):
|
|
36
|
+
text = text.split(" ", 1)[-1]
|
|
37
|
+
if text.isdigit():
|
|
38
|
+
index = int(text)
|
|
39
|
+
if total_choices is None or 0 <= index < total_choices:
|
|
40
|
+
return _index_to_label(index)
|
|
41
|
+
text = text.strip().rstrip(".")
|
|
42
|
+
if len(text) == 1 and text.isalpha():
|
|
43
|
+
return text.upper()
|
|
44
|
+
if total_choices is not None:
|
|
45
|
+
mapping = {str(idx): _index_to_label(idx) for idx in range(total_choices)}
|
|
46
|
+
normalized = mapping.get(text)
|
|
47
|
+
if normalized:
|
|
48
|
+
return normalized
|
|
49
|
+
return text
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class SuperGpqaSample(BaseModel):
    """One SuperGPQA multiple-choice record in normalized form.

    NOTE: the three ``mode="before"`` validators below rely on pydantic v2
    running validators in field-declaration order: ``choices`` is declared
    (and therefore validated) before ``answer`` and ``choice_labels``, so the
    latter two can read the already-validated choices from ``info.data``.
    Do not reorder the field declarations.
    """

    unique_id: str
    question: str
    choices: list[str]
    answer: str
    subject: str = Field(default="unknown")
    metadata: dict[str, Any] = Field(default_factory=dict)
    # Letter labels for the choices; auto-derived from len(choices) when empty.
    choice_labels: list[str] = Field(default_factory=list)

    @field_validator("choices", mode="before")
    @classmethod
    def _ensure_choices(cls, value: Any) -> list[str]:
        """Coerce a sequence or mapping of options to a list of strings."""
        if value is None:
            return []
        if isinstance(value, dict):
            # Sort by key to keep deterministic order
            return [str(v) for _, v in sorted(value.items())]
        if isinstance(value, (list, tuple)):
            return [str(item) for item in value]
        raise TypeError("choices must be a sequence or mapping")

    @field_validator("choice_labels", mode="before")
    @classmethod
    def _build_choice_labels(cls, value: Any, info: ValidationInfo) -> list[str]:
        """Default missing labels to "A".."Z" matching the number of choices."""
        if value:
            return [str(item) for item in value]
        # info.data holds already-validated fields; choices precedes this field.
        choices = info.data.get("choices") if hasattr(info, "data") else None
        total = len(choices) if isinstance(choices, list) else 0
        return [*_CHOICE_LABELS[:total]]

    @field_validator("answer", mode="before")
    @classmethod
    def _normalize_answer_field(cls, value: Any, info: ValidationInfo) -> str:
        """Run the module-level answer normalizer against the raw value."""
        choices = info.data.get("choices") if hasattr(info, "data") else None
        total = len(choices) if isinstance(choices, list) else None
        return _normalize_answer(value, total_choices=total)

    def to_generation_example(self) -> dict[str, Any]:
        """Render the sample as a plain dict for the generation pipeline.

        Labels fall back to "A".."Z" sized to the choices when the stored
        ``choice_labels`` list is empty; containers are copied.
        """
        effective_labels = (
            list(self.choice_labels)
            if self.choice_labels
            else list(_CHOICE_LABELS[: len(self.choices)])
        )
        return {
            "unique_id": self.unique_id,
            "question": self.question,
            "choices": list(self.choices),
            "choice_labels": effective_labels,
            "answer": self.answer,
            "subject": self.subject,
            "metadata": dict(self.metadata),
        }
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def load_super_gpqa(
    *,
    split: str = "test",
    limit: int | None = None,
    source: str = "huggingface",
    data_dir: str | Path | None = None,
    subjects: Sequence[str] | None = None,
) -> List[SuperGpqaSample]:
    """Load SuperGPQA samples from Hugging Face or a local directory.

    Args:
        split: Dataset split to read when ``source`` is ``"huggingface"``.
        limit: Optional cap on the number of returned samples.
        source: Either ``"huggingface"`` or ``"local"``.
        data_dir: Directory of JSON/JSONL files; required for ``"local"``.
        subjects: Optional case-insensitive subject filter.

    Raises:
        ValueError: On an unknown ``source``, or ``source='local'`` without
            a ``data_dir``.
    """

    if source == "huggingface":
        rows = _load_from_huggingface(split=split)
    elif source == "local":
        if data_dir is None:
            raise ValueError(
                "data_dir must be provided when source='local'. "
                "Pass dataset.data_dir in configs or --data-dir on the CLI."
            )
        rows = _load_from_local(Path(data_dir))
    else:
        raise ValueError(
            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
        )

    wanted = {name.lower() for name in subjects} if subjects else None
    samples: list[SuperGpqaSample] = []
    for position, record in enumerate(rows, start=1):
        subject = _extract_subject(record)
        if wanted is not None and subject.lower() not in wanted:
            continue
        samples.append(_row_to_sample(record, index=position, subject=subject))
        # Stop early once the requested number of samples is collected.
        if limit is not None and len(samples) >= limit:
            break
    return samples
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _extract_subject(row: dict[str, Any]) -> str:
|
|
145
|
+
for key in ("subject", "category", "field", "domain", "track"):
|
|
146
|
+
value = row.get(key)
|
|
147
|
+
if value:
|
|
148
|
+
return str(value)
|
|
149
|
+
return "unknown"
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _row_to_sample(row: dict[str, Any], *, index: int, subject: str) -> SuperGpqaSample:
    """Normalize one raw SuperGPQA row into a validated SuperGpqaSample.

    ``index`` (1-based position in the stream) seeds a fallback identifier;
    answer normalization happens inside the model's validators.
    """
    unique_id = (
        row.get("id")
        or row.get("question_id")
        or row.get("unique_id")
        or f"supergpqa-{index:05d}"
    )
    # Dumps disagree on column capitalization; probe the known variants.
    question = row.get("question") or row.get("Question") or row.get("prompt") or ""
    answer = (
        row.get("answer")
        or row.get("Answer")
        or row.get("correct_answer")
        or row.get("correct")
        or ""
    )
    # Everything not consumed above (or by choice extraction) is metadata.
    consumed = {
        "question",
        "Question",
        "prompt",
        "choices",
        "options",
        "answer",
        "Answer",
        "correct_answer",
        "correct",
    }
    extra = {key: value for key, value in row.items() if key not in consumed}
    return SuperGpqaSample.model_validate(
        {
            "unique_id": str(unique_id),
            "question": str(question),
            "choices": _extract_choices(row),
            "answer": answer,
            "subject": str(subject),
            "metadata": extra,
        }
    )
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _extract_choices(row: dict[str, Any]) -> list[str]:
|
|
194
|
+
candidates = row.get("choices") or row.get("options") or row.get("Choices")
|
|
195
|
+
if isinstance(candidates, dict):
|
|
196
|
+
return [str(value) for _, value in sorted(candidates.items())]
|
|
197
|
+
if isinstance(candidates, (list, tuple)):
|
|
198
|
+
return [str(item) for item in candidates]
|
|
199
|
+
|
|
200
|
+
choice_keys = [
|
|
201
|
+
"A",
|
|
202
|
+
"B",
|
|
203
|
+
"C",
|
|
204
|
+
"D",
|
|
205
|
+
"E",
|
|
206
|
+
"F",
|
|
207
|
+
"choice_a",
|
|
208
|
+
"choice_b",
|
|
209
|
+
"choice_c",
|
|
210
|
+
"choice_d",
|
|
211
|
+
"choice_e",
|
|
212
|
+
"choice_f",
|
|
213
|
+
"option_a",
|
|
214
|
+
"option_b",
|
|
215
|
+
"option_c",
|
|
216
|
+
"option_d",
|
|
217
|
+
"option_e",
|
|
218
|
+
"option_f",
|
|
219
|
+
]
|
|
220
|
+
collected: list[str] = []
|
|
221
|
+
for key in choice_keys:
|
|
222
|
+
if key in row:
|
|
223
|
+
collected.append(str(row[key]))
|
|
224
|
+
if collected:
|
|
225
|
+
return collected
|
|
226
|
+
return []
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
|
|
230
|
+
try:
|
|
231
|
+
from datasets import load_dataset
|
|
232
|
+
except ImportError as exc: # pragma: no cover - optional dependency
|
|
233
|
+
raise RuntimeError(
|
|
234
|
+
"datasets is required to load SuperGPQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
|
|
235
|
+
) from exc
|
|
236
|
+
|
|
237
|
+
dataset = load_dataset(_DATASET_NAME, split=split)
|
|
238
|
+
for row in dataset:
|
|
239
|
+
yield dict(row)
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
|
|
243
|
+
if not root.exists():
|
|
244
|
+
raise FileNotFoundError(f"Local dataset directory not found: {root}")
|
|
245
|
+
|
|
246
|
+
for path in root.rglob("*"):
|
|
247
|
+
if path.suffix.lower() == ".json":
|
|
248
|
+
with path.open("r", encoding="utf-8") as handle:
|
|
249
|
+
row = json.load(handle)
|
|
250
|
+
row.setdefault("id", path.stem)
|
|
251
|
+
yield row
|
|
252
|
+
elif path.suffix.lower() in {".jsonl", ".ndjson"}:
|
|
253
|
+
with path.open("r", encoding="utf-8") as handle:
|
|
254
|
+
for line_num, line in enumerate(handle, start=1):
|
|
255
|
+
line = line.strip()
|
|
256
|
+
if not line:
|
|
257
|
+
continue
|
|
258
|
+
row = json.loads(line)
|
|
259
|
+
row.setdefault("id", f"{path.stem}-{line_num}")
|
|
260
|
+
yield row
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
# Explicit public API for `from themis.datasets.super_gpqa import *` consumers.
__all__ = ["SuperGpqaSample", "load_super_gpqa"]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Evaluation domain primitives."""
|