cotlab 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65):
  1. cotlab/__init__.py +3 -0
  2. cotlab/analyse_experiments.py +392 -0
  3. cotlab/analysis/__init__.py +11 -0
  4. cotlab/analysis/cot_parser.py +243 -0
  5. cotlab/analysis/faithfulness_metrics.py +192 -0
  6. cotlab/backends/__init__.py +16 -0
  7. cotlab/backends/base.py +78 -0
  8. cotlab/backends/transformers_backend.py +335 -0
  9. cotlab/backends/vllm_backend.py +227 -0
  10. cotlab/cli.py +83 -0
  11. cotlab/core/__init__.py +34 -0
  12. cotlab/core/base.py +749 -0
  13. cotlab/core/config.py +90 -0
  14. cotlab/core/registry.py +68 -0
  15. cotlab/datasets/__init__.py +45 -0
  16. cotlab/datasets/loaders.py +1889 -0
  17. cotlab/experiment/__init__.py +315 -0
  18. cotlab/experiments/__init__.py +43 -0
  19. cotlab/experiments/activation_compare.py +290 -0
  20. cotlab/experiments/activation_patching.py +1050 -0
  21. cotlab/experiments/attention_analysis.py +885 -0
  22. cotlab/experiments/classification.py +235 -0
  23. cotlab/experiments/composite_shift_detector.py +524 -0
  24. cotlab/experiments/cot_ablation.py +277 -0
  25. cotlab/experiments/cot_faithfulness.py +187 -0
  26. cotlab/experiments/cot_heads.py +208 -0
  27. cotlab/experiments/full_layer_cot.py +232 -0
  28. cotlab/experiments/full_layer_patching.py +225 -0
  29. cotlab/experiments/h_neuron_analysis.py +712 -0
  30. cotlab/experiments/logit_lens.py +439 -0
  31. cotlab/experiments/multi_head_cot.py +220 -0
  32. cotlab/experiments/multi_head_patching.py +229 -0
  33. cotlab/experiments/probing_classifier.py +402 -0
  34. cotlab/experiments/residual_norm_ood.py +413 -0
  35. cotlab/experiments/sae_feature_analysis.py +673 -0
  36. cotlab/experiments/steering_vectors.py +223 -0
  37. cotlab/experiments/sycophancy_heads.py +224 -0
  38. cotlab/logging/__init__.py +5 -0
  39. cotlab/logging/json_logger.py +161 -0
  40. cotlab/main.py +317 -0
  41. cotlab/patching/__init__.py +24 -0
  42. cotlab/patching/cache.py +141 -0
  43. cotlab/patching/hooks.py +558 -0
  44. cotlab/patching/interventions.py +86 -0
  45. cotlab/patching/patcher.py +439 -0
  46. cotlab/patching/sae.py +181 -0
  47. cotlab/prompts/__init__.py +43 -0
  48. cotlab/prompts/cardiology.py +378 -0
  49. cotlab/prompts/histopathology.py +265 -0
  50. cotlab/prompts/length_matched_strategies.py +157 -0
  51. cotlab/prompts/mcq.py +193 -0
  52. cotlab/prompts/neurology.py +353 -0
  53. cotlab/prompts/oncology.py +367 -0
  54. cotlab/prompts/plab.py +162 -0
  55. cotlab/prompts/pubhealthbench.py +82 -0
  56. cotlab/prompts/pubmedqa.py +173 -0
  57. cotlab/prompts/radiology.py +414 -0
  58. cotlab/prompts/strategies.py +939 -0
  59. cotlab/prompts/tcga.py +168 -0
  60. cotlab/runner.py +204 -0
  61. cotlab-0.8.0.dist-info/METADATA +166 -0
  62. cotlab-0.8.0.dist-info/RECORD +65 -0
  63. cotlab-0.8.0.dist-info/WHEEL +4 -0
  64. cotlab-0.8.0.dist-info/entry_points.txt +3 -0
  65. cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1889 @@
1
+ """Dataset loaders for CoT research.
2
+
3
+ Supports JSON format with standardized structure:
4
+ {
5
+ "id": "unique_id",
6
+ "input": { ... },
7
+ "output": { ... },
8
+ "metadata": { ... }
9
+ }
10
+
11
+ Also supports CSV format with automatic detection based on file extension.
12
+
13
+ Datasets can specify compatible prompts via get_compatible_prompts() for
14
+ restricting which prompt strategies can be used.
15
+ """
16
+
17
+ import csv
18
+ import json
19
+ import re
20
+ from abc import ABC, abstractmethod
21
+ from dataclasses import dataclass
22
+ from pathlib import Path
23
+ from typing import Any, Dict, Iterator, List, Optional
24
+
25
+ from ..core.registry import Registry
26
+
27
+
28
@dataclass
class Sample:
    """One dataset record: an index, its text, an optional label and metadata."""

    idx: int
    text: str
    label: Optional[Any] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # A missing metadata mapping is normalised to a fresh, per-instance dict.
        self.metadata = {} if self.metadata is None else self.metadata

    def to_dict(self) -> Dict[str, Any]:
        """Return the four fields as a plain dictionary."""
        return dict(idx=self.idx, text=self.text, label=self.label, metadata=self.metadata)
43
+
44
+
45
class BaseDataset(ABC):
    """Abstract base class for datasets.

    Subclasses supply ``name``, ``__len__`` and ``__getitem__``; iteration,
    seeded random sampling and the prompt-compatibility default are built on
    top of those three.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Dataset name."""
        ...

    @abstractmethod
    def __len__(self) -> int:
        """Number of samples."""
        ...

    @abstractmethod
    def __getitem__(self, idx: int) -> "Sample":
        """Get a sample by index."""
        ...

    def __iter__(self) -> Iterator["Sample"]:
        """Iterate over samples in index order."""
        for i in range(len(self)):
            yield self[i]

    def sample(self, n: int, seed: int = 42) -> List["Sample"]:
        """Return a reproducible random sample of at most ``n`` items.

        Args:
            n: Requested sample size; capped at ``len(self)``.
            seed: Seed for the selection. The same seed always selects the
                same indices.

        Returns:
            The selected samples, in the order they were drawn.
        """
        import random

        # Use a private generator: calling random.seed() here would reset the
        # process-wide RNG and silently perturb every other user of `random`.
        # random.Random(seed).sample(...) draws the same indices as
        # random.seed(seed); random.sample(...) did.
        rng = random.Random(seed)
        size = len(self)
        indices = rng.sample(range(size), min(n, size))
        return [self[i] for i in indices]

    def get_compatible_prompts(self) -> Optional[List[str]]:
        """
        Return list of compatible prompt names, or None if compatible with all.

        Override this in specialized datasets to restrict usage.
        """
        return None
