themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. themis/cli/__init__.py +5 -0
  2. themis/cli/__main__.py +6 -0
  3. themis/cli/commands/__init__.py +19 -0
  4. themis/cli/commands/benchmarks.py +221 -0
  5. themis/cli/commands/comparison.py +394 -0
  6. themis/cli/commands/config_commands.py +244 -0
  7. themis/cli/commands/cost.py +214 -0
  8. themis/cli/commands/demo.py +68 -0
  9. themis/cli/commands/info.py +90 -0
  10. themis/cli/commands/leaderboard.py +362 -0
  11. themis/cli/commands/math_benchmarks.py +318 -0
  12. themis/cli/commands/mcq_benchmarks.py +207 -0
  13. themis/cli/commands/sample_run.py +244 -0
  14. themis/cli/commands/visualize.py +299 -0
  15. themis/cli/main.py +93 -0
  16. themis/cli/new_project.py +33 -0
  17. themis/cli/utils.py +51 -0
  18. themis/config/__init__.py +19 -0
  19. themis/config/loader.py +27 -0
  20. themis/config/registry.py +34 -0
  21. themis/config/runtime.py +214 -0
  22. themis/config/schema.py +112 -0
  23. themis/core/__init__.py +5 -0
  24. themis/core/conversation.py +354 -0
  25. themis/core/entities.py +164 -0
  26. themis/core/serialization.py +231 -0
  27. themis/core/tools.py +393 -0
  28. themis/core/types.py +141 -0
  29. themis/datasets/__init__.py +273 -0
  30. themis/datasets/base.py +264 -0
  31. themis/datasets/commonsense_qa.py +174 -0
  32. themis/datasets/competition_math.py +265 -0
  33. themis/datasets/coqa.py +133 -0
  34. themis/datasets/gpqa.py +190 -0
  35. themis/datasets/gsm8k.py +123 -0
  36. themis/datasets/gsm_symbolic.py +124 -0
  37. themis/datasets/math500.py +122 -0
  38. themis/datasets/med_qa.py +179 -0
  39. themis/datasets/medmcqa.py +169 -0
  40. themis/datasets/mmlu_pro.py +262 -0
  41. themis/datasets/piqa.py +146 -0
  42. themis/datasets/registry.py +201 -0
  43. themis/datasets/schema.py +245 -0
  44. themis/datasets/sciq.py +150 -0
  45. themis/datasets/social_i_qa.py +151 -0
  46. themis/datasets/super_gpqa.py +263 -0
  47. themis/evaluation/__init__.py +1 -0
  48. themis/evaluation/conditional.py +410 -0
  49. themis/evaluation/extractors/__init__.py +19 -0
  50. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  51. themis/evaluation/extractors/exceptions.py +7 -0
  52. themis/evaluation/extractors/identity_extractor.py +29 -0
  53. themis/evaluation/extractors/json_field_extractor.py +45 -0
  54. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  55. themis/evaluation/extractors/regex_extractor.py +43 -0
  56. themis/evaluation/math_verify_utils.py +87 -0
  57. themis/evaluation/metrics/__init__.py +21 -0
  58. themis/evaluation/metrics/composite_metric.py +47 -0
  59. themis/evaluation/metrics/consistency_metric.py +80 -0
  60. themis/evaluation/metrics/exact_match.py +51 -0
  61. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  62. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  63. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  64. themis/evaluation/metrics/response_length.py +33 -0
  65. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  66. themis/evaluation/pipeline.py +49 -0
  67. themis/evaluation/pipelines/__init__.py +15 -0
  68. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  69. themis/evaluation/pipelines/standard_pipeline.py +288 -0
  70. themis/evaluation/reports.py +293 -0
  71. themis/evaluation/statistics/__init__.py +53 -0
  72. themis/evaluation/statistics/bootstrap.py +79 -0
  73. themis/evaluation/statistics/confidence_intervals.py +121 -0
  74. themis/evaluation/statistics/distributions.py +207 -0
  75. themis/evaluation/statistics/effect_sizes.py +124 -0
  76. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  77. themis/evaluation/statistics/types.py +139 -0
  78. themis/evaluation/strategies/__init__.py +13 -0
  79. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  80. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  81. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  82. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  83. themis/experiment/__init__.py +5 -0
  84. themis/experiment/builder.py +151 -0
  85. themis/experiment/cache_manager.py +129 -0
  86. themis/experiment/comparison.py +631 -0
  87. themis/experiment/cost.py +310 -0
  88. themis/experiment/definitions.py +62 -0
  89. themis/experiment/export.py +690 -0
  90. themis/experiment/export_csv.py +159 -0
  91. themis/experiment/integration_manager.py +104 -0
  92. themis/experiment/math.py +192 -0
  93. themis/experiment/mcq.py +169 -0
  94. themis/experiment/orchestrator.py +373 -0
  95. themis/experiment/pricing.py +317 -0
  96. themis/experiment/storage.py +255 -0
  97. themis/experiment/visualization.py +588 -0
  98. themis/generation/__init__.py +1 -0
  99. themis/generation/agentic_runner.py +420 -0
  100. themis/generation/batching.py +254 -0
  101. themis/generation/clients.py +143 -0
  102. themis/generation/conversation_runner.py +236 -0
  103. themis/generation/plan.py +456 -0
  104. themis/generation/providers/litellm_provider.py +221 -0
  105. themis/generation/providers/vllm_provider.py +135 -0
  106. themis/generation/router.py +34 -0
  107. themis/generation/runner.py +207 -0
  108. themis/generation/strategies.py +98 -0
  109. themis/generation/templates.py +71 -0
  110. themis/generation/turn_strategies.py +393 -0
  111. themis/generation/types.py +9 -0
  112. themis/integrations/__init__.py +0 -0
  113. themis/integrations/huggingface.py +61 -0
  114. themis/integrations/wandb.py +65 -0
  115. themis/interfaces/__init__.py +83 -0
  116. themis/project/__init__.py +20 -0
  117. themis/project/definitions.py +98 -0
  118. themis/project/patterns.py +230 -0
  119. themis/providers/__init__.py +5 -0
  120. themis/providers/registry.py +39 -0
  121. themis/utils/api_generator.py +379 -0
  122. themis/utils/cost_tracking.py +376 -0
  123. themis/utils/dashboard.py +452 -0
  124. themis/utils/logging_utils.py +41 -0
  125. themis/utils/progress.py +58 -0
  126. themis/utils/tracing.py +320 -0
  127. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
  128. themis_eval-0.1.1.dist-info/RECORD +134 -0
  129. themis_eval-0.1.0.dist-info/RECORD +0 -8
  130. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
  131. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
  132. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,150 @@
