cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
cotlab/core/base.py
ADDED
|
@@ -0,0 +1,749 @@
|
|
|
1
|
+
"""Base classes and data structures for the CoT research framework."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import Any, Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class GenerationOutput:
    """Container for the result of a single model generation.

    Attributes:
        text: The generated text.
        tokens: Token ids of the generation.
        logprobs: Optional list of log-probabilities (presumably one per
            generated token; not enforced here). ``None`` when not provided.
    """

    text: str
    tokens: List[int]
    logprobs: Optional[List[float]] = None

    def __repr__(self) -> str:
        # Keep the repr short: a 50-character text preview plus the token count.
        preview = self.text[:50]
        token_count = len(self.tokens)
        return f"GenerationOutput(text={preview}..., tokens={token_count})"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class ExperimentResult:
    """JSON-serializable experiment result.

    Attributes:
        experiment_name: Name of the experiment that produced the result.
        model_name: Name of the evaluated model.
        prompt_strategy: Name of the prompt strategy that was used.
        metrics: Dictionary of computed metrics.
        raw_outputs: List of raw output records (presumably one per example).
        metadata: Free-form additional information.
        timestamp: Creation time as an ISO-8601 string; produced by
            ``datetime.now()``, so it is local time without a timezone offset.
    """

    experiment_name: str
    model_name: str
    prompt_strategy: str
    metrics: Dict[str, Any]
    raw_outputs: List[Dict[str, Any]] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_json(self) -> str:
        """Serialize all fields to an indented JSON string.

        Values that ``json`` cannot encode natively are converted with ``str``.
        """
        return json.dumps(self.__dict__, indent=2, default=str)

    def save(self, path: str) -> None:
        """Write the :meth:`to_json` output to ``path`` as UTF-8."""
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # would make result files non-portable between machines.
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.to_json())

    @classmethod
    def load(cls, path: str) -> "ExperimentResult":
        """Load a result from a UTF-8 JSON file such as one written by :meth:`save`.

        Raises:
            TypeError: if the JSON object has keys that are not fields of this
                class, or lacks a required field.
        """
        # JSON text is UTF-8 by specification (RFC 8259); do not rely on the locale.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls(**data)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class BasePromptStrategy(ABC):
    """Interface that every prompt-construction strategy implements.

    Subclasses must provide ``name``, ``build_prompt`` and ``parse_response``.
    ``get_system_message`` and ``get_compatible_datasets`` come with permissive
    defaults that return ``None``.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier of the strategy, used for logging."""
        ...

    @abstractmethod
    def build_prompt(self, input_data: Dict[str, Any]) -> str:
        """
        Turn one input record into the prompt text sent to the model.

        Args:
            input_data: Input record; expected to contain at least a 'question' key

        Returns:
            The fully formatted prompt string
        """
        ...

    @abstractmethod
    def parse_response(self, response: str) -> Dict[str, Any]:
        """
        Pull the answer and the reasoning out of a raw model response.

        Args:
            response: Unprocessed text produced by the model

        Returns:
            Mapping containing 'answer' and 'reasoning' plus any further
            fields the strategy extracts
        """
        ...

    def get_system_message(self) -> Optional[str]:
        """Return the system message to use; the default is ``None`` (no system message)."""
        return None

    def get_compatible_datasets(self) -> Optional[List[str]]:
        """
        Return the names of the datasets this strategy supports.

        ``None`` (the default) means "compatible with every dataset"; specialized
        strategies override this to restrict where they can be used.
        """
        return None
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# Output format options
class OutputFormat:
    """String constants naming the supported output formats for structured responses."""

    PLAIN = "plain"
    JSON = "json"
    TOON = "toon"
    TOML = "toml"
    XML = "xml"
    YAML = "yaml"
    MARKDOWN = "markdown"

    @classmethod
    def all(cls) -> list:
        """Return every format name in declaration order, as a fresh list."""
        ordered = (
            cls.PLAIN,
            cls.JSON,
            cls.TOON,
            cls.TOML,
            cls.XML,
            cls.YAML,
            cls.MARKDOWN,
        )
        return list(ordered)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# Plain text output schema for medical diagnosis tasks.
# Each *_OUTPUT_SCHEMA / *_COT_SCHEMA pair in this module is returned by
# StructuredOutputMixin._get_format_schema; the *_COT_SCHEMA variant is chosen
# when cot_format is True. The leading "\" after the opening quotes suppresses
# the first newline, so the text starts with the blank line that follows.
# The plain-text parser looks for the "FINAL ANSWER:" and "CONFIDENCE:" lines
# requested here (case-insensitive regex search).
PLAIN_OUTPUT_SCHEMA = """\

