cotlab 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. cotlab/__init__.py +3 -0
  2. cotlab/analyse_experiments.py +392 -0
  3. cotlab/analysis/__init__.py +11 -0
  4. cotlab/analysis/cot_parser.py +243 -0
  5. cotlab/analysis/faithfulness_metrics.py +192 -0
  6. cotlab/backends/__init__.py +16 -0
  7. cotlab/backends/base.py +78 -0
  8. cotlab/backends/transformers_backend.py +335 -0
  9. cotlab/backends/vllm_backend.py +227 -0
  10. cotlab/cli.py +83 -0
  11. cotlab/core/__init__.py +34 -0
  12. cotlab/core/base.py +749 -0
  13. cotlab/core/config.py +90 -0
  14. cotlab/core/registry.py +68 -0
  15. cotlab/datasets/__init__.py +45 -0
  16. cotlab/datasets/loaders.py +1889 -0
  17. cotlab/experiment/__init__.py +315 -0
  18. cotlab/experiments/__init__.py +43 -0
  19. cotlab/experiments/activation_compare.py +290 -0
  20. cotlab/experiments/activation_patching.py +1050 -0
  21. cotlab/experiments/attention_analysis.py +885 -0
  22. cotlab/experiments/classification.py +235 -0
  23. cotlab/experiments/composite_shift_detector.py +524 -0
  24. cotlab/experiments/cot_ablation.py +277 -0
  25. cotlab/experiments/cot_faithfulness.py +187 -0
  26. cotlab/experiments/cot_heads.py +208 -0
  27. cotlab/experiments/full_layer_cot.py +232 -0
  28. cotlab/experiments/full_layer_patching.py +225 -0
  29. cotlab/experiments/h_neuron_analysis.py +712 -0
  30. cotlab/experiments/logit_lens.py +439 -0
  31. cotlab/experiments/multi_head_cot.py +220 -0
  32. cotlab/experiments/multi_head_patching.py +229 -0
  33. cotlab/experiments/probing_classifier.py +402 -0
  34. cotlab/experiments/residual_norm_ood.py +413 -0
  35. cotlab/experiments/sae_feature_analysis.py +673 -0
  36. cotlab/experiments/steering_vectors.py +223 -0
  37. cotlab/experiments/sycophancy_heads.py +224 -0
  38. cotlab/logging/__init__.py +5 -0
  39. cotlab/logging/json_logger.py +161 -0
  40. cotlab/main.py +317 -0
  41. cotlab/patching/__init__.py +24 -0
  42. cotlab/patching/cache.py +141 -0
  43. cotlab/patching/hooks.py +558 -0
  44. cotlab/patching/interventions.py +86 -0
  45. cotlab/patching/patcher.py +439 -0
  46. cotlab/patching/sae.py +181 -0
  47. cotlab/prompts/__init__.py +43 -0
  48. cotlab/prompts/cardiology.py +378 -0
  49. cotlab/prompts/histopathology.py +265 -0
  50. cotlab/prompts/length_matched_strategies.py +157 -0
  51. cotlab/prompts/mcq.py +193 -0
  52. cotlab/prompts/neurology.py +353 -0
  53. cotlab/prompts/oncology.py +367 -0
  54. cotlab/prompts/plab.py +162 -0
  55. cotlab/prompts/pubhealthbench.py +82 -0
  56. cotlab/prompts/pubmedqa.py +173 -0
  57. cotlab/prompts/radiology.py +414 -0
  58. cotlab/prompts/strategies.py +939 -0
  59. cotlab/prompts/tcga.py +168 -0
  60. cotlab/runner.py +204 -0
  61. cotlab-0.8.0.dist-info/METADATA +166 -0
  62. cotlab-0.8.0.dist-info/RECORD +65 -0
  63. cotlab-0.8.0.dist-info/WHEEL +4 -0
  64. cotlab-0.8.0.dist-info/entry_points.txt +3 -0
  65. cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
