cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1889 @@
|
|
|
1
|
+
"""Dataset loaders for CoT research.
|
|
2
|
+
|
|
3
|
+
Supports JSON format with standardized structure:
|
|
4
|
+
{
|
|
5
|
+
"id": "unique_id",
|
|
6
|
+
"input": { ... },
|
|
7
|
+
"output": { ... },
|
|
8
|
+
"metadata": { ... }
|
|
9
|
+
}
|
|
10
|
+
|
|
11
|
+
Also supports CSV format with automatic detection based on file extension.
|
|
12
|
+
|
|
13
|
+
Datasets can specify compatible prompts via get_compatible_prompts() for
|
|
14
|
+
restricting which prompt strategies can be used.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import csv
|
|
18
|
+
import json
|
|
19
|
+
import re
|
|
20
|
+
from abc import ABC, abstractmethod
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import Any, Dict, Iterator, List, Optional
|
|
24
|
+
|
|
25
|
+
from ..core.registry import Registry
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class Sample:
    """One dataset record: an index, its text, and an optional label.

    ``metadata`` may be omitted or passed as ``None``; either way it is
    replaced by a fresh empty dict once the instance has been built, so
    every instance exposes a dict.
    """

    idx: int
    text: str
    label: Optional[Any] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Normalise a missing metadata mapping to a new, per-instance dict.
        self.metadata = {} if self.metadata is None else self.metadata

    def to_dict(self) -> Dict[str, Any]:
        """Return the four fields as a plain dict, in declaration order."""
        field_names = ("idx", "text", "label", "metadata")
        return {field_name: getattr(self, field_name) for field_name in field_names}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class BaseDataset(ABC):
    """Abstract base class for datasets.

    Subclasses provide ``name``, ``__len__`` and ``__getitem__``; iteration,
    random sampling and the prompt-compatibility hook are implemented here
    on top of those three members.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Dataset name."""
        ...

    @abstractmethod
    def __len__(self) -> int:
        """Number of samples."""
        ...

    @abstractmethod
    def __getitem__(self, idx: int) -> "Sample":
        """Get a sample by index."""
        ...

    def __iter__(self) -> Iterator["Sample"]:
        """Iterate over samples in index order, from 0 to ``len(self) - 1``."""
        for i in range(len(self)):
            yield self[i]

    def sample(self, n: int, seed: int = 42) -> List["Sample"]:
        """Return a reproducible random selection of at most ``n`` items.

        Args:
            n: Requested number of items; capped at ``len(self)``.
            seed: Seed for the random generator used for this call.

        Returns:
            The selected items, in the order the generator drew them.

        Raises:
            ValueError: If ``n`` is negative (raised by ``Random.sample``).
        """
        import random

        # Use a private generator rather than random.seed(): it draws the
        # same indices for a given seed but leaves the process-wide RNG
        # state untouched for the rest of the program.
        rng = random.Random(seed)
        total = len(self)
        indices = rng.sample(range(total), min(n, total))
        return [self[i] for i in indices]

    def get_compatible_prompts(self) -> Optional[List[str]]:
        """
        Return list of compatible prompt names, or None if compatible with all.

        Override this in specialized datasets to restrict usage.
        """
        return None
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class JSONDataset(BaseDataset):
    """Base class for JSON/CSV-based datasets with automatic format detection.

    On construction the data file is located through the ``data/datasets.yaml``
    registry, downloaded with ``hf_hub_download`` and parsed into ``Sample``
    objects. Files ending in ``.csv`` go through ``_parse_csv_row``; anything
    else is read as JSON and goes through ``_parse_item``.
    """

    def __init__(self, name: str, path: str, **kwargs):
        """Resolve the data file and load every sample.

        Args:
            name: Dataset name; also the key looked up in the registry.
            path: Default locally-styled path (e.g. ``data/radiology.json``),
                used to infer the remote filename when the registry entry
                has no explicit ``path``.
            **kwargs: Accepted for signature compatibility and ignored here.
        """
        self._name = name
        self.path = self._resolve_path_from_registry(name, path)
        self._samples: List[Sample] = []
        self._load()

    def _resolve_path_from_registry(self, name: str, default_path: str) -> Path:
        """Resolve path from data/datasets.yaml registry.

        Args:
            name: Dataset key under the registry's ``datasets`` section.
            default_path: Fallback path used to infer the remote filename.

        Returns:
            Path of the file as returned by ``hf_hub_download``.

        Raises:
            FileNotFoundError: If the registry file is missing, cannot be
                read, contains no ``repo_id`` for this dataset, or the
                download fails. The original error is chained as the cause.
        """
        import yaml
        from huggingface_hub import hf_hub_download

        # Locate registry relative to this file (src/cotlab/datasets/loaders.py -> root/data/datasets.yaml)
        # root is 3 levels up from this file's directory: src/cotlab/datasets -> src/cotlab -> src -> root
        root_dir = Path(__file__).parent.parent.parent.parent
        registry_path = root_dir / "data/datasets.yaml"

        if not registry_path.exists():
            # Fallback: try CWD
            registry_path = Path("data/datasets.yaml")
            if not registry_path.exists():
                raise FileNotFoundError(
                    "datasets.yaml not found. Configure data/datasets.yaml with a Hugging Face repo_id."
                )

        try:
            with open(registry_path, "r", encoding="utf-8") as f:
                # safe_load returns None for an empty file; treat as empty config.
                config = yaml.safe_load(f) or {}

            # 1. Get Repo ID
            # Sections written as bare keys in YAML come back as None.
            datasets_section = config.get("datasets") or {}
            default_section = config.get("default") or {}
            ds_config = datasets_section.get(name) or {}
            repo_id = ds_config.get("repo_id", default_section.get("repo_id"))

            if not repo_id:
                raise ValueError(
                    f"No repo_id found for dataset '{name}'. Set it in data/datasets.yaml."
                )

            # 2. Determine Filename
            # If explicit path in registry, use it.
            # Otherwise, infer from default locally-styled path (e.g. data/radiology.json -> radiology.json)
            filename = ds_config.get("path")
            if not filename:
                # heuristic: strip 'data/' prefix if present to map to HF root
                p = Path(default_path)
                if "data" in p.parts:
                    # e.g. data/foo -> foo, data/tcga/foo -> tcga/foo
                    try:
                        filename = str(p.relative_to("data"))
                    except ValueError:
                        # "data" is present but not the leading component.
                        filename = p.name
                else:
                    filename = p.name

            # 3. Download
            try:
                cached_path = hf_hub_download(
                    repo_id=repo_id, filename=filename, repo_type="dataset"
                )
                return Path(cached_path)
            except Exception as e:
                raise FileNotFoundError(
                    f"Failed to download {name} ({filename}) from HF repo {repo_id}: {e}"
                ) from e
        except FileNotFoundError:
            # Already carries a precise message (e.g. the download failure
            # above); do not re-wrap it as a registry-loading problem.
            raise
        except Exception as e:
            # Callers have always received FileNotFoundError from this method,
            # including for a missing repo_id, so the type is kept.
            raise FileNotFoundError(f"Failed to load datasets registry: {e}") from e

    def _load(self):
        """Load samples from JSON or CSV file based on extension.

        Raises:
            FileNotFoundError: If ``self.path`` does not exist.
        """
        if not self.path.exists():
            raise FileNotFoundError(f"Dataset not found: {self.path}")

        if self.path.suffix.lower() == ".csv":
            self._load_csv()
        else:
            self._load_json()

    def _load_json(self):
        """Load samples from JSON file.

        The decoded document is enumerated and each element is handed to
        ``_parse_item`` together with its position.
        """
        with open(self.path, "r", encoding="utf-8") as f:
            data = json.load(f)

        for idx, item in enumerate(data):
            sample = self._parse_item(idx, item)
            self._samples.append(sample)

    def _load_csv(self):
        """Load samples from CSV file, one ``_parse_csv_row`` call per row."""
        with open(self.path, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for idx, row in enumerate(reader):
                sample = self._parse_csv_row(idx, row)
                self._samples.append(sample)

    @abstractmethod
    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Parse a single JSON item into a Sample. Override in subclasses."""
        ...

    def _parse_csv_row(self, idx: int, row: Dict[str, Any]) -> Sample:
        """Parse a single CSV row into a Sample. Override in subclasses for custom CSV handling.

        Raises:
            NotImplementedError: Always, in this default implementation.
        """
        # Default implementation - subclasses should override for specific CSV formats
        raise NotImplementedError(
            f"CSV loading not implemented for {self.__class__.__name__}. "
            "Override _parse_csv_row() or use a JSON file."
        )

    @property
    def name(self) -> str:
        """Dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the loaded sample at position ``idx``."""
        return self._samples[idx]
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@Registry.register_dataset("radiology")
class RadiologyDataset(JSONDataset):
    """Radiology report dataset labelled for pathological fracture detection."""

    def __init__(self, name: str = "radiology", path: str = "data/radiology.json", **kwargs):
        """Load the dataset through JSONDataset using the radiology defaults."""
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from one JSON record.

        The text comes from ``input.report`` and the label from
        ``output.pathological_fracture``; ``metadata`` is passed through
        (empty dict when the key is absent).
        """
        report = item["input"]["report"]
        fracture_flag = item["output"]["pathological_fracture"]
        extra = item.get("metadata", {})
        return Sample(idx=idx, text=report, label=fracture_flag, metadata=extra)

    def _parse_csv_row(self, idx: int, row: Dict[str, Any]) -> Sample:
        """Build a Sample from one CSV row.

        Reads the ``Synthetic`` column as text and the ``Flag`` column as
        label; a missing column yields an empty string.
        """
        report = row.get("Synthetic", "")
        fracture_flag = row.get("Flag", "")
        return Sample(idx=idx, text=report, label=fracture_flag, metadata={})

    def get_compatible_prompts(self) -> list[str]:
        """
        Restrict this dataset to the ``radiology`` prompt.

        The data targets pathological fracture detection and should NOT be
        paired with general medical QA prompts.
        """
        allowed = ["radiology"]
        return allowed
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
@Registry.register_dataset("cardiology")
class CardiologyDataset(JSONDataset):
    """Cardiology report dataset labelled for congenital heart defect detection."""

    def __init__(self, name: str = "cardiology", path: str = "data/cardiology.json", **kwargs):
        """Load the dataset through JSONDataset using the cardiology defaults."""
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from one JSON record.

        The text comes from ``input.report`` and the label from
        ``output.congenital_heart_defect``; ``metadata`` is passed through
        (empty dict when the key is absent).
        """
        report = item["input"]["report"]
        defect_flag = item["output"]["congenital_heart_defect"]
        extra = item.get("metadata", {})
        return Sample(idx=idx, text=report, label=defect_flag, metadata=extra)

    def get_compatible_prompts(self) -> list[str]:
        """
        Restrict this dataset to the ``cardiology`` prompt.

        The data targets congenital heart defect detection and should NOT be
        paired with general medical QA prompts.
        """
        allowed = ["cardiology"]
        return allowed
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
@Registry.register_dataset("neurology")
class NeurologyDataset(JSONDataset):
    """Neurology report dataset labelled for neurological abnormality detection."""

    def __init__(self, name: str = "neurology", path: str = "data/neurology.json", **kwargs):
        """Load the dataset through JSONDataset using the neurology defaults."""
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from one JSON record.

        The text comes from ``input.report`` and the label from
        ``output.neurological_abnormality``; ``metadata`` is passed through
        (empty dict when the key is absent).
        """
        report = item["input"]["report"]
        abnormality_flag = item["output"]["neurological_abnormality"]
        extra = item.get("metadata", {})
        return Sample(idx=idx, text=report, label=abnormality_flag, metadata=extra)

    def get_compatible_prompts(self) -> list[str]:
        """Restrict this dataset to the ``neurology`` prompt."""
        allowed = ["neurology"]
        return allowed
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
@Registry.register_dataset("oncology")
class OncologyDataset(JSONDataset):
    """Oncology report dataset labelled for malignancy detection."""

    def __init__(self, name: str = "oncology", path: str = "data/oncology.json", **kwargs):
        """Load the dataset through JSONDataset using the oncology defaults."""
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from one JSON record.

        The text comes from ``input.report`` and the label from
        ``output.malignancy``; ``metadata`` is passed through (empty dict
        when the key is absent).
        """
        report = item["input"]["report"]
        malignancy_flag = item["output"]["malignancy"]
        extra = item.get("metadata", {})
        return Sample(idx=idx, text=report, label=malignancy_flag, metadata=extra)

    def get_compatible_prompts(self) -> list[str]:
        """Restrict this dataset to the ``oncology`` prompt."""
        allowed = ["oncology"]
        return allowed
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
@Registry.register_dataset("pediatrics")
class PediatricsDataset(JSONDataset):
    """Dataset of pediatric clinical scenarios paired with a diagnosis."""

    def __init__(self, name: str = "pediatrics", path: str = "data/pediatrics.json", **kwargs):
        """Load the dataset through JSONDataset using the pediatrics defaults."""
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from one JSON record.

        The text comes from ``input.scenario`` and the label from
        ``output.diagnosis``; ``metadata`` is passed through (empty dict
        when the key is absent).
        """
        scenario = item["input"]["scenario"]
        diagnosis = item["output"]["diagnosis"]
        extra = item.get("metadata", {})
        return Sample(idx=idx, text=scenario, label=diagnosis, metadata=extra)

    def _parse_csv_row(self, idx: int, row: Dict[str, Any]) -> Sample:
        """Build a Sample from one CSV row.

        Columns read: ``Scenario`` (text), ``Diagnosis`` (label),
        ``Age_Group`` and ``Category`` (metadata). A missing column yields
        an empty string.
        """
        scenario = row.get("Scenario", "")
        diagnosis = row.get("Diagnosis", "")
        extra = {
            "age_group": row.get("Age_Group", ""),
            "category": row.get("Category", ""),
        }
        return Sample(idx=idx, text=scenario, label=diagnosis, metadata=extra)
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
@Registry.register_dataset("synthetic")
class SyntheticMedicalDataset(JSONDataset):
    """Synthetic medical QA dataset whose contents can be repeated.

    With ``repeat=k`` the file is parsed ``k`` times in a row; sample
    indices keep increasing across the passes.
    """

    def __init__(
        self,
        name: str = "synthetic",
        path: str = "data/synthetic.json",
        repeat: int = 1,
        **kwargs,
    ):
        """Store ``repeat`` and load through JSONDataset.

        ``repeat`` is assigned before calling the parent constructor because
        the parent constructor invokes ``_load``, which reads it.
        """
        self.repeat = repeat
        super().__init__(name, path, **kwargs)

    def _load(self):
        """Load samples with optional repeat. Supports both JSON and CSV.

        Raises:
            FileNotFoundError: If ``self.path`` does not exist.
        """
        if not self.path.exists():
            raise FileNotFoundError(f"Dataset not found: {self.path}")

        is_csv = self.path.suffix.lower() == ".csv"
        if is_csv:
            self._load_csv_with_repeat()
            return
        self._load_json_with_repeat()

    def _load_json_with_repeat(self):
        """Parse the JSON document ``self.repeat`` times, numbering samples consecutively."""
        with open(self.path, "r", encoding="utf-8") as handle:
            records = json.load(handle)

        next_idx = 0  # running index; equals pass_number * len(records) + position
        for _ in range(self.repeat):
            for record in records:
                self._samples.append(self._parse_item(next_idx, record))
                next_idx += 1

    def _load_csv_with_repeat(self):
        """Parse the CSV rows ``self.repeat`` times, numbering samples consecutively."""
        with open(self.path, "r", encoding="utf-8") as handle:
            records = list(csv.DictReader(handle))

        next_idx = 0  # running index; equals pass_number * len(records) + position
        for _ in range(self.repeat):
            for record in records:
                self._samples.append(self._parse_csv_row(next_idx, record))
                next_idx += 1

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from one JSON record.

        The text comes from ``input.scenario`` and the label from
        ``output.diagnosis``; ``metadata`` is passed through (empty dict
        when the key is absent).
        """
        scenario = item["input"]["scenario"]
        diagnosis = item["output"]["diagnosis"]
        extra = item.get("metadata", {})
        return Sample(idx=idx, text=scenario, label=diagnosis, metadata=extra)

    def _parse_csv_row(self, idx: int, row: Dict[str, Any]) -> Sample:
        """Build a Sample from one CSV row.

        Columns read: ``Scenario`` (text), ``Expected_Answer`` (label) and
        ``Reasoning_Keywords`` (metadata). A missing column yields an empty
        string.
        """
        scenario = row.get("Scenario", "")
        expected = row.get("Expected_Answer", "")
        extra = {"reasoning_keywords": row.get("Reasoning_Keywords", "")}
        return Sample(idx=idx, text=scenario, label=expected, metadata=extra)
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
@Registry.register_dataset("patching_pairs")
class PatchingPairsDataset(JSONDataset):
    """Clean/corrupted prompt pairs for activation patching experiments.

    The clean prompt and answer become the sample's text and label; the
    corrupted counterpart and both answers are carried in ``metadata``.
    """

    def __init__(
        self, name: str = "patching_pairs", path: str = "data/patching_pairs.json", **kwargs
    ):
        """Load the dataset through JSONDataset using the patching-pairs defaults."""
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from one JSON record with ``clean`` and ``corrupted`` entries.

        ``metadata.category`` is copied when present, otherwise ``"general"``.
        """
        # Lookups are kept in the clean-then-corrupted order so that a
        # malformed record fails on the same missing key as before.
        clean_prompt = item["clean"]["input"]
        clean_answer = item["clean"]["output"]
        extra = {
            "corrupted_prompt": item["corrupted"]["input"],
            "clean_answer": item["clean"]["output"],
            "corrupted_answer": item["corrupted"]["output"],
            "category": item.get("metadata", {}).get("category", "general"),
        }
        return Sample(idx=idx, text=clean_prompt, label=clean_answer, metadata=extra)

    def _parse_csv_row(self, idx: int, row: Dict[str, Any]) -> Sample:
        """Build a Sample from one CSV row.

        Columns read: ``Clean_Prompt``, ``Corrupted_Prompt``, ``Clean_Answer``,
        ``Corrupted_Answer`` and ``Category``. Missing columns yield an empty
        string, except ``Category`` which falls back to ``"general"``.
        """
        clean_prompt = row.get("Clean_Prompt", "")
        clean_answer = row.get("Clean_Answer", "")
        extra = {
            "corrupted_prompt": row.get("Corrupted_Prompt", ""),
            "clean_answer": row.get("Clean_Answer", ""),
            "corrupted_answer": row.get("Corrupted_Answer", ""),
            "category": row.get("Category", "general"),
        }
        return Sample(idx=idx, text=clean_prompt, label=clean_answer, metadata=extra)
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
@Registry.register_dataset("tutorial")
class TutorialDataset(JSONDataset):
    """Minimal Q&A dataset intended for tutorials and demos."""

    def __init__(self, name: str = "tutorial", path: str = "data/tutorial.json", **kwargs):
        """Load the dataset through JSONDataset using the tutorial defaults."""
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from a flat record with ``text`` and ``label`` keys.

        Metadata is always an empty dict for this dataset.
        """
        question = item["text"]
        answer = item["label"]
        return Sample(idx=idx, text=question, label=answer, metadata={})
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
@Registry.register_dataset("probing_diagnosis")
class ProbingDiagnosisDataset(JSONDataset):
    """
    Dataset for probing experiments that test how a diagnosis is encoded.

    Every sample carries:
    - text: the medical scenario (``input.question``)
    - label: the correct diagnosis (``output.diagnosis``)
    - metadata: ``category``, ``difficulty``, ``confounders`` and
      ``key_features``, each with a default when absent from the record
    """

    def __init__(
        self,
        name: str = "probing_diagnosis",
        path: str = "data/probing_diagnosis.json",
        **kwargs,
    ):
        """Load the dataset through JSONDataset using the probing defaults."""
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from one JSON record, filling metadata defaults."""
        question = item["input"]["question"]
        diagnosis = item["output"]["diagnosis"]

        # Read after the required keys so a malformed record fails on the
        # same lookup as before.
        raw_meta = item.get("metadata", {})
        extra = {
            "category": raw_meta.get("category", "general"),
            "difficulty": raw_meta.get("difficulty", "medium"),
            "confounders": raw_meta.get("confounders", []),
            "key_features": raw_meta.get("key_features", []),
        }
        return Sample(idx=idx, text=question, label=diagnosis, metadata=extra)
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
@Registry.register_dataset("histopathology")
class HistopathologyDataset(BaseDataset):
    """Histopathology report quality evaluation dataset.

    Loads HARE human evaluation annotations and expands each row
    into up to 4 samples (one per model output with its human score).
    Cells with a missing report or an unparseable score are skipped.

    Labels: 0 (poor), 1 (partial), 2 (good)
    """

    def __init__(
        self,
        name: str = "histopathology",
        path: str = "data/histopathology.tsv",
        repo_id: Optional[str] = None,
        **kwargs,
    ):
        """Resolve the TSV location and load every sample.

        Args:
            name: Dataset name; also the key looked up in the registry.
            path: Local fallback path, also used to infer the remote filename.
            repo_id: Optional Hugging Face repo id that takes precedence over
                the registry.
            **kwargs: Accepted for signature compatibility and ignored here.
        """
        self._name = name
        self.repo_id = repo_id
        self.path = self._resolve_path_from_registry(name, path)
        self._samples: List[Sample] = []
        self._load()

    def _resolve_path_from_registry(self, name: str, default_path: str) -> Path:
        """Resolve path from data/datasets.yaml registry or fallback to default.

        This lookup is best-effort: a missing registry, a missing repo id, an
        unreadable registry or a failed download all fall back to
        ``default_path`` (with a printed warning for the last two).

        Returns:
            The downloaded file path, or ``Path(default_path)`` on fallback.
        """
        import yaml
        from huggingface_hub import hf_hub_download

        repo_id = self.repo_id

        # Locate registry relative to this file (src/cotlab/datasets/loaders.py -> root/data/datasets.yaml)
        # root is 3 levels up from this file's directory: src/cotlab/datasets -> src/cotlab -> src -> root
        root_dir = Path(__file__).parent.parent.parent.parent
        registry_path = root_dir / "data/datasets.yaml"

        if not registry_path.exists():
            # Fallback: try CWD
            registry_path = Path("data/datasets.yaml")
            if not registry_path.exists():
                return Path(default_path)

        try:
            with open(registry_path, "r") as f:
                # safe_load returns None for an empty file; treat as empty config.
                config = yaml.safe_load(f) or {}

            # Sections written as bare keys in YAML come back as None.
            datasets_section = config.get("datasets") or {}
            ds_config = datasets_section.get(name) or {}
            if not repo_id:
                default_section = config.get("default") or {}
                repo_id = ds_config.get("repo_id", default_section.get("repo_id"))

            if not repo_id:
                return Path(default_path)

            filename = ds_config.get("path")
            if not filename:
                # Map a locally-styled path onto the repo layout by dropping
                # a leading "data/" component when there is one.
                p = Path(default_path)
                if "data" in p.parts:
                    try:
                        filename = str(p.relative_to("data"))
                    except ValueError:
                        filename = p.name
                else:
                    filename = p.name

            try:
                cached_path = hf_hub_download(
                    repo_id=repo_id, filename=filename, repo_type="dataset"
                )
                return Path(cached_path)
            except Exception as e:
                # Deliberately non-fatal: fall through to the local default.
                print(
                    f"Warning: Failed to download {name} ({filename}) from HF repo {repo_id}: {e}"
                )
        except Exception as e:
            # Deliberately non-fatal: fall through to the local default.
            print(f"Warning: Failed to load registry: {e}")

        return Path(default_path)

    def _load(self):
        """Load TSV and expand to 4 samples per row.

        For model numbers 0-3 the report is read from the column named after
        the number and the score from ``Scoring <number>``.

        Raises:
            FileNotFoundError: If ``self.path`` does not exist.
        """
        if not self.path.exists():
            raise FileNotFoundError(f"Dataset not found: {self.path}")

        with open(self.path, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f, delimiter="\t")
            sample_idx = 0
            for row_idx, row in enumerate(reader):
                ground_truth = row.get("ground_truth", "")
                # Expand 4 model outputs per row
                for model_num in range(4):
                    report_text = row.get(str(model_num), "")
                    score = row.get(f"Scoring {model_num}", "")

                    # Skip if missing data
                    if not report_text or score == "":
                        continue

                    try:
                        label = int(float(score))
                    except (ValueError, TypeError, OverflowError):
                        # OverflowError covers scores such as "inf", which
                        # float() accepts but int() cannot convert.
                        continue

                    self._samples.append(
                        Sample(
                            idx=sample_idx,
                            text=report_text,
                            label=label,
                            metadata={
                                "ground_truth": ground_truth,
                                "model_id": model_num,
                                "row_id": row_idx,
                            },
                        )
                    )
                    sample_idx += 1

    @property
    def name(self) -> str:
        """Dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the loaded sample at position ``idx``."""
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Histopathology dataset works with histopathology prompt."""
        return ["histopathology"]