1
+ """Helpers for working with the allenai/sciq dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import random
7
+ from pathlib import Path
8
+ from typing import Any, Iterable, Iterator, List, Sequence
9
+
10
+ from pydantic import BaseModel, Field, field_validator
11
+
12
+ _DATASET_NAME = "allenai/sciq"
13
+
14
+
15
class SciQSample(BaseModel):
    """A single SciQ multiple-choice question with optional supporting text."""

    unique_id: str
    question: str
    choices: list[str]
    answer: str
    support: str = Field(default="")
    metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("choices", mode="before")
    @classmethod
    def _ensure_choices(cls, value: Any) -> list[str]:
        # None means "no choices"; any list/tuple is coerced element-wise to str.
        if value is None:
            return []
        if not isinstance(value, (list, tuple)):
            raise TypeError("choices must be a sequence")
        return [str(entry) for entry in value]

    def to_generation_example(self) -> dict[str, Any]:
        """Return a plain-dict view of the sample for generation pipelines."""
        example: dict[str, Any] = {
            "unique_id": self.unique_id,
            "question": self.question,
            "choices": list(self.choices),
            "answer": self.answer,
            "support": self.support,
            "metadata": dict(self.metadata),
        }
        return example
41
+
42
+
43
def load_sciq(
    *,
    split: str = "test",
    limit: int | None = None,
    source: str = "huggingface",
    data_dir: str | Path | None = None,
) -> List[SciQSample]:
    """Load SciQ samples from Hugging Face or a local directory.

    Args:
        split: Dataset split to load (Hugging Face source only).
        limit: Maximum number of samples to return; ``None`` means no limit.
        source: Either ``"huggingface"`` or ``"local"``.
        data_dir: Root directory of local JSON/JSONL files; required when
            ``source == "local"``.

    Returns:
        A list of validated :class:`SciQSample` objects.

    Raises:
        ValueError: If ``source`` is unsupported, or ``data_dir`` is missing
            for a local load.
    """
    if source not in {"huggingface", "local"}:
        raise ValueError(
            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
        )

    if source == "huggingface":
        rows = _load_from_huggingface(split=split)
    else:
        if data_dir is None:
            raise ValueError(
                "data_dir must be provided when source='local'. "
                "Pass dataset.data_dir in configs or --data-dir on the CLI."
            )
        rows = _load_from_local(Path(data_dir))

    # Fix: a non-positive limit previously still yielded one sample because the
    # append happened before the limit check. Treat limit<=0 as "no samples".
    if limit is not None and limit <= 0:
        return []

    samples: list[SciQSample] = []
    for index, row in enumerate(rows, start=1):
        samples.append(_row_to_sample(row, index=index))
        if limit is not None and len(samples) >= limit:
            break
    return samples
74
+
75
+
76
def _row_to_sample(row: dict[str, Any], *, index: int) -> SciQSample:
    """Convert one raw SciQ row into a validated :class:`SciQSample`."""
    uid = row.get("id") or row.get("unique_id") or f"sciq-{index:05d}"
    question_text = row.get("question") or ""

    # Raw SciQ rows hold the gold answer plus up to three distractor columns.
    correct_answer = str(row.get("correct_answer") or "")
    wrong_answers = [
        str(row.get(key) or "")
        for key in ("distractor1", "distractor2", "distractor3")
    ]
    wrong_answers = [text for text in wrong_answers if text]

    # Alphabetical ordering keeps the choice list deterministic across runs.
    option_texts = sorted([correct_answer, *wrong_answers])

    known_fields = {
        "question", "correct_answer", "distractor1", "distractor2",
        "distractor3", "support", "id",
    }
    extra = {key: value for key, value in row.items() if key not in known_fields}

    return SciQSample(
        unique_id=str(uid),
        question=str(question_text),
        choices=option_texts,
        answer=correct_answer,
        support=str(row.get("support") or ""),
        metadata=extra,
    )
114
+
115
+
116
def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
    """Stream SciQ rows for *split* from the Hugging Face hub as plain dicts."""
    try:
        from datasets import load_dataset
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise RuntimeError(
            "datasets is required to load SciQ from Hugging Face. Install it via `uv pip install '.[hf]'`."
        ) from exc

    for record in load_dataset(_DATASET_NAME, split=split):
        yield dict(record)
127
+
128
+
129
+ def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
130
+ if not root.exists():
131
+ raise FileNotFoundError(f"Local dataset directory not found: {root}")
132
+
133
+ for path in root.rglob("*"):
134
+ if path.suffix.lower() == ".json":
135
+ with path.open("r", encoding="utf-8") as handle:
136
+ row = json.load(handle)
137
+ row.setdefault("id", path.stem)
138
+ yield row
139
+ elif path.suffix.lower() in {".jsonl", ".ndjson"}:
140
+ with path.open("r", encoding="utf-8") as handle:
141
+ for line_num, line in enumerate(handle, start=1):
142
+ line = line.strip()
143
+ if not line:
144
+ continue
145
+ row = json.loads(line)
146
+ row.setdefault("id", f"{path.stem}-{line_num}")
147
+ yield row
148
+
149
+
150
+ __all__ = ["SciQSample", "load_sciq"]
@@ -0,0 +1,151 @@
1
+ """Helpers for working with the allenai/social_i_qa dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Any, Iterable, Iterator, List, Sequence
8
+
9
+ from pydantic import BaseModel, Field, field_validator
10
+
11
+ _DATASET_NAME = "allenai/social_i_qa"
12
+
13
+
14
class SocialIQaSample(BaseModel):
    """A single Social IQA question: a context, a question, and three options."""

    unique_id: str
    context: str
    question: str
    choices: list[str]
    answer: str
    metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("choices", mode="before")
    @classmethod
    def _ensure_choices(cls, value: Any) -> list[str]:
        # None means "no choices"; any list/tuple is coerced element-wise to str.
        if value is None:
            return []
        if not isinstance(value, (list, tuple)):
            raise TypeError("choices must be a sequence")
        return [str(entry) for entry in value]

    def to_generation_example(self) -> dict[str, Any]:
        """Return a plain-dict view of the sample for generation pipelines."""
        example: dict[str, Any] = {
            "unique_id": self.unique_id,
            "context": self.context,
            "question": self.question,
            "choices": list(self.choices),
            "answer": self.answer,
            "metadata": dict(self.metadata),
        }
        return example