84
+
85
+
86
class JSONDataset(BaseDataset):
    """Base class for JSON/CSV-based datasets with automatic format detection.

    The data file is located through the ``data/datasets.yaml`` registry and
    fetched from the Hugging Face Hub. The resolved file is parsed as CSV when
    its suffix is ``.csv`` and as JSON otherwise. Subclasses implement
    ``_parse_item`` (and ``_parse_csv_row`` when they support CSV).
    """

    def __init__(self, name: str, path: str, **kwargs):
        """Resolve, download and load the dataset.

        Args:
            name: Dataset key used to look up the registry entry.
            path: Default locally-styled path (e.g. ``data/radiology.json``),
                used to infer the Hub filename when the registry gives none.
            **kwargs: Accepted for signature compatibility; unused here.
        """
        self._name = name
        self.path = self._resolve_path_from_registry(name, path)
        self._samples: List[Sample] = []
        self._load()

    def _resolve_path_from_registry(self, name: str, default_path: str) -> Path:
        """Resolve the dataset file via data/datasets.yaml and download it.

        Args:
            name: Dataset key under ``datasets:`` in the registry.
            default_path: Path used to infer the Hub filename when the
                registry entry has no ``path``.

        Returns:
            Local path of the file downloaded by ``hf_hub_download``.

        Raises:
            FileNotFoundError: If the registry is missing or unreadable, no
                ``repo_id`` is configured, or the download fails. All failures
                keep this single exception type, as before.
        """
        import yaml
        from huggingface_hub import hf_hub_download

        # Four .parent hops from this file: the file's directory, then three
        # levels above it (src/cotlab/datasets -> src/cotlab -> src -> root).
        root_dir = Path(__file__).parent.parent.parent.parent
        registry_path = root_dir / "data/datasets.yaml"

        if not registry_path.exists():
            # Fallback: try CWD
            registry_path = Path("data/datasets.yaml")
            if not registry_path.exists():
                raise FileNotFoundError(
                    "datasets.yaml not found. Configure data/datasets.yaml with a Hugging Face repo_id."
                )

        # Only the registry read/parse is reported as a registry failure; the
        # later errors keep their own, accurate messages instead of being
        # re-wrapped as "Failed to load datasets registry".
        try:
            with open(registry_path, "r", encoding="utf-8") as f:
                config = yaml.safe_load(f) or {}  # an empty file parses to None
            ds_config = (config.get("datasets") or {}).get(name) or {}
            default_config = config.get("default") or {}
            repo_id = ds_config.get("repo_id", default_config.get("repo_id"))
            filename = ds_config.get("path")
        except Exception as e:
            raise FileNotFoundError(f"Failed to load datasets registry: {e}") from e

        if not repo_id:
            raise FileNotFoundError(
                f"No repo_id found for dataset '{name}'. Set it in data/datasets.yaml."
            )

        if not filename:
            # Heuristic: strip a leading 'data/' so the path maps to the HF repo
            # root (data/foo -> foo, data/tcga/foo -> tcga/foo); otherwise use
            # the bare file name. Hub filenames always use forward slashes.
            p = Path(default_path)
            try:
                filename = p.relative_to("data").as_posix()
            except ValueError:
                filename = p.name

        try:
            cached_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
        except Exception as e:
            raise FileNotFoundError(
                f"Failed to download {name} ({filename}) from HF repo {repo_id}: {e}"
            ) from e
        return Path(cached_path)

    def _load(self):
        """Load samples from JSON or CSV file based on extension."""
        if not self.path.exists():
            raise FileNotFoundError(f"Dataset not found: {self.path}")

        if self.path.suffix.lower() == ".csv":
            self._load_csv()
        else:
            self._load_json()

    def _load_json(self):
        """Load samples from a JSON file, one sample per top-level element."""
        with open(self.path, "r", encoding="utf-8") as f:
            data = json.load(f)

        for idx, item in enumerate(data):
            self._samples.append(self._parse_item(idx, item))

    def _load_csv(self):
        """Load samples from a CSV file, one sample per data row."""
        # newline="" lets the csv module handle line endings itself, so quoted
        # fields containing newlines are read correctly.
        with open(self.path, "r", encoding="utf-8", newline="") as f:
            for idx, row in enumerate(csv.DictReader(f)):
                self._samples.append(self._parse_csv_row(idx, row))

    @abstractmethod
    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Parse a single JSON item into a Sample. Override in subclasses."""
        ...

    def _parse_csv_row(self, idx: int, row: Dict[str, Any]) -> Sample:
        """Parse a single CSV row into a Sample. Override in subclasses for custom CSV handling."""
        # Default implementation - subclasses should override for specific CSV formats
        raise NotImplementedError(
            f"CSV loading not implemented for {self.__class__.__name__}. "
            "Override _parse_csv_row() or use a JSON file."
        )

    @property
    def name(self) -> str:
        """Dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        return self._samples[idx]
204
+
205
+
206
@Registry.register_dataset("radiology")
class RadiologyDataset(JSONDataset):
    """Radiology reports labelled for pathological fracture detection."""

    def __init__(self, name: str = "radiology", path: str = "data/radiology.json", **kwargs):
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from a JSON record (report text -> fracture flag)."""
        report = item["input"]["report"]
        flag = item["output"]["pathological_fracture"]
        meta = item.get("metadata", {})
        return Sample(idx=idx, text=report, label=flag, metadata=meta)

    def _parse_csv_row(self, idx: int, row: Dict[str, Any]) -> Sample:
        """Build a Sample from a CSV row: 'Synthetic' is the text, 'Flag' the label."""
        report = row.get("Synthetic", "")
        flag = row.get("Flag", "")
        return Sample(idx=idx, text=report, label=flag, metadata={})

    def get_compatible_prompts(self) -> list[str]:
        """
        Restrict this dataset to the radiology prompt.

        The data targets pathological fracture detection and should
        NOT be paired with general medical QA prompts.
        """
        return ["radiology"]
243
+
244
+
245
@Registry.register_dataset("cardiology")
class CardiologyDataset(JSONDataset):
    """Cardiology reports labelled for congenital heart defect detection."""

    def __init__(self, name: str = "cardiology", path: str = "data/cardiology.json", **kwargs):
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from a JSON record (report text -> defect flag)."""
        report = item["input"]["report"]
        flag = item["output"]["congenital_heart_defect"]
        meta = item.get("metadata", {})
        return Sample(idx=idx, text=report, label=flag, metadata=meta)

    def get_compatible_prompts(self) -> list[str]:
        """
        Restrict this dataset to the cardiology prompt.

        The data targets congenital heart defect detection and should
        NOT be paired with general medical QA prompts.
        """
        return ["cardiology"]
273
+
274
+
275
@Registry.register_dataset("neurology")
class NeurologyDataset(JSONDataset):
    """Neurology reports labelled for neurological abnormality detection."""

    def __init__(self, name: str = "neurology", path: str = "data/neurology.json", **kwargs):
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from a JSON record (report text -> abnormality flag)."""
        report = item["input"]["report"]
        flag = item["output"]["neurological_abnormality"]
        meta = item.get("metadata", {})
        return Sample(idx=idx, text=report, label=flag, metadata=meta)

    def get_compatible_prompts(self) -> list[str]:
        """Restrict this dataset to the neurology prompt."""
        return ["neurology"]
298
+
299
+
300
@Registry.register_dataset("oncology")
class OncologyDataset(JSONDataset):
    """Oncology reports labelled for malignancy detection."""

    def __init__(self, name: str = "oncology", path: str = "data/oncology.json", **kwargs):
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from a JSON record (report text -> malignancy flag)."""
        report = item["input"]["report"]
        flag = item["output"]["malignancy"]
        meta = item.get("metadata", {})
        return Sample(idx=idx, text=report, label=flag, metadata=meta)

    def get_compatible_prompts(self) -> list[str]:
        """Restrict this dataset to the oncology prompt."""
        return ["oncology"]
323
+
324
+
325
@Registry.register_dataset("pediatrics")
class PediatricsDataset(JSONDataset):
    """Pediatric clinical scenarios paired with a diagnosis label."""

    def __init__(self, name: str = "pediatrics", path: str = "data/pediatrics.json", **kwargs):
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from a JSON record (scenario text -> diagnosis)."""
        scenario = item["input"]["scenario"]
        diagnosis = item["output"]["diagnosis"]
        meta = item.get("metadata", {})
        return Sample(idx=idx, text=scenario, label=diagnosis, metadata=meta)

    def _parse_csv_row(self, idx: int, row: Dict[str, Any]) -> Sample:
        """Build a Sample from a CSV row.

        Columns read: Scenario, Diagnosis, Age_Group, Category. Absent
        columns default to an empty string.
        """
        scenario = row.get("Scenario", "")
        diagnosis = row.get("Diagnosis", "")
        extras = {
            "age_group": row.get("Age_Group", ""),
            "category": row.get("Category", ""),
        }
        return Sample(idx=idx, text=scenario, label=diagnosis, metadata=extras)
356
+
357
+
358
@Registry.register_dataset("synthetic")
class SyntheticMedicalDataset(JSONDataset):
    """Synthetic medical QA dataset whose records can be repeated ``repeat`` times."""

    def __init__(
        self,
        name: str = "synthetic",
        path: str = "data/synthetic.json",
        repeat: int = 1,
        **kwargs,
    ):
        # `repeat` must be set first: the base constructor calls _load().
        self.repeat = repeat
        super().__init__(name, path, **kwargs)

    def _load(self):
        """Pick the CSV or JSON repeating loader from the file suffix."""
        if not self.path.exists():
            raise FileNotFoundError(f"Dataset not found: {self.path}")

        is_csv = self.path.suffix.lower() == ".csv"
        loader = self._load_csv_with_repeat if is_csv else self._load_json_with_repeat
        loader()

    def _load_json_with_repeat(self):
        """Parse every JSON record once per repetition, with running indices."""
        with open(self.path, "r", encoding="utf-8") as f:
            records = json.load(f)

        for pass_no in range(self.repeat):
            for pos, record in enumerate(records):
                # Indices continue across passes: pass k starts at k * len(records).
                running_idx = pass_no * len(records) + pos
                self._samples.append(self._parse_item(running_idx, record))

    def _load_csv_with_repeat(self):
        """Parse every CSV row once per repetition, with running indices."""
        with open(self.path, "r", encoding="utf-8") as f:
            table = list(csv.DictReader(f))

        for pass_no in range(self.repeat):
            for pos, entry in enumerate(table):
                running_idx = pass_no * len(table) + pos
                self._samples.append(self._parse_csv_row(running_idx, entry))

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from a JSON record (scenario text -> diagnosis)."""
        scenario = item["input"]["scenario"]
        diagnosis = item["output"]["diagnosis"]
        meta = item.get("metadata", {})
        return Sample(idx=idx, text=scenario, label=diagnosis, metadata=meta)

    def _parse_csv_row(self, idx: int, row: Dict[str, Any]) -> Sample:
        """Build a Sample from a CSV row.

        Columns read: Scenario, Expected_Answer, Reasoning_Keywords.
        """
        scenario = row.get("Scenario", "")
        expected = row.get("Expected_Answer", "")
        extras = {"reasoning_keywords": row.get("Reasoning_Keywords", "")}
        return Sample(idx=idx, text=scenario, label=expected, metadata=extras)
