themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. themis/cli/__init__.py +5 -0
  2. themis/cli/__main__.py +6 -0
  3. themis/cli/commands/__init__.py +19 -0
  4. themis/cli/commands/benchmarks.py +221 -0
  5. themis/cli/commands/comparison.py +394 -0
  6. themis/cli/commands/config_commands.py +244 -0
  7. themis/cli/commands/cost.py +214 -0
  8. themis/cli/commands/demo.py +68 -0
  9. themis/cli/commands/info.py +90 -0
  10. themis/cli/commands/leaderboard.py +362 -0
  11. themis/cli/commands/math_benchmarks.py +318 -0
  12. themis/cli/commands/mcq_benchmarks.py +207 -0
  13. themis/cli/commands/sample_run.py +244 -0
  14. themis/cli/commands/visualize.py +299 -0
  15. themis/cli/main.py +93 -0
  16. themis/cli/new_project.py +33 -0
  17. themis/cli/utils.py +51 -0
  18. themis/config/__init__.py +19 -0
  19. themis/config/loader.py +27 -0
  20. themis/config/registry.py +34 -0
  21. themis/config/runtime.py +214 -0
  22. themis/config/schema.py +112 -0
  23. themis/core/__init__.py +5 -0
  24. themis/core/conversation.py +354 -0
  25. themis/core/entities.py +164 -0
  26. themis/core/serialization.py +231 -0
  27. themis/core/tools.py +393 -0
  28. themis/core/types.py +141 -0
  29. themis/datasets/__init__.py +273 -0
  30. themis/datasets/base.py +264 -0
  31. themis/datasets/commonsense_qa.py +174 -0
  32. themis/datasets/competition_math.py +265 -0
  33. themis/datasets/coqa.py +133 -0
  34. themis/datasets/gpqa.py +190 -0
  35. themis/datasets/gsm8k.py +123 -0
  36. themis/datasets/gsm_symbolic.py +124 -0
  37. themis/datasets/math500.py +122 -0
  38. themis/datasets/med_qa.py +179 -0
  39. themis/datasets/medmcqa.py +169 -0
  40. themis/datasets/mmlu_pro.py +262 -0
  41. themis/datasets/piqa.py +146 -0
  42. themis/datasets/registry.py +201 -0
  43. themis/datasets/schema.py +245 -0
  44. themis/datasets/sciq.py +150 -0
  45. themis/datasets/social_i_qa.py +151 -0
  46. themis/datasets/super_gpqa.py +263 -0
  47. themis/evaluation/__init__.py +1 -0
  48. themis/evaluation/conditional.py +410 -0
  49. themis/evaluation/extractors/__init__.py +19 -0
  50. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  51. themis/evaluation/extractors/exceptions.py +7 -0
  52. themis/evaluation/extractors/identity_extractor.py +29 -0
  53. themis/evaluation/extractors/json_field_extractor.py +45 -0
  54. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  55. themis/evaluation/extractors/regex_extractor.py +43 -0
  56. themis/evaluation/math_verify_utils.py +87 -0
  57. themis/evaluation/metrics/__init__.py +21 -0
  58. themis/evaluation/metrics/composite_metric.py +47 -0
  59. themis/evaluation/metrics/consistency_metric.py +80 -0
  60. themis/evaluation/metrics/exact_match.py +51 -0
  61. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  62. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  63. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  64. themis/evaluation/metrics/response_length.py +33 -0
  65. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  66. themis/evaluation/pipeline.py +49 -0
  67. themis/evaluation/pipelines/__init__.py +15 -0
  68. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  69. themis/evaluation/pipelines/standard_pipeline.py +288 -0
  70. themis/evaluation/reports.py +293 -0
  71. themis/evaluation/statistics/__init__.py +53 -0
  72. themis/evaluation/statistics/bootstrap.py +79 -0
  73. themis/evaluation/statistics/confidence_intervals.py +121 -0
  74. themis/evaluation/statistics/distributions.py +207 -0
  75. themis/evaluation/statistics/effect_sizes.py +124 -0
  76. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  77. themis/evaluation/statistics/types.py +139 -0
  78. themis/evaluation/strategies/__init__.py +13 -0
  79. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  80. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  81. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  82. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  83. themis/experiment/__init__.py +5 -0
  84. themis/experiment/builder.py +151 -0
  85. themis/experiment/cache_manager.py +129 -0
  86. themis/experiment/comparison.py +631 -0
  87. themis/experiment/cost.py +310 -0
  88. themis/experiment/definitions.py +62 -0
  89. themis/experiment/export.py +690 -0
  90. themis/experiment/export_csv.py +159 -0
  91. themis/experiment/integration_manager.py +104 -0
  92. themis/experiment/math.py +192 -0
  93. themis/experiment/mcq.py +169 -0
  94. themis/experiment/orchestrator.py +373 -0
  95. themis/experiment/pricing.py +317 -0
  96. themis/experiment/storage.py +255 -0
  97. themis/experiment/visualization.py +588 -0
  98. themis/generation/__init__.py +1 -0
  99. themis/generation/agentic_runner.py +420 -0
  100. themis/generation/batching.py +254 -0
  101. themis/generation/clients.py +143 -0
  102. themis/generation/conversation_runner.py +236 -0
  103. themis/generation/plan.py +456 -0
  104. themis/generation/providers/litellm_provider.py +221 -0
  105. themis/generation/providers/vllm_provider.py +135 -0
  106. themis/generation/router.py +34 -0
  107. themis/generation/runner.py +207 -0
  108. themis/generation/strategies.py +98 -0
  109. themis/generation/templates.py +71 -0
  110. themis/generation/turn_strategies.py +393 -0
  111. themis/generation/types.py +9 -0
  112. themis/integrations/__init__.py +0 -0
  113. themis/integrations/huggingface.py +61 -0
  114. themis/integrations/wandb.py +65 -0
  115. themis/interfaces/__init__.py +83 -0
  116. themis/project/__init__.py +20 -0
  117. themis/project/definitions.py +98 -0
  118. themis/project/patterns.py +230 -0
  119. themis/providers/__init__.py +5 -0
  120. themis/providers/registry.py +39 -0
  121. themis/utils/api_generator.py +379 -0
  122. themis/utils/cost_tracking.py +376 -0
  123. themis/utils/dashboard.py +452 -0
  124. themis/utils/logging_utils.py +41 -0
  125. themis/utils/progress.py +58 -0
  126. themis/utils/tracing.py +320 -0
  127. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
  128. themis_eval-0.1.1.dist-info/RECORD +134 -0
  129. themis_eval-0.1.0.dist-info/RECORD +0 -8
  130. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
  131. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
  132. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,273 @@