+ }
40
+
41
+
42
def load_social_i_qa(
    *,
    split: str = "validation",  # Test set usually has no labels
    limit: int | None = None,
    source: str = "huggingface",
    data_dir: str | Path | None = None,
) -> List[SocialIQaSample]:
    """Load Social IQA samples from Hugging Face or a local directory.

    Args:
        split: Dataset split to load (Hugging Face source only).
        limit: Maximum number of samples to return; ``None`` means no limit.
        source: Either ``"huggingface"`` or ``"local"``.
        data_dir: Root directory of local JSON/JSONL files; required when
            ``source == "local"``.

    Returns:
        A list of validated :class:`SocialIQaSample` objects.

    Raises:
        ValueError: If ``source`` is unsupported, or ``data_dir`` is missing
            for a local load.
    """
    if source not in {"huggingface", "local"}:
        raise ValueError(
            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
        )

    if source == "huggingface":
        rows = _load_from_huggingface(split=split)
    else:
        if data_dir is None:
            raise ValueError(
                "data_dir must be provided when source='local'. "
                "Pass dataset.data_dir in configs or --data-dir on the CLI."
            )
        rows = _load_from_local(Path(data_dir))

    # Fix: a non-positive limit previously still yielded one sample because the
    # append happened before the limit check. Treat limit<=0 as "no samples".
    if limit is not None and limit <= 0:
        return []

    samples: list[SocialIQaSample] = []
    for index, row in enumerate(rows, start=1):
        samples.append(_row_to_sample(row, index=index))
        if limit is not None and len(samples) >= limit:
            break
    return samples