420
+
421
+
422
@Registry.register_dataset("patching_pairs")
class PatchingPairsDataset(JSONDataset):
    """Clean/corrupted prompt pairs for activation patching experiments.

    The clean prompt and answer become the sample's text and label; the
    corrupted counterpart and both answers are carried in the metadata.
    """

    def __init__(
        self, name: str = "patching_pairs", path: str = "data/patching_pairs.json", **kwargs
    ):
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from a JSON record with 'clean' and 'corrupted' parts."""
        clean = item["clean"]
        prompt = clean["input"]
        answer = clean["output"]
        corrupted = item["corrupted"]
        extras = {
            "corrupted_prompt": corrupted["input"],
            "clean_answer": item["clean"]["output"],
            "corrupted_answer": corrupted["output"],
            "category": item.get("metadata", {}).get("category", "general"),
        }
        return Sample(idx=idx, text=prompt, label=answer, metadata=extras)

    def _parse_csv_row(self, idx: int, row: Dict[str, Any]) -> Sample:
        """Build a Sample from a CSV row.

        Columns read: Clean_Prompt, Corrupted_Prompt, Clean_Answer,
        Corrupted_Answer, Category.
        """
        prompt = row.get("Clean_Prompt", "")
        answer = row.get("Clean_Answer", "")
        extras = {
            "corrupted_prompt": row.get("Corrupted_Prompt", ""),
            "clean_answer": row.get("Clean_Answer", ""),
            "corrupted_answer": row.get("Corrupted_Answer", ""),
            "category": row.get("Category", "general"),
        }
        return Sample(idx=idx, text=prompt, label=answer, metadata=extras)
460
+
461
+
462
@Registry.register_dataset("tutorial")
class TutorialDataset(JSONDataset):
    """Minimal Q&A dataset for tutorials and demos (flat text/label records)."""

    def __init__(self, name: str = "tutorial", path: str = "data/tutorial.json", **kwargs):
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample from a record's top-level 'text' and 'label' keys."""
        body = item["text"]
        target = item["label"]
        return Sample(idx=idx, text=body, label=target, metadata={})
481
+
482
+
483
@Registry.register_dataset("probing_diagnosis")
class ProbingDiagnosisDataset(JSONDataset):
    """
    Dataset for probing experiments that test diagnosis encoding.

    Every sample carries:
    - question: Medical scenario (sample text)
    - diagnosis: Correct diagnosis (sample label)
    - category: Medical specialty
    - difficulty: Difficulty tag
    - confounders: Alternative diagnoses
    - key_features: Important clinical features
    """

    def __init__(
        self,
        name: str = "probing_diagnosis",
        path: str = "data/probing_diagnosis.json",
        **kwargs,
    ):
        super().__init__(name, path, **kwargs)

    def _parse_item(self, idx: int, item: Dict[str, Any]) -> Sample:
        """Build a Sample, filling defaults for any missing metadata field."""
        question = item["input"]["question"]
        diagnosis = item["output"]["diagnosis"]
        source_meta = item.get("metadata", {})
        extras = {
            "category": source_meta.get("category", "general"),
            "difficulty": source_meta.get("difficulty", "medium"),
            "confounders": source_meta.get("confounders", []),
            "key_features": source_meta.get("key_features", []),
        }
        return Sample(idx=idx, text=question, label=diagnosis, metadata=extras)
516
+
517
+
518
@Registry.register_dataset("histopathology")
class HistopathologyDataset(BaseDataset):
    """Histopathology report quality evaluation dataset.

    Reads the HARE human evaluation annotations (TSV) and turns each row
    into up to 4 samples, one per model output together with its human score.

    Labels: 0 (poor), 1 (partial), 2 (good)
    """

    def __init__(
        self,
        name: str = "histopathology",
        path: str = "data/histopathology.tsv",
        repo_id: Optional[str] = None,
        **kwargs,
    ):
        self._name = name
        # repo_id must be set before path resolution, which reads it.
        self.repo_id = repo_id
        self.path = self._resolve_path_from_registry(name, path)
        self._samples: List[Sample] = []
        self._load()

    def _resolve_path_from_registry(self, name: str, default_path: str) -> Path:
        """Resolve the TSV path via data/datasets.yaml, else fall back to ``default_path``.

        Best effort: a missing registry, a missing repo_id, an unreadable
        registry or a failed download all end in ``Path(default_path)``; the
        last two also print a warning.
        """
        import yaml
        from huggingface_hub import hf_hub_download

        hub_repo = self.repo_id

        # Registry next to the source checkout (four .parent hops from this
        # file), otherwise the one in the current working directory.
        registry_file = Path(__file__).parent.parent.parent.parent / "data/datasets.yaml"
        if not registry_file.exists():
            registry_file = Path("data/datasets.yaml")
            if not registry_file.exists():
                return Path(default_path)

        try:
            with open(registry_file, "r") as f:
                config = yaml.safe_load(f)

            entry = config.get("datasets", {}).get(name, {})
            if not hub_repo:
                hub_repo = entry.get("repo_id", config.get("default", {}).get("repo_id"))

            if not hub_repo:
                return Path(default_path)

            hub_file = entry.get("path")
            if not hub_file:
                # Drop a leading 'data/' component when there is one; any other
                # shape of path falls back to the bare file name.
                local = Path(default_path)
                try:
                    hub_file = str(local.relative_to("data"))
                except ValueError:
                    hub_file = local.name

            try:
                downloaded = hf_hub_download(
                    repo_id=hub_repo, filename=hub_file, repo_type="dataset"
                )
                return Path(downloaded)
            except Exception as e:
                print(
                    f"Warning: Failed to download {name} ({hub_file}) from HF repo {hub_repo}: {e}"
                )
        except Exception as e:
            print(f"Warning: Failed to load registry: {e}")

        return Path(default_path)

    def _load(self):
        """Read the TSV and emit one sample per scored model output."""
        if not self.path.exists():
            raise FileNotFoundError(f"Dataset not found: {self.path}")

        with open(self.path, "r", encoding="utf-8") as f:
            next_idx = 0
            for row_no, record in enumerate(csv.DictReader(f, delimiter="\t")):
                reference = record.get("ground_truth", "")
                # Model outputs live in columns "0".."3", scores in "Scoring 0".."Scoring 3".
                for model_no in range(4):
                    output_text = record.get(str(model_no), "")
                    raw_score = record.get(f"Scoring {model_no}", "")

                    # Nothing to evaluate when the output or its score is missing.
                    if not output_text or raw_score == "":
                        continue

                    try:
                        score_value = int(float(raw_score))
                    except (ValueError, TypeError):
                        # Non-numeric scores are skipped rather than failing the load.
                        continue

                    details = {
                        "ground_truth": reference,
                        "model_id": model_no,
                        "row_id": row_no,
                    }
                    self._samples.append(
                        Sample(idx=next_idx, text=output_text, label=score_value, metadata=details)
                    )
                    next_idx += 1

    @property
    def name(self) -> str:
        """Dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Restrict this dataset to the histopathology prompt."""
        return ["histopathology"]