Provide your answer in plain text. At the end of your response, clearly state:

FINAL ANSWER: [your diagnosis]
CONFIDENCE: [0-100]

Example:
Based on the clinical findings, the patient shows signs consistent with the condition.

FINAL ANSWER: positive
CONFIDENCE: 85"""

# Chain-of-thought variant: asks for numbered step-by-step reasoning before the
# same FINAL ANSWER / CONFIDENCE lines.
PLAIN_COT_SCHEMA = """\

Provide your reasoning step by step in plain text. At the end, clearly state:

FINAL ANSWER: [your diagnosis]
CONFIDENCE: [0-100]

Example:
Step 1: The report mentions...
Step 2: This suggests...
Step 3: Therefore...

FINAL ANSWER: positive
CONFIDENCE: 85"""
|
|
144
|
+
|
|
145
|
+
# JSON output schema for medical diagnosis tasks.
# Requests a ```json fenced object with "diagnosis", "confidence" and "reasoning".
# NOTE(review): the placeholder `"confidence": [0-100]` is not valid JSON if a
# model copies it literally; such a response fails json.loads.
JSON_OUTPUT_SCHEMA = """\

Output your answer ONLY in this JSON format:
```json
{
"diagnosis": "[your diagnosis]",
"confidence": [0-100],
"reasoning": "[brief explanation]"
}
```"""

# Chain-of-thought JSON variant: the reasoning is a "step_by_step" list of
# strings placed before "diagnosis" and "confidence".
JSON_COT_SCHEMA = """\

Output your answer ONLY in this JSON format:
```json
{
"step_by_step": ["Step 1: ...", "Step 2: ..."],
"diagnosis": "[your diagnosis]",
"confidence": [0-100]
}
```"""
|
|
167
|
+
|
|
168
|
+
# TOON (Token-Oriented Object Notation): flat "key: value" lines inside a code
# fence. Often described as using roughly 40% fewer tokens than JSON
# (unverified here; depends on tokenizer and content).
TOON_OUTPUT_SCHEMA = """\

Output your answer ONLY in this TOON format (indentation-based, no braces):
```
diagnosis: [your diagnosis]
confidence: [0-100]
reasoning: [brief explanation]
```"""

# Chain-of-thought TOON variant: "step_by_step: [3]" is a length-marked list
# header followed by one line per step.
# NOTE(review): any indentation of the step lines is not visible in this view
# of the file — confirm against the original source.
TOON_COT_SCHEMA = """\

Output your answer ONLY in this TOON format:
```
step_by_step: [3]
Step 1: ...
Step 2: ...
Step 3: ...
diagnosis: [your diagnosis]
confidence: [0-100]
```"""
|
|
189
|
+
|
|
190
|
+
# TOML format: answer fields under a [response] table inside a ```toml fence.
# NOTE(review): `confidence = 0-100` is not a valid TOML value, so a response
# that copies the placeholder literally is rejected by a TOML parser.
TOML_OUTPUT_SCHEMA = """\

Output your answer ONLY in this TOML format:
```toml
[response]
diagnosis = "[your diagnosis]"
confidence = 0-100
reasoning = "[brief explanation]"
```"""

# Chain-of-thought TOML variant: reasoning is a "step_by_step" array of strings.
TOML_COT_SCHEMA = """\

Output your answer ONLY in this TOML format:
```toml
[response]
step_by_step = ["Step 1: ...", "Step 2: ..."]
diagnosis = "[your diagnosis]"
confidence = 0-100
```"""
|
|
210
|
+
|
|
211
|
+
# XML format: a single <response> root element inside a ```xml fence.
XML_OUTPUT_SCHEMA = """\