|
|
648
|
+
|
|
649
|
+
|
|
650
|
+
@Registry.register_dataset("tcga")
|
|
651
|
+
class TCGADataset(BaseDataset):
|
|
652
|
+
"""TCGA Cancer Type Classification dataset.
|
|
653
|
+
|
|
654
|
+
Input: Pathology report text
|
|
655
|
+
Output: Cancer type (e.g., BRCA, LUAD)
|
|
656
|
+
|
|
657
|
+
Implements official 15% stratified test split using patient IDs.
|
|
658
|
+
"""
|
|
659
|
+
|
|
660
|
+
def __init__(
|
|
661
|
+
self,
|
|
662
|
+
name: str = "tcga",
|
|
663
|
+
repo_id: Optional[str] = None,
|
|
664
|
+
reports_filename: str = "tcga/TCGA_Reports.csv",
|
|
665
|
+
labels_filename: str = "tcga/tcga_patient_to_cancer_type.csv",
|
|
666
|
+
split: str = "test", # train, test, or all
|
|
667
|
+
random_seed: int = 0,
|
|
668
|
+
**kwargs,
|
|
669
|
+
):
|
|
670
|
+
self._name = name
|
|
671
|
+
self.repo_id = repo_id
|
|
672
|
+
self.reports_filename = reports_filename
|
|
673
|
+
self.labels_filename = labels_filename
|
|
674
|
+
self.split = split
|
|
675
|
+
self.random_seed = random_seed
|
|
676
|
+
self._samples: List[Sample] = []
|
|
677
|
+
self._load()
|
|
678
|
+
|
|
679
|
+
def _load(self):
|
|
680
|
+
import random
|
|
681
|
+
|
|
682
|
+
import yaml
|
|
683
|
+
from huggingface_hub import hf_hub_download
|
|
684
|
+
|
|
685
|
+
# Resolve Repo ID: Config > Registry > None
|
|
686
|
+
repo_id = self.repo_id
|
|
687
|
+
if not repo_id:
|
|
688
|
+
registry_path = Path("data/datasets.yaml")
|
|
689
|
+
if registry_path.exists():
|
|
690
|
+
try:
|
|
691
|
+
with open(registry_path, "r") as f:
|
|
692
|
+
config = yaml.safe_load(f)
|
|
693
|
+
# Check explicit "tcga" entry or default
|
|
694
|
+
ds_config = config.get("datasets", {}).get("tcga", {})
|
|
695
|
+
repo_id = ds_config.get("repo_id", config.get("default", {}).get("repo_id"))
|
|
696
|
+
except Exception as e:
|
|
697
|
+
print(f"Warning: Failed to load registry for TCGA: {e}")
|
|
698
|
+
|
|
699
|
+
if not repo_id:
|
|
700
|
+
raise ValueError("No repo_id found for TCGA dataset. Set it in data/datasets.yaml.")
|
|
701
|
+
|
|
702
|
+
try:
|
|
703
|
+
reports_path = hf_hub_download(
|
|
704
|
+
repo_id=repo_id, filename=self.reports_filename, repo_type="dataset"
|
|
705
|
+
)
|
|
706
|
+
labels_path = hf_hub_download(
|
|
707
|
+
repo_id=repo_id, filename=self.labels_filename, repo_type="dataset"
|
|
708
|
+
)
|
|
709
|
+
self.path = Path(reports_path)
|
|
710
|
+
self.labels_path = Path(labels_path)
|
|
711
|
+
except Exception as e:
|
|
712
|
+
raise FileNotFoundError(f"Failed to download from HF repo {repo_id}: {e}")
|
|
713
|
+
|
|
714
|
+
if not self.path.exists():
|
|
715
|
+
raise FileNotFoundError(f"Reports not found: {self.path}")
|
|
716
|
+
if not self.labels_path.exists():
|
|
717
|
+
raise FileNotFoundError(f"Labels not found: {self.labels_path}")
|
|
718
|
+
|
|
719
|
+
# 1. Load Labels
|
|
720
|
+
labels = {}
|
|
721
|
+
with open(self.labels_path, "r", encoding="utf-8") as f:
|
|
722
|
+
reader = csv.reader(f)
|
|
723
|
+
next(reader, None) # Skip header
|
|
724
|
+
for row in reader:
|
|
725
|
+
if len(row) >= 2:
|
|
726
|
+
labels[row[0]] = row[1]
|
|
727
|
+
|
|
728
|
+
# 2. Load Reports and Link
|
|
729
|
+
data_by_class = {} # ctype -> list of (patient_id, text)
|
|
730
|
+
|
|
731
|
+
with open(self.path, "r", encoding="utf-8") as f:
|
|
732
|
+
reader = csv.DictReader(f)
|
|
733
|
+
for row in reader:
|
|
734
|
+
pid_full = row.get("patient_filename", "")
|
|
735
|
+
text = row.get("text", "")
|
|
736
|
+
|
|
737
|
+
# Extract patient ID (TCGA-XX-YYYY)
|
|
738
|
+
if len(pid_full) < 12:
|
|
739
|
+
continue
|
|
740
|
+
pid = pid_full[:12]
|
|
741
|
+
|
|
742
|
+
if pid in labels:
|
|
743
|
+
ctype = labels[pid]
|
|
744
|
+
if ctype not in data_by_class:
|
|
745
|
+
data_by_class[ctype] = []
|
|
746
|
+
data_by_class[ctype].append((pid, text))
|
|
747
|
+
|
|
748
|
+
# 3. Stratified Split
|
|
749
|
+
final_samples = []
|
|
750
|
+
|
|
751
|
+
# Sort classes for deterministic order before shuffling
|
|
752
|
+
sorted_classes = sorted(data_by_class.keys())
|
|
753
|
+
|
|
754
|
+
for ctype in sorted_classes:
|
|
755
|
+
items = data_by_class[ctype]
|
|
756
|
+
|
|
757
|
+
# Deterministic shuffle
|
|
758
|
+
random.Random(self.random_seed).shuffle(items)
|
|
759
|
+
|
|
760
|
+
# Split index (15% test)
|
|
761
|
+
n_total = len(items)
|
|
762
|
+
n_test = int(n_total * 0.15)
|
|
763
|
+
|
|
764
|
+
if self.split == "all":
|
|
765
|
+
selected = items
|
|
766
|
+
elif self.split == "test":
|
|
767
|
+
# Test set is the first 15%
|
|
768
|
+
selected = items[:n_test]
|
|
769
|
+
elif self.split == "train":
|
|
770
|
+
selected = items[n_test:]
|
|
771
|
+
else:
|
|
772
|
+
selected = []
|
|
773
|
+
|
|
774
|
+
# Add to samples
|
|
775
|
+
for pid, text in selected:
|
|
776
|
+
final_samples.append((pid, text, ctype))
|
|
777
|
+
|
|
778
|
+
# 4. Create Sample objects
|
|
779
|
+
for i, (pid, text, ctype) in enumerate(final_samples):
|
|
780
|
+
self._samples.append(
|
|
781
|
+
Sample(
|
|
782
|
+
idx=i,
|
|
783
|
+
text=text,
|
|
784
|
+
label=ctype,
|
|
785
|
+
metadata={"patient_id": pid, "cancer_type": ctype},
|
|
786
|
+
)
|
|
787
|
+
)
|
|
788
|
+
|
|
789
|
+
@property
def name(self) -> str:
    """Return the dataset name stored on the instance as ``_name``."""
    dataset_name = self._name
    return dataset_name