648
+
649
+
650
@Registry.register_dataset("tcga")
class TCGADataset(BaseDataset):
    """TCGA Cancer Type Classification dataset.

    Input: Pathology report text
    Output: Cancer type (e.g., BRCA, LUAD)

    Reports are grouped by cancer type; within each type the reports are
    shuffled with ``random_seed`` and the first 15% form the test split, the
    remainder the train split.
    """

    def __init__(
        self,
        name: str = "tcga",
        repo_id: Optional[str] = None,
        reports_filename: str = "tcga/TCGA_Reports.csv",
        labels_filename: str = "tcga/tcga_patient_to_cancer_type.csv",
        split: str = "test",  # train, test, or all
        random_seed: int = 0,
        **kwargs,
    ):
        """Download the reports and labels files and build the requested split.

        Args:
            name: Dataset name.
            repo_id: Hugging Face dataset repo; looked up in
                data/datasets.yaml when omitted.
            reports_filename: Hub path of the reports CSV.
            labels_filename: Hub path of the patient -> cancer type CSV.
            split: "train", "test" or "all".
            random_seed: Seed for the per-class shuffle that defines the split.
            **kwargs: Accepted for signature compatibility; unused here.
        """
        self._name = name
        self.repo_id = repo_id
        self.reports_filename = reports_filename
        self.labels_filename = labels_filename
        self.split = split
        self.random_seed = random_seed
        self._samples: List[Sample] = []
        self._load()

    def _load(self):
        """Resolve the repo, download both CSVs, join them and apply the split.

        Raises:
            ValueError: If no repo_id is given or configured.
            FileNotFoundError: If a download fails or a downloaded file is missing.
        """
        import random

        import yaml
        from huggingface_hub import hf_hub_download

        # Resolve Repo ID: Config > Registry > None
        repo_id = self.repo_id
        if not repo_id:
            # The CWD registry keeps priority (the original lookup); the
            # checkout-relative registry used by the other loaders in this
            # file is tried as a fallback.
            candidates = (
                Path("data/datasets.yaml"),
                Path(__file__).parent.parent.parent.parent / "data/datasets.yaml",
            )
            registry_path = next((c for c in candidates if c.exists()), None)
            if registry_path is not None:
                try:
                    with open(registry_path, "r", encoding="utf-8") as f:
                        config = yaml.safe_load(f) or {}  # an empty file parses to None
                    # Check explicit "tcga" entry or default
                    ds_config = (config.get("datasets") or {}).get("tcga") or {}
                    default_config = config.get("default") or {}
                    repo_id = ds_config.get("repo_id", default_config.get("repo_id"))
                except Exception as e:
                    print(f"Warning: Failed to load registry for TCGA: {e}")

        if not repo_id:
            raise ValueError("No repo_id found for TCGA dataset. Set it in data/datasets.yaml.")

        try:
            reports_path = hf_hub_download(
                repo_id=repo_id, filename=self.reports_filename, repo_type="dataset"
            )
            labels_path = hf_hub_download(
                repo_id=repo_id, filename=self.labels_filename, repo_type="dataset"
            )
        except Exception as e:
            raise FileNotFoundError(f"Failed to download from HF repo {repo_id}: {e}") from e
        self.path = Path(reports_path)
        self.labels_path = Path(labels_path)

        if not self.path.exists():
            raise FileNotFoundError(f"Reports not found: {self.path}")
        if not self.labels_path.exists():
            raise FileNotFoundError(f"Labels not found: {self.labels_path}")

        # 1. Load labels: first column -> second column, header row skipped.
        with open(self.labels_path, "r", encoding="utf-8", newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # Skip header
            labels = {row[0]: row[1] for row in reader if len(row) >= 2}

        # 2. Load reports and link them to labels.
        data_by_class = {}  # ctype -> list of (patient_id, text)
        with open(self.path, "r", encoding="utf-8", newline="") as f:
            for row in csv.DictReader(f):
                pid_full = row.get("patient_filename", "")
                text = row.get("text", "")

                # The patient ID is the first 12 characters (TCGA-XX-YYYY).
                if len(pid_full) < 12:
                    continue
                pid = pid_full[:12]

                if pid in labels:
                    data_by_class.setdefault(labels[pid], []).append((pid, text))

        # 3. Stratified split
        if self.split not in ("all", "test", "train"):
            # An unrecognised split still yields an empty dataset, as before,
            # but no longer silently.
            print(
                f"Warning: Unknown TCGA split '{self.split}' "
                "(expected 'train', 'test' or 'all'); dataset will be empty."
            )

        final_samples = []
        # Sort classes for deterministic order before shuffling
        for ctype in sorted(data_by_class):
            items = data_by_class[ctype]

            # Deterministic shuffle: a fresh generator per class, same seed.
            random.Random(self.random_seed).shuffle(items)

            # Split index (15% test, rounded down)
            n_test = int(len(items) * 0.15)

            if self.split == "all":
                selected = items
            elif self.split == "test":
                # Test set is the first 15%
                selected = items[:n_test]
            elif self.split == "train":
                selected = items[n_test:]
            else:
                selected = []

            final_samples.extend((pid, text, ctype) for pid, text in selected)

        # 4. Create Sample objects
        for i, (pid, text, ctype) in enumerate(final_samples):
            self._samples.append(
                Sample(
                    idx=i,
                    text=text,
                    label=ctype,
                    metadata={"patient_id": pid, "cancer_type": ctype},
                )
            )

    @property
    def name(self) -> str:
        """Dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Restrict this dataset to the tcga prompt."""
        return ["tcga"]
801
+
802
+
803
@Registry.register_dataset("medqa")
class MedQADataset(BaseDataset):
    """MedQA USMLE-style 4-option MCQ dataset.

    Format: JSONL with fields:
    - question: Clinical vignette
    - options: Dict {"A": "...", "B": "...", "C": "...", "D": "..."}
    - answer_idx: Correct answer letter (A/B/C/D)
    - meta_info: USMLE step (step1, step2, step3)
    """

    def __init__(
        self,
        name: str = "medqa",
        repo_id: Optional[str] = None,
        filename: str = "medqa/test.jsonl",
        split: str = "test",
        **kwargs,
    ):
        """Download and parse the MedQA JSONL file.

        Args:
            name: Dataset name.
            repo_id: Hugging Face dataset repo; looked up in
                data/datasets.yaml when omitted.
            filename: Hub path of the JSONL file.
            split: Stored on the instance; the file to load is chosen by
                ``filename``.
            **kwargs: Accepted for signature compatibility; unused here.
        """
        self._name = name
        self.repo_id = repo_id
        self.filename = filename
        self.split = split
        self._samples: List[Sample] = []
        self._load()

    def _load(self):
        """Resolve the repo, download the JSONL file and build the samples.

        Raises:
            ValueError: If no repo_id is given or configured.
            FileNotFoundError: If the download fails.
        """
        import json

        import yaml
        from huggingface_hub import hf_hub_download

        # Resolve repo_id from registry if not provided
        repo_id = self.repo_id
        if not repo_id:
            # Checkout-relative registry first (the original lookup), then the
            # current working directory, as the JSON/CSV loaders in this file do.
            candidates = (
                Path(__file__).parent.parent.parent.parent / "data/datasets.yaml",
                Path("data/datasets.yaml"),
            )
            registry_path = next((c for c in candidates if c.exists()), None)
            if registry_path is not None:
                try:
                    with open(registry_path, "r", encoding="utf-8") as f:
                        config = yaml.safe_load(f) or {}  # an empty file parses to None
                    ds_config = (config.get("datasets") or {}).get("medqa") or {}
                    default_config = config.get("default") or {}
                    repo_id = ds_config.get("repo_id", default_config.get("repo_id"))
                except Exception as e:
                    print(f"Warning: Failed to load registry for MedQA: {e}")

        if not repo_id:
            raise ValueError("No repo_id found for MedQA dataset")

        # Download from HF
        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=self.filename,
                repo_type="dataset",
            )
        except Exception as e:
            raise FileNotFoundError(f"Failed to download MedQA from {repo_id}: {e}") from e

        # Parse JSONL. The encoding is explicit so the result does not depend
        # on the platform's default locale encoding.
        with open(local_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. trailing newlines) are not records.
                    continue
                data = json.loads(line)

                # Format question with options, one "KEY) text" line per
                # option in sorted key order.
                question = data["question"]
                options = data["options"]
                formatted_options = "\n".join(
                    f"{key}) {val}" for key, val in sorted(options.items())
                )
                text = f"{question}\n\n{formatted_options}"

                # Answer is the letter (A, B, C, D)
                label = data["answer_idx"]

                self._samples.append(
                    Sample(
                        # Position in the dataset; equals the line number
                        # whenever the file has no blank lines.
                        idx=len(self._samples),
                        text=text,
                        label=label,
                        metadata={
                            "step": data.get("meta_info", "unknown"),
                            "answer_text": data.get("answer", ""),
                        },
                    )
                )

    @property
    def name(self) -> str:
        """Dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Prompt strategies this multiple-choice dataset may be used with."""
        return [
            "mcq",
            "medqa",
            "direct_answer",
            "chain_of_thought",
            "uncertainty",
            "contrarian",
            "few_shot",
        ]