Output your answer ONLY in this XML format:
```xml
<response>
<diagnosis>[your diagnosis]</diagnosis>
<confidence>[0-100]</confidence>
<reasoning>[brief explanation]</reasoning>
</response>
```"""

# Chain-of-thought XML variant: steps are repeated <step> children of
# <step_by_step>.
XML_COT_SCHEMA = """\

Output your answer ONLY in this XML format:
```xml
<response>
<step_by_step>
<step>Step 1: ...</step>
<step>Step 2: ...</step>
</step_by_step>
<diagnosis>[your diagnosis]</diagnosis>
<confidence>[0-100]</confidence>
</response>
```"""
|
|
236
|
+
|
|
237
|
+
# YAML format: top-level mapping inside a ```yaml fence (the YAML parser only
# reads fenced content).
YAML_OUTPUT_SCHEMA = """\

Output your answer ONLY in this YAML format:
```yaml
diagnosis: "[your diagnosis]"
confidence: 0-100
reasoning: "[brief explanation]"
```"""

# Chain-of-thought YAML variant: "step_by_step" is a block sequence of quoted
# strings.
YAML_COT_SCHEMA = """\

Output your answer ONLY in this YAML format:
```yaml
step_by_step:
- "Step 1: ..."
- "Step 2: ..."
diagnosis: "[your diagnosis]"
confidence: 0-100
```"""
|
|
257
|
+
|
|
258
|
+
# Markdown structured format: one "## <Section>" header per field, no code fence.
# The Markdown parser locates sections by these "##" headers (case-insensitive),
# each running to the next "##" header or the end of the response.
MARKDOWN_OUTPUT_SCHEMA = """\

Output your answer ONLY in this structured Markdown format:
## Diagnosis
[your diagnosis]

## Confidence
[0-100]

## Reasoning
[brief explanation]"""

# Chain-of-thought Markdown variant: a numbered "## Step-by-Step Reasoning"
# section replaces "## Reasoning" and comes first.
MARKDOWN_COT_SCHEMA = """\

Output your answer ONLY in this structured Markdown format:
## Step-by-Step Reasoning
1. Step 1: ...
2. Step 2: ...

## Diagnosis
[your diagnosis]