|
|
792
|
+
|
|
793
|
+
def __len__(self) -> int:
    """Return how many samples are currently held in ``_samples``."""
    loaded = self._samples
    return len(loaded)
|
|
795
|
+
|
|
796
|
+
def __getitem__(self, idx: int) -> Sample:
|
|
797
|
+
return self._samples[idx]
|
|
798
|
+
|
|
799
|
+
def get_compatible_prompts(self) -> list[str]:
    """Return the prompt-strategy names this dataset declares itself compatible with."""
    compatible = ["tcga"]
    return compatible
|
|
801
|
+
|
|
802
|
+
|
|
803
|
+
@Registry.register_dataset("medqa")
class MedQADataset(BaseDataset):
    """MedQA USMLE-style 4-option MCQ dataset.

    Format: JSONL with fields:
        - question: Clinical vignette
        - options: Dict {"A": "...", "B": "...", "C": "...", "D": "..."}
        - answer_idx: Correct answer letter (A/B/C/D)
        - meta_info: USMLE step (step1, step2, step3)

    The JSONL file is obtained through ``hf_hub_download`` while the
    instance is being constructed (``__init__`` calls ``_load``).
    """

    def __init__(
        self,
        name: str = "medqa",
        repo_id: Optional[str] = None,
        filename: str = "medqa/test.jsonl",
        split: str = "test",
        **kwargs,
    ):
        """Store the configuration and load all samples immediately.

        Args:
            name: Name reported by the ``name`` property.
            repo_id: Hub dataset repository. When falsy, it is looked up in
                ``data/datasets.yaml`` (see ``_resolve_repo_id``).
            filename: Path of the JSONL file inside the repository.
            split: Stored on the instance; this loader selects the file via
                ``filename`` only.
            **kwargs: Accepted and ignored.
        """
        self._name = name
        self.repo_id = repo_id
        self.filename = filename
        self.split = split
        self._samples: List[Sample] = []
        self._load()

    def _resolve_repo_id(self) -> Optional[str]:
        """Return the explicit ``repo_id`` or the one found in ``data/datasets.yaml``.

        Returns ``None`` when no repository can be determined; a registry
        file that fails to load only produces a printed warning.
        """
        if self.repo_id:
            return self.repo_id

        root_dir = Path(__file__).parent.parent.parent.parent
        registry_path = root_dir / "data/datasets.yaml"
        if not registry_path.exists():
            return None

        import yaml

        try:
            with open(registry_path, "r", encoding="utf-8") as f:
                config = yaml.safe_load(f)
            ds_config = config.get("datasets", {}).get("medqa", {})
            return ds_config.get("repo_id", config.get("default", {}).get("repo_id"))
        except Exception as e:
            # Best effort: a broken registry must not mask the clearer
            # "No repo_id found" error raised by the caller.
            print(f"Warning: Failed to load registry for MedQA: {e}")
            return None

    def _load(self) -> None:
        """Download the JSONL file and turn every record into a ``Sample``.

        Raises:
            ValueError: If no ``repo_id`` is configured.
            FileNotFoundError: If the download fails for any reason (the
                original exception is chained as the cause).
            KeyError: If a record lacks ``question``, ``options`` or
                ``answer_idx``.
        """
        import json

        from huggingface_hub import hf_hub_download

        repo_id = self._resolve_repo_id()
        if not repo_id:
            raise ValueError("No repo_id found for MedQA dataset")

        # Download from HF
        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=self.filename,
                repo_type="dataset",
            )
        except Exception as e:
            raise FileNotFoundError(f"Failed to download MedQA from {repo_id}: {e}") from e

        # Parse JSONL. The encoding is pinned so decoding does not depend on
        # the platform locale.
        with open(local_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) are not records.
                    continue
                data = json.loads(line)

                # Format question with options, one "<letter>) <text>" per line
                options = data["options"]
                formatted_options = "\n".join(
                    f"{key}) {val}" for key, val in sorted(options.items())
                )
                text = f"{data['question']}\n\n{formatted_options}"

                self._samples.append(
                    Sample(
                        # Position in the sample list; equals the line index
                        # whenever the file has no blank lines.
                        idx=len(self._samples),
                        text=text,
                        # Answer is the letter (A, B, C, D)
                        label=data["answer_idx"],
                        metadata={
                            "step": data.get("meta_info", "unknown"),
                            "answer_text": data.get("answer", ""),
                        },
                    )
                )

    @property
    def name(self) -> str:
        """Dataset name given at construction time."""
        return self._name

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the sample at position ``idx``."""
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Return the prompt-strategy names this dataset declares itself compatible with."""
        return [
            "mcq",
            "medqa",
            "direct_answer",
            "chain_of_thought",
            "uncertainty",
            "contrarian",
            "few_shot",
        ]