cotlab/core/base.py ADDED
@@ -0,0 +1,749 @@
1
+ """Base classes and data structures for the CoT research framework."""
2
+
3
+ import json
4
+ from abc import ABC, abstractmethod
5
+ from dataclasses import dataclass, field
6
+ from datetime import datetime
7
+ from typing import Any, Dict, List, Optional
8
+
9
+
10
@dataclass
class GenerationOutput:
    """Uniform container for what a backend returns from one generation.

    Attributes:
        text: The generated text.
        tokens: Integer tokens of the generation (presumably token ids).
        logprobs: Optional list of floats, ``None`` when not supplied.
    """

    text: str
    tokens: List[int]
    logprobs: Optional[List[float]] = None

    def __repr__(self) -> str:
        # Keep the repr short: only a 50-character preview of the text and
        # the number of tokens are shown, never the full payload.
        preview = self.text[:50]
        token_count = len(self.tokens)
        return f"GenerationOutput(text={preview}..., tokens={token_count})"
20
+
21
+
22
@dataclass
class ExperimentResult:
    """JSON-serializable record of one experiment run.

    Attributes:
        experiment_name: Name of the experiment that produced the result.
        model_name: Name of the model that was evaluated.
        prompt_strategy: Name of the prompt strategy that was used.
        metrics: Mapping of metric name to value.
        raw_outputs: Per-sample output records (empty list by default).
        metadata: Free-form extra information (empty dict by default).
        timestamp: ISO-8601 string; defaults to the local time at creation.
    """

    experiment_name: str
    model_name: str
    prompt_strategy: str
    metrics: Dict[str, Any]
    raw_outputs: List[Dict[str, Any]] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_json(self) -> str:
        """Serialize the instance attributes to an indented JSON string.

        Values that ``json`` cannot encode natively are stringified via
        ``default=str`` instead of raising.
        """
        return json.dumps(self.__dict__, indent=2, default=str)

    def save(self, path: str) -> None:
        """Write the JSON representation to ``path``.

        The file is always written as UTF-8 so results are portable across
        machines regardless of the platform's default locale encoding.
        """
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.to_json())

    @classmethod
    def load(cls, path: str) -> "ExperimentResult":
        """Rebuild an ``ExperimentResult`` from a JSON file written by ``save``.

        The file is read as UTF-8. The top-level JSON object is passed as
        keyword arguments to the constructor, so unknown keys raise
        ``TypeError``.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls(**data)
49
+
50
+
51
class BasePromptStrategy(ABC):
    """Interface every prompt-construction strategy has to implement.

    Concrete strategies must provide ``name``, ``build_prompt`` and
    ``parse_response``; the remaining hooks have permissive defaults.
    """

    # ------------------------------------------------------------------
    # Optional hooks (safe defaults, override when needed)
    # ------------------------------------------------------------------
    def get_system_message(self) -> Optional[str]:
        """Return the system message to use, or ``None`` when there is none."""
        return None

    def get_compatible_datasets(self) -> Optional[List[str]]:
        """
        Return the dataset names this strategy may be used with.

        ``None`` (the default) means the strategy is compatible with every
        dataset. Specialized prompts override this to restrict usage.
        """
        return None

    # ------------------------------------------------------------------
    # Required interface
    # ------------------------------------------------------------------
    @property
    @abstractmethod
    def name(self) -> str:
        """Strategy name, used for logging."""
        ...

    @abstractmethod
    def build_prompt(self, input_data: Dict[str, Any]) -> str:
        """
        Turn one input record into a prompt.

        Args:
            input_data: Dictionary with at least a 'question' key

        Returns:
            The formatted prompt string
        """
        ...

    @abstractmethod
    def parse_response(self, response: str) -> Dict[str, Any]:
        """
        Extract the answer and reasoning from a raw model response.

        Args:
            response: Raw model output

        Returns:
            Dictionary with 'answer', 'reasoning', and any other extracted fields
        """
        ...
97
+
98
+
99
# Identifiers of the structured-output formats a prompt can request
class OutputFormat:
    """Namespace of the string identifiers for each supported response format."""

    PLAIN = "plain"
    JSON = "json"
    TOON = "toon"
    TOML = "toml"
    XML = "xml"
    YAML = "yaml"
    MARKDOWN = "markdown"

    @classmethod
    def all(cls) -> list:
        """Return every format identifier, in declaration order."""
        attribute_names = ("PLAIN", "JSON", "TOON", "TOML", "XML", "YAML", "MARKDOWN")
        return [getattr(cls, attribute) for attribute in attribute_names]
114
+
115
+
116
# ---------------------------------------------------------------------------
# Format-instruction templates.
#
# Each output format has two templates: *_OUTPUT_SCHEMA (answer only) and
# *_COT_SCHEMA (answer plus step-by-step reasoning). StructuredOutputMixin
# looks them up per format in _get_format_schema and returns one of the pair
# from _add_format_instruction depending on cot_format.
#
# Every literal opens with a backslash line-continuation followed by an empty
# line, so each template starts with exactly one newline character.
# ---------------------------------------------------------------------------

# Plain text: free-form answer terminated by "FINAL ANSWER:" / "CONFIDENCE:"
# lines (the markers the plain-text parser searches for).
PLAIN_OUTPUT_SCHEMA = """\