1
+ """Dataset helpers for Themis experiments."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from . import (
8
+ competition_math,
9
+ commonsense_qa,
10
+ coqa,
11
+ gpqa,
12
+ gsm_symbolic,
13
+ gsm8k,
14
+ math500,
15
+ med_qa,
16
+ medmcqa,
17
+ mmlu_pro,
18
+ piqa,
19
+ sciq,
20
+ social_i_qa,
21
+ super_gpqa,
22
+ )
23
+ from .registry import (
24
+ create_dataset,
25
+ is_dataset_registered,
26
+ list_datasets,
27
+ register_dataset,
28
+ unregister_dataset,
29
+ )
30
+
31
+ # Factory functions for built-in datasets
32
+
33
+
34
def _create_math500(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Load MATH-500 and convert each sample into a generation example.

    Recognized options: source, data_dir, split, limit, subjects.
    """
    opt = options.get
    records = math500.load_math500(
        source=opt("source", "huggingface"),
        data_dir=opt("data_dir"),
        split=opt("split", "test"),
        limit=opt("limit"),
        subjects=opt("subjects"),
    )
    return [record.to_generation_example() for record in records]
44
+
45
+
46
+ def _create_competition_math(options: dict[str, Any]) -> list[dict[str, Any]]:
47
+ """Factory for competition math datasets (AIME, AMC, etc.)."""
48
+ # Get dataset and subset from options
49
+ dataset = options.get("dataset")
50
+ if not dataset:
51
+ raise ValueError(
52
+ "Competition math requires 'dataset' option "
53
+ "(e.g., 'math-ai/aime24', 'math-ai/amc23')"
54
+ )
55
+
56
+ samples = competition_math.load_competition_math(
57
+ dataset=dataset,
58
+ subset=options.get("subset"),
59
+ source=options.get("source", "huggingface"),
60
+ data_dir=options.get("data_dir"),
61
+ split=options.get("split", "test"),
62
+ limit=options.get("limit"),
63
+ subjects=options.get("subjects"),
64
+ )
65
+ return [sample.to_generation_example() for sample in samples]
66
+
67
+
68
def _create_super_gpqa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Load SuperGPQA and convert each sample into a generation example.

    Recognized options: source, data_dir, split, limit, subjects.
    """
    opt = options.get
    records = super_gpqa.load_super_gpqa(
        source=opt("source", "huggingface"),
        data_dir=opt("data_dir"),
        split=opt("split", "test"),
        limit=opt("limit"),
        subjects=opt("subjects"),
    )
    return [record.to_generation_example() for record in records]
78
+
79
+
80
def _create_mmlu_pro(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Load MMLU-Pro and convert each sample into a generation example.

    Recognized options: source, data_dir, split, limit, subjects.
    """
    opt = options.get
    records = mmlu_pro.load_mmlu_pro(
        source=opt("source", "huggingface"),
        data_dir=opt("data_dir"),
        split=opt("split", "test"),
        limit=opt("limit"),
        subjects=opt("subjects"),
    )
    return [record.to_generation_example() for record in records]