## Confidence
[0-100]"""
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
class StructuredOutputMixin:
    """Mixin to add multi-format structured output capability to any prompt strategy.

    Supports: PLAIN, JSON, TOON, TOML, XML, YAML, MARKDOWN

    It provides:

    * ``_add_format_instruction`` - the schema text for the current format;
    * ``_format_example`` - one example rendered in a chosen format;
    * ``_parse_formatted_response`` - turns a model response into a dict with
      at least ``answer``, ``reasoning``, ``raw``, ``parse_success`` and
      ``format`` keys.

    Usage:
        class MyStrategy(StructuredOutputMixin, BasePromptStrategy):
            def __init__(self, output_format="plain", ...):
                self.output_format = output_format
                self.cot_format = False  # Set True for CoT inside structure
    """

    output_format: str = "plain"
    cot_format: bool = False

    # Backward compatibility with json_output
    @property
    def json_output(self) -> bool:
        """True when the current output format is JSON."""
        return self.output_format == OutputFormat.JSON

    @json_output.setter
    def json_output(self, value: bool):
        # True selects JSON. False used to be a no-op, which left the getter
        # returning True; it now reverts JSON to plain text (other formats are
        # left untouched).
        if value:
            self.output_format = OutputFormat.JSON
        elif self.output_format == OutputFormat.JSON:
            self.output_format = OutputFormat.PLAIN

    def _get_format_schema(self) -> tuple:
        """Return (output_schema, cot_schema) for current format; ("", "") if unknown."""
        schemas = {
            OutputFormat.PLAIN: (PLAIN_OUTPUT_SCHEMA, PLAIN_COT_SCHEMA),
            OutputFormat.JSON: (JSON_OUTPUT_SCHEMA, JSON_COT_SCHEMA),
            OutputFormat.TOON: (TOON_OUTPUT_SCHEMA, TOON_COT_SCHEMA),
            OutputFormat.TOML: (TOML_OUTPUT_SCHEMA, TOML_COT_SCHEMA),
            OutputFormat.XML: (XML_OUTPUT_SCHEMA, XML_COT_SCHEMA),
            OutputFormat.YAML: (YAML_OUTPUT_SCHEMA, YAML_COT_SCHEMA),
            OutputFormat.MARKDOWN: (MARKDOWN_OUTPUT_SCHEMA, MARKDOWN_COT_SCHEMA),
        }
        return schemas.get(self.output_format, ("", ""))

    def _add_format_instruction(self) -> str:
        """Return format instruction to append to prompt (CoT variant if cot_format)."""
        output_schema, cot_schema = self._get_format_schema()
        return cot_schema if self.cot_format else output_schema

    @staticmethod
    def _example_rationale(example_data: Dict[str, Any]) -> Any:
        """Return ``example_data["evidence"]["rationale"]``, or "" when unavailable."""
        evidence = example_data.get("evidence")
        if isinstance(evidence, dict):
            return evidence.get("rationale", "")
        return ""

    def _format_example(self, example_data: Dict[str, Any], format_type: Optional[str] = None) -> str:
        """Format a single example in the specified format.

        Args:
            example_data: Dictionary with example data. The plain, TOON and
                Markdown renderings read ``"_plain_answer"`` and
                ``["evidence"]["rationale"]`` (empty when missing); JSON, YAML,
                TOML and XML serialize the whole dictionary.
            format_type: Output format (json, yaml, toml, xml, toon, markdown,
                plain). Uses self.output_format if None. Unknown values fall
                back to JSON.

        Returns:
            Formatted example string. JSON/YAML/TOML/XML/TOON are wrapped in
            code block markers; plain and Markdown are returned unfenced.
        """
        fmt = format_type or self.output_format

        if fmt == OutputFormat.JSON:
            return f"```json\n{json.dumps(example_data, indent=4)}\n```"

        if fmt == OutputFormat.YAML:
            # Simple YAML conversion

            def yaml_scalar(value):
                # JSON string quoting is valid YAML double-quoted syntax and
                # escapes embedded quotes, backslashes and newlines.
                if isinstance(value, bool):
                    return str(value).lower()
                if value is None:
                    return "null"
                if isinstance(value, str):
                    return json.dumps(value, ensure_ascii=False)
                return str(value)

            def to_yaml(obj, prefix=""):
                lines = []
                if isinstance(obj, dict):
                    for k, v in obj.items():
                        if isinstance(v, (dict, list)):
                            lines.append(f"{prefix}{k}:")
                            lines.append(to_yaml(v, prefix + " "))
                        else:
                            lines.append(f"{prefix}{k}: {yaml_scalar(v)}")
                elif isinstance(obj, list):
                    for item in obj:
                        if isinstance(item, (dict, list)) and item:
                            # Render the nested block two columns deeper, then
                            # move its first line onto the "- " marker so the
                            # remaining lines stay aligned with it.
                            nested = to_yaml(item, prefix + "  ")
                            lines.append(f"{prefix}- {nested[len(prefix) + 2:]}")
                        else:
                            lines.append(f"{prefix}- {yaml_scalar(item)}")
                return "\n".join(lines)

            return f"```yaml\n{to_yaml(example_data)}\n```"

        if fmt == OutputFormat.TOML:
            # Simple TOML conversion
            import re

            def toml_key(key):
                key = str(key)
                # Bare keys may only contain A-Z a-z 0-9 _ -; quote anything else.
                if re.fullmatch(r"[A-Za-z0-9_-]+", key):
                    return key
                return json.dumps(key, ensure_ascii=False)

            def toml_value(value):
                if isinstance(value, bool):
                    return str(value).lower()
                if isinstance(value, str):
                    # JSON string escapes are valid TOML basic-string escapes.
                    return json.dumps(value, ensure_ascii=False)
                if isinstance(value, list):
                    return "[" + ", ".join(toml_value(i) for i in value) + "]"
                if isinstance(value, dict):
                    pairs = ", ".join(f"{toml_key(k)} = {toml_value(v)}" for k, v in value.items())
                    return "{" + pairs + "}"
                return str(value)

            def to_toml(obj, section=""):
                lines = [f"[{section}]"] if section else []
                tables = []
                for k, v in obj.items():
                    if isinstance(v, dict):
                        # Emit sub-tables after all plain keys: in TOML every key
                        # that follows a [table] header belongs to that table.
                        tables.append((k, v))
                    elif v is None:
                        continue  # TOML has no null value; omit the key
                    else:
                        lines.append(f"{toml_key(k)} = {toml_value(v)}")
                for k, v in tables:
                    name = f"{section}.{toml_key(k)}" if section else toml_key(k)
                    lines.append(to_toml(v, name))
                return "\n".join(lines)

            return f"```toml\n{to_toml(example_data)}\n```"

        if fmt == OutputFormat.PLAIN:
            # Plain text - use plain_answer if available
            plain_answer = example_data.get("_plain_answer", "")
            reasoning = self._example_rationale(example_data)
            return f"{reasoning}\n\nFINAL ANSWER: {plain_answer}"

        if fmt == OutputFormat.XML:
            from xml.sax.saxutils import escape

            def xml_text(value):
                # Escape &, < and > so free-text values cannot break the markup.
                text = str(value).lower() if isinstance(value, bool) else str(value)
                return escape(text)

            def to_xml(obj, root="response"):
                lines = [f"<{root}>"]
                for k, v in obj.items():
                    if isinstance(v, dict):
                        lines.append(to_xml(v, k))
                    elif isinstance(v, list):
                        lines.append(f" <{k}>")
                        lines.extend(f" <item>{xml_text(item)}</item>" for item in v)
                        lines.append(f" </{k}>")
                    else:
                        lines.append(f" <{k}>{xml_text(v)}</{k}>")
                lines.append(f"</{root}>")
                return "\n".join(lines)

            return f"```xml\n{to_xml(example_data)}\n```"

        if fmt == OutputFormat.TOON:
            # TOON: Simple Key-Value pairs
            plain_answer = example_data.get("_plain_answer", "")
            reasoning = self._example_rationale(example_data)

            lines = [f"diagnosis: {plain_answer}", f"reasoning: {reasoning}"]
            # Add other flat keys if they exist and aren't dicts/lists
            for k, v in example_data.items():
                if k not in ["evidence", "_plain_answer"] and not isinstance(v, (dict, list)):
                    lines.append(f"{k}: {str(v).lower() if isinstance(v, bool) else v}")
            return f"```toon\n{chr(10).join(lines)}\n```"

        if fmt == OutputFormat.MARKDOWN:
            # Markdown: Header sections
            plain_answer = example_data.get("_plain_answer", "")
            reasoning = self._example_rationale(example_data)

            md = f"## Diagnosis\n{plain_answer}\n\n"
            md += f"## Reasoning\n{reasoning}\n\n"
            md += "## Confidence\n95"  # Example confidence
            return md

        # Default to JSON
        return f"```json\n{json.dumps(example_data, indent=4)}\n```"

    def _parse_formatted_response(self, response: str) -> Dict[str, Any]:
        """Parse response based on output format.

        Plain text is searched for ``FINAL ANSWER:`` / ``CONFIDENCE:`` lines
        (``confidence`` is ``None`` when absent). Structured formats are
        delegated to the matching ``_parse_<format>`` method. An unknown
        format yields a failed parse carrying an ``error`` message.
        """
        import re

        if self.output_format == OutputFormat.PLAIN:
            # Extract FINAL ANSWER and CONFIDENCE from plain text
            answer_match = re.search(r"FINAL ANSWER:\s*(.+?)(?:\n|$)", response, re.IGNORECASE)
            confidence_match = re.search(r"CONFIDENCE:\s*(\d+)", response, re.IGNORECASE)

            answer = answer_match.group(1).strip() if answer_match else response.strip()
            confidence = int(confidence_match.group(1)) if confidence_match else None

            return {
                "answer": answer,
                "confidence": confidence,
                "reasoning": response,
                "raw": response,
                "parse_success": bool(answer_match),
                "format": "plain",
            }

        parsers = {
            OutputFormat.JSON: self._parse_json,
            OutputFormat.TOON: self._parse_toon,
            OutputFormat.TOML: self._parse_toml,
            OutputFormat.XML: self._parse_xml,
            OutputFormat.YAML: self._parse_yaml,
            OutputFormat.MARKDOWN: self._parse_markdown,
        }
        parser = parsers.get(self.output_format)
        if parser is not None:
            return parser(response)

        return {
            "answer": response.strip(),
            "reasoning": response,
            "raw": response,
            "parse_success": False,
            "format": self.output_format,
            "error": f"Unknown format: {self.output_format}",
        }

    def _parse_json(self, response: str) -> Dict[str, Any]:
        """Parse JSON response.

        Prefers a ```json fenced block; otherwise takes the first flat
        ``{...}`` object that mentions "diagnosis".
        """
        import re

        json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
        if json_match:
            json_str = json_match.group(1)
        else:
            json_match = re.search(r'\{[^{}]*"diagnosis"[^{}]*\}', response, re.DOTALL)
            if not json_match:
                return self._parse_failure(response, "json")
            json_str = json_match.group(0)

        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError:
            return self._parse_failure(response, "json")
        return self._extract_fields(parsed, response, "json")

    def _parse_toon(self, response: str) -> Dict[str, Any]:
        """Parse TOON (Token-Oriented Object Notation) response.

        Reads ``key: value`` lines from the first code fence (or from the whole
        response when there is none); ``true``/``false`` become booleans.

        A ``key: [N]`` line (the form used by TOON_COT_SCHEMA) or a
        ``key[N]:`` line starts a list. The lines that follow become its items
        until the next unindented ``identifier: value`` field. When no item
        lines follow, a ``key: [N]`` value is kept as the raw string.

        Succeeds only when a ``diagnosis`` key was found.
        """
        import re

        # Extract TOON block from code fence
        toon_match = re.search(r"```(?:toon)?\s*(.*?)\s*```", response, re.DOTALL)
        content = toon_match.group(1) if toon_match else response

        field_key = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(?:\[\d*\])?")
        length_marker = re.compile(r"\[\d*\]")
        bracket_key = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\[\d*\]")

        result: Dict[str, Any] = {}
        list_key = None  # key of the list currently being collected, if any
        for line in content.strip().split("\n"):
            key, sep, value = line.partition(":")
            key = key.strip()
            value = value.strip()

            if list_key is not None:
                # "Step 1: ..." (key with a space) or an indented line is an item;
                # an unindented "identifier: value" line starts the next field.
                starts_field = bool(sep) and not line[:1].isspace() and field_key.fullmatch(key)
                if not starts_field:
                    item = line.strip()
                    if item:
                        if not isinstance(result.get(list_key), list):
                            result[list_key] = []
                        result[list_key].append(item)
                    continue
                list_key = None

            if not (sep and key):
                continue

            header = bracket_key.fullmatch(key)
            if header and not value:
                list_key = header.group(1)
            elif length_marker.fullmatch(value):
                result[key] = value  # replaced by a list if items follow
                list_key = key
            elif value:
                # Auto-convert booleans
                if value.lower() == "true":
                    result[key] = True
                elif value.lower() == "false":
                    result[key] = False
                else:
                    result[key] = value

        if "diagnosis" in result:
            return self._extract_fields(result, response, "toon")
        return self._parse_failure(response, "toon")

    def _parse_toml(self, response: str) -> Dict[str, Any]:
        """Parse TOML response from a ```toml fence, unwrapping a [response] table.

        Requires ``tomllib`` (Python 3.11+); if it is unavailable, or the TOML
        is invalid, the parse is reported as a failure.
        """
        import re

        toml_match = re.search(r"```toml\s*(.*?)\s*```", response, re.DOTALL)
        if not toml_match:
            return self._parse_failure(response, "toml")

        try:
            import tomllib

            parsed = tomllib.loads(toml_match.group(1))
            # Handle [response] section
            if "response" in parsed:
                parsed = parsed["response"]
            return self._extract_fields(parsed, response, "toml")
        except Exception:
            # Best-effort parsing: any failure is reported, not raised.
            return self._parse_failure(response, "toml")

    def _parse_xml(self, response: str) -> Dict[str, Any]:
        """Parse XML response.

        Uses ``xmltodict`` when importable. If it is missing or fails, falls
        back to ``xml.etree.ElementTree``. ``true``/``false`` text becomes a
        boolean, and a single root element (e.g. ``<response>``) is unwrapped.
        """
        import re

        xml_match = re.search(r"```xml\s*(.*?)\s*```", response, re.DOTALL)
        content = (xml_match.group(1) if xml_match else response).strip()

        try:
            import xmltodict

            def postprocessor(path, key, value):
                if value and isinstance(value, str):
                    if value.lower() == "true":
                        return key, True
                    if value.lower() == "false":
                        return key, False
                return key, value

            parsed = xmltodict.parse(content, postprocessor=postprocessor)
            # Remove root element wrapper if present (e.g. <response>)
            if len(parsed) == 1:
                parsed = parsed[next(iter(parsed))]
            return self._extract_fields(parsed, response, "xml")
        except Exception:
            # xmltodict missing or unable to parse: use the ElementTree fallback below.
            pass

        try:
            import xml.etree.ElementTree as ET

            def elem_to_dict(elem):
                # Leaf elements map to their text; repeated child tags become lists.
                text = elem.text.strip() if elem.text else None
                # Convert booleans
                if text:
                    if text.lower() == "true":
                        text = True
                    elif text.lower() == "false":
                        text = False

                children = list(elem)
                if not children:
                    return text

                result = {}
                for child in children:
                    child_val = elem_to_dict(child)
                    if child.tag in result:
                        if not isinstance(result[child.tag], list):
                            result[child.tag] = [result[child.tag]]
                        result[child.tag].append(child_val)
                    else:
                        result[child.tag] = child_val
                return result

            root = ET.fromstring(content)
            return self._extract_fields(elem_to_dict(root), response, "xml")
        except Exception:
            return self._parse_failure(response, "xml")

    def _parse_yaml(self, response: str) -> Dict[str, Any]:
        """Parse YAML response from a ```yaml fence using ``yaml.safe_load``."""
        import re

        yaml_match = re.search(r"```yaml\s*(.*?)\s*```", response, re.DOTALL)
        if not yaml_match:
            return self._parse_failure(response, "yaml")

        try:
            import yaml

            parsed = yaml.safe_load(yaml_match.group(1))
            return self._extract_fields(parsed, response, "yaml")
        except Exception:
            # Best-effort parsing: PyYAML missing or invalid YAML.
            return self._parse_failure(response, "yaml")

    def _parse_markdown(self, response: str) -> Dict[str, Any]:
        """Parse structured Markdown response.

        Reads the ``## Diagnosis``, ``## Confidence``, ``## Reasoning`` and
        ``## Step-by-Step Reasoning`` sections. Headers are matched
        case-insensitively, and each section runs to the next ``##`` or the
        end. Succeeds only when a Diagnosis section is present.
        """
        import re

        def section(title):
            match = re.search(
                rf"##\s*{title}\s*\n(.*?)(?=\n##|\Z)", response, re.DOTALL | re.IGNORECASE
            )
            return match.group(1).strip() if match else None

        diagnosis = section("Diagnosis")
        confidence = section("Confidence")
        reasoning = section("Reasoning")
        steps = section("Step-by-Step Reasoning")

        result: Dict[str, Any] = {}
        if diagnosis is not None:
            result["diagnosis"] = diagnosis
        if confidence is not None:
            result["confidence"] = self._coerce_confidence(confidence)
        if reasoning is not None:
            result["reasoning"] = reasoning
        if steps:
            # MARKDOWN_COT_SCHEMA puts the reasoning in this section; keep one entry per line.
            result["step_by_step"] = [line.strip() for line in steps.splitlines() if line.strip()]

        if "diagnosis" in result:
            return self._extract_fields(result, response, "markdown")
        return self._parse_failure(response, "markdown")

    @staticmethod
    def _coerce_confidence(value: Any) -> int:
        """Best-effort conversion of a confidence value to int; 0 when impossible.

        Strings yield their first number ("85", "85%", "85.5/100" -> 85).
        Ints (including bools) are returned unchanged, and floats are truncated.
        """
        import re

        if isinstance(value, int):
            return value
        if isinstance(value, str):
            match = re.search(r"-?\d+(?:\.\d+)?", value)
            if not match:
                return 0
            value = match.group(0)
        try:
            return int(float(value))
        except (TypeError, ValueError, OverflowError):
            return 0

    @staticmethod
    def _normalize_steps(steps: Any) -> List[str]:
        """Coerce a ``step_by_step`` value into a list of strings.

        Handles a list, a single string, ``None``, and an XML wrapper such as
        ``{"step": [...]}`` or ``{"item": "..."}``. Dict items, such as unquoted
        YAML ``- Step 1: text`` parsing to ``{"Step 1": "text"}``, become
        ``"Step 1: text"``. Anything else yields an empty list.
        """
        if isinstance(steps, dict) and len(steps) == 1:
            steps = next(iter(steps.values()))
        if isinstance(steps, str):
            return [steps] if steps.strip() else []
        if not isinstance(steps, list):
            return []

        normalized = []
        for item in steps:
            if item is None:
                continue
            if isinstance(item, dict):
                normalized.append(", ".join(f"{k}: {v}" for k, v in item.items()))
            else:
                normalized.append(str(item))
        return normalized

    def _extract_fields(self, parsed: dict, response: str, fmt: str) -> Dict[str, Any]:
        """Extract standard fields from parsed response.

        Args:
            parsed: Data decoded from the response; must be a dict, otherwise
                a parse failure is returned.
            response: The raw response text, stored under ``raw``.
            fmt: Format name, stored under ``format``.

        Returns:
            ``answer``: ``diagnosis`` as a string, or "" if it is missing.
            ``reasoning``: the steps joined by newlines, or the ``reasoning``
                field when there are no steps.
            ``confidence``: an int, 0 when missing or unparseable.
            ``step_by_step``: a list of strings.
            ``raw``, ``parsed_data``, ``parse_success`` (True) and ``format``.
            Every other key of ``parsed`` is copied in too, as long as it does
            not collide with the keys above.
        """
        if not isinstance(parsed, dict):
            # e.g. a JSON array or a bare YAML scalar: there are no fields to read.
            return self._parse_failure(response, fmt)

        steps = self._normalize_steps(parsed.get("step_by_step"))
        if steps:
            reasoning = "\n".join(steps)
        else:
            # Non-CoT schemas put their explanation in a "reasoning" field.
            reasoning = parsed.get("reasoning")
            if reasoning is None:
                reasoning = ""
            elif not isinstance(reasoning, str):
                reasoning = str(reasoning)

        diagnosis = parsed.get("diagnosis")
        result = {
            "answer": "" if diagnosis is None else str(diagnosis),
            "reasoning": reasoning,
            "confidence": self._coerce_confidence(parsed.get("confidence", 0)),
            "step_by_step": steps,
            "raw": response,
            "parsed_data": parsed,
            "parse_success": True,
            "format": fmt,
        }

        # Merge all parsed fields into top-level result for easy access
        # (excluding standard keys to avoid overwriting processed values)
        for k, v in parsed.items():
            if k not in result:
                result[k] = v

        return result

    def _parse_failure(self, response: str, fmt: str) -> Dict[str, Any]:
        """Return parse failure result: the whole stripped response becomes the answer."""
        return {
            "answer": response.strip(),
            "reasoning": response,
            "raw": response,
            "parse_success": False,
            "format": fmt,
        }
|
|
715
|
+
|
|
716
|
+
|
|
717
|
+
class BaseExperiment(ABC):
    """Interface for experiments that run a prompt strategy through an inference backend."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier of the experiment, used for logging."""
        ...

    @abstractmethod
    def run(
        self,
        backend: "InferenceBackend",
        dataset: Any,
        prompt_strategy: BasePromptStrategy,
        **kwargs,
    ) -> ExperimentResult:
        """
        Execute the experiment.

        Args:
            backend: Inference backend (vLLM or Transformers)
            dataset: Data the experiment runs on
            prompt_strategy: Strategy that builds the prompts
            **kwargs: Experiment-specific options

        Returns:
            ExperimentResult holding the metrics and outputs
        """
        ...

    def validate_backend(self, backend: "InferenceBackend") -> None:
        """Hook for verifying that ``backend`` supports the required features.

        The base implementation accepts any backend; subclasses may override it.
        """
        return None
|