|
|
910
|
+
|
|
911
|
+
|
|
912
|
+
@Registry.register_dataset("mmlu_medical")
class MMLUMedicalDataset(BaseDataset):
    """MMLU Medical subset dataset (anatomy, clinical_knowledge, medical_genetics, college_biology).

    Format: JSONL with fields:
        - question: Question text
        - choices: List of 4 options
        - answer: Integer index (0-3)
        - subject: Subject name

    The JSONL file is obtained through ``hf_hub_download`` while the
    instance is being constructed (``__init__`` calls ``_load``).
    """

    def __init__(
        self,
        name: str = "mmlu_medical",
        repo_id: Optional[str] = None,
        filename: str = "mmlu/medical_test.jsonl",
        **kwargs,
    ):
        """Store the configuration and load all samples immediately.

        Args:
            name: Name reported by the ``name`` property.
            repo_id: Hub dataset repository. When falsy, it is looked up in
                ``data/datasets.yaml`` (see ``_resolve_repo_id``).
            filename: Path of the JSONL file inside the repository.
            **kwargs: Accepted and ignored.
        """
        self._name = name
        self.repo_id = repo_id
        self.filename = filename
        self._samples: List[Sample] = []
        self._load()

    def _resolve_repo_id(self) -> Optional[str]:
        """Return the explicit ``repo_id`` or the one found in ``data/datasets.yaml``.

        Returns ``None`` when no repository can be determined; a registry
        file that fails to load only produces a printed warning.
        """
        if self.repo_id:
            return self.repo_id

        root_dir = Path(__file__).parent.parent.parent.parent
        registry_path = root_dir / "data/datasets.yaml"
        if not registry_path.exists():
            return None

        import yaml

        try:
            with open(registry_path, "r", encoding="utf-8") as f:
                config = yaml.safe_load(f)
            ds_config = config.get("datasets", {}).get("mmlu_medical", {})
            return ds_config.get("repo_id", config.get("default", {}).get("repo_id"))
        except Exception as e:
            # Best effort: fall through to the "No repo_id found" error.
            print(f"Warning: Failed to load registry for MMLU Medical: {e}")
            return None

    def _load(self) -> None:
        """Download the JSONL file and turn every record into a ``Sample``.

        Raises:
            ValueError: If no ``repo_id`` is configured.
            FileNotFoundError: If the download fails for any reason (the
                original exception is chained as the cause).
            KeyError: If a record lacks ``question``, ``choices`` or ``answer``.
        """
        import json

        from huggingface_hub import hf_hub_download

        repo_id = self._resolve_repo_id()
        if not repo_id:
            raise ValueError("No repo_id found for MMLU Medical dataset")

        # Download from HF
        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=self.filename,
                repo_type="dataset",
            )
        except Exception as e:
            raise FileNotFoundError(
                f"Failed to download MMLU Medical from {repo_id}: {e}"
            ) from e

        index_to_letter = {0: "A", 1: "B", 2: "C", 3: "D"}

        # Parse JSONL. The encoding is pinned so decoding does not depend on
        # the platform locale.
        with open(local_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) are not records.
                    continue
                data = json.loads(line)

                # Format question with options lettered A, B, C, ... in order
                formatted_options = "\n".join(
                    f"{chr(65 + j)}) {opt}" for j, opt in enumerate(data["choices"])
                )
                text = f"{data['question']}\n\n{formatted_options}"

                # Convert integer answer to letter.
                # NOTE(review): an index outside 0-3 silently becomes "A";
                # kept for backward compatibility, but it can mislabel
                # malformed records - confirm this is intended.
                label = index_to_letter.get(data["answer"], "A")

                self._samples.append(
                    Sample(
                        # Position in the sample list; equals the line index
                        # whenever the file has no blank lines.
                        idx=len(self._samples),
                        text=text,
                        label=label,
                        metadata={
                            "subject": data.get("subject", "unknown"),
                        },
                    )
                )

    @property
    def name(self) -> str:
        """Dataset name given at construction time."""
        return self._name

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the sample at position ``idx``."""
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Return the prompt-strategy names this dataset declares itself compatible with."""
        return [
            "mcq",
            "direct_answer",
            "chain_of_thought",
            "uncertainty",
            "contrarian",
            "few_shot",
        ]
