themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +93 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +164 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +288 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +129 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +690 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +373 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +255 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +61 -0
- themis/integrations/wandb.py +65 -0
- themis/interfaces/__init__.py +83 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
- themis_eval-0.1.1.dist-info/RECORD +134 -0
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
"""Dataset helpers for Themis experiments."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from . import (
|
|
8
|
+
competition_math,
|
|
9
|
+
commonsense_qa,
|
|
10
|
+
coqa,
|
|
11
|
+
gpqa,
|
|
12
|
+
gsm_symbolic,
|
|
13
|
+
gsm8k,
|
|
14
|
+
math500,
|
|
15
|
+
med_qa,
|
|
16
|
+
medmcqa,
|
|
17
|
+
mmlu_pro,
|
|
18
|
+
piqa,
|
|
19
|
+
sciq,
|
|
20
|
+
social_i_qa,
|
|
21
|
+
super_gpqa,
|
|
22
|
+
)
|
|
23
|
+
from .registry import (
|
|
24
|
+
create_dataset,
|
|
25
|
+
is_dataset_registered,
|
|
26
|
+
list_datasets,
|
|
27
|
+
register_dataset,
|
|
28
|
+
unregister_dataset,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
# Factory functions for built-in datasets
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _create_math500(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Factory for the MATH-500 dataset.

    Reads loader settings from *options* (source, data_dir, split, limit,
    subjects) and returns samples converted to generation examples.
    """
    load_kwargs = {
        "source": options.get("source", "huggingface"),
        "data_dir": options.get("data_dir"),
        "split": options.get("split", "test"),
        "limit": options.get("limit"),
        "subjects": options.get("subjects"),
    }
    loaded = math500.load_math500(**load_kwargs)
    return [item.to_generation_example() for item in loaded]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _create_competition_math(options: dict[str, Any]) -> list[dict[str, Any]]:
|
|
47
|
+
"""Factory for competition math datasets (AIME, AMC, etc.)."""
|
|
48
|
+
# Get dataset and subset from options
|
|
49
|
+
dataset = options.get("dataset")
|
|
50
|
+
if not dataset:
|
|
51
|
+
raise ValueError(
|
|
52
|
+
"Competition math requires 'dataset' option "
|
|
53
|
+
"(e.g., 'math-ai/aime24', 'math-ai/amc23')"
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
samples = competition_math.load_competition_math(
|
|
57
|
+
dataset=dataset,
|
|
58
|
+
subset=options.get("subset"),
|
|
59
|
+
source=options.get("source", "huggingface"),
|
|
60
|
+
data_dir=options.get("data_dir"),
|
|
61
|
+
split=options.get("split", "test"),
|
|
62
|
+
limit=options.get("limit"),
|
|
63
|
+
subjects=options.get("subjects"),
|
|
64
|
+
)
|
|
65
|
+
return [sample.to_generation_example() for sample in samples]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _create_super_gpqa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Factory for the SuperGPQA dataset."""
    load_kwargs = {
        "source": options.get("source", "huggingface"),
        "data_dir": options.get("data_dir"),
        "split": options.get("split", "test"),
        "limit": options.get("limit"),
        "subjects": options.get("subjects"),
    }
    loaded = super_gpqa.load_super_gpqa(**load_kwargs)
    return [item.to_generation_example() for item in loaded]
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _create_mmlu_pro(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Factory for the MMLU-Pro dataset."""
    load_kwargs = {
        "source": options.get("source", "huggingface"),
        "data_dir": options.get("data_dir"),
        "split": options.get("split", "test"),
        "limit": options.get("limit"),
        "subjects": options.get("subjects"),
    }
    loaded = mmlu_pro.load_mmlu_pro(**load_kwargs)
    return [item.to_generation_example() for item in loaded]
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _create_gsm8k(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Factory for the GSM8K dataset (defaults to the 'main' subset)."""
    load_kwargs = {
        "source": options.get("source", "huggingface"),
        "data_dir": options.get("data_dir"),
        "split": options.get("split", "test"),
        "limit": options.get("limit"),
        "subset": options.get("subset", "main"),
    }
    loaded = gsm8k.load_gsm8k(**load_kwargs)
    return [item.to_generation_example() for item in loaded]
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _create_gpqa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Factory for the GPQA dataset (defaults to the 'gpqa_diamond' subset)."""
    load_kwargs = {
        "source": options.get("source", "huggingface"),
        "data_dir": options.get("data_dir"),
        "split": options.get("split", "test"),
        "limit": options.get("limit"),
        "subset": options.get("subset", "gpqa_diamond"),
    }
    loaded = gpqa.load_gpqa(**load_kwargs)
    return [item.to_generation_example() for item in loaded]
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _create_gsm_symbolic(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Factory for the GSM-Symbolic dataset (defaults to the 'main' subset)."""
    load_kwargs = {
        "source": options.get("source", "huggingface"),
        "data_dir": options.get("data_dir"),
        "split": options.get("split", "test"),
        "limit": options.get("limit"),
        "subset": options.get("subset", "main"),
    }
    loaded = gsm_symbolic.load_gsm_symbolic(**load_kwargs)
    return [item.to_generation_example() for item in loaded]
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _create_medmcqa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Factory for the MedMCQA dataset."""
    load_kwargs = {
        "source": options.get("source", "huggingface"),
        "data_dir": options.get("data_dir"),
        "split": options.get("split", "test"),
        "limit": options.get("limit"),
        "subset": options.get("subset"),
    }
    loaded = medmcqa.load_medmcqa(**load_kwargs)
    return [item.to_generation_example() for item in loaded]
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _create_med_qa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Factory for the MedQA dataset (defaults to the BigBio QA config)."""
    load_kwargs = {
        "source": options.get("source", "huggingface"),
        "data_dir": options.get("data_dir"),
        "split": options.get("split", "test"),
        "limit": options.get("limit"),
        "subset": options.get("subset", "med_qa_en_bigbio_qa"),
    }
    loaded = med_qa.load_med_qa(**load_kwargs)
    return [item.to_generation_example() for item in loaded]
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _create_sciq(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Factory for the SciQ dataset."""
    load_kwargs = {
        "source": options.get("source", "huggingface"),
        "data_dir": options.get("data_dir"),
        "split": options.get("split", "test"),
        "limit": options.get("limit"),
    }
    loaded = sciq.load_sciq(**load_kwargs)
    return [item.to_generation_example() for item in loaded]
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _create_commonsense_qa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Factory for the CommonsenseQA dataset (default split: validation)."""
    load_kwargs = {
        "source": options.get("source", "huggingface"),
        "data_dir": options.get("data_dir"),
        "split": options.get("split", "validation"),
        "limit": options.get("limit"),
    }
    loaded = commonsense_qa.load_commonsense_qa(**load_kwargs)
    return [item.to_generation_example() for item in loaded]
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def _create_piqa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Factory for the PIQA dataset (default split: validation)."""
    load_kwargs = {
        "source": options.get("source", "huggingface"),
        "data_dir": options.get("data_dir"),
        "split": options.get("split", "validation"),
        "limit": options.get("limit"),
    }
    loaded = piqa.load_piqa(**load_kwargs)
    return [item.to_generation_example() for item in loaded]
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _create_social_i_qa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Factory for the Social IQA dataset (default split: validation)."""
    load_kwargs = {
        "source": options.get("source", "huggingface"),
        "data_dir": options.get("data_dir"),
        "split": options.get("split", "validation"),
        "limit": options.get("limit"),
    }
    loaded = social_i_qa.load_social_i_qa(**load_kwargs)
    return [item.to_generation_example() for item in loaded]
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _create_coqa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Factory for the CoQA dataset (default split: validation)."""
    load_kwargs = {
        "source": options.get("source", "huggingface"),
        "data_dir": options.get("data_dir"),
        "split": options.get("split", "validation"),
        "limit": options.get("limit"),
    }
    loaded = coqa.load_coqa(**load_kwargs)
    return [item.to_generation_example() for item in loaded]
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
# Auto-register built-in datasets under their canonical names.
_BUILTIN_FACTORIES = (
    ("math500", _create_math500),
    ("competition_math", _create_competition_math),
    ("supergpqa", _create_super_gpqa),
    ("mmlu-pro", _create_mmlu_pro),
    ("gsm8k", _create_gsm8k),
    ("gpqa", _create_gpqa),
    ("gsm-symbolic", _create_gsm_symbolic),
    ("medmcqa", _create_medmcqa),
    ("med_qa", _create_med_qa),
    ("sciq", _create_sciq),
    ("commonsense_qa", _create_commonsense_qa),
    ("piqa", _create_piqa),
    ("social_i_qa", _create_social_i_qa),
    ("coqa", _create_coqa),
)
for _name, _factory in _BUILTIN_FACTORIES:
    register_dataset(_name, _factory)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
# Also register specific competition datasets as aliases
|
|
225
|
+
def _create_aime24(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Alias factory pinning the dataset to 'math-ai/aime24'."""
    merged = dict(options)
    merged["dataset"] = "math-ai/aime24"
    return _create_competition_math(merged)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _create_aime25(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Alias factory pinning the dataset to 'math-ai/aime25'."""
    merged = dict(options)
    merged["dataset"] = "math-ai/aime25"
    return _create_competition_math(merged)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _create_amc23(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Alias factory pinning the dataset to 'math-ai/amc23'."""
    merged = dict(options)
    merged["dataset"] = "math-ai/amc23"
    return _create_competition_math(merged)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def _create_olympiadbench(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Alias factory pinning the dataset to 'math-ai/olympiadbench'."""
    merged = dict(options)
    merged["dataset"] = "math-ai/olympiadbench"
    return _create_competition_math(merged)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _create_beyondaime(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Alias factory pinning the dataset to 'ByteDance-Seed/BeyondAIME'."""
    merged = dict(options)
    merged["dataset"] = "ByteDance-Seed/BeyondAIME"
    return _create_competition_math(merged)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
# Register the competition-specific aliases alongside the generic entry.
for _alias, _alias_factory in (
    ("aime24", _create_aime24),
    ("aime25", _create_aime25),
    ("amc23", _create_amc23),
    ("olympiadbench", _create_olympiadbench),
    ("beyondaime", _create_beyondaime),
):
    register_dataset(_alias, _alias_factory)
|
|
250
|
+
|
|
251
|
+
# Explicit public API: the dataset modules themselves plus the registry helpers.
__all__ = [
    # Legacy module exports
    "competition_math",
    "commonsense_qa",
    "coqa",
    "gpqa",
    "gsm_symbolic",
    "gsm8k",
    "math500",
    "med_qa",
    "medmcqa",
    "mmlu_pro",
    "piqa",
    "sciq",
    "social_i_qa",
    "super_gpqa",
    # Registry functions
    "register_dataset",
    "unregister_dataset",
    "create_dataset",
    "list_datasets",
    "is_dataset_registered",
]
|
themis/datasets/base.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""Base dataset implementation with schema support.
|
|
2
|
+
|
|
3
|
+
This module provides a base class that implements common dataset operations
|
|
4
|
+
like filtering, limiting, and stratification.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
import random
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
from typing import Any, Callable, Iterable
|
|
13
|
+
|
|
14
|
+
from themis.datasets import schema as dataset_schema
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BaseDataset:
    """Reusable dataset container satisfying the DatasetAdapter protocol.

    Holds samples as plain dictionaries together with a schema and
    metadata, and provides derived views — filtering, limiting,
    stratified sampling, and shuffling. Every derived view is a new
    ``BaseDataset`` that shares the same schema and metadata object;
    the stored sample list is never mutated in place.

    Protocol compliance is structural (duck typing): the class does not
    inherit from ``DatasetAdapter`` explicitly, but because it implements
    ``iter_samples()``, instances pass ``isinstance(obj, DatasetAdapter)``
    checks at runtime.

    Subclasses typically assemble their samples, schema, and metadata and
    delegate to ``super().__init__``:

        class MyDataset(BaseDataset):
            def __init__(self):
                samples = [
                    {"id": "1", "problem": "What is 2+2?", "answer": "4"},
                    {"id": "2", "problem": "What is 3+3?", "answer": "6"},
                ]
                schema = DatasetSchema(
                    id_field="id",
                    reference_field="answer",
                    required_fields={"id", "problem", "answer"},
                )
                metadata = DatasetMetadata(
                    name="SimpleArithmetic",
                    version="1.0",
                    total_samples=2,
                )
                super().__init__(samples, schema, metadata)
    """

    def __init__(
        self,
        samples: Iterable[dict[str, Any]],
        schema: dataset_schema.DatasetSchema,
        metadata: dataset_schema.DatasetMetadata,
        validate: bool = True,
    ):
        """Materialize *samples* and optionally validate them.

        Args:
            samples: Iterable of sample dictionaries; consumed into a list.
            schema: Schema the samples are expected to follow.
            metadata: Descriptive metadata for the dataset.
            validate: When True (default), every sample is checked against
                the schema immediately.

        Raises:
            ValueError: If validation is enabled and any sample fails it.
        """
        self._samples = list(samples)
        self._schema = schema
        self._metadata = metadata

        if validate:
            self._validate_all()

        # Backfill the sample count when the caller left it unset; build a
        # fresh metadata object rather than mutating the caller's.
        if self._metadata.total_samples is None:
            fields = dict(self._metadata.__dict__)
            fields["total_samples"] = len(self._samples)
            self._metadata = dataset_schema.DatasetMetadata(**fields)

    def _validate_all(self) -> None:
        """Check every sample against the schema, failing fast on the first error."""
        logger.debug(
            "Validating %d samples for dataset %s",
            len(self._samples),
            self._metadata.name,
        )

        for index, sample in enumerate(self._samples):
            try:
                self._schema.validate_sample(sample)
            except ValueError as err:
                logger.error("Validation failed for sample %d: %s", index, err)
                raise ValueError(f"Sample {index} validation failed: {err}") from err

        logger.debug("All samples validated successfully")

    def iter_samples(self) -> Iterable[dict[str, Any]]:
        """Yield the dataset's samples in order."""
        return iter(self._samples)

    def get_schema(self) -> dataset_schema.DatasetSchema:
        """Return the schema this dataset was constructed with."""
        return self._schema

    def get_metadata(self) -> dataset_schema.DatasetMetadata:
        """Return the dataset's metadata."""
        return self._metadata

    def filter(self, predicate: Callable[[dict[str, Any]], bool]) -> BaseDataset:
        """Return a new dataset containing only samples matching *predicate*.

        Args:
            predicate: Callable returning True for samples to keep.

        Returns:
            New BaseDataset over the retained samples.
        """
        kept = [sample for sample in self._samples if predicate(sample)]
        logger.debug(
            "Filtered dataset from %d to %d samples",
            len(self._samples),
            len(kept),
        )

        # Samples were already validated on construction, so skip re-checking.
        return BaseDataset(
            samples=kept,
            schema=self._schema,
            metadata=self._metadata,
            validate=False,
        )

    def limit(self, n: int) -> BaseDataset:
        """Return a new dataset holding at most the first *n* samples.

        Args:
            n: Maximum number of samples to keep.

        Returns:
            New BaseDataset over the leading slice.
        """
        head = self._samples[:n]
        logger.debug(
            "Limited dataset from %d to %d samples",
            len(self._samples),
            len(head),
        )

        return BaseDataset(
            samples=head,
            schema=self._schema,
            metadata=self._metadata,
            validate=False,
        )

    def stratify(
        self, field: str, distribution: dict[str, float], seed: int | None = None
    ) -> BaseDataset:
        """Return a new dataset sampled to approximate *distribution* over *field*.

        Args:
            field: Sample key whose values define the strata.
            distribution: Desired ratio per field value; values should sum
                to roughly 1.0 (only a warning is logged otherwise).
            seed: Optional seed for reproducible sampling.

        Returns:
            New BaseDataset whose per-value counts follow the requested
            ratios, capped at each group's available size.

        Raises:
            ValueError: If any sample lacks *field*.
        """
        # Bucket samples by their value for the stratification field.
        buckets: dict[Any, list[dict[str, Any]]] = defaultdict(list)
        for sample in self._samples:
            if field not in sample:
                raise ValueError(f"Field '{field}' not found in sample")
            buckets[sample[field]].append(sample)

        # Sanity-check the requested ratios (warn only, to stay permissive).
        ratio_total = sum(distribution.values())
        if not (0.99 <= ratio_total <= 1.01):
            logger.warning("Distribution values sum to %f, expected ~1.0", ratio_total)

        population = len(self._samples)
        chosen: list[dict[str, Any]] = []
        rng = random.Random(seed) if seed is not None else random.Random()

        for value, ratio in distribution.items():
            if value not in buckets:
                logger.warning(
                    "Value '%s' specified in distribution but not found in dataset",
                    value,
                )
                continue

            available = buckets[value]
            # Target count is a share of the whole dataset, capped by what
            # this stratum actually contains.
            take = min(int(population * ratio), len(available))
            chosen.extend(rng.sample(available, take))

        logger.debug(
            "Stratified dataset by field '%s' from %d to %d samples",
            field,
            population,
            len(chosen),
        )

        return BaseDataset(
            samples=chosen,
            schema=self._schema,
            metadata=self._metadata,
            validate=False,
        )

    def shuffle(self, seed: int | None = None) -> BaseDataset:
        """Return a new dataset with the samples in random order.

        Args:
            seed: Optional seed for a reproducible permutation.

        Returns:
            New BaseDataset over the shuffled copy.
        """
        reordered = list(self._samples)
        if seed is None:
            random.shuffle(reordered)
        else:
            random.Random(seed).shuffle(reordered)

        return BaseDataset(
            samples=reordered,
            schema=self._schema,
            metadata=self._metadata,
            validate=False,
        )

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Return the sample at position *idx*."""
        return self._samples[idx]
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
# Public API of this module.
__all__ = ["BaseDataset"]
|