73
+
74
+
75
def _row_to_sample(row: dict[str, Any], *, index: int) -> SocialIQaSample:
    """Convert one raw Social IQA row into a validated :class:`SocialIQaSample`."""
    uid = row.get("id") or row.get("unique_id") or f"social-iqa-{index:05d}"

    # The three candidate answers live in fixed columns.
    options = [str(row.get(key) or "") for key in ("answerA", "answerB", "answerC")]

    # `label` is a 1-based index (the strings "1".."3"); anything absent or
    # unparsable leaves the gold answer empty.
    gold = ""
    raw_label = row.get("label")
    if raw_label is not None:
        try:
            position = int(raw_label)
        except (ValueError, TypeError):
            position = None
        if position is not None and 1 <= position <= len(options):
            gold = options[position - 1]

    known_fields = {"context", "question", "answerA", "answerB", "answerC", "label", "id"}
    extra = {key: value for key, value in row.items() if key not in known_fields}

    return SocialIQaSample(
        unique_id=str(uid),
        context=str(row.get("context") or ""),
        question=str(row.get("question") or ""),
        choices=options,
        answer=gold,
        metadata=extra,
    )
+ )
115
+
116
+
117
def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
    """Stream Social IQA rows for *split* from the Hugging Face hub as plain dicts."""
    try:
        from datasets import load_dataset
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise RuntimeError(
            "datasets is required to load Social IQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
        ) from exc

    for record in load_dataset(_DATASET_NAME, split=split):
        yield dict(record)