|
|
1018
|
+
|
|
1019
|
+
|
|
1020
|
+
@Registry.register_dataset("m_arc")
class MARCDataset(BaseDataset):
    """M-ARC benchmark.

    Upstream dataset: mkieffer/M-ARC (Parquet, split=test, 100 items).

    Expected columns:
        - question_id: string
        - question: string
        - options: struct with keys A..G
        - answer: string label (A..G)
        - src: string (source/category)

    The Parquet file is obtained through ``hf_hub_download`` while the
    instance is being constructed (``__init__`` calls ``_load``).
    """

    def __init__(
        self,
        name: str = "m_arc",
        repo_id: Optional[str] = None,
        filename: str = "m_arc/test-00000-of-00001.parquet",
        split: str = "test",
        **kwargs,
    ):
        """Store the configuration and load all samples immediately.

        Args:
            name: Name reported by the ``name`` property.
            repo_id: Hub dataset repository. When falsy, it is looked up in
                ``data/datasets.yaml`` (see ``_resolve_repo_id``).
            filename: Path of the Parquet file inside the repository.
            split: Stored on the instance; this loader selects the file via
                ``filename`` only.
            **kwargs: Accepted and ignored.
        """
        self._name = name
        self.repo_id = repo_id
        self.filename = filename
        self.split = split
        self._samples: List[Sample] = []
        self._load()

    @property
    def name(self) -> str:
        """Dataset name given at construction time."""
        return self._name

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the sample at position ``idx``."""
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Return the prompt-strategy names this dataset declares itself compatible with."""
        return [
            "mcq",
            "direct_answer",
            "chain_of_thought",
            "uncertainty",
            "contrarian",
            "few_shot",
        ]

    def _resolve_repo_id(self) -> Optional[str]:
        """Resolve default repo_id from data/datasets.yaml, if present."""
        import yaml

        root_dir = Path(__file__).parent.parent.parent.parent
        registry_path = root_dir / "data/datasets.yaml"
        if not registry_path.exists():
            return None

        try:
            with open(registry_path, "r") as f:
                config = yaml.safe_load(f)
            ds_config = config.get("datasets", {}).get("m_arc", {})
            return ds_config.get("repo_id", config.get("default", {}).get("repo_id"))
        except Exception as e:
            # Best effort: the caller raises a clear error when this is None.
            print(f"Warning: Failed to load registry for M-ARC: {e}")
            return None

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return *value* as a string, mapping ``None`` and float NaN to ``""``.

        pandas represents missing cells as ``None`` or NaN; calling string
        methods on those directly would raise or yield the text "nan".
        """
        if value is None:
            return ""
        if isinstance(value, float) and value != value:  # only NaN differs from itself
            return ""
        return str(value)

    @staticmethod
    def _format_options(options: Dict[str, Any]) -> str:
        """Format A..G options as lines, skipping empty values."""
        if not options:
            return ""
        lines: list[str] = []
        for key in sorted(options.keys()):
            val = options.get(key)
            if val is None:
                continue
            if isinstance(val, str) and not val.strip():
                continue
            lines.append(f"{key}) {str(val).strip()}")
        return "\n".join(lines)

    @staticmethod
    def _coerce_options(raw: Any) -> Dict[str, Any]:
        """Convert pyarrow struct / dict-like values into a plain dict.

        Anything that is neither a dict nor an object whose ``as_py()``
        returns a dict is treated as "no options" and yields ``{}``.
        """
        if raw is None:
            return {}
        if isinstance(raw, dict):
            return raw
        # pyarrow StructScalar has as_py()
        as_py = getattr(raw, "as_py", None)
        if callable(as_py):
            val = as_py()
            if isinstance(val, dict):
                return val
        return {}

    def _load(self) -> None:
        """Download the Parquet file and turn every row into a ``Sample``.

        Raises:
            ImportError: If ``pyarrow`` is not installed.
            ValueError: If no ``repo_id`` is configured.
            FileNotFoundError: If the download fails for any reason (the
                original exception is chained as the cause).
        """
        try:
            import pyarrow.parquet as pq
        except ImportError as e:
            raise ImportError("M-ARC requires pyarrow. Install with: uv pip install pyarrow") from e

        from huggingface_hub import hf_hub_download

        repo_id = self.repo_id or self._resolve_repo_id()
        if not repo_id:
            raise ValueError(
                "No repo_id found for M-ARC. Set default.repo_id in data/datasets.yaml "
                "or pass repo_id= explicitly."
            )

        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=self.filename,
                repo_type="dataset",
            )
        except Exception as e:
            raise FileNotFoundError(
                f"Failed to download M-ARC from {repo_id} (file: {self.filename}): {e}"
            ) from e

        df = pq.read_table(Path(local_path)).to_pandas()

        for i, row in df.iterrows():
            question = self._as_text(row.get("question"))
            options = self._coerce_options(row.get("options"))
            options_formatted = self._format_options(options)
            text = f"{question}\n\n{options_formatted}".strip()
            # Normalise the gold label; a missing answer becomes "".
            label = self._as_text(row.get("answer")).strip().upper()

            self._samples.append(
                Sample(
                    idx=int(i),
                    text=text,
                    label=label,
                    metadata={
                        "question_id": row.get("question_id"),
                        "src": row.get("src"),
                        "question": question,
                        "options": options,
                        "options_formatted": options_formatted,
                        "repo_id": repo_id,
                        "filename": self.filename,
                    },
                )
            )
|
|
1173
|
+
|
|
1174
|
+
|
|
1175
|
+
@Registry.register_dataset("medbullets")
class MedBulletsDataset(BaseDataset):
    """MedBullets MCQ benchmark dataset.

    Upstream dataset: mkieffer/Medbullets (Parquet, splits: op4_test, op5_test).

    Splits:
        - op4_test: effectively 4 options (may include an empty option field)
        - op5_test: same stems, with an additional answer choice (letters may differ)

    Expected columns:
        - idx: string id
        - question: string
        - options: struct/dict with keys A..E (some values may be empty strings)
        - answer: string label (A..E)
        - explanation: string (rationale)
        - link: string (source)

    The Parquet file is obtained through ``hf_hub_download`` while the
    instance is being constructed (``__init__`` calls ``_load``).
    """

    def __init__(
        self,
        name: str = "medbullets",
        repo_id: Optional[str] = None,
        split: str = "op5_test",
        filename: Optional[str] = None,
        **kwargs,
    ):
        """Store the configuration and load all samples immediately.

        Args:
            name: Name reported by the ``name`` property.
            repo_id: Hub dataset repository. When falsy, it is looked up in
                ``data/datasets.yaml`` (see ``_resolve_repo_id``).
            split: Split name; used to derive the default file name and
                recorded in each sample's metadata.
            filename: Explicit Parquet path inside the repository. When
                falsy, ``_default_filename()`` is used.
            **kwargs: Accepted and ignored.
        """
        self._name = name
        self.repo_id = repo_id
        self.split = split
        self.filename = filename
        self._samples: List[Sample] = []
        self._load()

    @property
    def name(self) -> str:
        """Dataset name given at construction time."""
        return self._name

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the sample at position ``idx``."""
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Return the prompt-strategy names this dataset declares itself compatible with."""
        return [
            "mcq",
            "direct_answer",
            "chain_of_thought",
            "uncertainty",
            "contrarian",
            "few_shot",
        ]

    def _default_filename(self) -> str:
        """Return the Parquet path derived from the configured split."""
        return f"medbullets/{self.split}-00000-of-00001.parquet"

    def _resolve_repo_id(self) -> str:
        """Return the explicit ``repo_id`` or the one found in ``data/datasets.yaml``.

        Raises:
            ValueError: If neither source provides a repository id.
        """
        import yaml

        if self.repo_id:
            return self.repo_id

        root_dir = Path(__file__).parent.parent.parent.parent
        registry_path = root_dir / "data/datasets.yaml"
        if registry_path.exists():
            try:
                with open(registry_path, "r") as f:
                    config = yaml.safe_load(f)
                ds_config = config.get("datasets", {}).get("medbullets", {})
                repo_id = ds_config.get("repo_id", config.get("default", {}).get("repo_id"))
                if repo_id:
                    return repo_id
            except Exception as e:
                # Best effort: fall through to the explicit ValueError below.
                print(f"Warning: Failed to load registry for MedBullets: {e}")

        raise ValueError(
            "No repo_id found for MedBullets. Set default.repo_id in data/datasets.yaml "
            "or pass repo_id= explicitly."
        )

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return *value* as a string, mapping ``None`` and float NaN to ``""``.

        pandas represents missing cells as ``None`` or NaN; calling string
        methods on those directly would raise or yield the text "nan".
        """
        if value is None:
            return ""
        if isinstance(value, float) and value != value:  # only NaN differs from itself
            return ""
        return str(value)

    @staticmethod
    def _coerce_options(raw: Any) -> Dict[str, Any]:
        """Convert a dict or an object exposing ``as_py()`` into a plain dict.

        Anything else (including ``None``) yields ``{}``.
        """
        if raw is None:
            return {}
        if isinstance(raw, dict):
            return raw
        as_py = getattr(raw, "as_py", None)
        if callable(as_py):
            val = as_py()
            if isinstance(val, dict):
                return val
        return {}

    @staticmethod
    def _format_options(options: Dict[str, Any]) -> str:
        """Render options as ``"<key>) <value>"`` lines in sorted key order.

        ``None`` values and blank strings are skipped, which is how the
        empty fifth option of a 4-option item is dropped.
        """
        if not options:
            return ""
        lines: list[str] = []
        for key in sorted(options.keys()):
            val = options.get(key)
            if val is None:
                continue
            if isinstance(val, str) and not val.strip():
                continue
            lines.append(f"{key}) {str(val).strip()}")
        return "\n".join(lines)

    def _load(self) -> None:
        """Download the Parquet file and turn every row into a ``Sample``.

        Raises:
            ImportError: If ``pyarrow`` is not installed.
            ValueError: If no ``repo_id`` is configured.
            FileNotFoundError: If the download fails for any reason (the
                original exception is chained as the cause).
        """
        try:
            import pyarrow.parquet as pq
        except ImportError as e:
            raise ImportError(
                "MedBullets requires pyarrow. Install with: uv pip install pyarrow"
            ) from e

        from huggingface_hub import hf_hub_download

        repo_id = self._resolve_repo_id()
        filename = self.filename or self._default_filename()

        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                repo_type="dataset",
            )
        except Exception as e:
            raise FileNotFoundError(
                f"Failed to download MedBullets from {repo_id} (file: {filename}): {e}"
            ) from e

        df = pq.read_table(Path(local_path)).to_pandas()

        for i, row in df.iterrows():
            question = self._as_text(row.get("question"))
            options = self._coerce_options(row.get("options"))
            options_formatted = self._format_options(options)
            text = f"{question}\n\n{options_formatted}".strip()
            # Normalise the gold label; a missing answer becomes "".
            label = self._as_text(row.get("answer")).strip().upper()

            self._samples.append(
                Sample(
                    idx=int(i),
                    text=text,
                    label=label,
                    metadata={
                        "idx": row.get("idx"),
                        "split": self.split,
                        "question": question,
                        "options": options,
                        "options_formatted": options_formatted,
                        "explanation": row.get("explanation"),
                        "link": row.get("link"),
                        "repo_id": repo_id,
                        "filename": filename,
                    },
                )
            )