90
+
91
+
92
def _create_gsm8k(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Load GSM8K and convert each sample into a generation example.

    Recognized options: source, data_dir, split, limit, subset
    (subset defaults to "main").
    """
    opt = options.get
    records = gsm8k.load_gsm8k(
        source=opt("source", "huggingface"),
        data_dir=opt("data_dir"),
        split=opt("split", "test"),
        limit=opt("limit"),
        subset=opt("subset", "main"),
    )
    return [record.to_generation_example() for record in records]
102
+
103
+
104
def _create_gpqa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Load GPQA and convert each sample into a generation example.

    Recognized options: source, data_dir, split, limit, subset
    (subset defaults to "gpqa_diamond").
    """
    opt = options.get
    records = gpqa.load_gpqa(
        source=opt("source", "huggingface"),
        data_dir=opt("data_dir"),
        split=opt("split", "test"),
        limit=opt("limit"),
        subset=opt("subset", "gpqa_diamond"),
    )
    return [record.to_generation_example() for record in records]
114
+
115
+
116
def _create_gsm_symbolic(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Load GSM-Symbolic and convert each sample into a generation example.

    Recognized options: source, data_dir, split, limit, subset
    (subset defaults to "main").
    """
    opt = options.get
    records = gsm_symbolic.load_gsm_symbolic(
        source=opt("source", "huggingface"),
        data_dir=opt("data_dir"),
        split=opt("split", "test"),
        limit=opt("limit"),
        subset=opt("subset", "main"),
    )
    return [record.to_generation_example() for record in records]
126
+
127
+
128
def _create_medmcqa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Load MedMCQA and convert each sample into a generation example.

    Recognized options: source, data_dir, split, limit, subset
    (subset has no default).
    """
    opt = options.get
    records = medmcqa.load_medmcqa(
        source=opt("source", "huggingface"),
        data_dir=opt("data_dir"),
        split=opt("split", "test"),
        limit=opt("limit"),
        subset=opt("subset"),
    )
    return [record.to_generation_example() for record in records]
138
+
139
+
140
def _create_med_qa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Load MedQA and convert each sample into a generation example.

    Recognized options: source, data_dir, split, limit, subset
    (subset defaults to "med_qa_en_bigbio_qa").
    """
    opt = options.get
    records = med_qa.load_med_qa(
        source=opt("source", "huggingface"),
        data_dir=opt("data_dir"),
        split=opt("split", "test"),
        limit=opt("limit"),
        subset=opt("subset", "med_qa_en_bigbio_qa"),
    )
    return [record.to_generation_example() for record in records]
150
+
151
+
152
def _create_sciq(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Load SciQ and convert each sample into a generation example.

    Recognized options: source, data_dir, split, limit.
    """
    opt = options.get
    records = sciq.load_sciq(
        source=opt("source", "huggingface"),
        data_dir=opt("data_dir"),
        split=opt("split", "test"),
        limit=opt("limit"),
    )
    return [record.to_generation_example() for record in records]
161
+
162
+
163
def _create_commonsense_qa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Load CommonsenseQA and convert each sample into a generation example.

    Recognized options: source, data_dir, split, limit
    (split defaults to "validation").
    """
    opt = options.get
    records = commonsense_qa.load_commonsense_qa(
        source=opt("source", "huggingface"),
        data_dir=opt("data_dir"),
        split=opt("split", "validation"),
        limit=opt("limit"),
    )
    return [record.to_generation_example() for record in records]
172
+
173
+
174
def _create_piqa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Load PIQA and convert each sample into a generation example.

    Recognized options: source, data_dir, split, limit
    (split defaults to "validation").
    """
    opt = options.get
    records = piqa.load_piqa(
        source=opt("source", "huggingface"),
        data_dir=opt("data_dir"),
        split=opt("split", "validation"),
        limit=opt("limit"),
    )
    return [record.to_generation_example() for record in records]
183
+
184
+
185
def _create_social_i_qa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Load Social IQA and convert each sample into a generation example.

    Recognized options: source, data_dir, split, limit
    (split defaults to "validation").
    """
    opt = options.get
    records = social_i_qa.load_social_i_qa(
        source=opt("source", "huggingface"),
        data_dir=opt("data_dir"),
        split=opt("split", "validation"),
        limit=opt("limit"),
    )
    return [record.to_generation_example() for record in records]
194
+
195
+
196
def _create_coqa(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Load CoQA and convert each sample into a generation example.

    Recognized options: source, data_dir, split, limit
    (split defaults to "validation").
    """
    opt = options.get
    records = coqa.load_coqa(
        source=opt("source", "huggingface"),
        data_dir=opt("data_dir"),
        split=opt("split", "validation"),
        limit=opt("limit"),
    )
    return [record.to_generation_example() for record in records]
205
+
206
+
207
# Auto-register built-in datasets at import time so that create_dataset("<name>")
# works as soon as `themis.datasets` is imported.
# NOTE(review): registry key style is mixed — "mmlu-pro"/"gsm-symbolic" use
# hyphens while "med_qa"/"commonsense_qa" use underscores. Kept as-is because
# renaming the keys would break existing callers.
register_dataset("math500", _create_math500)
register_dataset("competition_math", _create_competition_math)
register_dataset("supergpqa", _create_super_gpqa)
register_dataset("mmlu-pro", _create_mmlu_pro)
register_dataset("gsm8k", _create_gsm8k)
register_dataset("gpqa", _create_gpqa)
register_dataset("gsm-symbolic", _create_gsm_symbolic)
register_dataset("medmcqa", _create_medmcqa)
register_dataset("med_qa", _create_med_qa)
register_dataset("sciq", _create_sciq)
register_dataset("commonsense_qa", _create_commonsense_qa)
register_dataset("piqa", _create_piqa)
register_dataset("social_i_qa", _create_social_i_qa)
register_dataset("coqa", _create_coqa)
222
+
223
+
224
+ # Also register specific competition datasets as aliases
225
def _create_aime24(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Competition-math factory pinned to the AIME 2024 dataset."""
    pinned = dict(options)
    pinned["dataset"] = "math-ai/aime24"
    return _create_competition_math(pinned)
227
+
228
+
229
def _create_aime25(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Competition-math factory pinned to the AIME 2025 dataset."""
    pinned = dict(options)
    pinned["dataset"] = "math-ai/aime25"
    return _create_competition_math(pinned)
231
+
232
+
233
def _create_amc23(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Competition-math factory pinned to the AMC 2023 dataset."""
    pinned = dict(options)
    pinned["dataset"] = "math-ai/amc23"
    return _create_competition_math(pinned)
235
+
236
+
237
def _create_olympiadbench(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Competition-math factory pinned to the OlympiadBench dataset."""
    pinned = dict(options)
    pinned["dataset"] = "math-ai/olympiadbench"
    return _create_competition_math(pinned)
239
+
240
+
241
def _create_beyondaime(options: dict[str, Any]) -> list[dict[str, Any]]:
    """Competition-math factory pinned to the BeyondAIME dataset."""
    pinned = dict(options)
    pinned["dataset"] = "ByteDance-Seed/BeyondAIME"
    return _create_competition_math(pinned)
243
+
244
+
245
+ register_dataset("aime24", _create_aime24)
246
+ register_dataset("aime25", _create_aime25)
247
+ register_dataset("amc23", _create_amc23)
248
+ register_dataset("olympiadbench", _create_olympiadbench)
249
+ register_dataset("beyondaime", _create_beyondaime)
250
+
251
# Public API of themis.datasets: the per-dataset loader modules plus the
# registry helper functions. Factory functions (_create_*) stay private —
# consumers go through create_dataset() instead.
__all__ = [
    # Legacy module exports
    "competition_math",
    "commonsense_qa",
    "coqa",
    "gpqa",
    "gsm_symbolic",
    "gsm8k",
    "math500",
    "med_qa",
    "medmcqa",
    "mmlu_pro",
    "piqa",
    "sciq",
    "social_i_qa",
    "super_gpqa",
    # Registry functions
    "register_dataset",
    "unregister_dataset",
    "create_dataset",
    "list_datasets",
    "is_dataset_registered",
]
@@ -0,0 +1,264 @@
1
+ """Base dataset implementation with schema support.
2
+
3
+ This module provides a base class that implements common dataset operations
4
+ like filtering, limiting, and stratification.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ import random
11
+ from collections import defaultdict
12
+ from typing import Any, Callable, Iterable
13
+
14
+ from themis.datasets import schema as dataset_schema
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class BaseDataset:
    """Concrete base class for datasets satisfying the DatasetAdapter protocol.

    Holds samples in memory and layers reusable operations on top of them:
    schema validation, filtering, limiting, stratified sampling, and shuffling.
    Protocol compliance is structural (duck typing): the class implements
    ``iter_samples()`` without inheriting from DatasetAdapter, so instances
    still pass ``isinstance(obj, DatasetAdapter)`` checks at runtime.

    Subclasses supply the initial samples, schema, and metadata, typically via
    ``super().__init__(samples, schema, metadata)`` from their constructor.

    Examples:
        class MyDataset(BaseDataset):
            def __init__(self):
                samples = [
                    {"id": "1", "problem": "What is 2+2?", "answer": "4"},
                    {"id": "2", "problem": "What is 3+3?", "answer": "6"},
                ]
                schema = DatasetSchema(
                    id_field="id",
                    reference_field="answer",
                    required_fields={"id", "problem", "answer"},
                )
                metadata = DatasetMetadata(
                    name="SimpleArithmetic",
                    version="1.0",
                    total_samples=2,
                )
                super().__init__(samples, schema, metadata)
    """

    def __init__(
        self,
        samples: Iterable[dict[str, Any]],
        schema: dataset_schema.DatasetSchema,
        metadata: dataset_schema.DatasetMetadata,
        validate: bool = True,
    ):
        """Materialize *samples* and optionally validate them against *schema*.

        Args:
            samples: Iterable of sample dictionaries (consumed eagerly).
            schema: Dataset schema used for per-sample validation.
            metadata: Dataset metadata; ``total_samples`` is backfilled when unset.
            validate: Whether to validate samples against the schema (default: True).

        Raises:
            ValueError: If validation is enabled and a sample violates the schema.
        """
        self._samples = list(samples)
        self._schema = schema
        self._metadata = metadata

        if validate:
            self._validate_all()

        # Backfill the sample count when the caller left it unset.
        if self._metadata.total_samples is None:
            fields = {**self._metadata.__dict__, "total_samples": len(self._samples)}
            self._metadata = dataset_schema.DatasetMetadata(**fields)

    def _validate_all(self) -> None:
        """Check every sample against the schema; raise on the first failure."""
        logger.debug(
            "Validating %d samples for dataset %s",
            len(self._samples),
            self._metadata.name,
        )

        for idx, record in enumerate(self._samples):
            try:
                self._schema.validate_sample(record)
            except ValueError as err:
                logger.error("Validation failed for sample %d: %s", idx, err)
                raise ValueError(f"Sample {idx} validation failed: {err}") from err

        logger.debug("All samples validated successfully")

    def _spawn(self, samples: list[dict[str, Any]]) -> BaseDataset:
        """Build a sibling dataset sharing this schema/metadata.

        Validation is skipped because the samples come from an
        already-validated dataset.
        """
        return BaseDataset(
            samples=samples,
            schema=self._schema,
            metadata=self._metadata,
            validate=False,
        )

    def iter_samples(self) -> Iterable[dict[str, Any]]:
        """Iterate over dataset samples."""
        return iter(self._samples)

    def get_schema(self) -> dataset_schema.DatasetSchema:
        """Get the dataset schema."""
        return self._schema

    def get_metadata(self) -> dataset_schema.DatasetMetadata:
        """Get dataset metadata."""
        return self._metadata

    def filter(self, predicate: Callable[[dict[str, Any]], bool]) -> BaseDataset:
        """Return a new dataset keeping only samples for which *predicate* is true.

        Args:
            predicate: Function that returns True for samples to keep.

        Returns:
            New BaseDataset with the matching samples.
        """
        kept = [record for record in self._samples if predicate(record)]
        logger.debug(
            "Filtered dataset from %d to %d samples",
            len(self._samples),
            len(kept),
        )
        return self._spawn(kept)

    def limit(self, n: int) -> BaseDataset:
        """Return a new dataset truncated to the first *n* samples.

        Args:
            n: Maximum number of samples.

        Returns:
            New BaseDataset with at most *n* samples.
        """
        head = self._samples[:n]
        logger.debug(
            "Limited dataset from %d to %d samples",
            len(self._samples),
            len(head),
        )
        return self._spawn(head)

    def stratify(
        self, field: str, distribution: dict[str, float], seed: int | None = None
    ) -> BaseDataset:
        """Return a stratified subsample of the dataset.

        Args:
            field: Sample key to stratify by; must exist in every sample.
            distribution: Desired ratio per field value (values should sum to ~1.0).
            seed: Random seed for reproducibility.

        Returns:
            New BaseDataset whose per-value counts approximate *distribution*.

        Raises:
            ValueError: If *field* is missing from any sample.
        """
        # Bucket samples by the value they carry in *field*.
        buckets: dict[Any, list[dict[str, Any]]] = defaultdict(list)
        for record in self._samples:
            if field not in record:
                raise ValueError(f"Field '{field}' not found in sample")
            buckets[record[field]].append(record)

        # A distribution that does not sum to ~1.0 is suspicious but tolerated.
        ratio_sum = sum(distribution.values())
        if not (0.99 <= ratio_sum <= 1.01):
            logger.warning("Distribution values sum to %f, expected ~1.0", ratio_sum)

        population = len(self._samples)
        chosen: list[dict[str, Any]] = []
        # Random(None) self-seeds from the OS, matching unseeded construction.
        rng = random.Random(seed)

        for value, ratio in distribution.items():
            if value not in buckets:
                logger.warning(
                    "Value '%s' specified in distribution but not found in dataset",
                    value,
                )
                continue

            pool = buckets[value]
            # Target count, capped by what the bucket actually holds.
            take = min(int(population * ratio), len(pool))
            chosen.extend(rng.sample(pool, take))

        logger.debug(
            "Stratified dataset by field '%s' from %d to %d samples",
            field,
            len(self._samples),
            len(chosen),
        )
        return self._spawn(chosen)

    def shuffle(self, seed: int | None = None) -> BaseDataset:
        """Return a new dataset with the samples in random order.

        Args:
            seed: Random seed for reproducibility. When None, the module-level
                RNG is used (its global state is consumed, as before).

        Returns:
            New BaseDataset with shuffled samples.
        """
        reordered = list(self._samples)
        if seed is None:
            random.shuffle(reordered)
        else:
            random.Random(seed).shuffle(reordered)
        return self._spawn(reordered)

    def __len__(self) -> int:
        """Return number of samples in dataset."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Get sample by index."""
        return self._samples[idx]
262
+
263
+
264
# Public API of this module.
__all__ = ["BaseDataset"]