Provide your answer in plain text. At the end of your response, clearly state:

FINAL ANSWER: [your diagnosis]
CONFIDENCE: [0-100]

Example:
Based on the clinical findings, the patient shows signs consistent with the condition.

FINAL ANSWER: positive
CONFIDENCE: 85"""

# Plain text with explicit step-by-step reasoning before the final markers.
PLAIN_COT_SCHEMA = """\

Provide your reasoning step by step in plain text. At the end, clearly state:

FINAL ANSWER: [your diagnosis]
CONFIDENCE: [0-100]

Example:
Step 1: The report mentions...
Step 2: This suggests...
Step 3: Therefore...

FINAL ANSWER: positive
CONFIDENCE: 85"""

# JSON: a fenced ```json object with diagnosis / confidence / reasoning keys.
JSON_OUTPUT_SCHEMA = """\

Output your answer ONLY in this JSON format:
```json
{
"diagnosis": "[your diagnosis]",
"confidence": [0-100],
"reasoning": "[brief explanation]"
}
```"""

# JSON with the reasoning expressed as a "step_by_step" array.
JSON_COT_SCHEMA = """\

Output your answer ONLY in this JSON format:
```json
{
"step_by_step": ["Step 1: ...", "Step 2: ..."],
"diagnosis": "[your diagnosis]",
"confidence": [0-100]
}
```"""

# TOON (Token-Oriented Object Notation) - 40% fewer tokens than JSON
# (figure as stated by the original author; not verified here).
TOON_OUTPUT_SCHEMA = """\

Output your answer ONLY in this TOON format (indentation-based, no braces):
```
diagnosis: [your diagnosis]
confidence: [0-100]
reasoning: [brief explanation]
```"""

# TOON with a "step_by_step" header followed by one line per step.
TOON_COT_SCHEMA = """\

Output your answer ONLY in this TOON format:
```
step_by_step: [3]
Step 1: ...
Step 2: ...
Step 3: ...
diagnosis: [your diagnosis]
confidence: [0-100]
```"""

# TOML: a fenced ```toml document with a single [response] table.
TOML_OUTPUT_SCHEMA = """\

Output your answer ONLY in this TOML format:
```toml
[response]
diagnosis = "[your diagnosis]"
confidence = 0-100
reasoning = "[brief explanation]"
```"""

# TOML with the reasoning expressed as a "step_by_step" array.
TOML_COT_SCHEMA = """\

Output your answer ONLY in this TOML format:
```toml
[response]
step_by_step = ["Step 1: ...", "Step 2: ..."]
diagnosis = "[your diagnosis]"
confidence = 0-100
```"""

# XML: a fenced ```xml document rooted at <response>.
XML_OUTPUT_SCHEMA = """\

Output your answer ONLY in this XML format:
```xml
<response>
<diagnosis>[your diagnosis]</diagnosis>
<confidence>[0-100]</confidence>
<reasoning>[brief explanation]</reasoning>
</response>
```"""

# XML with the reasoning expressed as <step> children of <step_by_step>.
XML_COT_SCHEMA = """\

Output your answer ONLY in this XML format:
```xml
<response>
<step_by_step>
<step>Step 1: ...</step>
<step>Step 2: ...</step>
</step_by_step>
<diagnosis>[your diagnosis]</diagnosis>
<confidence>[0-100]</confidence>
</response>
```"""

# YAML: a fenced ```yaml mapping with diagnosis / confidence / reasoning keys.
YAML_OUTPUT_SCHEMA = """\

Output your answer ONLY in this YAML format:
```yaml
diagnosis: "[your diagnosis]"
confidence: 0-100
reasoning: "[brief explanation]"
```"""

# YAML with the reasoning expressed as a "step_by_step" sequence.
YAML_COT_SCHEMA = """\