910
+
911
+
912
@Registry.register_dataset("mmlu_medical")
class MMLUMedicalDataset(BaseDataset):
    """MMLU Medical subset dataset (anatomy, clinical_knowledge, medical_genetics, college_biology).

    Format: JSONL with fields:
    - question: Question text
    - choices: List of 4 options
    - answer: Integer index (0-3)
    - subject: Subject name
    """

    def __init__(
        self,
        name: str = "mmlu_medical",
        repo_id: Optional[str] = None,
        filename: str = "mmlu/medical_test.jsonl",
        **kwargs,
    ):
        """Download and parse the MMLU medical JSONL file.

        Args:
            name: Dataset name.
            repo_id: Hugging Face dataset repo; looked up in
                data/datasets.yaml when omitted.
            filename: Hub path of the JSONL file.
            **kwargs: Accepted for signature compatibility; unused here.
        """
        self._name = name
        self.repo_id = repo_id
        self.filename = filename
        self._samples: List[Sample] = []
        self._load()

    def _load(self):
        """Resolve the repo, download the JSONL file and build the samples.

        Raises:
            ValueError: If no repo_id is given or configured, or a record's
                ``answer`` is not a valid index into its ``choices``.
            FileNotFoundError: If the download fails.
        """
        import json

        import yaml
        from huggingface_hub import hf_hub_download

        # Resolve repo_id from registry if not provided
        repo_id = self.repo_id
        if not repo_id:
            # Checkout-relative registry first (the original lookup), then the
            # current working directory, as the JSON/CSV loaders in this file do.
            candidates = (
                Path(__file__).parent.parent.parent.parent / "data/datasets.yaml",
                Path("data/datasets.yaml"),
            )
            registry_path = next((c for c in candidates if c.exists()), None)
            if registry_path is not None:
                try:
                    with open(registry_path, "r", encoding="utf-8") as f:
                        config = yaml.safe_load(f) or {}  # an empty file parses to None
                    ds_config = (config.get("datasets") or {}).get("mmlu_medical") or {}
                    default_config = config.get("default") or {}
                    repo_id = ds_config.get("repo_id", default_config.get("repo_id"))
                except Exception as e:
                    print(f"Warning: Failed to load registry for MMLU Medical: {e}")

        if not repo_id:
            raise ValueError("No repo_id found for MMLU Medical dataset")

        # Download from HF
        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=self.filename,
                repo_type="dataset",
            )
        except Exception as e:
            raise FileNotFoundError(f"Failed to download MMLU Medical from {repo_id}: {e}") from e

        # Parse JSONL. The encoding is explicit so the result does not depend
        # on the platform's default locale encoding.
        with open(local_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # Blank lines (e.g. trailing newlines) are not records.
                    continue
                data = json.loads(line)

                # Format question with options lettered A, B, C, ...
                question = data["question"]
                choices = data["choices"]
                formatted_options = "\n".join(
                    f"{chr(65 + j)}) {opt}" for j, opt in enumerate(choices)
                )
                text = f"{question}\n\n{formatted_options}"

                # Convert the integer answer to the letter of the option it
                # indexes. An invalid index is an error: defaulting it to "A"
                # would record a wrong gold label without any signal.
                answer = data["answer"]
                try:
                    answer_pos = int(answer)
                except (TypeError, ValueError):
                    answer_pos = -1
                if answer_pos != answer or not 0 <= answer_pos < len(choices):
                    raise ValueError(
                        f"Invalid answer index {answer!r} on line {line_no} of "
                        f"{self.filename}: expected an integer index into "
                        f"{len(choices)} choices"
                    )
                label = chr(65 + answer_pos)

                self._samples.append(
                    Sample(
                        # Position in the dataset; equals the zero-based line
                        # number whenever the file has no blank lines.
                        idx=len(self._samples),
                        text=text,
                        label=label,
                        metadata={
                            "subject": data.get("subject", "unknown"),
                        },
                    )
                )

    @property
    def name(self) -> str:
        """Dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Prompt strategies this multiple-choice dataset may be used with."""
        return [
            "mcq",
            "direct_answer",
            "chain_of_thought",
            "uncertainty",
            "contrarian",
            "few_shot",
        ]
1018
+
1019
+
1020
@Registry.register_dataset("m_arc")
class MARCDataset(BaseDataset):
    """M-ARC benchmark.

    Upstream dataset: mkieffer/M-ARC (Parquet, split=test, 100 items).

    Expected columns:
    - question_id: string
    - question: string
    - options: struct with keys A..G
    - answer: string label (A..G)
    - src: string (source/category)

    The Parquet file is fetched with ``hf_hub_download`` and parsed eagerly in
    ``__init__``.
    """

    def __init__(
        self,
        name: str = "m_arc",
        repo_id: Optional[str] = None,
        filename: str = "m_arc/test-00000-of-00001.parquet",
        split: str = "test",
        **kwargs,
    ):
        """Store the configuration and load all samples immediately.

        Args:
            name: Value returned by the ``name`` property.
            repo_id: Hugging Face dataset repo id. When falsy, it is looked up
                in ``data/datasets.yaml``.
            filename: Path of the Parquet file inside the repo.
            split: Stored on the instance; this class does not read it again.
            **kwargs: Accepted and ignored.
        """
        self._name = name
        self.repo_id = repo_id
        self.filename = filename
        self.split = split
        self._samples: List[Sample] = []
        self._load()

    @property
    def name(self) -> str:
        """Return the dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the loaded sample at position ``idx``."""
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Return the prompt names this dataset declares itself compatible with."""
        return [
            "mcq",
            "direct_answer",
            "chain_of_thought",
            "uncertainty",
            "contrarian",
            "few_shot",
        ]

    def _resolve_repo_id(self) -> Optional[str]:
        """Resolve default repo_id from data/datasets.yaml, if present.

        Returns:
            The ``datasets.m_arc.repo_id`` entry, else ``default.repo_id``,
            else ``None`` (also when the file is missing or unreadable).
        """
        import yaml

        # The registry is looked up four directory levels above this file.
        root_dir = Path(__file__).parent.parent.parent.parent
        registry_path = root_dir / "data/datasets.yaml"
        if not registry_path.exists():
            return None

        try:
            with open(registry_path, "r", encoding="utf-8") as f:
                # An empty YAML document parses to None.
                config = yaml.safe_load(f) or {}
            ds_config = config.get("datasets", {}).get("m_arc", {})
            return ds_config.get("repo_id", config.get("default", {}).get("repo_id"))
        except Exception as e:
            # Best effort: the caller raises if no repo id can be resolved.
            print(f"Warning: Failed to load registry for M-ARC: {e}")
            return None

    @staticmethod
    def _format_options(options: Dict[str, Any]) -> str:
        """Format A..G options as lines, skipping empty values.

        Keys are emitted in sorted order as ``"<key>) <value>"``; ``None`` and
        whitespace-only string values are dropped.
        """
        if not options:
            return ""
        lines: list[str] = []
        for key in sorted(options.keys()):
            val = options.get(key)
            if val is None:
                continue
            if isinstance(val, str) and not val.strip():
                continue
            lines.append(f"{key}) {str(val).strip()}")
        return "\n".join(lines)

    @staticmethod
    def _coerce_options(raw: Any) -> Dict[str, Any]:
        """Convert pyarrow struct / dict-like values into a plain dict.

        Returns an empty dict for ``None`` and for anything that is neither a
        dict nor an object whose ``as_py()`` returns a dict.
        """
        if raw is None:
            return {}
        if isinstance(raw, dict):
            return raw
        # pyarrow StructScalar has as_py()
        as_py = getattr(raw, "as_py", None)
        if callable(as_py):
            val = as_py()
            if isinstance(val, dict):
                return val
        return {}

    def _load(self) -> None:
        """Download the Parquet file and append one ``Sample`` per row.

        Raises:
            ImportError: If pyarrow is not installed.
            ValueError: If no repo id can be resolved.
            FileNotFoundError: If the download fails (chained to the original error).
        """
        try:
            import pyarrow.parquet as pq
        except ImportError as e:
            raise ImportError("M-ARC requires pyarrow. Install with: uv pip install pyarrow") from e

        from huggingface_hub import hf_hub_download

        repo_id = self.repo_id or self._resolve_repo_id()
        if not repo_id:
            raise ValueError(
                "No repo_id found for M-ARC. Set default.repo_id in data/datasets.yaml "
                "or pass repo_id= explicitly."
            )

        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=self.filename,
                repo_type="dataset",
            )
        except Exception as e:
            raise FileNotFoundError(
                f"Failed to download M-ARC from {repo_id} (file: {self.filename}): {e}"
            ) from e

        table = pq.read_table(Path(local_path))
        df = table.to_pandas()

        for i, row in df.iterrows():
            question_id = row.get("question_id")
            question = row.get("question", "") or ""
            options = self._coerce_options(row.get("options"))
            options_formatted = self._format_options(options)
            # strip() removes the trailing blank lines left when there are no options.
            text = f"{question}\n\n{options_formatted}".strip()
            label = (row.get("answer") or "").strip().upper()

            self._samples.append(
                Sample(
                    idx=int(i),
                    text=text,
                    label=label,
                    metadata={
                        "question_id": question_id,
                        "src": row.get("src"),
                        "question": question,
                        "options": options,
                        "options_formatted": options_formatted,
                        "repo_id": repo_id,
                        "filename": self.filename,
                    },
                )
            )
1173
+
1174
+
1175
@Registry.register_dataset("medbullets")
class MedBulletsDataset(BaseDataset):
    """MedBullets MCQ benchmark dataset.

    Upstream dataset: mkieffer/Medbullets (Parquet, splits: op4_test, op5_test).

    Splits:
    - op4_test: effectively 4 options (may include an empty option field)
    - op5_test: same stems, with an additional answer choice (letters may differ)

    Expected columns:
    - idx: string id
    - question: string
    - options: struct/dict with keys A..E (some values may be empty strings)
    - answer: string label (A..E)
    - explanation: string (rationale)
    - link: string (source)

    The Parquet file is fetched with ``hf_hub_download`` and parsed eagerly in
    ``__init__``.
    """

    def __init__(
        self,
        name: str = "medbullets",
        repo_id: Optional[str] = None,
        split: str = "op5_test",
        filename: Optional[str] = None,
        **kwargs,
    ):
        """Store the configuration and load all samples immediately.

        Args:
            name: Value returned by the ``name`` property.
            repo_id: Hugging Face dataset repo id. When falsy, it is looked up
                in ``data/datasets.yaml``.
            split: Split name; used to build the default filename and recorded
                in each sample's metadata.
            filename: Path of the Parquet file inside the repo. When falsy, it
                is derived from ``split``.
            **kwargs: Accepted and ignored.
        """
        self._name = name
        self.repo_id = repo_id
        self.split = split
        self.filename = filename
        self._samples: List[Sample] = []
        self._load()

    @property
    def name(self) -> str:
        """Return the dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the loaded sample at position ``idx``."""
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Return the prompt names this dataset declares itself compatible with."""
        return [
            "mcq",
            "direct_answer",
            "chain_of_thought",
            "uncertainty",
            "contrarian",
            "few_shot",
        ]

    def _default_filename(self) -> str:
        """Return the in-repo Parquet path derived from ``self.split``."""
        return f"medbullets/{self.split}-00000-of-00001.parquet"

    def _resolve_repo_id(self) -> str:
        """Return the repo id from the constructor argument or the registry.

        Raises:
            ValueError: If neither source yields a repo id.
        """
        import yaml

        if self.repo_id:
            return self.repo_id

        # The registry is looked up four directory levels above this file.
        root_dir = Path(__file__).parent.parent.parent.parent
        registry_path = root_dir / "data/datasets.yaml"
        if registry_path.exists():
            try:
                with open(registry_path, "r", encoding="utf-8") as f:
                    # An empty YAML document parses to None.
                    config = yaml.safe_load(f) or {}
                ds_config = config.get("datasets", {}).get("medbullets", {})
                repo_id = ds_config.get("repo_id", config.get("default", {}).get("repo_id"))
                if repo_id:
                    return repo_id
            except Exception as e:
                # Best effort: fall through to the ValueError below.
                print(f"Warning: Failed to load registry for MedBullets: {e}")

        raise ValueError(
            "No repo_id found for MedBullets. Set default.repo_id in data/datasets.yaml "
            "or pass repo_id= explicitly."
        )

    @staticmethod
    def _coerce_options(raw: Any) -> Dict[str, Any]:
        """Return ``raw`` as a plain dict, or an empty dict if that is not possible.

        Dicts are returned as-is; objects with a callable ``as_py()`` are
        converted through it when the result is a dict.
        """
        if raw is None:
            return {}
        if isinstance(raw, dict):
            return raw
        as_py = getattr(raw, "as_py", None)
        if callable(as_py):
            val = as_py()
            if isinstance(val, dict):
                return val
        return {}

    @staticmethod
    def _format_options(options: Dict[str, Any]) -> str:
        """Render options as ``"<key>) <value>"`` lines in sorted key order.

        ``None`` and whitespace-only string values are skipped.
        """
        if not options:
            return ""
        lines: list[str] = []
        for key in sorted(options.keys()):
            val = options.get(key)
            if val is None:
                continue
            if isinstance(val, str) and not val.strip():
                continue
            lines.append(f"{key}) {str(val).strip()}")
        return "\n".join(lines)

    def _load(self) -> None:
        """Download the Parquet file and append one ``Sample`` per row.

        Raises:
            ImportError: If pyarrow is not installed.
            ValueError: If no repo id can be resolved.
            FileNotFoundError: If the download fails (chained to the original error).
        """
        try:
            import pyarrow.parquet as pq
        except ImportError as e:
            raise ImportError(
                "MedBullets requires pyarrow. Install with: uv pip install pyarrow"
            ) from e

        from huggingface_hub import hf_hub_download

        repo_id = self._resolve_repo_id()
        filename = self.filename or self._default_filename()

        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                repo_type="dataset",
            )
        except Exception as e:
            raise FileNotFoundError(
                f"Failed to download MedBullets from {repo_id} (file: {filename}): {e}"
            ) from e

        table = pq.read_table(Path(local_path))
        df = table.to_pandas()

        for i, row in df.iterrows():
            qid = row.get("idx")
            question = row.get("question", "") or ""
            options = self._coerce_options(row.get("options"))
            options_formatted = self._format_options(options)
            # strip() removes the trailing blank lines left when there are no options.
            text = f"{question}\n\n{options_formatted}".strip()
            label = (row.get("answer") or "").strip().upper()

            self._samples.append(
                Sample(
                    idx=int(i),
                    text=text,
                    label=label,
                    metadata={
                        "idx": qid,
                        "split": self.split,
                        "question": question,
                        "options": options,
                        "options_formatted": options_formatted,
                        "explanation": row.get("explanation"),
                        "link": row.get("link"),
                        "repo_id": repo_id,
                        "filename": filename,
                    },
                )
            )
