cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
"""Automatic experiment documentation generator.
|
|
2
|
+
|
|
3
|
+
Generates EXPERIMENT.md files for every experiment run with:
|
|
4
|
+
- Human-readable experiment description
|
|
5
|
+
- Research context and questions
|
|
6
|
+
- Full configuration
|
|
7
|
+
- Reproduction commands
|
|
8
|
+
- Results summary
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any, Dict, Optional
|
|
14
|
+
|
|
15
|
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ExperimentDocumenter:
    """Generates markdown documentation for experiments.

    The documenter renders an ``EXPERIMENT.md`` file inside ``output_dir``.
    ``create_initial_doc`` produces the document for a running experiment and
    ``update_with_results`` rewrites the status, duration and results
    sections once the run has finished.

    The configuration is only accessed through ``.get`` / ``in`` / item
    lookup, plus the OmegaConf helpers for variants and the YAML dump.
    """

    # File name of the generated document inside ``output_dir``.
    _DOC_FILENAME = "EXPERIMENT.md"
    # Placeholder written by ``create_initial_doc`` and swapped out by
    # ``update_with_results``; both methods must agree on the exact text.
    _RESULTS_PLACEHOLDER = "_Results will be added after experiment completes..._"
    # Timestamp format shared by the initial document and the duration update.
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, config: "DictConfig", output_dir: Path):
        """Store the configuration and output directory; record the start time.

        Args:
            config: Experiment configuration (sections ``experiment``,
                ``prompt`` and ``dataset`` are read).
            output_dir: Directory that will contain ``EXPERIMENT.md``.
        """
        self.config = config
        self.output_dir = Path(output_dir)
        self.start_time = datetime.now()

    def _section(self, key: str) -> Any:
        """Return config section *key*, or ``{}`` when it is missing or null."""
        # ``or {}`` also covers a section that exists but is set to null,
        # which would otherwise make the chained ``.get`` calls fail.
        return self.config.get(key, None) or {}

    def _experiment_name(self) -> str:
        """Return ``experiment.name`` from the config, or ``""`` if unset."""
        return self._section("experiment").get("name", "")

    def _experiment_variants(self) -> list[dict]:
        """Return ``experiment.variants`` as a plain list (empty if unset)."""
        variants = self._section("experiment").get("variants", [])
        if isinstance(variants, (DictConfig, ListConfig)):
            # Convert to builtin containers so callers can use plain dict access.
            variants = OmegaConf.to_container(variants, resolve=True)
        return list(variants) if variants else []

    @staticmethod
    def _format_variant_value(value: Any, base_value: str, default_label: str) -> str:
        """Resolve a variant's dataset/prompt entry to a display name.

        Args:
            value: Either a name, or a mapping whose ``name`` key is used.
            base_value: Name to show when the variant refers to the base
                configuration (``"base"`` / ``"default"``) or is empty.
            default_label: Fallback used when a mapping has no ``name`` key.
        """
        if isinstance(value, DictConfig):
            value = OmegaConf.to_container(value, resolve=True)
        if isinstance(value, dict):
            value = value.get("name", default_label)
        if value in ("base", "default"):
            return base_value
        return value if value else base_value

    @staticmethod
    def _reasoning_mode(prompt_cfg: Any) -> str:
        """Return ``"answer_first"``, ``"contrarian"`` or ``"standard"``.

        ``answer_first`` takes priority when both flags are set so that the
        title, the configuration summary and the research questions of one
        document always describe the same mode.
        """
        if prompt_cfg.get("answer_first", False):
            return "answer_first"
        if prompt_cfg.get("contrarian", False):
            return "contrarian"
        return "standard"

    def generate_title(self) -> str:
        """Generate human-readable experiment title from config."""
        if self._experiment_name() == "activation_compare":
            variants = self._experiment_variants()
            if variants:
                run_names = [v.get("name", "run") for v in variants]
                return f"Activation Compare: {' vs '.join(run_names)}"
            return "Activation Compare"

        prompt_cfg = self._section("prompt")
        mode_labels = {
            "answer_first": "Answer-First",
            "contrarian": "Contrarian",
            "standard": "Standard",
        }

        # Dataset, then reasoning mode
        dataset_name = self._section("dataset").get("name", "unknown")
        parts = [dataset_name.capitalize(), mode_labels[self._reasoning_mode(prompt_cfg)]]

        # Few-shot is the default, so only the zero-shot case is called out
        if not prompt_cfg.get("few_shot", True):
            parts.append("Zero-Shot")

        # Output format (JSON is the default and is left implicit)
        output_fmt = prompt_cfg.get("output_format", "json")
        if output_fmt != "json":
            parts.append(f"({output_fmt.upper()})")

        return " ".join(parts)

    def infer_research_questions(self) -> list[str]:
        """Infer research questions from configuration."""
        if self._experiment_name() == "activation_compare":
            return [
                "How do residual stream activations differ across runs and datasets?",
                "Which layers show the largest activation divergence between runs?",
                "Do activation differences align with task or prompt changes?",
            ]

        questions = []
        prompt_cfg = self._section("prompt")

        # Few-shot ablation
        if not prompt_cfg.get("few_shot", True):
            questions.append("Does the model perform well without few-shot examples (zero-shot)?")

        # Reasoning mode (same priority as the title and configuration summary)
        mode = self._reasoning_mode(prompt_cfg)
        if mode == "answer_first":
            questions.append(
                'Does "answer first, then justify" reasoning order affect performance?'
            )
        elif mode == "contrarian":
            questions.append("Does skeptical/contrarian reasoning improve diagnostic accuracy?")

        # Output format
        output_fmt = prompt_cfg.get("output_format", "json")
        if output_fmt != "json":
            questions.append(
                f"How does {output_fmt.upper()} output format affect parsing and accuracy?"
            )

        # Default question if none inferred
        if not questions:
            questions.append("Standard baseline experiment for comparison.")

        return questions

    def generate_reproduction_command(self) -> str:
        """Generate exact CLI command to reproduce experiment."""
        parts = ["python -m cotlab.main"]

        exp_cfg = self._section("experiment")
        prompt_cfg = self._section("prompt")
        dataset_cfg = self._section("dataset")

        # Experiment selection and its overrides
        exp_name = exp_cfg.get("name")
        if exp_name:
            parts.append(f"experiment={exp_name}")
        if exp_cfg.get("num_samples") is not None:
            parts.append(f"experiment.num_samples={exp_cfg['num_samples']}")
        if exp_cfg.get("seed") is not None:
            parts.append(f"experiment.seed={exp_cfg['seed']}")

        # Prompt selection
        if "name" in prompt_cfg:
            parts.append(f"prompt={prompt_cfg['name']}")

        # Prompt parameters (only non-default values are emitted)
        if prompt_cfg.get("contrarian", False):
            parts.append("prompt.contrarian=true")
        if prompt_cfg.get("answer_first", False):
            parts.append("prompt.answer_first=true")
        if not prompt_cfg.get("few_shot", True):
            parts.append("prompt.few_shot=false")

        output_fmt = prompt_cfg.get("output_format", "json")
        if output_fmt != "json":
            parts.append(f"prompt.output_format={output_fmt}")

        # Dataset
        if "name" in dataset_cfg:
            parts.append(f"dataset={dataset_cfg['name']}")

        # One override per line, joined with shell line continuations
        return " \\\n ".join(parts)

    def create_initial_doc(self) -> str:
        """Create initial experiment documentation (before results)."""
        title = self.generate_title()
        questions = self.infer_research_questions()
        repro_cmd = self.generate_reproduction_command()
        started = self.start_time.strftime(self._TIME_FORMAT)

        # Configuration summary
        prompt_cfg = self._section("prompt")
        dataset_cfg = self._section("dataset")

        reasoning_mode = {
            "answer_first": "Answer-First",
            "contrarian": "Contrarian (skeptical)",
            "standard": "Standard",
        }[self._reasoning_mode(prompt_cfg)]

        few_shot = "Yes" if prompt_cfg.get("few_shot", True) else "No (zero-shot)"
        output_fmt = prompt_cfg.get("output_format", "json").upper()
        dataset_name = dataset_cfg.get("name", "unknown")
        prompt_name = prompt_cfg.get("name", "unknown")

        doc = f"""# Experiment: {title}