Output your answer ONLY in this YAML format:
```yaml
step_by_step:
- "Step 1: ..."
- "Step 2: ..."
diagnosis: "[your diagnosis]"
confidence: 0-100
```"""

# Markdown: one "## <Section>" heading per field, no code fence.
MARKDOWN_OUTPUT_SCHEMA = """\

Output your answer ONLY in this structured Markdown format:
## Diagnosis
[your diagnosis]

## Confidence
[0-100]

## Reasoning
[brief explanation]"""

# Markdown with a leading "## Step-by-Step Reasoning" numbered list.
MARKDOWN_COT_SCHEMA = """\

Output your answer ONLY in this structured Markdown format:
## Step-by-Step Reasoning
1. Step 1: ...
2. Step 2: ...

## Diagnosis
[your diagnosis]

## Confidence
[0-100]"""
283
+
284
+
285
class StructuredOutputMixin:
    """Mixin to add multi-format structured output capability to any prompt strategy.

    Supports: PLAIN, JSON, TOON, TOML, XML, YAML, MARKDOWN

    The mixin provides three services: choosing the format instruction to
    append to a prompt, rendering an example record in the chosen format, and
    parsing a model response back into a dictionary with the standard keys
    ``answer``, ``reasoning``, ``confidence``, ``raw``, ``parse_success`` and
    ``format``.

    Usage:
        class MyStrategy(StructuredOutputMixin, BasePromptStrategy):
            def __init__(self, output_format="plain", ...):
                self.output_format = output_format
                self.cot_format = False  # Set True for CoT inside structure
    """

    output_format: str = "plain"
    cot_format: bool = False

    # Backward compatibility with json_output
    @property
    def json_output(self) -> bool:
        """Whether the current output format is JSON."""
        return self.output_format == OutputFormat.JSON

    @json_output.setter
    def json_output(self, value: bool):
        # Only a truthy value has an effect; a falsy value deliberately leaves
        # whatever format is currently selected untouched.
        if value:
            self.output_format = OutputFormat.JSON

    def _get_format_schema(self) -> tuple:
        """Return (output_schema, cot_schema) for current format.

        An unknown ``output_format`` yields a pair of empty strings.
        """
        schemas = {
            OutputFormat.PLAIN: (PLAIN_OUTPUT_SCHEMA, PLAIN_COT_SCHEMA),
            OutputFormat.JSON: (JSON_OUTPUT_SCHEMA, JSON_COT_SCHEMA),
            OutputFormat.TOON: (TOON_OUTPUT_SCHEMA, TOON_COT_SCHEMA),
            OutputFormat.TOML: (TOML_OUTPUT_SCHEMA, TOML_COT_SCHEMA),
            OutputFormat.XML: (XML_OUTPUT_SCHEMA, XML_COT_SCHEMA),
            OutputFormat.YAML: (YAML_OUTPUT_SCHEMA, YAML_COT_SCHEMA),
            OutputFormat.MARKDOWN: (MARKDOWN_OUTPUT_SCHEMA, MARKDOWN_COT_SCHEMA),
        }
        return schemas.get(self.output_format, ("", ""))

    def _add_format_instruction(self) -> str:
        """Return format instruction to append to prompt.

        The chain-of-thought variant is returned when ``cot_format`` is true.
        """
        output_schema, cot_schema = self._get_format_schema()
        return cot_schema if self.cot_format else output_schema

    def _format_example(
        self, example_data: Dict[str, Any], format_type: Optional[str] = None
    ) -> str:
        """Format a single example in the specified format.

        Args:
            example_data: Dictionary with example data. The plain, TOON and
                Markdown renderers read the ``_plain_answer`` key and
                ``evidence["rationale"]``.
            format_type: Output format (json, yaml, toml, xml, plain, toon,
                markdown). Uses self.output_format if None.

        Returns:
            Formatted example string. JSON, YAML, TOML, XML and TOON examples
            are wrapped in code block markers; plain and Markdown are not.
            Unknown formats fall back to JSON.
        """
        import json

        fmt = format_type or self.output_format

        if fmt == OutputFormat.JSON or fmt == "json":
            return f"```json\n{json.dumps(example_data, indent=4)}\n```"

        elif fmt == OutputFormat.YAML or fmt == "yaml":
            # Simple YAML conversion
            def yaml_scalar(v):
                # Booleans become lowercase words, strings are double-quoted,
                # everything else is rendered with its default str() form.
                if isinstance(v, bool):
                    return str(v).lower()
                if isinstance(v, str):
                    return f'"{v}"'
                return v

            def to_yaml(obj, indent=0):
                lines = []
                prefix = " " * indent
                if isinstance(obj, dict):
                    for k, v in obj.items():
                        if isinstance(v, (dict, list)):
                            lines.append(f"{prefix}{k}:")
                            lines.append(to_yaml(v, indent + 1))
                        else:
                            lines.append(f"{prefix}{k}: {yaml_scalar(v)}")
                elif isinstance(obj, list):
                    for item in obj:
                        if isinstance(item, (dict, list)):
                            # Nested container: put it on the following,
                            # deeper-indented lines so the block stays valid.
                            lines.append(f"{prefix}-")
                            lines.append(to_yaml(item, indent + 1))
                        else:
                            # Scalars (including numbers and booleans) are
                            # written inline instead of being dropped.
                            lines.append(f"{prefix}- {yaml_scalar(item)}")
                return "\n".join(lines)

            return f"```yaml\n{to_yaml(example_data)}\n```"

        elif fmt == OutputFormat.TOML or fmt == "toml":
            # Simple TOML conversion
            def to_toml(obj, section=""):
                lines = []
                if section:
                    lines.append(f"[{section}]")
                for k, v in obj.items():
                    if isinstance(v, dict):
                        lines.append(to_toml(v, f"{section}.{k}" if section else k))
                    elif isinstance(v, list):
                        items = ", ".join(f'"{i}"' if isinstance(i, str) else str(i) for i in v)
                        lines.append(f"{k} = [{items}]")
                    elif isinstance(v, bool):
                        lines.append(f"{k} = {str(v).lower()}")
                    elif isinstance(v, str):
                        lines.append(f'{k} = "{v}"')
                    else:
                        lines.append(f"{k} = {v}")
                return "\n".join(lines)

            return f"```toml\n{to_toml(example_data)}\n```"

        elif fmt == OutputFormat.PLAIN or fmt == "plain":
            # Plain text - use plain_answer if available
            plain_answer = example_data.get("_plain_answer", "")
            reasoning = example_data.get("evidence", {}).get("rationale", "")
            return f"{reasoning}\n\nFINAL ANSWER: {plain_answer}"

        elif fmt == OutputFormat.XML or fmt == "xml":

            def to_xml(obj, root="response"):
                lines = [f"<{root}>"]
                for k, v in obj.items():
                    if isinstance(v, dict):
                        lines.append(to_xml(v, k))
                    elif isinstance(v, list):
                        lines.append(f" <{k}>")
                        for item in v:
                            lines.append(f" <item>{item}</item>")
                        lines.append(f" </{k}>")
                    else:
                        val = str(v).lower() if isinstance(v, bool) else v
                        lines.append(f" <{k}>{val}</{k}>")
                lines.append(f"</{root}>")
                return "\n".join(lines)

            return f"```xml\n{to_xml(example_data)}\n```"

        elif fmt == OutputFormat.TOON or fmt == "toon":
            # TOON: Simple Key-Value pairs
            lines = []
            # Extract main fields from example data
            plain_answer = example_data.get("_plain_answer", "")
            reasoning = example_data.get("evidence", {}).get("rationale", "")

            lines.append(f"diagnosis: {plain_answer}")
            lines.append(f"reasoning: {reasoning}")
            # Add other flat keys if they exist and aren't dicts/lists
            for k, v in example_data.items():
                if k not in ["evidence", "_plain_answer"] and not isinstance(v, (dict, list)):
                    lines.append(f"{k}: {str(v).lower() if isinstance(v, bool) else v}")
            return f"```toon\n{chr(10).join(lines)}\n```"

        elif fmt == OutputFormat.MARKDOWN or fmt == "markdown":
            # Markdown: Header sections
            plain_answer = example_data.get("_plain_answer", "")
            reasoning = example_data.get("evidence", {}).get("rationale", "")

            md = f"## Diagnosis\n{plain_answer}\n\n"
            md += f"## Reasoning\n{reasoning}\n\n"
            md += "## Confidence\n95"  # Example confidence
            return md

        else:
            # Default to JSON
            return f"```json\n{json.dumps(example_data, indent=4)}\n```"

    def _parse_formatted_response(self, response: str) -> Dict[str, Any]:
        """Parse response based on output format.

        Plain text is handled inline; every other known format is delegated
        to its ``_parse_<format>`` method. An unknown format produces a
        failure record that carries an ``error`` message.
        """
        import re

        if self.output_format == OutputFormat.PLAIN:
            # Extract FINAL ANSWER and CONFIDENCE from plain text
            answer_match = re.search(r"FINAL ANSWER:\s*(.+?)(?:\n|$)", response, re.IGNORECASE)
            confidence_match = re.search(r"CONFIDENCE:\s*(\d+)", response, re.IGNORECASE)

            # Without a FINAL ANSWER marker the whole response is the answer.
            answer = answer_match.group(1).strip() if answer_match else response.strip()
            confidence = int(confidence_match.group(1)) if confidence_match else None

            return {
                "answer": answer,
                "confidence": confidence,
                "reasoning": response,
                "raw": response,
                "parse_success": bool(answer_match),
                "format": "plain",
            }

        if self.output_format == OutputFormat.JSON:
            return self._parse_json(response)
        elif self.output_format == OutputFormat.TOON:
            return self._parse_toon(response)
        elif self.output_format == OutputFormat.TOML:
            return self._parse_toml(response)
        elif self.output_format == OutputFormat.XML:
            return self._parse_xml(response)
        elif self.output_format == OutputFormat.YAML:
            return self._parse_yaml(response)
        elif self.output_format == OutputFormat.MARKDOWN:
            return self._parse_markdown(response)
        else:
            return {
                "answer": response.strip(),
                "reasoning": response,
                "raw": response,
                "parse_success": False,
                "format": self.output_format,
                "error": f"Unknown format: {self.output_format}",
            }

    def _parse_json(self, response: str) -> Dict[str, Any]:
        """Parse JSON response.

        Looks for a fenced ```json block first, then for a bare, brace-free
        object that contains a "diagnosis" key. Invalid JSON, or JSON whose
        top level is not an object, yields a parse-failure record.
        """
        import re

        json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
        if json_match:
            json_str = json_match.group(1)
        else:
            json_match = re.search(r'\{[^{}]*"diagnosis"[^{}]*\}', response, re.DOTALL)
            if json_match:
                json_str = json_match.group(0)
            else:
                return self._parse_failure(response, "json")

        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError:
            return self._parse_failure(response, "json")

        # A list/number/string payload is valid JSON but not a usable answer.
        if not isinstance(parsed, dict):
            return self._parse_failure(response, "json")
        return self._extract_fields(parsed, response, "json")

    def _parse_toon(self, response: str) -> Dict[str, Any]:
        """Parse TOON (Token-Oriented Object Notation) response.

        Every ``key: value`` line becomes an entry; the strings "true" and
        "false" (any case) are converted to booleans. The parse succeeds only
        when a ``diagnosis`` key is present.
        """
        import re

        # Extract TOON block from code fence
        toon_match = re.search(r"```(?:toon)?\s*(.*?)\s*```", response, re.DOTALL)
        content = toon_match.group(1) if toon_match else response

        # Parse key: value pairs
        result = {}
        for line in content.strip().split("\n"):
            if ":" in line:
                key, _, value = line.partition(":")
                key = key.strip()
                value = value.strip()
                if key and value:
                    # Auto-convert booleans
                    if value.lower() == "true":
                        result[key] = True
                    elif value.lower() == "false":
                        result[key] = False
                    else:
                        result[key] = value

        if "diagnosis" in result:
            return self._extract_fields(result, response, "toon")
        return self._parse_failure(response, "toon")

    def _parse_toml(self, response: str) -> Dict[str, Any]:
        """Parse TOML response.

        Requires a fenced ```toml block. If ``tomllib`` cannot be imported or
        the block is not valid TOML, a parse-failure record is returned.
        """
        import re

        toml_match = re.search(r"```toml\s*(.*?)\s*```", response, re.DOTALL)
        if not toml_match:
            return self._parse_failure(response, "toml")

        try:
            import tomllib

            parsed = tomllib.loads(toml_match.group(1))
        except Exception:
            # Best effort: a missing module and malformed TOML are treated
            # alike as "could not parse".
            return self._parse_failure(response, "toml")

        # Handle [response] section
        if "response" in parsed:
            parsed = parsed["response"]
        if not isinstance(parsed, dict):
            return self._parse_failure(response, "toml")
        return self._extract_fields(parsed, response, "toml")

    @staticmethod
    def _xml_element_to_dict(elem):
        """Convert an ElementTree element to nested dicts/lists/scalars.

        A leaf element becomes its stripped text (``None`` when empty, a bool
        for "true"/"false"); an element with children becomes a dict keyed by
        child tag, with repeated tags collected into a list.
        """
        text = elem.text.strip() if elem.text else None
        # Convert booleans
        if text:
            if text.lower() == "true":
                text = True
            elif text.lower() == "false":
                text = False

        children = list(elem)
        if not children:
            return text

        result = {}
        for child in children:
            child_val = StructuredOutputMixin._xml_element_to_dict(child)
            if child.tag in result:
                if not isinstance(result[child.tag], list):
                    result[child.tag] = [result[child.tag]]
                result[child.tag].append(child_val)
            else:
                result[child.tag] = child_val
        return result

    def _parse_xml(self, response: str) -> Dict[str, Any]:
        """Parse XML response.

        Uses the content of a fenced ```xml block when present, otherwise the
        whole response. ``xmltodict`` is tried first; if it is unavailable,
        fails, or does not produce a mapping, the standard-library
        ElementTree parser is used as a fallback.
        """
        import re

        xml_match = re.search(r"```xml\s*(.*?)\s*```", response, re.DOTALL)
        content = xml_match.group(1) if xml_match else response

        parsed = None
        try:
            import xmltodict

            def postprocessor(path, key, value):
                if value and isinstance(value, str):
                    if value.lower() == "true":
                        return key, True
                    if value.lower() == "false":
                        return key, False
                return key, value

            parsed = xmltodict.parse(content.strip(), postprocessor=postprocessor)
            # Remove root element wrapper if present (e.g. <response>)
            if len(parsed) == 1:
                key = next(iter(parsed))
                parsed = parsed[key]
        except Exception:
            parsed = None

        if not isinstance(parsed, dict):
            # Fallback to ElementTree with recursive parsing
            # NOTE(review): ElementTree is not hardened against maliciously
            # constructed XML; model output is parsed here as-is.
            try:
                import xml.etree.ElementTree as ET

                root = ET.fromstring(content.strip())
                parsed = self._xml_element_to_dict(root)
            except Exception:
                return self._parse_failure(response, "xml")

        # A root without child elements yields plain text (or None), which
        # carries no fields to extract.
        if not isinstance(parsed, dict):
            return self._parse_failure(response, "xml")
        return self._extract_fields(parsed, response, "xml")

    def _parse_yaml(self, response: str) -> Dict[str, Any]:
        """Parse YAML response.

        Requires a fenced ```yaml block whose content loads (via
        ``yaml.safe_load``) to a mapping; anything else is a parse failure.
        """
        import re

        yaml_match = re.search(r"```yaml\s*(.*?)\s*```", response, re.DOTALL)
        if not yaml_match:
            return self._parse_failure(response, "yaml")

        try:
            import yaml

            parsed = yaml.safe_load(yaml_match.group(1))
        except Exception:
            # Best effort: a missing module and malformed YAML are treated
            # alike as "could not parse".
            return self._parse_failure(response, "yaml")

        if not isinstance(parsed, dict):
            return self._parse_failure(response, "yaml")
        return self._extract_fields(parsed, response, "yaml")

    def _parse_markdown(self, response: str) -> Dict[str, Any]:
        """Parse structured Markdown response.

        Reads the ``## Diagnosis``, ``## Confidence``, ``## Reasoning`` and
        ``## Step-by-Step Reasoning`` sections (case-insensitive). The parse
        succeeds only when a Diagnosis section is present.
        """
        import re

        flags = re.DOTALL | re.IGNORECASE

        def section(title_pattern):
            # Body of a "## <title>" section, up to the next "##" or the end.
            match = re.search(
                r"##\s*" + title_pattern + r"\s*\n(.*?)(?=\n##|\Z)", response, flags
            )
            return match.group(1).strip() if match else None

        result = {}
        diagnosis = section("Diagnosis")
        confidence = section("Confidence")
        reasoning = section("Reasoning")
        steps = section("Step-by-Step Reasoning")

        if diagnosis is not None:
            result["diagnosis"] = diagnosis
        if confidence is not None:
            # Unparseable confidence text falls back to 0.
            result["confidence"] = self._coerce_confidence(confidence)
        if reasoning is not None:
            result["reasoning"] = reasoning
        if steps:
            # One entry per non-empty line of the chain-of-thought section.
            result["step_by_step"] = [
                line.strip() for line in steps.splitlines() if line.strip()
            ]

        if "diagnosis" in result:
            return self._extract_fields(result, response, "markdown")
        return self._parse_failure(response, "markdown")

    @staticmethod
    def _coerce_confidence(value: Any) -> int:
        """Convert a parsed confidence value to an int, or 0 if impossible.

        Strings may carry surrounding whitespace, a trailing ``%`` or a
        decimal part ("85.5" -> 85). Values that cannot be converted (lists,
        ``None``, NaN, infinity, arbitrary text) yield 0 instead of raising.
        """
        if isinstance(value, str):
            text = value.strip().rstrip("%").strip()
            try:
                return int(text)
            except ValueError:
                try:
                    value = float(text)
                except ValueError:
                    return 0
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return 0

    def _extract_fields(self, parsed: dict, response: str, fmt: str) -> Dict[str, Any]:
        """Extract standard fields from parsed response.

        Args:
            parsed: Mapping produced by one of the format parsers.
            response: The raw model response, stored under ``raw``.
            fmt: Format name, stored under ``format``.

        Returns:
            Dictionary with ``answer`` (from ``diagnosis``), ``reasoning``,
            ``confidence`` (always an int), ``step_by_step``, ``raw``,
            ``parsed_data``, ``parse_success`` (True) and ``format``, plus
            every other key of ``parsed`` that does not clash with those.
        """
        steps = parsed.get("step_by_step", [])
        if isinstance(steps, list) and steps:
            # Chain-of-thought steps take precedence as the reasoning text;
            # str() guards against non-string step entries.
            reasoning = "\n".join(str(step) for step in steps)
        else:
            # No usable steps: fall back to the explicit reasoning field.
            reasoning = parsed.get("reasoning", "")

        # Ensure confidence is integer
        confidence = self._coerce_confidence(parsed.get("confidence", 0))

        result = {
            "answer": str(parsed.get("diagnosis", "")),
            "reasoning": reasoning,
            "confidence": confidence,
            "step_by_step": steps,
            "raw": response,
            "parsed_data": parsed,
            "parse_success": True,
            "format": fmt,
        }

        # Merge all parsed fields into top-level result for easy access
        # (excluding standard keys to avoid overwriting processed values)
        for k, v in parsed.items():
            if k not in result:
                result[k] = v

        return result

    def _parse_failure(self, response: str, fmt: str) -> Dict[str, Any]:
        """Return parse failure result.

        The stripped response doubles as the answer so callers always get
        something to score, with ``parse_success`` set to False.
        """
        return {
            "answer": response.strip(),
            "reasoning": response,
            "raw": response,
            "parse_success": False,
            "format": fmt,
        }
715
+
716
+
717
class BaseExperiment(ABC):
    """Interface every experiment has to implement.

    Concrete experiments must provide ``name`` and ``run``;
    ``validate_backend`` is an optional hook that does nothing by default.
    """

    def validate_backend(self, backend: "InferenceBackend") -> None:
        """Hook for checking that ``backend`` supports the required features.

        The base implementation accepts every backend and returns ``None``.
        """
        return None

    @property
    @abstractmethod
    def name(self) -> str:
        """Experiment name, used for logging."""
        ...

    @abstractmethod
    def run(
        self,
        backend: "InferenceBackend",
        dataset: Any,
        prompt_strategy: BasePromptStrategy,
        **kwargs,
    ) -> ExperimentResult:
        """
        Execute the experiment.

        Args:
            backend: Inference backend (vLLM or Transformers)
            dataset: Dataset to run on
            prompt_strategy: How to construct prompts
            **kwargs: Additional experiment-specific options

        Returns:
            ExperimentResult with metrics and outputs
        """
        ...