themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +93 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +164 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +288 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +129 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +690 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +373 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +255 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +61 -0
- themis/integrations/wandb.py +65 -0
- themis/interfaces/__init__.py +83 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
- themis_eval-0.1.1.dist-info/RECORD +134 -0
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""Helpers for working with the tau/commonsense_qa dataset."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import string
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Iterable, Iterator, List, Sequence
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
|
11
|
+
|
|
12
|
+
# Hugging Face Hub dataset identifier passed to ``load_dataset`` below.
_DATASET_NAME = "tau/commonsense_qa"
|
|
13
|
+
# Default option labels ("A", "B", ..., "Z"); a sample with N choices uses the first N.
_CHOICE_LABELS = tuple(string.ascii_uppercase)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CommonsenseQaSample(BaseModel):
    """A single CommonsenseQA multiple-choice question.

    ``choices`` holds the option texts and ``choice_labels`` the parallel
    option labels (``"A"``, ``"B"``, ...). ``answer`` is stored exactly as
    supplied by the caller (the loader passes the row's ``answerKey``).
    """

    unique_id: str
    question: str
    choices: list[str]
    answer: str
    concept: str = Field(default="")
    metadata: dict[str, Any] = Field(default_factory=dict)
    choice_labels: list[str] = Field(default_factory=list)

    @field_validator("choices", mode="before")
    @classmethod
    def _ensure_choices(cls, value: Any) -> list[str]:
        """Coerce ``choices`` into a list of option strings.

        Accepted inputs:
          * ``None`` -> no choices;
          * a list/tuple of options, kept in order;
          * the Hugging Face layout ``{"label": [...], "text": [...]}`` ->
            the texts, ordered by their labels (or in source order when the
            labels are missing or do not line up with the texts);
          * any other mapping of label -> text, ordered by label.

        Raises:
            TypeError: For any other input type.
        """
        if value is None:
            return []
        if isinstance(value, dict):
            texts = value.get("text")
            if isinstance(texts, (list, tuple)):
                # Parallel-lists layout: stringifying value.items() would yield
                # the two lists themselves as "choices".
                labels = value.get("label")
                if isinstance(labels, (list, tuple)) and len(labels) == len(texts):
                    ordered = sorted(zip(labels, texts), key=lambda pair: pair[0])
                    return [str(text) for _, text in ordered]
                return [str(text) for text in texts]
            return [str(v) for _, v in sorted(value.items())]
        if isinstance(value, (list, tuple)):
            return [str(item) for item in value]
        raise TypeError("choices must be a sequence or mapping")

    @field_validator("choice_labels", mode="before")
    @classmethod
    def _build_choice_labels(cls, value: Any, info: ValidationInfo) -> list[str]:
        """Stringify supplied labels, or derive ``A, B, ...`` from ``choices``.

        Pydantic does not validate defaulted fields unless ``validate_default``
        is set. This fallback therefore only runs when an explicitly empty value
        is passed, and ``to_generation_example`` repeats the fallback for the
        defaulted case.
        """
        if value:
            return [str(item) for item in value]
        # ``choices`` is declared earlier, so its validated value is in info.data.
        choices = info.data.get("choices") if hasattr(info, "data") else None
        total = len(choices) if isinstance(choices, list) else 0
        return [*_CHOICE_LABELS[:total]]

    def to_generation_example(self) -> dict[str, Any]:
        """Return the sample as a plain dict, filling in default labels if none are set."""
        effective_labels = (
            list(self.choice_labels)
            if self.choice_labels
            else list(_CHOICE_LABELS[: len(self.choices)])
        )
        return {
            "unique_id": self.unique_id,
            "question": self.question,
            "choices": list(self.choices),
            "choice_labels": effective_labels,
            "answer": self.answer,
            "concept": self.concept,
            "metadata": dict(self.metadata),
        }
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def load_commonsense_qa(
    *,
    split: str = "validation",  # Test set usually has no labels
    limit: int | None = None,
    source: str = "huggingface",
    data_dir: str | Path | None = None,
) -> List[CommonsenseQaSample]:
    """Load CommonsenseQA samples from Hugging Face or a local directory.

    Args:
        split: Split to load from Hugging Face (ignored for ``source="local"``).
            The default is ``"validation"`` because the test split usually has
            no labels.
        limit: Maximum number of samples to return. ``None`` returns all
            samples; ``0`` or a negative value returns an empty list without
            reading the source.
        source: ``"huggingface"`` or ``"local"``.
        data_dir: Location of the local rows; required when ``source="local"``.

    Returns:
        The parsed samples in source order.

    Raises:
        ValueError: If ``source`` is unsupported, or ``source="local"`` is used
            without ``data_dir``.
    """
    if source not in {"huggingface", "local"}:
        raise ValueError(
            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
        )
    if source == "local" and data_dir is None:
        raise ValueError(
            "data_dir must be provided when source='local'. "
            "Pass dataset.data_dir in configs or --data-dir on the CLI."
        )
    # The append-then-compare loop below would otherwise fetch the source and
    # still return one sample for limit=0.
    if limit is not None and limit <= 0:
        return []

    if source == "huggingface":
        rows = _load_from_huggingface(split=split)
    else:
        rows = _load_from_local(Path(data_dir))

    samples: list[CommonsenseQaSample] = []
    for index, row in enumerate(rows, start=1):
        samples.append(_row_to_sample(row, index=index))
        if limit is not None and len(samples) >= limit:
            break
    return samples
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _row_to_sample(row: dict[str, Any], *, index: int) -> CommonsenseQaSample:
    """Convert one raw CommonsenseQA row into a ``CommonsenseQaSample``.

    Args:
        row: Raw row. ``choices`` may use the Hugging Face layout
            ``{"label": [...], "text": [...]}`` (options are ordered by label)
            or be a plain list of option texts (labelled ``A, B, C, ...`` in
            source order). The correct label is read from ``answerKey``.
        index: 1-based row position, used for the fallback id ``csqa-NNNNN``.

    Returns:
        The parsed sample. Keys not consumed here are kept in ``metadata``.
    """
    unique_id = row.get("id") or row.get("unique_id") or f"csqa-{index:05d}"
    question = row.get("question") or ""

    choices_data = row.get("choices") or {}
    choices: list[str] = []
    choice_labels: list[str] = []

    if isinstance(choices_data, dict):
        texts = choices_data.get("text") or []
        # Without labels, zip() below would silently drop every choice, so
        # fall back to A, B, C, ...
        labels = choices_data.get("label") or _CHOICE_LABELS[: len(texts)]
        for label, text in sorted(zip(labels, texts), key=lambda pair: pair[0]):
            choices.append(str(text))
            choice_labels.append(str(label))
    elif isinstance(choices_data, (list, tuple)):
        # A bare list of option texts: keep source order and label it A, B, C, ...
        choices = [str(text) for text in choices_data]
        choice_labels = list(_CHOICE_LABELS[: len(choices)])

    answer = str(row.get("answerKey") or "")
    concept = str(row.get("question_concept") or "")

    # Everything mapped onto a typed field (including either id key) is
    # excluded from metadata.
    consumed_keys = {
        "question",
        "choices",
        "answerKey",
        "question_concept",
        "id",
        "unique_id",
    }
    metadata = {key: value for key, value in row.items() if key not in consumed_keys}

    return CommonsenseQaSample(
        unique_id=str(unique_id),
        question=str(question),
        choices=choices,
        choice_labels=choice_labels,
        answer=answer,
        concept=concept,
        metadata=metadata,
    )
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
    """Yield the rows of CommonsenseQA ``split`` from the Hugging Face Hub as dicts.

    ``datasets`` is imported lazily so it stays an optional dependency. Since
    this is a generator, the import (and the RuntimeError raised when it is
    missing) only happens once iteration begins.
    """
    try:
        from datasets import load_dataset
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise RuntimeError(
            "datasets is required to load CommonsenseQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
        ) from exc

    yield from (dict(record) for record in load_dataset(_DATASET_NAME, split=split))
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
    """Yield raw rows from ``.json``, ``.jsonl`` and ``.ndjson`` files under ``root``.

    Files are visited in sorted path order, so results (and any ``limit``
    applied by the caller) are reproducible across machines. ``root`` may also
    point directly at a single data file, and other file types are ignored.

    A ``.json`` file may hold a single row object or an array of row objects.
    Rows without an ``id`` get one derived from the file stem. For arrays the
    stem is followed by the 1-based position, and for JSON Lines files by the
    line number.

    Raises:
        FileNotFoundError: If ``root`` does not exist.
        json.JSONDecodeError: If a data file contains invalid JSON.
    """
    if not root.exists():
        raise FileNotFoundError(f"Local dataset directory not found: {root}")

    # rglob() order is filesystem-dependent; sort it for determinism.
    paths = [root] if root.is_file() else sorted(root.rglob("*"))
    for path in paths:
        if not path.is_file():
            # Skip directories, including ones whose names end in ".json".
            continue
        suffix = path.suffix.lower()
        if suffix == ".json":
            # Parse eagerly so the file is closed before rows are handed out.
            with path.open("r", encoding="utf-8") as handle:
                payload = json.load(handle)
            if isinstance(payload, list):
                for position, row in enumerate(payload, start=1):
                    row.setdefault("id", f"{path.stem}-{position}")
                    yield row
            else:
                payload.setdefault("id", path.stem)
                yield payload
        elif suffix in {".jsonl", ".ndjson"}:
            with path.open("r", encoding="utf-8") as handle:
                for line_num, line in enumerate(handle, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    row = json.loads(line)
                    row.setdefault("id", f"{path.stem}-{line_num}")
                    yield row
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
# Public names exported by ``from ... import *``: the sample model and its loader.
__all__ = ["CommonsenseQaSample", "load_commonsense_qa"]
|
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
"""Helpers for competition-style math benchmarks from Hugging Face."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Iterable, Iterator, List, Sequence
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, Field, field_validator
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CompetitionMathSample(BaseModel):
    """One competition-math problem normalized from a raw dataset row.

    ``level`` is an ``int`` when the source value is a whole number, and a
    string otherwise (``"unknown"`` when missing or empty). ``metadata`` holds
    extra source fields that were not mapped onto the typed attributes.
    """

    unique_id: str
    problem: str
    solution: str
    answer: str
    subject: str = Field(default="unknown")
    level: str | int = Field(default="unknown")
    metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("metadata", mode="before")
    @classmethod
    def _ensure_metadata(cls, value: Any) -> dict[str, Any]:
        """Treat ``None``/falsy input as ``{}`` and store a copy of the mapping."""
        return dict(value or {})

    @field_validator("level", mode="before")
    @classmethod
    def _normalize_level(cls, value: Any) -> str | int:
        """Normalize ``level`` to an ``int`` for whole numbers, else to a string."""
        if value is None or value == "":
            return "unknown"
        # int() would silently truncate 2.5 -> 2 and raise OverflowError for
        # inf, so non-integral floats are kept as text instead.
        if isinstance(value, float) and not value.is_integer():
            return str(value)
        try:
            return int(value)
        except (TypeError, ValueError):
            return str(value)

    def to_generation_example(self) -> dict[str, Any]:
        """Return the core fields followed by the metadata entries as one flat dict.

        Core fields take precedence: a metadata entry whose key matches a core
        field (e.g. ``"answer"``) is ignored rather than overwriting it.
        """
        payload: dict[str, Any] = {
            "unique_id": self.unique_id,
            "problem": self.problem,
            "solution": self.solution,
            "answer": self.answer,
            "subject": self.subject,
            "level": self.level,
        }
        for key, value in self.metadata.items():
            payload.setdefault(key, value)
        return payload
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def load_competition_math(
    *,
    dataset: str,
    split: str = "test",
    limit: int | None = None,
    source: str = "huggingface",
    data_dir: str | Path | None = None,
    subjects: Sequence[str] | None = None,
    subset: str | None = None,
) -> List[CompetitionMathSample]:
    """Load math competition samples from Hugging Face or a local directory.

    Args:
        dataset: Hugging Face dataset name. It is also used to build fallback
            ids for rows that carry none.
        split: Split to load from Hugging Face (ignored for ``source="local"``).
        limit: Maximum number of samples to return, counted after subject
            filtering. ``None`` returns all samples; ``0`` or a negative value
            returns an empty list without reading the source.
        source: ``"huggingface"`` or ``"local"``.
        data_dir: Location of the local rows; required when ``source="local"``.
        subjects: Optional subject filter, compared case-insensitively with
            each row's resolved subject. ``None`` or an empty sequence keeps
            every row.
        subset: Optional Hugging Face configuration name.

    Returns:
        The parsed samples in source order.

    Raises:
        ValueError: If ``source`` is unsupported, or ``source="local"`` is used
            without ``data_dir``.
    """
    if source not in {"huggingface", "local"}:
        raise ValueError(
            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
        )
    if source == "local" and data_dir is None:
        raise ValueError(
            "data_dir must be provided when source='local'. "
            "Pass dataset.data_dir in configs or --data-dir on the CLI."
        )
    # The append-then-compare loop below would otherwise fetch the source and
    # still return one sample for limit=0.
    if limit is not None and limit <= 0:
        return []

    if source == "huggingface":
        rows = _load_from_huggingface(dataset=dataset, split=split, subset=subset)
    else:
        rows = _load_from_local(Path(data_dir))

    samples: list[CompetitionMathSample] = []
    selected_subjects = {s.lower() for s in subjects} if subjects else None
    for index, row in enumerate(rows, start=1):
        subject = _extract_subject(row) or "unknown"
        if selected_subjects and subject.lower() not in selected_subjects:
            continue
        samples.append(
            _row_to_sample(row=row, index=index, dataset=dataset, subject=subject)
        )
        if limit is not None and len(samples) >= limit:
            break
    return samples
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _load_from_huggingface(
    *, dataset: str, split: str, subset: str | None
) -> Iterable[dict[str, Any]]:
    """Yield the rows of ``dataset`` (optionally a ``subset`` config) as plain dicts.

    ``datasets`` is imported lazily so it stays an optional dependency. Since
    this is a generator, the import happens when iteration begins.
    """
    try:
        from datasets import load_dataset
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise RuntimeError(
            "datasets is required to load competition math benchmarks from Hugging Face. "
            "Install it via `uv pip install '.[hf]'`."
        ) from exc

    # The configuration name is only passed when one was requested.
    positional = (dataset, subset) if subset else (dataset,)
    for record in load_dataset(*positional, split=split):
        yield dict(record)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
    """Yield raw rows from ``.json``, ``.jsonl`` and ``.ndjson`` files under ``root``.

    Files are visited in sorted path order, so results (and any ``limit``
    applied by the caller) are reproducible across machines. ``root`` may also
    point directly at a single data file, and other file types are ignored.

    A ``.json`` file may hold a single row object or an array of row objects.
    Rows without an ``id`` get one derived from the file stem. For arrays the
    stem is followed by the 1-based position, and for JSON Lines files by the
    line number.

    Raises:
        FileNotFoundError: If ``root`` does not exist.
        json.JSONDecodeError: If a data file contains invalid JSON.
    """
    if not root.exists():
        raise FileNotFoundError(f"Local dataset directory not found: {root}")

    # rglob() order is filesystem-dependent; sort it for determinism.
    paths = [root] if root.is_file() else sorted(root.rglob("*"))
    for path in paths:
        if not path.is_file():
            # Skip directories, including ones whose names end in ".json".
            continue
        suffix = path.suffix.lower()
        if suffix == ".json":
            # Parse eagerly so the file is closed before rows are handed out.
            with path.open("r", encoding="utf-8") as handle:
                payload = json.load(handle)
            if isinstance(payload, list):
                for position, row in enumerate(payload, start=1):
                    row.setdefault("id", f"{path.stem}-{position}")
                    yield row
            else:
                payload.setdefault("id", path.stem)
                yield payload
        elif suffix in {".jsonl", ".ndjson"}:
            with path.open("r", encoding="utf-8") as handle:
                for line_num, line in enumerate(handle, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    row = json.loads(line)
                    row.setdefault("id", f"{path.stem}-{line_num}")
                    yield row
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _extract_subject(row: dict[str, Any]) -> str | None:
    """Return the first truthy subject-like field of ``row``, stringified.

    Keys are tried in priority order (``subject`` first, ``level`` last).
    Returns ``None`` when none of them holds a truthy value.
    """
    preferred_keys = (
        "subject",
        "category",
        "topic",
        "domain",
        "contest",
        "source",
        "level",
    )
    return next(
        (str(found) for key in preferred_keys if (found := row.get(key))),
        None,
    )
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _extract_problem(row: dict[str, Any]) -> str:
    """Return the problem statement from the first truthy known field, or ``""``."""
    field_names = (
        "problem",
        "problem_text",
        "problem_statement",
        "question",
        "prompt",
        "problem_markdown",
    )
    candidates = (row.get(name) for name in field_names)
    return next((str(text) for text in candidates if text), "")
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def _extract_solution(row: dict[str, Any]) -> str:
    """Return the worked solution from the first truthy known field, or ``""``."""
    solution_fields = (
        "solution",
        "solution_text",
        "solution_markdown",
        "answer_explanation",
        "worked_solution",
        "reasoning",
    )
    for field_name in solution_fields:
        candidate = row.get(field_name)
        if not candidate:
            continue
        return str(candidate)
    return ""
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _extract_answer(row: dict[str, Any]) -> str:
    """Return the reference answer from the first populated known field.

    Fields are tried in priority order. Falsy-but-meaningful values such as
    ``0`` are kept. A field that is missing, ``None``, or blank after
    stripping is skipped, so a later field (e.g. ``final_answer``) can still
    supply the answer.

    Returns:
        The stripped answer text, or ``""`` when no field holds one.
    """
    for key in (
        "answer",
        "final_answer",
        "ground_truth",
        "answer_text",
        "answer_value",
    ):
        value = row.get(key)
        if value is None:
            continue
        text = str(value).strip()
        if text:
            return text
    return ""
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _extract_level(row: dict[str, Any]) -> str | int:
    """Return the raw value of the first truthy difficulty-like field.

    Checks ``difficulty``, then ``level``, then ``year``. The value is returned
    unconverted, or ``"unknown"`` when none of the fields is truthy.
    """
    found = [row[key] for key in ("difficulty", "level", "year") if row.get(key)]
    return found[0] if found else "unknown"
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _row_to_sample(
    *,
    row: dict[str, Any],
    index: int,
    dataset: str,
    subject: str,
) -> CompetitionMathSample:
    """Build a ``CompetitionMathSample`` from one raw row.

    Args:
        row: Raw dataset row.
        index: 1-based row position, used for the fallback id
            ``<dataset with '/' replaced by '-'>-<index:05d>``.
        dataset: Dataset name; only used for the fallback id.
        subject: Subject already resolved by the caller.

    Every key the field extractors or the subject lookup may read is excluded
    from ``metadata``; all other keys are kept there unchanged.
    """
    unique_id = (
        row.get("id")
        or row.get("problem_id")
        or row.get("unique_id")
        or f"{dataset.replace('/', '-')}-{index:05d}"
    )

    id_keys = ("id", "problem_id", "unique_id")
    problem_keys = (
        "problem",
        "problem_text",
        "problem_statement",
        "question",
        "prompt",
        "problem_markdown",
    )
    solution_keys = (
        "solution",
        "solution_text",
        "solution_markdown",
        "answer_explanation",
        "worked_solution",
        "reasoning",
    )
    answer_keys = (
        "answer",
        "final_answer",
        "ground_truth",
        "answer_text",
        "answer_value",
    )
    level_keys = ("difficulty", "level", "year")
    subject_keys = ("subject", "category", "topic", "domain", "contest", "source")
    consumed = frozenset(
        id_keys + problem_keys + solution_keys + answer_keys + level_keys + subject_keys
    )
    extras = {key: value for key, value in row.items() if key not in consumed}

    return CompetitionMathSample.model_validate(
        {
            "unique_id": str(unique_id),
            "problem": _extract_problem(row),
            "solution": _extract_solution(row),
            "answer": _extract_answer(row),
            "subject": str(subject),
            "level": _extract_level(row),
            "metadata": extras,
        }
    )
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
# Public names exported by ``from ... import *``: the sample model and its loader.
__all__ = ["CompetitionMathSample", "load_competition_math"]
|
themis/datasets/coqa.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""Helpers for working with the stanfordnlp/coqa dataset."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Iterable, Iterator, List, Sequence
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, Field, field_validator
|
|
10
|
+
|
|
11
|
+
# Hugging Face Hub dataset identifier passed to ``load_dataset`` below.
_DATASET_NAME = "stanfordnlp/coqa"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class CoQaSample(BaseModel):
    """A single CoQA turn: one question about ``story`` and its reference answer."""

    unique_id: str
    story: str
    question: str
    answer: str
    metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("metadata", mode="before")
    @classmethod
    def _ensure_metadata(cls, value: Any) -> dict[str, Any]:
        """Map ``None``/empty input to ``{}`` and otherwise store a fresh dict copy."""
        if not value:
            return {}
        return dict(value)

    def to_generation_example(self) -> dict[str, Any]:
        """Return the sample as a plain dict with a copied ``metadata`` mapping."""
        example: dict[str, Any] = dict(
            unique_id=self.unique_id,
            story=self.story,
            question=self.question,
            answer=self.answer,
        )
        example["metadata"] = dict(self.metadata)
        return example
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def load_coqa(
    *,
    split: str = "validation",  # Test set usually has no labels
    limit: int | None = None,
    source: str = "huggingface",
    data_dir: str | Path | None = None,
) -> List[CoQaSample]:
    """Load CoQA samples from Hugging Face or a local directory.

    CoQA stores several question/answer turns per story, and each turn is
    flattened into its own ``CoQaSample``:

      * ``unique_id`` is ``coqa-<row index:05d>-<turn:02d>``;
      * ``metadata`` records the turn number and the row's ``source``.

    Questions and answers are paired by position, and turns without a partner
    are dropped.

    Args:
        split: Split to load from Hugging Face (ignored for ``source="local"``).
        limit: Maximum number of samples (turns) to return. ``None`` returns
            all samples; ``0`` or a negative value returns an empty list
            without reading the source.
        source: ``"huggingface"`` or ``"local"``.
        data_dir: Location of the local rows; required when ``source="local"``.

    Returns:
        The flattened samples in source order.

    Raises:
        ValueError: If ``source`` is unsupported, or ``source="local"`` is used
            without ``data_dir``.
    """
    if source not in {"huggingface", "local"}:
        raise ValueError(
            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
        )
    if source == "local" and data_dir is None:
        raise ValueError(
            "data_dir must be provided when source='local'. "
            "Pass dataset.data_dir in configs or --data-dir on the CLI."
        )
    # Without this check, the append-then-compare loop below would fetch the
    # source and still return one sample for limit=0.
    if limit is not None and limit <= 0:
        return []

    if source == "huggingface":
        rows = _load_from_huggingface(split=split)
    else:
        rows = _load_from_local(Path(data_dir))

    samples: list[CoQaSample] = []
    for index, row in enumerate(rows, start=1):
        story = row.get("story") or ""
        questions = [_coqa_turn_text(item) for item in row.get("questions") or []]
        answers = [_coqa_turn_text(item) for item in _coqa_answer_items(row.get("answers"))]
        # zip() stops at the shorter list, so unmatched trailing turns are dropped.
        for turn, (question, answer) in enumerate(zip(questions, answers)):
            samples.append(
                CoQaSample(
                    unique_id=f"coqa-{index:05d}-{turn:02d}",
                    story=story,
                    question=question,
                    answer=answer,
                    metadata={"turn": turn, "source": row.get("source")},
                )
            )
            if limit is not None and len(samples) >= limit:
                return samples
    return samples


def _coqa_answer_items(answers: Any) -> list[Any]:
    """Return the per-turn answer entries from a row's ``answers`` field.

    Accepts a mapping with an ``input_text`` list (as used by the Hugging Face
    rows) or a plain list of answers (strings or ``{"input_text": ...}`` dicts).
    Anything else yields no answers.
    """
    if isinstance(answers, dict):
        return list(answers.get("input_text") or [])
    if isinstance(answers, (list, tuple)):
        return list(answers)
    return []


def _coqa_turn_text(item: Any) -> str:
    """Return the text of one question/answer entry: a string or an ``input_text`` dict."""
    if isinstance(item, dict):
        return str(item.get("input_text") or "")
    return str(item)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
    """Yield the rows of CoQA ``split`` from the Hugging Face Hub as plain dicts.

    ``datasets`` is imported lazily so it stays an optional dependency. Since
    this is a generator, the import (and the RuntimeError raised when it is
    missing) only happens once iteration begins.
    """
    try:
        from datasets import load_dataset
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise RuntimeError(
            "datasets is required to load CoQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
        ) from exc

    yield from (dict(record) for record in load_dataset(_DATASET_NAME, split=split))
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
    """Yield raw rows from ``.json``, ``.jsonl`` and ``.ndjson`` files under ``root``.

    Files are visited in sorted path order, so results (and any ``limit``
    applied by the caller) are reproducible across machines. ``root`` may also
    point directly at a single data file, and other file types are ignored.

    A ``.json`` file may hold a single row object or an array of row objects.
    Rows without an ``id`` get one derived from the file stem. For arrays the
    stem is followed by the 1-based position, and for JSON Lines files by the
    line number.

    Raises:
        FileNotFoundError: If ``root`` does not exist.
        json.JSONDecodeError: If a data file contains invalid JSON.
    """
    if not root.exists():
        raise FileNotFoundError(f"Local dataset directory not found: {root}")

    # rglob() order is filesystem-dependent; sort it for determinism.
    paths = [root] if root.is_file() else sorted(root.rglob("*"))
    for path in paths:
        if not path.is_file():
            # Skip directories, including ones whose names end in ".json".
            continue
        suffix = path.suffix.lower()
        if suffix == ".json":
            # Parse eagerly so the file is closed before rows are handed out.
            with path.open("r", encoding="utf-8") as handle:
                payload = json.load(handle)
            if isinstance(payload, list):
                for position, row in enumerate(payload, start=1):
                    row.setdefault("id", f"{path.stem}-{position}")
                    yield row
            else:
                payload.setdefault("id", path.stem)
                yield payload
        elif suffix in {".jsonl", ".ndjson"}:
            with path.open("r", encoding="utf-8") as handle:
                for line_num, line in enumerate(handle, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    row = json.loads(line)
                    row.setdefault("id", f"{path.stem}-{line_num}")
                    yield row
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# Public names exported by ``from ... import *``: the sample model and its loader.
__all__ = ["CoQaSample", "load_coqa"]
|