|
|
1336
|
+
|
|
1337
|
+
|
|
1338
|
+
@Registry.register_dataset("plab")
class PLABDataset(BaseDataset):
    """PLAB-style MCQ dataset loader from raw JSON files.

    Two JSON files are fetched through ``hf_hub_download`` while the
    instance is being constructed: the question file (required) and a
    topics file (optional; used only to attach a ``topic`` to each sample).
    Rows without question text, without a list of options, or without a
    recognisable answer letter that matches one of the options are skipped.
    """

    def __init__(
        self,
        name: str = "plab",
        repo_id: Optional[str] = None,
        filename: str = "plab/data.json",
        topics_filename: str = "plab/topics.json",
        split: str = "main",
        **kwargs,
    ):
        """Store the configuration and load all samples immediately.

        Args:
            name: Name reported by the ``name`` property.
            repo_id: Hub dataset repository. When falsy, it is looked up in
                ``data/datasets.yaml`` (see ``_resolve_repo_id``).
            filename: Path of the question JSON file inside the repository.
            topics_filename: Path of the topics JSON file inside the repository.
            split: Recorded in each sample's metadata; it does not affect
                which file is read.
            **kwargs: Accepted and ignored.
        """
        self._name = name
        self.repo_id = repo_id
        self.filename = filename
        self.topics_filename = topics_filename
        self.split = split
        self._samples: List[Sample] = []
        self._load()

    @property
    def name(self) -> str:
        """Dataset name given at construction time."""
        return self._name

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the sample at position ``idx``."""
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Return the prompt-strategy names this dataset declares itself compatible with."""
        return [
            "plab",
            "mcq",
            "direct_answer",
            "chain_of_thought",
            "uncertainty",
            "contrarian",
            "few_shot",
        ]

    def _resolve_repo_id(self) -> str:
        """Return the explicit ``repo_id`` or the one found in ``data/datasets.yaml``.

        Raises:
            ValueError: If neither source provides a repository id.
        """
        import yaml

        if self.repo_id:
            return self.repo_id

        root_dir = Path(__file__).parent.parent.parent.parent
        registry_path = root_dir / "data/datasets.yaml"
        if registry_path.exists():
            try:
                with open(registry_path, "r") as f:
                    config = yaml.safe_load(f)
                ds_config = config.get("datasets", {}).get("plab", {})
                repo_id = ds_config.get("repo_id", config.get("default", {}).get("repo_id"))
                if repo_id:
                    return repo_id
            except Exception as e:
                # Best effort: fall through to the explicit ValueError below.
                print(f"Warning: Failed to load registry for PLAB: {e}")

        raise ValueError(
            "No repo_id found for PLAB. Set default.repo_id in data/datasets.yaml "
            "or pass repo_id= explicitly."
        )

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return *value* as text: ``None`` -> ``""``, list -> concatenated items."""
        if value is None:
            return ""
        if isinstance(value, list):
            return "".join(str(x) for x in value)
        return str(value)

    @staticmethod
    def _extract_question_id(question_text: str) -> Optional[int]:
        """Return the leading ``"<digits>."`` number of the question, or ``None``."""
        m = re.match(r"\s*(\d+)\.", question_text)
        if not m:
            return None
        try:
            return int(m.group(1))
        except ValueError:
            return None

    @staticmethod
    def _extract_label(answer_text: str) -> Optional[str]:
        """Return the answer letter (A-G, upper-cased) found in *answer_text*.

        The patterns are tried in order and matched case-insensitively; the
        first hit wins. ``None`` is returned when nothing matches.

        NOTE(review): because matching is case-insensitive, a lower-case
        article right after the keyword (e.g. "key is a ...") is also
        accepted as the letter "A" - confirm against the data.
        """
        patterns = [
            r"\bkey\s+is\s*([A-G])\b",
            r"\bAns\.?\s*([A-G])\b",
            r"\bAnswer\s*[:\-]?\s*([A-G])\b",
        ]
        for pat in patterns:
            m = re.search(pat, answer_text, flags=re.IGNORECASE)
            if m:
                return m.group(1).upper()
        return None

    @staticmethod
    def _format_options(raw_options: List[Any]) -> tuple[str, Dict[str, str]]:
        """Normalise raw option strings into lettered lines and a letter map.

        Options that already start with a letter A-G followed by ``)``,
        ``.``, ``:`` or ``-`` keep that letter; all others receive the next
        automatic letter. Blank options, and lettered options with no text,
        are dropped.

        Returns:
            ``(formatted, options_map)`` where *formatted* has one
            ``"<letter>) <text>"`` line per kept option and *options_map*
            maps each letter to its text. If the same explicit letter occurs
            twice, the later text replaces the earlier one in the map.
        """
        lines: List[str] = []
        options_map: Dict[str, str] = {}
        next_letter_ord = ord("A")

        for raw in raw_options:
            option = str(raw).strip()
            if not option:
                continue

            m = re.match(r"^\s*([A-Ga-g])[\)\.\:\-]\s*(.*)$", option)
            if m:
                letter = m.group(1).upper()
                text = m.group(2).strip()
                if not text:
                    continue
                # Keep the automatic counter ahead of explicit letters so a
                # following unlabelled option cannot reuse this letter.
                next_letter_ord = max(next_letter_ord, ord(letter) + 1)
            else:
                letter = chr(next_letter_ord)
                text = option
                next_letter_ord += 1

            options_map[letter] = text
            lines.append(f"{letter}) {text}")

        return "\n".join(lines), options_map

    def _load(self) -> None:
        """Download the JSON files and build the sample list.

        Raises:
            ValueError: If no ``repo_id`` is configured.
            FileNotFoundError: If the question file cannot be downloaded
                (the original exception is chained as the cause).
        """
        # Imported locally, as the sibling JSON loaders do, so this method
        # does not depend on a module-level ``json`` import.
        import json

        from huggingface_hub import hf_hub_download

        repo_id = self._resolve_repo_id()

        try:
            data_path = Path(
                hf_hub_download(
                    repo_id=repo_id,
                    filename=self.filename,
                    repo_type="dataset",
                )
            )
        except Exception as e:
            raise FileNotFoundError(
                f"Failed to download PLAB data from {repo_id} (file: {self.filename}): {e}"
            ) from e

        # Topics are optional metadata: any failure leaves the map empty and
        # loading continues, but the reason is reported instead of hidden.
        topics_by_qid: Dict[int, str] = {}
        try:
            topics_path = Path(
                hf_hub_download(
                    repo_id=repo_id,
                    filename=self.topics_filename,
                    repo_type="dataset",
                )
            )
            with open(topics_path, "r", encoding="utf-8") as f:
                topics_data = json.load(f)
            for row in topics_data:
                try:
                    qid_int = int(row.get("question"))
                except (TypeError, ValueError):
                    continue
                topic = self._as_text(row.get("topic")).strip()
                if topic:
                    topics_by_qid[qid_int] = topic
        except Exception as e:
            print(f"Warning: Failed to load PLAB topics ({self.topics_filename}): {e}")
            topics_by_qid = {}

        with open(data_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        for row in data:
            question_text = self._as_text(row.get("question")).strip()
            if not question_text:
                continue

            raw_options = row.get("options")
            if not isinstance(raw_options, list):
                continue

            options_formatted, options_map = self._format_options(raw_options)
            if not options_formatted:
                continue

            answer_explanation = self._as_text(row.get("answer")).strip()
            label = self._extract_label(answer_explanation)
            # A label that does not name one of the parsed options is unusable.
            if label is None or label not in options_map:
                continue

            qid = self._extract_question_id(question_text)
            topic = topics_by_qid.get(qid) if qid is not None else None
            text = f"{question_text}\n\n{options_formatted}".strip()

            self._samples.append(
                Sample(
                    # Consecutive index over the rows that were kept.
                    idx=len(self._samples),
                    text=text,
                    label=label,
                    metadata={
                        "question_id": qid,
                        "topic": topic,
                        "question": question_text,
                        "options": options_map,
                        "options_formatted": options_formatted,
                        "answer_explanation": answer_explanation,
                        "split": self.split,
                        "repo_id": repo_id,
                        "filename": self.filename,
                        "topics_filename": self.topics_filename,
                    },
                )
            )
|
|
1550
|
+
|
|
1551
|
+
|
|
1552
|
+
@Registry.register_dataset("pubhealthbench")
|
|
1553
|
+
class PubHealthBenchDataset(BaseDataset):
|
|
1554
|
+
"""PubHealthBench multiple-choice public health QA dataset.
|
|
1555
|
+
|
|
1556
|
+
Format: Parquet with fields:
|
|
1557
|
+
- question: Question text
|
|
1558
|
+
- options: List of answer choices
|
|
1559
|
+
- options_formatted: Pre-formatted options string (A. ..., B. ...)
|
|
1560
|
+
- answer_index: Integer index (0-based)
|
|
1561
|
+
- answer: Correct answer letter (A/B/C/...)
|
|
1562
|
+
- category, intended_audience, source_document_title, source_chunk_text
|
|
1563
|
+
- review_annotation, retrieved_context_for_judge, question_id
|
|
1564
|
+
"""
|
|
1565
|
+
|
|
1566
|
+
def __init__(
|
|
1567
|
+
self,
|
|
1568
|
+
name: str = "pubhealthbench",
|
|
1569
|
+
repo_id: Optional[str] = None,
|
|
1570
|
+
filename: Optional[str] = None,
|
|
1571
|
+
split: str = "test",
|
|
1572
|
+
**kwargs,
|
|
1573
|
+
):
|
|
1574
|
+
self._name = name
|
|
1575
|
+
self.repo_id = repo_id
|
|
1576
|
+
self.filename = filename
|
|
1577
|
+
self.split = split
|
|
1578
|
+
self._samples: List[Sample] = []
|
|
1579
|
+
self._load()
|
|
1580
|
+
|
|
1581
|
+
def _default_filename(self) -> str:
|
|
1582
|
+
if self.split == "reviewed":
|
|
1583
|
+
return "pubhealthbench/reviewed-00000-of-00001.parquet"
|
|
1584
|
+
return f"pubhealthbench/{self.split}-00000-of-00001.parquet"
|
|
1585
|
+
|
|
1586
|
+
def _resolve_repo_id(self) -> str:
|
|
1587
|
+
import yaml
|
|
1588
|
+
|
|
1589
|
+
repo_id = self.repo_id
|
|
1590
|
+
if repo_id:
|
|
1591
|
+
return repo_id
|
|
1592
|
+
|
|
1593
|
+
# Try registry override
|
|
1594
|
+
root_dir = Path(__file__).parent.parent.parent.parent
|
|
1595
|
+
registry_path = root_dir / "data/datasets.yaml"
|
|
1596
|
+
if registry_path.exists():
|
|
1597
|
+
try:
|
|
1598
|
+
with open(registry_path, "r") as f:
|
|
1599
|
+
config = yaml.safe_load(f)
|
|
1600
|
+
ds_config = config.get("datasets", {}).get("pubhealthbench", {})
|
|
1601
|
+
# Prefer explicit override, otherwise fall back to default repo_id.
|
|
1602
|
+
repo_id = ds_config.get("repo_id", config.get("default", {}).get("repo_id"))
|
|
1603
|
+
except Exception as e:
|
|
1604
|
+
print(f"Warning: Failed to load registry for PubHealthBench: {e}")
|
|
1605
|
+
|
|
1606
|
+
if not repo_id:
|
|
1607
|
+
raise ValueError(
|
|
1608
|
+
"No repo_id found for PubHealthBench. Set default.repo_id in data/datasets.yaml "
|
|
1609
|
+
"or pass repo_id= explicitly."
|
|
1610
|
+
)
|
|
1611
|
+
return repo_id
|
|
1612
|
+
|
|
1613
|
+
def _format_options(self, options: Optional[List[str]]) -> str:
|
|
1614
|
+
if not options:
|
|
1615
|
+
return ""
|
|
1616
|
+
letters = [chr(65 + i) for i in range(len(options))]
|
|
1617
|
+
return "\n".join(f"{letter}. {opt}" for letter, opt in zip(letters, options))
|
|
1618
|
+
|
|
1619
|
+
def _load(self):
|
|
1620
|
+
import math
|
|
1621
|
+
|
|
1622
|
+
try:
|
|
1623
|
+
import pyarrow.parquet as pq
|
|
1624
|
+
except ImportError as e:
|
|
1625
|
+
raise ImportError(
|
|
1626
|
+
"PubHealthBench requires pyarrow. Install with: uv pip install pyarrow"
|
|
1627
|
+
) from e
|
|
1628
|
+
|
|
1629
|
+
from huggingface_hub import hf_hub_download
|
|
1630
|
+
|
|
1631
|
+
filename = self.filename or self._default_filename()
|
|
1632
|
+
repo_id = self._resolve_repo_id()
|
|
1633
|
+
try:
|
|
1634
|
+
path = Path(
|
|
1635
|
+
hf_hub_download(
|
|
1636
|
+
repo_id=repo_id,
|
|
1637
|
+
filename=filename,
|
|
1638
|
+
repo_type="dataset",
|
|
1639
|
+
)
|
|
1640
|
+
)
|
|
1641
|
+
except Exception as e:
|
|
1642
|
+
raise FileNotFoundError(
|
|
1643
|
+
f"Failed to download PubHealthBench from {repo_id} (file: {filename}): {e}"
|
|
1644
|
+
)
|
|
1645
|
+
|
|
1646
|
+
table = pq.read_table(path)
|
|
1647
|
+
df = table.to_pandas()
|
|
1648
|
+
|
|
1649
|
+
index_to_letter = {i: chr(65 + i) for i in range(26)}
|
|
1650
|
+
|
|
1651
|
+
for i, row in df.iterrows():
|
|
1652
|
+
|
|
1653
|
+
def clean_value(value: Any) -> Any:
|
|
1654
|
+
if value is None:
|
|
1655
|
+
return None
|
|
1656
|
+
if isinstance(value, float) and math.isnan(value):
|
|
1657
|
+
return None
|
|
1658
|
+
return value
|
|
1659
|
+
|
|
1660
|
+
question = row.get("question", "")
|
|
1661
|
+
options_formatted = clean_value(row.get("options_formatted"))
|
|
1662
|
+
if not options_formatted:
|
|
1663
|
+
options_formatted = self._format_options(row.get("options"))
|
|
1664
|
+
|
|
1665
|
+
text = f"{question}\n\n{options_formatted}".strip()
|
|
1666
|
+
|
|
1667
|
+
answer = clean_value(row.get("answer"))
|
|
1668
|
+
if not answer and row.get("answer_index") is not None:
|
|
1669
|
+
answer = index_to_letter.get(int(row["answer_index"]), "A")
|
|
1670
|
+
|
|
1671
|
+
metadata = {
|
|
1672
|
+
"question_id": clean_value(row.get("question_id")),
|
|
1673
|
+
"question": question,
|
|
1674
|
+
"options_formatted": options_formatted,
|
|
1675
|
+
"category": clean_value(row.get("category")),
|
|
1676
|
+
"intended_audience": clean_value(row.get("intended_audience")),
|
|
1677
|
+
"source_document_title": clean_value(row.get("source_document_title")),
|
|
1678
|
+
"source_chunk_text": clean_value(row.get("source_chunk_text")),
|
|
1679
|
+
"review_annotation": clean_value(row.get("review_annotation")),
|
|
1680
|
+
"retrieved_context_for_judge": clean_value(row.get("retrieved_context_for_judge")),
|
|
1681
|
+
}
|
|
1682
|
+
|
|
1683
|
+
self._samples.append(
|
|
1684
|
+
Sample(
|
|
1685
|
+
idx=i,
|
|
1686
|
+
text=text,
|
|
1687
|
+
label=answer,
|
|
1688
|
+
metadata=metadata,
|
|
1689
|
+
)
|
|
1690
|
+
)
|
|
1691
|
+
|
|
1692
|
+
@property
def name(self) -> str:
    """Return the dataset name held in the instance's ``_name`` attribute."""
    dataset_name = self._name
    return dataset_name