1336
+
1337
+
1338
@Registry.register_dataset("plab")
class PLABDataset(BaseDataset):
    """PLAB-style MCQ dataset loader from raw JSON files.

    Two JSON files are fetched with ``hf_hub_download``: the question data and
    an optional topics file. Rows without question text, a list of options, or
    an extractable answer letter that matches one of the options are skipped.
    """

    def __init__(
        self,
        name: str = "plab",
        repo_id: Optional[str] = None,
        filename: str = "plab/data.json",
        topics_filename: str = "plab/topics.json",
        split: str = "main",
        **kwargs,
    ):
        """Store the configuration and load all samples immediately.

        Args:
            name: Value returned by the ``name`` property.
            repo_id: Hugging Face dataset repo id. When falsy, it is looked up
                in ``data/datasets.yaml``.
            filename: Path of the questions JSON file inside the repo.
            topics_filename: Path of the topics JSON file inside the repo.
            split: Recorded in each sample's metadata.
            **kwargs: Accepted and ignored.
        """
        self._name = name
        self.repo_id = repo_id
        self.filename = filename
        self.topics_filename = topics_filename
        self.split = split
        self._samples: List[Sample] = []
        self._load()

    @property
    def name(self) -> str:
        """Return the dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the loaded sample at position ``idx``."""
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Return the prompt names this dataset declares itself compatible with."""
        return [
            "plab",
            "mcq",
            "direct_answer",
            "chain_of_thought",
            "uncertainty",
            "contrarian",
            "few_shot",
        ]

    def _resolve_repo_id(self) -> str:
        """Return the repo id from the constructor argument or the registry.

        Raises:
            ValueError: If neither source yields a repo id.
        """
        import yaml

        if self.repo_id:
            return self.repo_id

        # The registry is looked up four directory levels above this file.
        root_dir = Path(__file__).parent.parent.parent.parent
        registry_path = root_dir / "data/datasets.yaml"
        if registry_path.exists():
            try:
                with open(registry_path, "r", encoding="utf-8") as f:
                    # An empty YAML document parses to None.
                    config = yaml.safe_load(f) or {}
                ds_config = config.get("datasets", {}).get("plab", {})
                repo_id = ds_config.get("repo_id", config.get("default", {}).get("repo_id"))
                if repo_id:
                    return repo_id
            except Exception as e:
                # Best effort: fall through to the ValueError below.
                print(f"Warning: Failed to load registry for PLAB: {e}")

        raise ValueError(
            "No repo_id found for PLAB. Set default.repo_id in data/datasets.yaml "
            "or pass repo_id= explicitly."
        )

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return ``value`` as a string.

        ``None`` becomes ``""``; a list is joined by concatenating ``str()`` of
        each element with no separator; anything else goes through ``str()``.
        """
        if value is None:
            return ""
        if isinstance(value, list):
            return "".join(str(x) for x in value)
        return str(value)

    @staticmethod
    def _extract_question_id(question_text: str) -> Optional[int]:
        """Return the leading ``"<digits>."`` number of the question, if any."""
        m = re.match(r"\s*(\d+)\.", question_text)
        if not m:
            return None
        try:
            return int(m.group(1))
        except ValueError:
            return None

    @staticmethod
    def _extract_label(answer_text: str) -> Optional[str]:
        """Return the answer letter (A-G, upper-cased) found in ``answer_text``.

        The patterns are tried in order, case-insensitively; the first match
        wins. Returns ``None`` when none of them matches.
        """
        patterns = [
            r"\bkey\s+is\s*([A-G])\b",
            r"\bAns\.?\s*([A-G])\b",
            r"\bAnswer\s*[:\-]?\s*([A-G])\b",
        ]
        for pat in patterns:
            m = re.search(pat, answer_text, flags=re.IGNORECASE)
            if m:
                return m.group(1).upper()
        return None

    @staticmethod
    def _format_options(raw_options: List[Any]) -> tuple[str, Dict[str, str]]:
        """Build the options text and a letter-to-text map from raw options.

        An option that starts with a letter A-G followed by one of ``) . : -``
        keeps that letter (upper-cased); any other option is assigned the next
        unused auto letter, starting at ``A``. Empty options are skipped.

        Returns:
            ``(lines joined by newlines, {letter: text})``.
        """
        lines: List[str] = []
        options_map: Dict[str, str] = {}
        next_letter_ord = ord("A")

        for raw in raw_options:
            option = str(raw).strip()
            if not option:
                continue

            m = re.match(r"^\s*([A-Ga-g])[\)\.\:\-]\s*(.*)$", option)
            if m:
                letter = m.group(1).upper()
                text = m.group(2).strip()
            else:
                letter = chr(next_letter_ord)
                text = option
                next_letter_ord += 1

            if not text:
                # A bare prefix such as "C)" carries no option text.
                continue

            options_map[letter] = text
            lines.append(f"{letter}) {text}")

        return "\n".join(lines), options_map

    def _load(self) -> None:
        """Download the JSON files and append one ``Sample`` per usable row.

        The topics file is optional: any failure to fetch or parse it prints a
        warning and loading continues without topics.

        Raises:
            ValueError: If no repo id can be resolved.
            FileNotFoundError: If the questions file cannot be downloaded
                (chained to the original error).
        """
        from huggingface_hub import hf_hub_download

        repo_id = self._resolve_repo_id()

        try:
            data_path = Path(
                hf_hub_download(
                    repo_id=repo_id,
                    filename=self.filename,
                    repo_type="dataset",
                )
            )
        except Exception as e:
            raise FileNotFoundError(
                f"Failed to download PLAB data from {repo_id} (file: {self.filename}): {e}"
            ) from e

        topics_by_qid: Dict[int, str] = {}
        try:
            topics_path = Path(
                hf_hub_download(
                    repo_id=repo_id,
                    filename=self.topics_filename,
                    repo_type="dataset",
                )
            )
            with open(topics_path, "r", encoding="utf-8") as f:
                topics_data = json.load(f)
            for row in topics_data:
                qid = row.get("question")
                try:
                    qid_int = int(qid)
                except (TypeError, ValueError):
                    continue
                topic = self._as_text(row.get("topic")).strip()
                if topic:
                    topics_by_qid[qid_int] = topic
        except Exception as e:
            # Topics are best effort; report the failure instead of hiding it.
            print(
                f"Warning: Failed to load PLAB topics from {repo_id} "
                f"(file: {self.topics_filename}): {e}"
            )
            topics_by_qid = {}

        with open(data_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Counts kept samples only, so indices stay contiguous when rows are skipped.
        sample_idx = 0
        for row in data:
            question_text = self._as_text(row.get("question")).strip()
            if not question_text:
                continue

            raw_options = row.get("options")
            if not isinstance(raw_options, list):
                continue

            options_formatted, options_map = self._format_options(raw_options)
            if not options_formatted:
                continue

            answer_explanation = self._as_text(row.get("answer")).strip()
            label = self._extract_label(answer_explanation)
            if label is None or label not in options_map:
                continue

            qid = self._extract_question_id(question_text)
            topic = topics_by_qid.get(qid) if qid is not None else None
            text = f"{question_text}\n\n{options_formatted}".strip()

            self._samples.append(
                Sample(
                    idx=sample_idx,
                    text=text,
                    label=label,
                    metadata={
                        "question_id": qid,
                        "topic": topic,
                        "question": question_text,
                        "options": options_map,
                        "options_formatted": options_formatted,
                        "answer_explanation": answer_explanation,
                        "split": self.split,
                        "repo_id": repo_id,
                        "filename": self.filename,
                        "topics_filename": self.topics_filename,
                    },
                )
            )
            sample_idx += 1
1550
+
1551
+
1552
@Registry.register_dataset("pubhealthbench")
class PubHealthBenchDataset(BaseDataset):
    """PubHealthBench multiple-choice public health QA dataset.

    Format: Parquet with fields:
    - question: Question text
    - options: List of answer choices
    - options_formatted: Pre-formatted options string (A. ..., B. ...)
    - answer_index: Integer index (0-based)
    - answer: Correct answer letter (A/B/C/...)
    - category, intended_audience, source_document_title, source_chunk_text
    - review_annotation, retrieved_context_for_judge, question_id

    The Parquet file is fetched with ``hf_hub_download`` and parsed eagerly in
    ``__init__``.
    """

    def __init__(
        self,
        name: str = "pubhealthbench",
        repo_id: Optional[str] = None,
        filename: Optional[str] = None,
        split: str = "test",
        **kwargs,
    ):
        """Store the configuration and load all samples immediately.

        Args:
            name: Value returned by the ``name`` property.
            repo_id: Hugging Face dataset repo id. When falsy, it is looked up
                in ``data/datasets.yaml``.
            filename: Path of the Parquet file inside the repo. When falsy, it
                is derived from ``split``.
            split: Split name used to build the default filename.
            **kwargs: Accepted and ignored.
        """
        self._name = name
        self.repo_id = repo_id
        self.filename = filename
        self.split = split
        self._samples: List[Sample] = []
        self._load()

    def _default_filename(self) -> str:
        """Return the in-repo Parquet path derived from ``self.split``."""
        # Every split, including "reviewed", follows the same naming pattern.
        return f"pubhealthbench/{self.split}-00000-of-00001.parquet"

    def _resolve_repo_id(self) -> str:
        """Return the repo id from the constructor argument or the registry.

        Raises:
            ValueError: If neither source yields a repo id.
        """
        import yaml

        repo_id = self.repo_id
        if repo_id:
            return repo_id

        # Try registry override; it is looked up four directory levels above this file.
        root_dir = Path(__file__).parent.parent.parent.parent
        registry_path = root_dir / "data/datasets.yaml"
        if registry_path.exists():
            try:
                with open(registry_path, "r", encoding="utf-8") as f:
                    # An empty YAML document parses to None.
                    config = yaml.safe_load(f) or {}
                ds_config = config.get("datasets", {}).get("pubhealthbench", {})
                # Prefer explicit override, otherwise fall back to default repo_id.
                repo_id = ds_config.get("repo_id", config.get("default", {}).get("repo_id"))
            except Exception as e:
                # Best effort: fall through to the check below.
                print(f"Warning: Failed to load registry for PubHealthBench: {e}")

        if not repo_id:
            raise ValueError(
                "No repo_id found for PubHealthBench. Set default.repo_id in data/datasets.yaml "
                "or pass repo_id= explicitly."
            )
        return repo_id

    def _format_options(self, options: Optional[List[str]]) -> str:
        """Render options as ``"A. ..."``, ``"B. ..."`` lines.

        Args:
            options: Sequence of option texts, or ``None``.

        Returns:
            The lines joined by newlines, or ``""`` when there are no options.
        """
        # Use len() rather than truthiness: pandas hands list columns over as
        # NumPy arrays, whose truth value is ambiguous for more than one element.
        if options is None or len(options) == 0:
            return ""
        return "\n".join(f"{chr(65 + i)}. {opt}" for i, opt in enumerate(options))

    def _load(self):
        """Download the Parquet file and append one ``Sample`` per row.

        Raises:
            ImportError: If pyarrow is not installed.
            ValueError: If no repo id can be resolved.
            FileNotFoundError: If the download fails (chained to the original error).
        """
        import math

        try:
            import pyarrow.parquet as pq
        except ImportError as e:
            raise ImportError(
                "PubHealthBench requires pyarrow. Install with: uv pip install pyarrow"
            ) from e

        from huggingface_hub import hf_hub_download

        filename = self.filename or self._default_filename()
        repo_id = self._resolve_repo_id()
        try:
            path = Path(
                hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    repo_type="dataset",
                )
            )
        except Exception as e:
            raise FileNotFoundError(
                f"Failed to download PubHealthBench from {repo_id} (file: {filename}): {e}"
            ) from e

        table = pq.read_table(path)
        df = table.to_pandas()

        index_to_letter = {i: chr(65 + i) for i in range(26)}

        def clean_value(value: Any) -> Any:
            """Map ``None`` and float NaN to ``None``; pass anything else through."""
            if value is None:
                return None
            if isinstance(value, float) and math.isnan(value):
                return None
            return value

        for i, row in df.iterrows():
            question = row.get("question", "")
            options_formatted = clean_value(row.get("options_formatted"))
            if not options_formatted:
                # Fall back to building the text from the raw options list.
                options_formatted = self._format_options(clean_value(row.get("options")))

            text = f"{question}\n\n{options_formatted}".strip()

            answer = clean_value(row.get("answer"))
            if not answer:
                # Derive the letter from the 0-based index; a missing (NaN)
                # index leaves the label unset instead of failing in int().
                answer_index = clean_value(row.get("answer_index"))
                if answer_index is not None:
                    answer = index_to_letter.get(int(answer_index), "A")

            metadata = {
                "question_id": clean_value(row.get("question_id")),
                "question": question,
                "options_formatted": options_formatted,
                "category": clean_value(row.get("category")),
                "intended_audience": clean_value(row.get("intended_audience")),
                "source_document_title": clean_value(row.get("source_document_title")),
                "source_chunk_text": clean_value(row.get("source_chunk_text")),
                "review_annotation": clean_value(row.get("review_annotation")),
                "retrieved_context_for_judge": clean_value(row.get("retrieved_context_for_judge")),
            }

            self._samples.append(
                Sample(
                    idx=i,
                    text=text,
                    label=answer,
                    metadata=metadata,
                )
            )

    @property
    def name(self) -> str:
        """Return the dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the loaded sample at position ``idx``."""
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Return the prompt names this dataset declares itself compatible with."""
        return [
            "mcq",
            "pubhealthbench",
            "direct_answer",
            "chain_of_thought",
            "uncertainty",
            "contrarian",
            "few_shot",
        ]
1712
+
1713
+
1714
@Registry.register_dataset("pubmedqa")
class PubMedQADataset(BaseDataset):
    """PubMedQA research question answering dataset.

    Format: JSONL with fields:
    - question: Research question
    - context: Abstract context
    - answer: yes/no/maybe
    - pmid: PubMed ID

    The file is fetched with ``hf_hub_download`` and parsed eagerly in
    ``__init__``.
    """

    def __init__(
        self,
        name: str = "pubmedqa",
        repo_id: Optional[str] = None,
        filename: str = "pubmedqa/test.jsonl",
        **kwargs,
    ):
        """Store the configuration and load all samples immediately.

        Args:
            name: Value returned by the ``name`` property.
            repo_id: Hugging Face dataset repo id. When falsy, it is looked up
                in ``data/datasets.yaml``.
            filename: Path of the JSONL file inside the repo.
            **kwargs: Accepted and ignored.
        """
        self._name = name
        self.repo_id = repo_id
        self.filename = filename
        self._samples: List[Sample] = []
        self._load()

    def _load(self):
        """Download the JSONL file and append one ``Sample`` per record.

        Raises:
            ValueError: If no repo id was passed and none is found in the registry.
            FileNotFoundError: If the download fails (chained to the original error).
        """
        import json

        import yaml
        from huggingface_hub import hf_hub_download

        # Resolve repo_id from registry if not provided
        repo_id = self.repo_id
        if not repo_id:
            # The registry is looked up four directory levels above this file.
            root_dir = Path(__file__).parent.parent.parent.parent
            registry_path = root_dir / "data/datasets.yaml"
            if registry_path.exists():
                try:
                    with open(registry_path, "r", encoding="utf-8") as f:
                        # An empty YAML document parses to None.
                        config = yaml.safe_load(f) or {}
                    ds_config = config.get("datasets", {}).get("pubmedqa", {})
                    repo_id = ds_config.get("repo_id", config.get("default", {}).get("repo_id"))
                except Exception as e:
                    # Best effort: a broken registry falls through to the check below.
                    print(f"Warning: Failed to load registry for PubMedQA: {e}")

        if not repo_id:
            raise ValueError("No repo_id found for PubMedQA dataset")

        # Download from HF
        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=self.filename,
                repo_type="dataset",
            )
        except Exception as e:
            raise FileNotFoundError(f"Failed to download PubMedQA from {repo_id}: {e}") from e

        # Parse JSONL
        with open(local_path, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) are not records.
                    continue
                data = json.loads(line)

                # Format: question + context, predict yes/no/maybe
                question = data["question"]
                context = data.get("context", "")

                # Truncate very long contexts to 1500 characters plus an ellipsis.
                if len(context) > 1500:
                    context = context[:1500] + "..."

                text = f"Question: {question}\n\nContext: {context}"

                # Label is yes/no/maybe
                label = data["answer"]

                self._samples.append(
                    Sample(
                        idx=i,
                        text=text,
                        label=label,
                        metadata={
                            "pmid": data.get("pmid", ""),
                        },
                    )
                )

    @property
    def name(self) -> str:
        """Return the dataset name given at construction."""
        return self._name

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Sample:
        """Return the loaded sample at position ``idx``."""
        return self._samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Return the prompt names this dataset declares itself compatible with."""
        return [
            "pubmedqa",
            "mcq",
            "direct_answer",
            "chain_of_thought",
            "uncertainty",
            "contrarian",
            "few_shot",
        ]
