themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/datasets/piqa.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""Helpers for working with the ybisk/piqa dataset."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Iterable, Iterator, List, Sequence
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, Field, field_validator
|
|
10
|
+
|
|
11
|
+
_DATASET_NAME = "ybisk/piqa"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class PiqaSample(BaseModel):
    """A single PIQA item: a goal, its candidate solutions and the answer text.

    ``choices`` is normalised by ``_ensure_choices`` before validation, so a
    missing value becomes an empty list and every item is stored as ``str``.
    """

    unique_id: str
    goal: str
    choices: list[str]
    answer: str
    metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("choices", mode="before")
    @classmethod
    def _ensure_choices(cls, value: Any) -> list[str]:
        """Coerce the raw ``choices`` input into a list of strings.

        ``None`` maps to ``[]``; lists and tuples are stringified item by
        item; any other input raises ``TypeError``.
        """
        if value is None:
            return []
        if not isinstance(value, (list, tuple)):
            raise TypeError("choices must be a sequence")
        return list(map(str, value))

    def to_generation_example(self) -> dict[str, Any]:
        """Return the sample as a plain dict with copied containers."""
        # Copy the list/dict so callers cannot mutate the model's own state.
        example: dict[str, Any] = dict(
            unique_id=self.unique_id,
            goal=self.goal,
            choices=list(self.choices),
            answer=self.answer,
            metadata=dict(self.metadata),
        )
        return example
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def load_piqa(
    *,
    split: str = "validation",  # Test set usually has no labels
    limit: int | None = None,
    source: str = "huggingface",
    data_dir: str | Path | None = None,
) -> List[PiqaSample]:
    """Load PIQA samples from Hugging Face or a local directory.

    Args:
        split: Dataset split requested from Hugging Face. It is not used
            when ``source='local'``.
        limit: Maximum number of samples to return. ``None`` means no cap;
            zero or a negative value yields an empty list.
        source: Either ``'huggingface'`` or ``'local'``.
        data_dir: Directory scanned for data files when ``source='local'``.

    Returns:
        The loaded samples, in the order the underlying loader yields rows.

    Raises:
        ValueError: If ``source`` is not supported, or ``source='local'``
            is used without ``data_dir``.
    """
    if source not in {"huggingface", "local"}:
        raise ValueError(
            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
        )
    if source == "local" and data_dir is None:
        raise ValueError(
            "data_dir must be provided when source='local'. "
            "Pass dataset.data_dir in configs or --data-dir on the CLI."
        )

    # A non-positive limit asks for nothing. Answer before touching any data:
    # the loop below only checks the cap *after* appending, so it would
    # otherwise always return at least one sample.
    if limit is not None and limit <= 0:
        return []

    if source == "huggingface":
        rows = _load_from_huggingface(split=split)
    else:
        rows = _load_from_local(Path(data_dir))

    samples: list[PiqaSample] = []
    for index, row in enumerate(rows, start=1):
        samples.append(_row_to_sample(row, index=index))
        # Stop as soon as the cap is reached so the lazy loaders are not
        # drained further than necessary.
        if limit is not None and len(samples) >= limit:
            break
    return samples
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _row_to_sample(row: dict[str, Any], *, index: int) -> PiqaSample:
    """Convert one raw PIQA row into a ``PiqaSample``.

    Args:
        row: Raw record. The keys read here are ``id``/``unique_id``,
            ``goal``, ``sol1``, ``sol2`` and ``label``.
        index: Position of the row, used to build a fallback identifier.

    Returns:
        A sample whose ``answer`` is the text of the labelled solution, or
        an empty string when the label is absent, not an integer, or out of
        range.
    """
    # Pick the first identifier that is actually present. Only ``None`` and
    # the empty string count as "missing", so a legitimate id such as ``0``
    # is kept instead of being replaced by the positional fallback.
    unique_id: Any = None
    for key in ("id", "unique_id"):
        candidate = row.get(key)
        if candidate is not None and candidate != "":
            unique_id = candidate
            break
    if unique_id is None:
        unique_id = f"piqa-{index:05d}"

    goal = row.get("goal") or ""

    # PIQA has 'sol1', 'sol2'
    choices = [
        str(row.get("sol1") or ""),
        str(row.get("sol2") or ""),
    ]

    # label is integer 0 or 1
    label = row.get("label")
    answer = ""
    if label is not None:
        try:
            label_int = int(label)
        except (ValueError, TypeError):
            # Unparseable label: leave the answer empty rather than failing.
            label_int = -1
        if 0 <= label_int < len(choices):
            answer = choices[label_int]

    # Everything that was not consumed above is preserved as metadata.
    # NOTE(review): 'unique_id' is not excluded here, so it is also copied
    # into metadata when present in the row - confirm this is intended.
    metadata_keys = {"goal", "sol1", "sol2", "label", "id"}
    metadata = {key: value for key, value in row.items() if key not in metadata_keys}

    return PiqaSample(
        unique_id=str(unique_id),
        goal=str(goal),
        choices=choices,
        answer=answer,
        metadata=metadata,
    )
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
    """Lazily yield rows of the requested split as plain dictionaries.

    This is a generator, so the optional ``datasets`` import (and the
    ``RuntimeError`` raised when it is missing) only happens once iteration
    starts.
    """
    try:
        from datasets import load_dataset
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise RuntimeError(
            "datasets is required to load PIQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
        ) from exc

    records = load_dataset(_DATASET_NAME, split=split)
    # Each record is copied into a fresh dict before being handed out.
    yield from map(dict, records)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