**Status:** Running
**Started:** {started}

## Research Questions

"""
        doc += "".join(f"{i}. {question}\n" for i, question in enumerate(questions, 1))

        doc += f"""
## Configuration

**Prompt Strategy:** {prompt_name.capitalize()}
**Reasoning Mode:** {reasoning_mode}
**Few-Shot Examples:** {few_shot}
**Output Format:** {output_fmt}
**Dataset:** {dataset_name}
"""

        if self._experiment_name() == "activation_compare":
            variants = self._experiment_variants()
            if variants:
                doc += "\n**Variants:**\n"
                for variant in variants:
                    variant_name = variant.get("name", "run")
                    variant_samples = variant.get("num_samples", "default")
                    variant_seed = variant.get("seed", "default")
                    # "base"/"default" entries are shown as the base config's names
                    variant_dataset = self._format_variant_value(
                        variant.get("dataset", "base"), dataset_name, "dataset"
                    )
                    variant_prompt = self._format_variant_value(
                        variant.get("prompt", "base"), prompt_cfg.get("name", "prompt"), "prompt"
                    )
                    doc += (
                        f"- {variant_name}: dataset={variant_dataset}, "
                        f"prompt={variant_prompt}, samples={variant_samples}, seed={variant_seed}\n"
                    )

        doc += f"""