|
|
1695
|
+
|
|
1696
|
+
def __len__(self) -> int:
    """Return how many samples are currently stored in ``_samples``."""
    samples = self._samples
    return len(samples)
|
|
1698
|
+
|
|
1699
|
+
def __getitem__(self, idx: int) -> Sample:
|
|
1700
|
+
return self._samples[idx]
|
|
1701
|
+
|
|
1702
|
+
def get_compatible_prompts(self) -> list[str]:
    """Return the names of the prompt strategies this dataset declares as compatible.

    A new list is built on every call, so callers may mutate the result.
    """
    prompt_names = (
        "mcq",
        "pubhealthbench",
        "direct_answer",
        "chain_of_thought",
        "uncertainty",
        "contrarian",
        "few_shot",
    )
    return list(prompt_names)
|
|
1712
|
+
|
|
1713
|
+
|
|
1714
|
+
@Registry.register_dataset("pubmedqa")
class PubMedQADataset(BaseDataset):
    """PubMedQA research question answering dataset.

    The data file is fetched from a Hugging Face dataset repository and
    parsed eagerly when the instance is constructed.

    Format: JSONL with fields:
    - question: Research question
    - context: Abstract context
    - answer: yes/no/maybe
    - pmid: PubMed ID
    """

    def __init__(
        self,
        name: str = "pubmedqa",
        repo_id: Optional[str] = None,
        filename: str = "pubmedqa/test.jsonl",
        **kwargs,
    ):
        """Store the configuration and load all samples immediately.

        Args:
            name: Value returned by the ``name`` property.
            repo_id: Hugging Face dataset repository id. When falsy, it is
                looked up in ``data/datasets.yaml`` (see ``_load``).
            filename: Path of the JSONL file inside the repository.
            **kwargs: Accepted and ignored by this constructor.

        Raises:
            ValueError: If no repository id can be resolved.
            FileNotFoundError: If the download fails.
        """
        self._name = name
        self.repo_id = repo_id
        self.filename = filename
        self._samples: List[Sample] = []
        self._load()

    def _load(self):
        """Resolve the repository, download the JSONL file and fill ``_samples``.

        Raises:
            ValueError: If no repository id is given and none is found in
                the registry file.
            FileNotFoundError: If ``hf_hub_download`` raises for any reason;
                the original exception is attached as the cause.
        """
        import json

        import yaml
        from huggingface_hub import hf_hub_download

        # Resolve repo_id from registry if not provided
        repo_id = self.repo_id
        if not repo_id:
            # The registry is looked up four directory levels above this file.
            root_dir = Path(__file__).parent.parent.parent.parent
            registry_path = root_dir / "data/datasets.yaml"
            if registry_path.exists():
                try:
                    with open(registry_path, "r", encoding="utf-8") as f:
                        config = yaml.safe_load(f)
                    ds_config = config.get("datasets", {}).get("pubmedqa", {})
                    # Fall back to the registry-wide default repository.
                    repo_id = ds_config.get("repo_id", config.get("default", {}).get("repo_id"))
                except Exception as e:
                    # Best effort: a broken registry is reported, and the
                    # missing repo_id is then rejected just below.
                    print(f"Warning: Failed to load registry for PubMedQA: {e}")

        if not repo_id:
            raise ValueError("No repo_id found for PubMedQA dataset")

        # Download from HF
        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=self.filename,
                repo_type="dataset",
            )
        except Exception as e:
            raise FileNotFoundError(f"Failed to download PubMedQA from {repo_id}: {e}") from e

        # Parse JSONL; decode explicitly as UTF-8 so the result does not
        # depend on the platform's default encoding.
        with open(local_path, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline) instead
                    # of failing in json.loads.
                    continue
                data = json.loads(line)

                # Format: question + context, predict yes/no/maybe
                question = data["question"]
                context = data.get("context", "")

                # Truncate very long contexts
                if len(context) > 1500:
                    context = context[:1500] + "..."

                text = f"Question: {question}\n\nContext: {context}"

                # Label is yes/no/maybe
                label = data["answer"]

                self._samples.append(
                    Sample(
                        idx=i,
                        text=text,
                        label=label,
                        metadata={
                            "pmid": data.get("pmid", ""),
                        },
                    )
                )

    @property
    def name(self) -> str:
        """Return the dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the loaded sample at position ``idx``."""
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Return the names of the prompt strategies declared compatible with this dataset."""
        return [
            "pubmedqa",
            "mcq",
            "direct_answer",
            "chain_of_thought",
            "uncertainty",
            "contrarian",
            "few_shot",
        ]
|
|
1820
|
+
|
|
1821
|
+
|
|
1822
|
+
@Registry.register_dataset("movie_ood")
class MovieOODDataset(BaseDataset):
    """Post-cutoff movie MCQ dataset for true zero-knowledge OOD evaluation.

    Generated from TMDB API: films released after MedGemma training cutoff (2025-07-01).
    Question types: director, cast, genre, production_country.

    Format: JSONL with fields:
    - question: Question text
    - options: Dict {"A": "...", "B": "...", "C": "...", "D": "..."}
    - answer: Correct answer letter (A/B/C/D)
    - metadata: movie_id, title, release_date, question_type
    """

    def __init__(
        self,
        name: str = "movie_ood",
        path: str = "data/movie_ood.jsonl",
        **kwargs,
    ):
        """Store the name and load all samples from ``path`` immediately.

        Args:
            name: Value returned by the ``name`` property.
            path: Location of the JSONL file.
            **kwargs: Accepted and ignored by this constructor.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        self._name = name
        self._samples: List[Sample] = []
        self._load(Path(path))

    def _load(self, path: Path):
        """Parse the JSONL file at ``path`` and append one sample per record.

        Blank lines are skipped. Each sample's ``idx`` is the zero-based
        line number in the file, so skipped lines leave gaps in ``idx``.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if not path.exists():
            raise FileNotFoundError(
                f"movie_ood dataset not found at {path}.\n"
                "Run: python scripts/build_movie_ood_dataset.py"
            )
        with open(path, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)
                question = data["question"]
                options = data["options"]
                # Sort by option key so the choices are listed in key order.
                formatted_options = "\n".join(
                    f"{key}) {val}" for key, val in sorted(options.items())
                )
                text = f"{question}\n\n{formatted_options}"
                # A missing or null answer becomes an empty label instead of
                # raising AttributeError on None.strip().
                answer = str(data.get("answer") or "")
                self._samples.append(
                    Sample(
                        idx=i,
                        text=text,
                        label=answer.strip().upper(),
                        metadata=data.get("metadata", {}),
                    )
                )

    @property
    def name(self) -> str:
        """Return the dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the loaded sample at position ``idx``."""
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Return the names of the prompt strategies declared compatible with this dataset."""
        return [
            "mcq",
            "direct_answer",
            "chain_of_thought",
            "few_shot",
        ]
|