128
+
129
+
130
+ def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
131
+ if not root.exists():
132
+ raise FileNotFoundError(f"Local dataset directory not found: {root}")
133
+
134
+ for path in root.rglob("*"):
135
+ if path.suffix.lower() == ".json":
136
+ with path.open("r", encoding="utf-8") as handle:
137
+ row = json.load(handle)
138
+ row.setdefault("id", path.stem)
139
+ yield row
140
+ elif path.suffix.lower() in {".jsonl", ".ndjson"}:
141
+ with path.open("r", encoding="utf-8") as handle:
142
+ for line_num, line in enumerate(handle, start=1):
143
+ line = line.strip()
144
+ if not line:
145
+ continue
146
+ row = json.loads(line)
147
+ row.setdefault("id", f"{path.stem}-{line_num}")
148
+ yield row
149
+
150
+
151
+ __all__ = ["SocialIQaSample", "load_social_i_qa"]
@@ -0,0 +1,263 @@
1
+ """Helpers for working with the m-a-p/SuperGPQA dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import string
7
+ from pathlib import Path
8
+ from typing import Any, Iterable, Iterator, List, Sequence
9
+
10
+ from pydantic import BaseModel, Field, ValidationInfo, field_validator
11
+
12
+ _DATASET_NAME = "m-a-p/SuperGPQA"
13
+ _CHOICE_LABELS = tuple(string.ascii_uppercase)
14
+
15
+
16
+ def _index_to_label(index: int) -> str:
17
+ if 0 <= index < len(_CHOICE_LABELS):
18
+ return _CHOICE_LABELS[index]
19
+ return str(index)
20
+
21
+
22
+ def _normalize_answer(value: Any, *, total_choices: int | None = None) -> str:
23
+ if isinstance(value, int):
24
+ return _index_to_label(value)
25
+ if isinstance(value, float):
26
+ as_int = int(value)
27
+ if as_int == value:
28
+ return _index_to_label(as_int)
29
+ text = str(value or "").strip()
30
+ if not text:
31
+ return ""
32
+ lowered = text.lower()
33
+ if lowered.startswith("option "):
34
+ text = text.split(" ", 1)[-1]
35
+ if lowered.startswith("choice "):
36
+ text = text.split(" ", 1)[-1]
37
+ if text.isdigit():
38
+ index = int(text)
39
+ if total_choices is None or 0 <= index < total_choices:
40
+ return _index_to_label(index)
41
+ text = text.strip().rstrip(".")
42
+ if len(text) == 1 and text.isalpha():
43
+ return text.upper()
44
+ if total_choices is not None:
45
+ mapping = {str(idx): _index_to_label(idx) for idx in range(total_choices)}
46
+ normalized = mapping.get(text)
47
+ if normalized:
48
+ return normalized
49
+ return text
50
+
51
+
52
class SuperGpqaSample(BaseModel):
    """A single SuperGPQA multiple-choice question with letter choice labels."""

    unique_id: str
    question: str
    choices: list[str]
    answer: str
    subject: str = Field(default="unknown")
    metadata: dict[str, Any] = Field(default_factory=dict)
    choice_labels: list[str] = Field(default_factory=list)

    @field_validator("choices", mode="before")
    @classmethod
    def _ensure_choices(cls, value: Any) -> list[str]:
        # Mappings are flattened in key order so the result stays deterministic.
        if value is None:
            return []
        if isinstance(value, dict):
            return [str(item) for _, item in sorted(value.items())]
        if isinstance(value, (list, tuple)):
            return [str(item) for item in value]
        raise TypeError("choices must be a sequence or mapping")

    @field_validator("choice_labels", mode="before")
    @classmethod
    def _build_choice_labels(cls, value: Any, info: ValidationInfo) -> list[str]:
        # Explicit labels win; otherwise derive "A", "B", ... from the choices
        # field validated just before this one.
        if value:
            return [str(item) for item in value]
        prior = info.data.get("choices") if hasattr(info, "data") else None
        count = len(prior) if isinstance(prior, list) else 0
        return list(_CHOICE_LABELS[:count])

    @field_validator("answer", mode="before")
    @classmethod
    def _normalize_answer_field(cls, value: Any, info: ValidationInfo) -> str:
        # Normalize against the already-validated choice count when available.
        prior = info.data.get("choices") if hasattr(info, "data") else None
        count = len(prior) if isinstance(prior, list) else None
        return _normalize_answer(value, total_choices=count)

    def to_generation_example(self) -> dict[str, Any]:
        """Return a plain-dict view of the sample for generation pipelines."""
        labels = list(self.choice_labels) or list(_CHOICE_LABELS[: len(self.choices)])
        return {
            "unique_id": self.unique_id,
            "question": self.question,
            "choices": list(self.choices),
            "choice_labels": labels,
            "answer": self.answer,
            "subject": self.subject,
            "metadata": dict(self.metadata),
        }
104
+
105
+
106
def load_super_gpqa(
    *,
    split: str = "test",
    limit: int | None = None,
    source: str = "huggingface",
    data_dir: str | Path | None = None,
    subjects: Sequence[str] | None = None,
) -> List[SuperGpqaSample]:
    """Load SuperGPQA samples from Hugging Face or a local directory.

    Args:
        split: Dataset split to load (Hugging Face source only).
        limit: Maximum number of samples to return; ``None`` means no limit.
        source: Either ``"huggingface"`` or ``"local"``.
        data_dir: Root directory of local JSON/JSONL files; required when
            ``source == "local"``.
        subjects: Optional case-insensitive subject filter; rows whose subject
            is not listed are skipped before the limit is applied.

    Returns:
        A list of validated :class:`SuperGpqaSample` objects.

    Raises:
        ValueError: If ``source`` is unsupported, or ``data_dir`` is missing
            for a local load.
    """
    if source not in {"huggingface", "local"}:
        raise ValueError(
            f"Unsupported source '{source}'. Expected one of: 'huggingface', 'local'."
        )

    if source == "huggingface":
        rows = _load_from_huggingface(split=split)
    else:
        if data_dir is None:
            raise ValueError(
                "data_dir must be provided when source='local'. "
                "Pass dataset.data_dir in configs or --data-dir on the CLI."
            )
        rows = _load_from_local(Path(data_dir))

    # Fix: a non-positive limit previously still yielded one sample because the
    # append happened before the limit check. Treat limit<=0 as "no samples".
    if limit is not None and limit <= 0:
        return []

    samples: list[SuperGpqaSample] = []
    selected_subjects = {s.lower() for s in subjects} if subjects else None
    for index, row in enumerate(rows, start=1):
        subject = _extract_subject(row)
        if selected_subjects and subject.lower() not in selected_subjects:
            continue
        samples.append(_row_to_sample(row, index=index, subject=subject))
        if limit is not None and len(samples) >= limit:
            break
    return samples
142
+
143
+
144
+ def _extract_subject(row: dict[str, Any]) -> str:
145
+ for key in ("subject", "category", "field", "domain", "track"):
146
+ value = row.get(key)
147
+ if value:
148
+ return str(value)
149
+ return "unknown"
150
+
151
+
152
def _row_to_sample(row: dict[str, Any], *, index: int, subject: str) -> SuperGpqaSample:
    """Convert one raw SuperGPQA row into a validated :class:`SuperGpqaSample`."""
    uid = (
        row.get("id")
        or row.get("question_id")
        or row.get("unique_id")
        or f"supergpqa-{index:05d}"
    )
    question_text = row.get("question") or row.get("Question") or row.get("prompt") or ""
    raw_answer = (
        row.get("answer")
        or row.get("Answer")
        or row.get("correct_answer")
        or row.get("correct")
        or ""
    )

    # Everything we consumed above stays out of the free-form metadata bag.
    consumed = {
        "question", "Question", "prompt", "choices", "options",
        "answer", "Answer", "correct_answer", "correct",
    }
    extra = {key: value for key, value in row.items() if key not in consumed}

    payload = {
        "unique_id": str(uid),
        "question": str(question_text),
        "choices": _extract_choices(row),
        "answer": raw_answer,
        "subject": str(subject),
        "metadata": extra,
    }
    return SuperGpqaSample.model_validate(payload)
191
+
192
+
193
+ def _extract_choices(row: dict[str, Any]) -> list[str]:
194
+ candidates = row.get("choices") or row.get("options") or row.get("Choices")
195
+ if isinstance(candidates, dict):
196
+ return [str(value) for _, value in sorted(candidates.items())]
197
+ if isinstance(candidates, (list, tuple)):
198
+ return [str(item) for item in candidates]
199
+
200
+ choice_keys = [
201
+ "A",
202
+ "B",
203
+ "C",
204
+ "D",
205
+ "E",
206
+ "F",
207
+ "choice_a",
208
+ "choice_b",
209
+ "choice_c",
210
+ "choice_d",
211
+ "choice_e",
212
+ "choice_f",
213
+ "option_a",
214
+ "option_b",
215
+ "option_c",
216
+ "option_d",
217
+ "option_e",
218
+ "option_f",
219
+ ]
220
+ collected: list[str] = []
221
+ for key in choice_keys:
222
+ if key in row:
223
+ collected.append(str(row[key]))
224
+ if collected:
225
+ return collected
226
+ return []
227
+
228
+
229
def _load_from_huggingface(*, split: str) -> Iterable[dict[str, Any]]:
    """Stream SuperGPQA rows for *split* from the Hugging Face hub as plain dicts."""
    try:
        from datasets import load_dataset
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise RuntimeError(
            "datasets is required to load SuperGPQA from Hugging Face. Install it via `uv pip install '.[hf]'`."
        ) from exc

    for record in load_dataset(_DATASET_NAME, split=split):
        yield dict(record)
240
+
241
+
242
+ def _load_from_local(root: Path) -> Iterator[dict[str, Any]]:
243
+ if not root.exists():
244
+ raise FileNotFoundError(f"Local dataset directory not found: {root}")
245
+
246
+ for path in root.rglob("*"):
247
+ if path.suffix.lower() == ".json":
248
+ with path.open("r", encoding="utf-8") as handle:
249
+ row = json.load(handle)
250
+ row.setdefault("id", path.stem)
251
+ yield row
252
+ elif path.suffix.lower() in {".jsonl", ".ndjson"}:
253
+ with path.open("r", encoding="utf-8") as handle:
254
+ for line_num, line in enumerate(handle, start=1):
255
+ line = line.strip()
256
+ if not line:
257
+ continue
258
+ row = json.loads(line)
259
+ row.setdefault("id", f"{path.stem}-{line_num}")
260
+ yield row
261
+
262
+
263
+ __all__ = ["SuperGpqaSample", "load_super_gpqa"]
@@ -0,0 +1 @@
1
+ """Evaluation domain primitives."""