1820
+
1821
+
1822
@Registry.register_dataset("movie_ood")
class MovieOODDataset(BaseDataset):
    """Post-cutoff movie MCQ dataset for true zero-knowledge OOD evaluation.

    Generated from TMDB API: films released after MedGemma training cutoff (2025-07-01).
    Question types: director, cast, genre, production_country.

    Format: JSONL with fields:
    - question: Question text
    - options: Dict {"A": "...", "B": "...", "C": "...", "D": "..."}
    - answer: Correct answer letter (A/B/C/D)
    - metadata: movie_id, title, release_date, question_type

    Unlike the hub-backed loaders, this one reads a local file and parses it
    eagerly in ``__init__``.
    """

    def __init__(
        self,
        name: str = "movie_ood",
        path: str = "data/movie_ood.jsonl",
        **kwargs,
    ):
        """Remember the dataset name and read every record from ``path``.

        Args:
            name: Value returned by the ``name`` property.
            path: Location of the JSONL file; wrapped in ``Path`` before loading.
            **kwargs: Accepted and ignored.
        """
        self._name = name
        self._samples: List[Sample] = []
        source = Path(path)
        self._load(source)

    def _load(self, path: Path):
        """Parse the JSONL file at ``path`` into ``self._samples``.

        Blank lines are skipped; ``idx`` is the line's position in the file,
        so it is not contiguous when blank lines are present.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if not path.exists():
            raise FileNotFoundError(
                f"movie_ood dataset not found at {path}.\n"
                "Run: python scripts/build_movie_ood_dataset.py"
            )
        with open(path, "r", encoding="utf-8") as handle:
            for line_no, raw_line in enumerate(handle):
                payload = raw_line.strip()
                if payload:
                    record = json.loads(payload)
                    stem = record["question"]
                    choices = record["options"]
                    # One "<key>) <value>" line per option, ordered by key.
                    option_lines = [
                        f"{key}) {val}" for key, val in sorted(choices.items())
                    ]
                    formatted_options = "\n".join(option_lines)
                    prompt_text = f"{stem}\n\n{formatted_options}"
                    answer_letter = record.get("answer", "").strip().upper()
                    extra = record.get("metadata", {})
                    sample = Sample(
                        idx=line_no,
                        text=prompt_text,
                        label=answer_letter,
                        metadata=extra,
                    )
                    self._samples.append(sample)

    @property
    def name(self) -> str:
        """Return the dataset name given at construction."""
        dataset_name = self._name
        return dataset_name

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        sample_count = len(self._samples)
        return sample_count

    def __getitem__(self, idx: int) -> Sample:
        """Return the loaded sample at position ``idx``."""
        samples = self._samples
        return samples[idx]

    def get_compatible_prompts(self) -> list[str]:
        """Return the prompt names this dataset declares itself compatible with."""
        prompt_names = (
            "mcq",
            "direct_answer",
            "chain_of_thought",
            "few_shot",
        )
        return list(prompt_names)