<details>
<summary>Full Configuration (YAML)</summary>

```yaml
{OmegaConf.to_yaml(self.config)}
```
</details>

## Reproduce

```bash
{repro_cmd}
```

## Results

{self._RESULTS_PLACEHOLDER}
"""

        return doc

    def update_with_results(
        self, results: Optional[Dict[str, Any]] = None, duration_seconds: Optional[float] = None
    ) -> str:
        """Update documentation with results after completion.

        Reads ``EXPERIMENT.md`` from ``output_dir`` (or builds the initial
        document when the file does not exist), marks it completed, inserts
        the duration after the start timestamp and replaces the results
        placeholder. The updated text is returned, not written; pass it to
        ``save`` to persist it.

        Calling this again on an already updated document does not add a
        second duration line.

        Args:
            results: Metrics to render; a falsy value leaves the results
                section untouched.
            duration_seconds: Wall-clock duration; ``None`` skips the
                duration line.
        """
        # Read existing doc
        doc_path = self.output_dir / self._DOC_FILENAME
        if doc_path.exists():
            doc = doc_path.read_text(encoding="utf-8")
        else:
            doc = self.create_initial_doc()

        # Update status
        doc = doc.replace("**Status:** Running", "**Status:** Completed")

        # Add duration if provided
        if duration_seconds is not None:
            minutes = int(duration_seconds // 60)
            seconds = int(duration_seconds % 60)
            duration_str = (
                f"{minutes} minutes {seconds} seconds" if minutes > 0 else f"{seconds} seconds"
            )
            started_line = f"**Started:** {self.start_time.strftime(self._TIME_FORMAT)}"
            # Only insert when no duration line already follows the start
            # timestamp; otherwise repeated updates would stack duplicates.
            after_started = doc.partition(started_line)[2]
            if not after_started.lstrip().startswith("**Duration:**"):
                doc = doc.replace(
                    started_line,
                    f"{started_line} \n**Duration:** {duration_str}",
                )

        # Update results section
        if results:
            results_md = self._format_results(results)
            doc = doc.replace(
                f"## Results\n\n{self._RESULTS_PLACEHOLDER}",
                f"## Results\n\n{results_md}",
            )

        return doc

    def _format_results(self, results: Dict[str, Any]) -> str:
        """Format results dictionary as markdown.

        Well-known metrics are listed first in a fixed order with dedicated
        formatting; every other key follows in dictionary order, with floats
        rendered to three decimals.
        """
        lines = []
        # Results carrying both keys are treated as activation-compare output,
        # for which the "Samples Processed" line is suppressed.
        is_activation_compare = "num_runs" in results and "pair_count" in results

        # Basic metrics
        if "accuracy" in results:
            lines.append(f"- **Accuracy:** {results['accuracy']:.1%}")
        if "total_samples" in results and not is_activation_compare:
            lines.append(f"- **Samples Processed:** {results['total_samples']}")
        if "parse_failures" in results:
            lines.append(f"- **Parse Failures:** {results['parse_failures']}")
        if "avg_time" in results:
            lines.append(f"- **Average Time per Sample:** {results['avg_time']:.2f}s")

        # Additional metrics
        basic_keys = ("accuracy", "total_samples", "parse_failures", "avg_time")
        for key, value in results.items():
            if key in basic_keys:
                continue
            label = key.replace("_", " ").title()
            if isinstance(value, float):
                lines.append(f"- **{label}:** {value:.3f}")
            else:
                lines.append(f"- **{label}:** {value}")

        return "".join(f"{line}\n" for line in lines)

    def save(self, content: Optional[str] = None) -> Path:
        """Save experiment documentation to file.

        Args:
            content: Document text to write; when ``None`` the initial
                document is generated and written.

        Returns:
            Path of the written ``EXPERIMENT.md``.
        """
        self.output_dir.mkdir(parents=True, exist_ok=True)
        doc_path = self.output_dir / self._DOC_FILENAME

        if content is None:
            content = self.create_initial_doc()

        # Explicit UTF-8 so the file round-trips regardless of platform locale.
        doc_path.write_text(content, encoding="utf-8")
        return doc_path
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Experiments module."""
|
|
2
|
+
|
|
3
|
+
from .activation_compare import ActivationCompareExperiment
|
|
4
|
+
from .activation_patching import ActivationPatchingExperiment
|
|
5
|
+
from .attention_analysis import AttentionAnalysisExperiment
|
|
6
|
+
from .classification import ClassificationExperiment
|
|
7
|
+
from .composite_shift_detector import CompositeShiftDetectorExperiment
|
|
8
|
+
from .cot_ablation import CoTAblationExperiment
|
|
9
|
+
from .cot_faithfulness import CoTFaithfulnessExperiment
|
|
10
|
+
from .cot_heads import CoTHeadsExperiment
|
|
11
|
+
from .full_layer_cot import FullLayerCoTExperiment
|
|
12
|
+
from .full_layer_patching import FullLayerPatchingExperiment
|
|
13
|
+
from .h_neuron_analysis import HNeuronAnalysisExperiment
|
|
14
|
+
from .logit_lens import LogitLensExperiment
|
|
15
|
+
from .multi_head_cot import MultiHeadCoTExperiment
|
|
16
|
+
from .multi_head_patching import MultiHeadPatchingExperiment
|
|
17
|
+
from .probing_classifier import ProbingClassifierExperiment
|
|
18
|
+
from .residual_norm_ood import ResidualNormOODExperiment
|
|
19
|
+
from .sae_feature_analysis import SAEFeatureAnalysisExperiment
|
|
20
|
+
from .steering_vectors import SteeringVectorsExperiment
|
|
21
|
+
from .sycophancy_heads import SycophancyHeadsExperiment
|
|
22
|
+
|
|
23
|
+
# Public API of the experiments package: the names exported by
# ``from cotlab.experiments import *``. Each entry is an experiment class
# imported from its submodule above.
__all__ = [
    "HNeuronAnalysisExperiment",
    "CompositeShiftDetectorExperiment",
    "CoTAblationExperiment",
    "CoTFaithfulnessExperiment",
    "ActivationPatchingExperiment",
    "ActivationCompareExperiment",
    "AttentionAnalysisExperiment",
    "ProbingClassifierExperiment",
    "ClassificationExperiment",
    "ResidualNormOODExperiment",
    "SAEFeatureAnalysisExperiment",
    "SycophancyHeadsExperiment",
    "MultiHeadPatchingExperiment",
    "FullLayerPatchingExperiment",
    "SteeringVectorsExperiment",
    "CoTHeadsExperiment",
    "LogitLensExperiment",
    "MultiHeadCoTExperiment",
    "FullLayerCoTExperiment",
]
|
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
"""Activation Compare experiment — collect mean residual-stream vectors.
|
|
2
|
+
|
|
3
|
+
One run = one condition (dataset + prompt settings).
|
|
4
|
+
Saves per-layer mean activation vectors to results.json so that two or more
|
|
5
|
+
runs can be compared offline with ``scripts/compare_activations.py``.
|
|
6
|
+
|
|
7
|
+
Design follows logit_lens.py: hooks project inside the callback and move
|
|
8
|
+
tensors to CPU immediately to avoid GPU page-faults on long reports.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from typing import Any, Dict, List, Optional
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from tqdm import tqdm
|
|
15
|
+
|
|
16
|
+
from ..backends.base import InferenceBackend
|
|
17
|
+
from ..core.base import BaseExperiment, ExperimentResult
|
|
18
|
+
from ..core.registry import Registry
|
|
19
|
+
from ..datasets.loaders import BaseDataset
|
|
20
|
+
from ..logging import ExperimentLogger
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@Registry.register_experiment("activation_compare")
class ActivationCompareExperiment(BaseExperiment):
    """
    Collect layer-wise mean residual-stream activations for one condition.

    Forward-passes N samples through the model. At each selected layer a
    lightweight hook captures the hidden state at the last token (or
    mean-pooled across all real tokens) and moves it to CPU immediately.
    After all samples the per-layer mean vector is computed and returned in
    the ``ExperimentResult``.

    Two saved runs can then be compared with ``scripts/compare_activations.py``
    which computes cosine-similarity and L2-distance profiles per layer.
    """

    def __init__(
        self,
        name: str = "activation_compare",
        description: str = "Collect mean layer activations for representational comparison",
        layer_stride: int = 2,
        num_samples: Optional[int] = None,
        pooling: str = "last_token",  # "last_token" | "mean"
        max_input_tokens: int = 1024,
        seed: int = 42,
        answer_cue: str = "\n\nAnswer:",  # appended so last position mirrors logit_lens
        batch_size: int = 1,
        # Legacy fields kept so old YAML configs don't break
        layers: Optional[List[int]] = None,
        variants: Optional[List[Dict[str, Any]]] = None,
        comparison_mode: str = "pairwise",
        store_per_layer: bool = True,
        log_samples: bool = False,
        **kwargs,
    ):
        """Store the experiment settings.

        Args:
            name: Experiment name reported in the result.
            description: Human-readable description.
            layer_stride: Every ``layer_stride``-th layer is hooked.
            num_samples: Number of samples to draw; ``None`` uses the whole
                dataset.
            pooling: ``"last_token"`` or ``"mean"``.
            max_input_tokens: Truncation length passed to the tokenizer.
            seed: Seed passed to ``dataset.sample``.
            answer_cue: Text appended to every prompt.
            batch_size: Prompts per forward pass (clamped to at least 1).
            layers, variants, comparison_mode, store_per_layer, log_samples,
                **kwargs: Accepted for backward compatibility; they do not
                affect ``run``.
        """
        self._name = name
        self.description = description
        self.layer_stride = layer_stride
        self.num_samples = num_samples
        self.pooling = pooling
        self.max_input_tokens = max_input_tokens
        self.seed = seed
        self.answer_cue = answer_cue
        self.batch_size = max(1, int(batch_size))
        # Legacy fields silently ignored — kept for backward compat
        self._layers_legacy = layers
        self._variants_legacy = variants

    @property
    def name(self) -> str:
        """Experiment name given at construction."""
        return self._name

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _resolve_layers(self, backend: InferenceBackend) -> List[int]:
        """Return every ``layer_stride``-th layer index of the backend's model."""
        all_layers = list(range(backend.hook_manager.num_layers))
        return all_layers[:: self.layer_stride]

    def _pool(self, tensor: torch.Tensor) -> torch.Tensor:
        """tensor: [seq_len, hidden_size] → [hidden_size]

        Raises:
            ValueError: If ``self.pooling`` is not a known mode.
        """
        if self.pooling == "last_token":
            return tensor[-1]
        if self.pooling == "mean":
            return tensor.mean(dim=0)
        raise ValueError(f"Unknown pooling: {self.pooling}")

    def _pool_batch(
        self, tensor: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """tensor: [B, seq_len, hidden_size] → [B, hidden_size]

        Args:
            tensor: Hidden states of a (left-padded) batch.
            attention_mask: Optional [B, seq_len] mask, non-zero at real
                tokens. When given, ``"mean"`` pooling averages over real
                tokens only so padding does not leak into the result. When
                omitted, all positions are averaged.

        Raises:
            ValueError: If ``self.pooling`` is not a known mode.
        """
        if self.pooling == "last_token":
            return tensor[:, -1, :]  # left-padded: last position = last real token
        if self.pooling == "mean":
            if attention_mask is None:
                return tensor.mean(dim=1)
            # Weighted mean in float32: pad positions get weight 0, and the
            # divisor is the number of real tokens per row (at least 1 to
            # avoid division by zero on an all-padding row).
            weights = attention_mask.to(device=tensor.device, dtype=torch.float32).unsqueeze(-1)
            real_tokens = weights.sum(dim=1).clamp(min=1.0)
            return (tensor.float() * weights).sum(dim=1) / real_tokens
        raise ValueError(f"Unknown pooling: {self.pooling}")

    def _encode(self, backend: InferenceBackend, tokenizer: Any, prompt_strs: List[str]):
        """Tokenize prompts and move them to the backend device.

        Returns:
            ``(tokens, pad_mask)`` where ``pad_mask`` is the attention mask of
            a padded batch, or ``None`` for the unpadded single-prompt path.
        """
        if len(prompt_strs) == 1:
            # Single-sample path — no padding overhead
            tokens = tokenizer(
                prompt_strs[0],
                return_tensors="pt",
                truncation=True,
                max_length=self.max_input_tokens,
            ).to(backend.device)
            return tokens, None

        # Batched path — left-pad so position [-1] = last real token
        orig_side = tokenizer.padding_side
        tokenizer.padding_side = "left"
        try:
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            tokens = tokenizer(
                prompt_strs,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_input_tokens,
                padding=True,
            ).to(backend.device)
        finally:
            # Restore even when tokenization fails; the tokenizer is shared
            # with the backend.
            tokenizer.padding_side = orig_side

        # Compute position_ids that skip padding so positional embeddings
        # are identical to the non-padded single-sample case.
        # For left-padded input: e.g. [PAD, PAD, t0, t1, t2] → positions [0,0,0,1,2]
        attention_mask = tokens["attention_mask"]
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids = position_ids.masked_fill(attention_mask == 0, 0)
        tokens["position_ids"] = position_ids
        return tokens, attention_mask

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------

    def run(
        self,
        backend: InferenceBackend,
        dataset: BaseDataset,
        prompt_strategy: Any,
        logger: Optional[ExperimentLogger] = None,
        **kwargs,
    ) -> ExperimentResult:
        """Collect mean residual-stream activations for one dataset condition.

        Args:
            backend: Provides the model, tokenizer, device and hook manager.
            dataset: Source of samples (``.text`` and ``.metadata`` are read).
            prompt_strategy: Object with ``build_prompt(dict) -> str``.
            logger: Accepted for interface compatibility; not used here.

        Returns:
            ``ExperimentResult`` with per-layer norms/stds in ``metrics`` and
            the mean vectors in ``raw_outputs``. Batches whose forward pass
            raises are skipped and reported, not counted as processed.
        """

        target_layers = self._resolve_layers(backend)
        tokenizer = backend._tokenizer

        print(f"Model : {backend.model_name}")
        print(f"Layers ({len(target_layers)}): {target_layers}")
        print(f"Pooling: {self.pooling}")
        print(f"Batch size: {self.batch_size}")

        if self.num_samples is None:
            samples = list(dataset)
        else:
            samples = dataset.sample(self.num_samples, seed=self.seed)
        n = len(samples)
        print(f"Samples: {n}\n")

        # Accumulators: layer_idx → running sum tensor (float32, CPU)
        layer_sums: Dict[int, torch.Tensor] = {}
        layer_sq_sums: Dict[int, torch.Tensor] = {}
        layer_counts: Dict[int, int] = {}
        processed = 0

        # Chunk into batches
        batches = [
            samples[i : i + self.batch_size] for i in range(0, len(samples), self.batch_size)
        ]

        for batch in tqdm(batches, desc="Activation collect"):
            prompt_strs = []
            for sample in batch:
                # The same text is offered under several keys so any prompt
                # strategy finds the field it expects.
                prompt_input = {
                    "text": sample.text,
                    "question": sample.text,
                    "report": sample.text,
                    "metadata": sample.metadata or {},
                }
                prompt_strs.append(prompt_strategy.build_prompt(prompt_input) + self.answer_cue)

            B = len(prompt_strs)
            tokens, pad_mask = self._encode(backend, tokenizer, prompt_strs)

            # Capture hidden states in hooks; move to CPU immediately
            layer_vecs: Dict[int, torch.Tensor] = {}  # layer_idx -> [B, hidden] CPU float32

            def make_hook(layer_idx: int):
                def hook(module, inp, output):
                    tensor = output[0] if isinstance(output, tuple) else output
                    with torch.no_grad():
                        if B == 1:
                            # tensor: [1, seq_len, hidden]
                            vec = self._pool(tensor[0]).unsqueeze(0).cpu().float()  # [1, hidden]
                        else:
                            # tensor: [B, seq_len, hidden]; mask keeps padding
                            # out of mean pooling
                            vec = self._pool_batch(tensor, pad_mask).cpu().float()  # [B, hidden]
                    layer_vecs[layer_idx] = vec
                    return output

                return hook

            handles = [
                backend.hook_manager.get_residual_module(layer_idx).register_forward_hook(
                    make_hook(layer_idx)
                )
                for layer_idx in target_layers
            ]

            try:
                with torch.no_grad():
                    backend._model(**tokens)
            except Exception as e:
                # Deliberate best-effort: report the failed batch and move on
                # so one bad input does not abort the whole collection.
                tqdm.write(f" [skip] batch starting at {batch[0].idx}: {type(e).__name__}: {e}")
                torch.cuda.empty_cache()
                continue
            finally:
                # Runs on success, on error and on ``continue``.
                for h in handles:
                    h.remove()

            # Accumulate — sum over the batch dimension of each layer's vectors
            for layer_idx, vecs in layer_vecs.items():  # vecs: [B, hidden]
                if layer_idx not in layer_sums:
                    layer_sums[layer_idx] = torch.zeros_like(vecs[0])
                    layer_sq_sums[layer_idx] = torch.zeros_like(vecs[0])
                    layer_counts[layer_idx] = 0
                layer_sums[layer_idx] += vecs.sum(dim=0)
                layer_sq_sums[layer_idx] += (vecs**2).sum(dim=0)
                layer_counts[layer_idx] += vecs.shape[0]

            del layer_vecs
            torch.cuda.empty_cache()
            processed += B

        # --- Compute statistics -----------------------------------------
        mean_activations: Dict[int, List[float]] = {}
        activation_norm: Dict[int, float] = {}
        activation_std: Dict[int, float] = {}

        for layer_idx in target_layers:
            cnt = layer_counts.get(layer_idx, 0)
            if cnt == 0:
                continue
            mean_vec = layer_sums[layer_idx] / cnt  # [hidden]
            # Var = E[x^2] - E[x]^2 per dimension; clamp guards against tiny
            # negative values from floating-point rounding.
            var_vec = layer_sq_sums[layer_idx] / cnt - mean_vec**2
            std_val = float(var_vec.clamp(min=0).sqrt().mean().item())
            norm_val = float(mean_vec.norm().item())
            mean_activations[layer_idx] = mean_vec.tolist()
            activation_norm[layer_idx] = round(norm_val, 4)
            activation_std[layer_idx] = round(std_val, 6)

        # --- Print summary -----------------------------------------------
        print("\n" + "=" * 60)
        print("ACTIVATION COLLECT SUMMARY")
        print("=" * 60)
        print(f"Processed samples : {processed} / {n}")
        print(f"{'Layer':>6} {'Norm':>10} {'Std':>10}")
        print("-" * 32)
        for layer_idx in target_layers:
            if layer_idx in activation_norm:
                print(
                    f"{layer_idx:>6} {activation_norm[layer_idx]:>10.2f} {activation_std[layer_idx]:>10.6f}"
                )
        print("=" * 60)

        return ExperimentResult(
            experiment_name=self.name,
            model_name=backend.model_name,
            prompt_strategy=getattr(prompt_strategy, "name", "custom"),
            metrics={
                "num_samples": processed,
                "pooling": self.pooling,
                "layer_stride": self.layer_stride,
                "activation_norms": activation_norm,
                "activation_std": activation_std,
            },
            raw_outputs={
                "mean_activations_per_layer": mean_activations,
            },
            metadata={
                "target_layers": target_layers,
                "pooling": self.pooling,
                "num_samples": processed,
                "seed": self.seed,
                "answer_cue": self.answer_cue,
            },
        )
|