|
|
126
|
+
if not root.exists():
|
|
127
|
+
raise FileNotFoundError(f"Local dataset directory not found: {root}")
|
|
128
|
+
|
|
129
|
+
for path in root.rglob("*"):
|
|
130
|
+
if path.suffix.lower() == ".json":
|
|
131
|
+
with path.open("r", encoding="utf-8") as handle:
|
|
132
|
+
row = json.load(handle)
|
|
133
|
+
row.setdefault("id", path.stem)
|
|
134
|
+
yield row
|
|
135
|
+
elif path.suffix.lower() in {".jsonl", ".ndjson"}:
|
|
136
|
+
with path.open("r", encoding="utf-8") as handle:
|
|
137
|
+
for line_num, line in enumerate(handle, start=1):
|
|
138
|
+
line = line.strip()
|
|
139
|
+
if not line:
|
|
140
|
+
continue
|
|
141
|
+
row = json.loads(line)
|
|
142
|
+
row.setdefault("id", f"{path.stem}-{line_num}")
|
|
143
|
+
yield row
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
__all__ = ["PiqaSample", "load_piqa"]
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""Registry for dataset loaders.
|
|
2
|
+
|
|
3
|
+
This module provides a plugin-based registry system for datasets, allowing
|
|
4
|
+
users to register custom datasets without modifying core Themis code.
|
|
5
|
+
|
|
6
|
+
Example:
|
|
7
|
+
```python
|
|
8
|
+
from themis.datasets import register_dataset
|
|
9
|
+
|
|
10
|
+
def create_my_dataset(options):
|
|
11
|
+
from my_module import MyDataset
|
|
12
|
+
return MyDataset(path=options.get('path'))
|
|
13
|
+
|
|
14
|
+
register_dataset('my-dataset', create_my_dataset)
|
|
15
|
+
```
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
from typing import Any, Callable
|
|
21
|
+
|
|
22
|
+
# Factory type: takes config options, returns list of samples
|
|
23
|
+
DatasetFactory = Callable[[dict[str, Any]], list[dict[str, Any]]]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DatasetRegistry:
    """Name-to-factory lookup table for dataset loaders.

    Each entry maps a dataset name to a callable. ``create`` invokes that
    callable with a single dict holding the keyword options it was given.
    """

    def __init__(self):
        # name -> factory callable
        self._datasets: dict[str, DatasetFactory] = {}

    def register(self, name: str, factory: DatasetFactory) -> None:
        """Add ``factory`` under ``name``.

        Args:
            name: Unique identifier for the dataset (e.g., 'math500', 'my-dataset')
            factory: Callable invoked with the options dict; returns the samples

        Raises:
            ValueError: If ``name`` is already present in the registry
        """
        already_known = name in self._datasets
        if already_known:
            message = (
                f"Dataset '{name}' is already registered. "
                "Use a different name or unregister the existing dataset first."
            )
            raise ValueError(message)
        self._datasets[name] = factory

    def unregister(self, name: str) -> None:
        """Remove the entry stored under ``name``.

        Args:
            name: Dataset identifier to remove

        Raises:
            ValueError: If ``name`` is not present in the registry
        """
        if name in self._datasets:
            del self._datasets[name]
            return
        raise ValueError(f"Dataset '{name}' is not registered")

    def create(self, name: str, **options) -> list[dict[str, Any]]:
        """Call the factory registered under ``name`` and return its result.

        Args:
            name: Registered dataset identifier
            **options: Collected into one dict and passed to the factory as
                its only positional argument. Common options include:
                - source: 'huggingface', 'local', or custom source
                - data_dir: Path for local datasets
                - split: Dataset split (e.g., 'train', 'test')
                - limit: Maximum number of samples to load
                - subjects: List of subjects to filter

        Returns:
            Whatever the factory returns (expected: list of sample dicts)

        Raises:
            ValueError: If ``name`` is not present in the registry; the
                message lists the registered names in sorted order
        """
        if name in self._datasets:
            build = self._datasets[name]
            return build(options)

        known_names = ", ".join(sorted(self._datasets)) or "none"
        raise ValueError(
            f"Unknown dataset: '{name}'. Available datasets: {known_names}"
        )

    def list_datasets(self) -> list[str]:
        """Return the registered dataset names in sorted order."""
        names = list(self._datasets)
        names.sort()
        return names

    def is_registered(self, name: str) -> bool:
        """Return True when ``name`` has a registered factory, else False."""
        return name in self._datasets
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# Global registry instance
|
|
115
|
+
_REGISTRY = DatasetRegistry()
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def register_dataset(name: str, factory: DatasetFactory) -> None:
    """Add a dataset factory to the module-level registry.

    Args:
        name: Unique identifier for the dataset
        factory: Callable that receives the options dict and returns samples

    Example:
        ```python
        def create_my_dataset(options):
            from my_module import load_data
            return load_data(
                path=options.get('path'),
                limit=options.get('limit')
            )

        register_dataset('my-dataset', create_my_dataset)
        ```
    """
    registry = _REGISTRY
    registry.register(name, factory)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def unregister_dataset(name: str) -> None:
    """Remove a dataset factory from the module-level registry.

    Args:
        name: Dataset identifier to remove
    """
    registry = _REGISTRY
    registry.unregister(name)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def create_dataset(name: str, **options) -> list[dict[str, Any]]:
    """Load samples through the factory registered under ``name``.

    Args:
        name: Registered dataset identifier
        **options: Configuration options forwarded to the module-level registry

    Returns:
        The result of the registry's ``create`` call (list of sample dicts)

    Example:
        ```python
        samples = create_dataset(
            'math500',
            source='huggingface',
            split='test',
            limit=10
        )
        ```
    """
    samples = _REGISTRY.create(name, **options)
    return samples
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def list_datasets() -> list[str]:
    """Return the names known to the module-level registry.

    Returns:
        Dataset names as produced by the registry's ``list_datasets``
    """
    names = _REGISTRY.list_datasets()
    return names
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def is_dataset_registered(name: str) -> bool:
    """Report whether the module-level registry knows ``name``.

    Args:
        name: Dataset identifier

    Returns:
        True if registered, False otherwise
    """
    registered = _REGISTRY.is_registered(name)
    return registered
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
__all__ = [
|
|
194
|
+
"DatasetFactory",
|
|
195
|
+
"DatasetRegistry",
|
|
196
|
+
"register_dataset",
|
|
197
|
+
"unregister_dataset",
|
|
198
|
+
"create_dataset",
|
|
199
|
+
"list_datasets",
|
|
200
|
+
"is_dataset_registered",
|
|
201
|
+
]
|
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
"""Dataset schema and metadata definitions.
|
|
2
|
+
|
|
3
|
+
This module provides enhanced dataset abstractions with schema validation,
|
|
4
|
+
metadata, and filtering capabilities while maintaining backward compatibility.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Any, Callable, Iterable, Protocol, runtime_checkable
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class DatasetSchema:
    """Describes the structure and validation rules for dataset samples.

    Attributes:
        id_field: Key holding a sample's identifier; used in error messages.
        reference_field: Key holding the reference answer, if any.
        required_fields: Keys that must be present in every sample.
        optional_fields: Keys that may be present.
        metadata_fields: Keys carrying metadata.
        validators: Callables run on each sample after the required-field
            check; each is expected to raise on invalid input.

    Examples:
        # Basic schema
        schema = DatasetSchema(
            id_field="unique_id",
            reference_field="answer",
            required_fields={"unique_id", "problem", "answer"},
        )

        # Schema with validation
        def validate_problem(sample: dict) -> None:
            if len(sample.get("problem", "")) < 10:
                raise ValueError("Problem text too short")

        schema = DatasetSchema(
            id_field="id",
            reference_field="expected",
            required_fields={"id", "problem", "expected"},
            validators=[validate_problem],
        )
    """

    id_field: str
    reference_field: str | None
    required_fields: set[str] = field(default_factory=set)
    optional_fields: set[str] = field(default_factory=set)
    metadata_fields: set[str] = field(default_factory=set)
    validators: list[Callable[[dict], None]] = field(default_factory=list)

    def validate_sample(self, sample: dict[str, Any]) -> None:
        """Validate a single sample against this schema.

        Args:
            sample: Sample to validate

        Raises:
            ValueError: If a required field is missing. When several are
                missing, the alphabetically first one is reported.
                Custom validators may raise their own errors.
        """
        # Sets have no defined iteration order, so sort the missing names to
        # make the reported field (and thus the message) reproducible.
        missing = sorted(
            field_name
            for field_name in self.required_fields
            if field_name not in sample
        )
        if missing:
            raise ValueError(
                f"Missing required field '{missing[0]}' in sample {sample.get(self.id_field)}"
            )

        # Run custom validators in the order they were supplied.
        for validator in self.validators:
            validator(sample)

    def get_all_fields(self) -> set[str]:
        """Get all known fields (required + optional + metadata)."""
        return self.required_fields | self.optional_fields | self.metadata_fields
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@dataclass
class DatasetMetadata:
    """Descriptive information about a dataset as a whole.

    Holds identification (name, version), size, category listings and
    free-form provenance fields such as licence and citation.

    Examples:
        metadata = DatasetMetadata(
            name="MATH-500",
            version="1.0",
            total_samples=500,
            categories={
                "subject": ["algebra", "geometry", "number_theory"],
                "difficulty": ["easy", "medium", "hard"],
            },
            difficulty_distribution={
                "easy": 100,
                "medium": 250,
                "hard": 150,
            },
            description="Math problems from competition mathematics",
        )
    """

    name: str
    version: str = "1.0"
    total_samples: int | None = None
    categories: dict[str, list[str]] = field(default_factory=dict)
    difficulty_distribution: dict[str, int] | None = None
    description: str = ""
    source_url: str | None = None
    license: str | None = None
    citation: str | None = None
    custom_metadata: dict[str, Any] = field(default_factory=dict)

    def get_category_values(self, category: str) -> list[str]:
        """Return the values listed for ``category``, or ``[]`` if unknown."""
        fallback: list[str] = []
        values = self.categories.get(category, fallback)
        return values

    def has_category(self, category: str) -> bool:
        """Return True when ``category`` is a key of ``categories``."""
        known = self.categories
        return category in known
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@runtime_checkable
class EnhancedDatasetAdapter(Protocol):
    """Structural interface for datasets that expose schema and metadata.

    On top of plain sample iteration, implementers provide a schema, dataset
    metadata, and three derivation operations (filter, limit, stratify) that
    each hand back another adapter. Because the protocol is
    runtime-checkable, ``isinstance`` can be used to test for the presence of
    these methods.
    """

    def iter_samples(self) -> Iterable[dict[str, Any]]:
        """Yield the dataset's samples as dictionaries."""
        ...

    def get_schema(self) -> DatasetSchema:
        """Return the schema describing the samples."""
        ...

    def get_metadata(self) -> DatasetMetadata:
        """Return the metadata describing the dataset."""
        ...

    def filter(
        self, predicate: Callable[[dict[str, Any]], bool]
    ) -> EnhancedDatasetAdapter:
        """Keep only the samples accepted by ``predicate``.

        Args:
            predicate: Called per sample; a True result keeps the sample

        Returns:
            An adapter over the retained samples
        """
        ...

    def limit(self, n: int) -> EnhancedDatasetAdapter:
        """Restrict the dataset to its first ``n`` samples.

        Args:
            n: Upper bound on the number of samples

        Returns:
            An adapter over at most ``n`` samples
        """
        ...

    def stratify(
        self, field: str, distribution: dict[str, float]
    ) -> EnhancedDatasetAdapter:
        """Draw a sample of the dataset stratified on ``field``.

        Args:
            field: Sample key to stratify by
            distribution: Target share per value (values should sum to 1.0)

        Returns:
            An adapter over the stratified samples
        """
        ...
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
# Common validators
|
|
175
|
+
def validate_non_empty_field(field_name: str) -> Callable[[dict], None]:
    """Build a validator that rejects samples whose field is missing or falsy.

    Args:
        field_name: Key to look up in each sample

    Returns:
        A callable that raises ``ValueError`` when the looked-up value is
        falsy (absent, ``None``, empty, zero) and returns ``None`` otherwise
    """

    def validator(sample: dict) -> None:
        # Guard clause: any truthy value is acceptable.
        if sample.get(field_name):
            return
        raise ValueError(f"Field '{field_name}' cannot be empty")

    return validator
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def validate_field_type(field_name: str, expected_type: type) -> Callable[[dict], None]:
    """Build a validator that checks a field's value against a type.

    Args:
        field_name: Key to look up in each sample
        expected_type: Type the value must be an instance of

    Returns:
        A callable that raises ``ValueError`` when the value is present
        (not ``None``) and not an instance of ``expected_type``
    """

    def validator(sample: dict) -> None:
        found = sample.get(field_name)
        # A missing / None value is not this validator's concern.
        if found is None or isinstance(found, expected_type):
            return
        wanted_name = expected_type.__name__
        actual_name = type(found).__name__
        raise ValueError(
            f"Field '{field_name}' expected type {wanted_name}, got {actual_name}"
        )

    return validator
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def validate_field_in_choices(
    field_name: str, choices: set[str]
) -> Callable[[dict], None]:
    """Build a validator that restricts a field to a set of allowed values.

    Args:
        field_name: Key to look up in each sample
        choices: Allowed values

    Returns:
        A callable that raises ``ValueError`` when the value is present
        (not ``None``) and not contained in ``choices``
    """

    def validator(sample: dict) -> None:
        found = sample.get(field_name)
        # A missing / None value is accepted; only concrete values are checked.
        if found is None or found in choices:
            return
        raise ValueError(
            f"Field '{field_name}' value '{found}' not in allowed choices: {choices}"
        )

    return validator
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
__all__ = [
|
|
239
|
+
"DatasetSchema",
|
|
240
|
+
"DatasetMetadata",
|
|
241
|
+
"EnhancedDatasetAdapter",
|
|
242
|
+
"validate_non_empty_field",
|
|
243
|
+
"validate_field_type",
|
|
244
|
+
"validate_field_in_choices",
|
|
245
|
+
]
|