cotlab 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. cotlab/__init__.py +3 -0
  2. cotlab/analyse_experiments.py +392 -0
  3. cotlab/analysis/__init__.py +11 -0
  4. cotlab/analysis/cot_parser.py +243 -0
  5. cotlab/analysis/faithfulness_metrics.py +192 -0
  6. cotlab/backends/__init__.py +16 -0
  7. cotlab/backends/base.py +78 -0
  8. cotlab/backends/transformers_backend.py +335 -0
  9. cotlab/backends/vllm_backend.py +227 -0
  10. cotlab/cli.py +83 -0
  11. cotlab/core/__init__.py +34 -0
  12. cotlab/core/base.py +749 -0
  13. cotlab/core/config.py +90 -0
  14. cotlab/core/registry.py +68 -0
  15. cotlab/datasets/__init__.py +45 -0
  16. cotlab/datasets/loaders.py +1889 -0
  17. cotlab/experiment/__init__.py +315 -0
  18. cotlab/experiments/__init__.py +43 -0
  19. cotlab/experiments/activation_compare.py +290 -0
  20. cotlab/experiments/activation_patching.py +1050 -0
  21. cotlab/experiments/attention_analysis.py +885 -0
  22. cotlab/experiments/classification.py +235 -0
  23. cotlab/experiments/composite_shift_detector.py +524 -0
  24. cotlab/experiments/cot_ablation.py +277 -0
  25. cotlab/experiments/cot_faithfulness.py +187 -0
  26. cotlab/experiments/cot_heads.py +208 -0
  27. cotlab/experiments/full_layer_cot.py +232 -0
  28. cotlab/experiments/full_layer_patching.py +225 -0
  29. cotlab/experiments/h_neuron_analysis.py +712 -0
  30. cotlab/experiments/logit_lens.py +439 -0
  31. cotlab/experiments/multi_head_cot.py +220 -0
  32. cotlab/experiments/multi_head_patching.py +229 -0
  33. cotlab/experiments/probing_classifier.py +402 -0
  34. cotlab/experiments/residual_norm_ood.py +413 -0
  35. cotlab/experiments/sae_feature_analysis.py +673 -0
  36. cotlab/experiments/steering_vectors.py +223 -0
  37. cotlab/experiments/sycophancy_heads.py +224 -0
  38. cotlab/logging/__init__.py +5 -0
  39. cotlab/logging/json_logger.py +161 -0
  40. cotlab/main.py +317 -0
  41. cotlab/patching/__init__.py +24 -0
  42. cotlab/patching/cache.py +141 -0
  43. cotlab/patching/hooks.py +558 -0
  44. cotlab/patching/interventions.py +86 -0
  45. cotlab/patching/patcher.py +439 -0
  46. cotlab/patching/sae.py +181 -0
  47. cotlab/prompts/__init__.py +43 -0
  48. cotlab/prompts/cardiology.py +378 -0
  49. cotlab/prompts/histopathology.py +265 -0
  50. cotlab/prompts/length_matched_strategies.py +157 -0
  51. cotlab/prompts/mcq.py +193 -0
  52. cotlab/prompts/neurology.py +353 -0
  53. cotlab/prompts/oncology.py +367 -0
  54. cotlab/prompts/plab.py +162 -0
  55. cotlab/prompts/pubhealthbench.py +82 -0
  56. cotlab/prompts/pubmedqa.py +173 -0
  57. cotlab/prompts/radiology.py +414 -0
  58. cotlab/prompts/strategies.py +939 -0
  59. cotlab/prompts/tcga.py +168 -0
  60. cotlab/runner.py +204 -0
  61. cotlab-0.8.0.dist-info/METADATA +166 -0
  62. cotlab-0.8.0.dist-info/RECORD +65 -0
  63. cotlab-0.8.0.dist-info/WHEEL +4 -0
  64. cotlab-0.8.0.dist-info/entry_points.txt +3 -0
  65. cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,315 @@
1
+ """Automatic experiment documentation generator.
2
+
3
+ Generates EXPERIMENT.md files for every experiment run with:
4
+ - Human-readable experiment description
5
+ - Research context and questions
6
+ - Full configuration
7
+ - Reproduction commands
8
+ - Results summary
9
+ """
10
+
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+ from typing import Any, Dict, Optional
14
+
15
+ from omegaconf import DictConfig, ListConfig, OmegaConf
16
+
17
+
18
class ExperimentDocumenter:
    """Generates markdown documentation (``EXPERIMENT.md``) for experiments.

    The documenter reads the ``experiment``, ``prompt`` and ``dataset``
    sections of the supplied config to build a title, a list of research
    questions, a configuration summary and a reproduction command. After the
    run, :meth:`update_with_results` marks the document as completed and fills
    in duration and results.
    """

    def __init__(self, config: DictConfig, output_dir: Path):
        """Store the config and output directory and record the start time.

        Args:
            config: Experiment configuration; queried with ``.get``.
            output_dir: Directory in which ``EXPERIMENT.md`` is written.
        """
        self.config = config
        self.output_dir = Path(output_dir)
        self.start_time = datetime.now()

    def _experiment_name(self) -> str:
        """Return ``experiment.name`` from the config, or ``""`` if absent."""
        return self.config.get("experiment", {}).get("name", "")

    def _experiment_variants(self) -> list[dict]:
        """Return ``experiment.variants`` as a plain list (empty if unset)."""
        exp_cfg = self.config.get("experiment", {})
        variants = exp_cfg.get("variants", [])
        # OmegaConf containers are converted so callers get plain Python data.
        if isinstance(variants, (DictConfig, ListConfig)):
            variants = OmegaConf.to_container(variants, resolve=True)
        return list(variants) if variants else []

    @staticmethod
    def _format_variant_value(value: Any, base_value: str, default_label: str) -> str:
        """Resolve a variant's dataset/prompt entry to a display value.

        Args:
            value: Raw variant entry; a mapping is reduced to its ``"name"``.
            base_value: Value used when the entry is ``"base"``, ``"default"``
                or falsy.
            default_label: Fallback when a mapping entry has no ``"name"``.
        """
        if isinstance(value, DictConfig):
            value = OmegaConf.to_container(value, resolve=True)
        if isinstance(value, dict):
            value = value.get("name", default_label)
        if value in ("base", "default"):
            return base_value
        return value if value else base_value

    @staticmethod
    def _format_metric(value: Any, spec: str) -> str:
        """Format ``value`` with ``spec``, falling back to ``str(value)``.

        A metric that is ``None`` or otherwise not formattable with a numeric
        format spec must not abort documentation of a finished run.
        """
        try:
            return format(value, spec)
        except (TypeError, ValueError):
            return str(value)

    def generate_title(self) -> str:
        """Generate human-readable experiment title from config."""
        parts = []

        exp_name = self._experiment_name()
        if exp_name == "activation_compare":
            variants = self._experiment_variants()
            if variants:
                run_names = [v.get("name", "run") for v in variants]
                return f"Activation Compare: {' vs '.join(run_names)}"
            return "Activation Compare"

        # Dataset
        dataset_name = self.config.get("dataset", {}).get("name", "unknown")
        parts.append(dataset_name.capitalize())

        # Reasoning mode (answer_first takes precedence over contrarian)
        prompt_cfg = self.config.get("prompt", {})
        if prompt_cfg.get("answer_first", False):
            parts.append("Answer-First")
        elif prompt_cfg.get("contrarian", False):
            parts.append("Contrarian")
        else:
            parts.append("Standard")

        # Few-shot
        if not prompt_cfg.get("few_shot", True):
            parts.append("Zero-Shot")

        # Output format (only mentioned when it is not the default "json")
        output_fmt = prompt_cfg.get("output_format", "json")
        if output_fmt != "json":
            parts.append(f"({output_fmt.upper()})")

        return " ".join(parts)

    def infer_research_questions(self) -> list[str]:
        """Infer research questions from configuration."""
        questions = []
        exp_name = self._experiment_name()

        if exp_name == "activation_compare":
            return [
                "How do residual stream activations differ across runs and datasets?",
                "Which layers show the largest activation divergence between runs?",
                "Do activation differences align with task or prompt changes?",
            ]

        prompt_cfg = self.config.get("prompt", {})

        # Few-shot ablation
        if not prompt_cfg.get("few_shot", True):
            questions.append("Does the model perform well without few-shot examples (zero-shot)?")

        # Reasoning mode (here contrarian takes precedence over answer_first)
        if prompt_cfg.get("contrarian", False):
            questions.append("Does skeptical/contrarian reasoning improve diagnostic accuracy?")
        elif prompt_cfg.get("answer_first", False):
            questions.append(
                'Does "answer first, then justify" reasoning order affect performance?'
            )

        # Output format
        output_fmt = prompt_cfg.get("output_format", "json")
        if output_fmt != "json":
            questions.append(
                f"How does {output_fmt.upper()} output format affect parsing and accuracy?"
            )

        # Default question if none inferred
        if not questions:
            questions.append("Standard baseline experiment for comparison.")

        return questions

    def generate_reproduction_command(self) -> str:
        """Generate exact CLI command to reproduce experiment."""
        parts = ["python -m cotlab.main"]

        # Add overrides from config
        exp_cfg = self.config.get("experiment", {})
        prompt_cfg = self.config.get("prompt", {})
        dataset_cfg = self.config.get("dataset", {})

        exp_name = exp_cfg.get("name")
        if exp_name:
            parts.append(f"experiment={exp_name}")
        if exp_cfg.get("num_samples") is not None:
            parts.append(f"experiment.num_samples={exp_cfg['num_samples']}")
        if exp_cfg.get("seed") is not None:
            parts.append(f"experiment.seed={exp_cfg['seed']}")

        # Prompt selection
        if "name" in prompt_cfg:
            parts.append(f"prompt={prompt_cfg['name']}")

        # Prompt parameters (only non-default values are emitted)
        if prompt_cfg.get("contrarian", False):
            parts.append("prompt.contrarian=true")
        if prompt_cfg.get("answer_first", False):
            parts.append("prompt.answer_first=true")
        if not prompt_cfg.get("few_shot", True):
            parts.append("prompt.few_shot=false")

        output_fmt = prompt_cfg.get("output_format", "json")
        if output_fmt != "json":
            parts.append(f"prompt.output_format={output_fmt}")

        # Dataset
        if "name" in dataset_cfg:
            parts.append(f"dataset={dataset_cfg['name']}")

        # One override per line, joined with shell line continuations.
        return " \\\n ".join(parts)

    def create_initial_doc(self) -> str:
        """Create initial experiment documentation (before results)."""
        title = self.generate_title()
        questions = self.infer_research_questions()
        repro_cmd = self.generate_reproduction_command()

        # Configuration summary
        prompt_cfg = self.config.get("prompt", {})
        dataset_cfg = self.config.get("dataset", {})

        reasoning_mode = "Standard"
        if prompt_cfg.get("answer_first", False):
            reasoning_mode = "Answer-First"
        elif prompt_cfg.get("contrarian", False):
            reasoning_mode = "Contrarian (skeptical)"

        few_shot = "Yes" if prompt_cfg.get("few_shot", True) else "No (zero-shot)"
        output_fmt = prompt_cfg.get("output_format", "json").upper()
        dataset_name = dataset_cfg.get("name", "unknown")

        doc = f"""# Experiment: {title}

**Status:** Running
**Started:** {self.start_time.strftime("%Y-%m-%d %H:%M:%S")}

## Research Questions

"""
        for i, question in enumerate(questions, 1):
            doc += f"{i}. {question}\n"

        doc += f"""
## Configuration

**Prompt Strategy:** {prompt_cfg.get("name", "unknown").capitalize()}
**Reasoning Mode:** {reasoning_mode}
**Few-Shot Examples:** {few_shot}
**Output Format:** {output_fmt}
**Dataset:** {dataset_name}
"""

        exp_name = self._experiment_name()
        if exp_name == "activation_compare":
            variants = self._experiment_variants()
            if variants:
                doc += "\n**Variants:**\n"
                for variant in variants:
                    variant_name = variant.get("name", "run")
                    variant_dataset = variant.get("dataset", "base")
                    variant_prompt = variant.get("prompt", "base")
                    variant_samples = variant.get("num_samples", "default")
                    variant_seed = variant.get("seed", "default")
                    variant_dataset = self._format_variant_value(
                        variant_dataset, dataset_name, "dataset"
                    )
                    variant_prompt = self._format_variant_value(
                        variant_prompt, prompt_cfg.get("name", "prompt"), "prompt"
                    )
                    doc += (
                        f"- {variant_name}: dataset={variant_dataset}, "
                        f"prompt={variant_prompt}, samples={variant_samples}, seed={variant_seed}\n"
                    )

        doc += f"""
<details>
<summary>Full Configuration (YAML)</summary>

```yaml
{OmegaConf.to_yaml(self.config)}
```
</details>

## Reproduce

```bash
{repro_cmd}
```

## Results

_Results will be added after experiment completes..._
"""

        return doc

    def update_with_results(
        self, results: Optional[Dict[str, Any]] = None, duration_seconds: Optional[float] = None
    ) -> str:
        """Update documentation with results after completion.

        Reads ``EXPERIMENT.md`` from ``output_dir`` if it exists (otherwise
        starts from :meth:`create_initial_doc`), marks the status as
        completed, and optionally adds a duration line and the results.

        Args:
            results: Metrics to render into the Results section. When falsy,
                the placeholder text is left in place.
            duration_seconds: Wall-clock duration of the run, if known.

        Returns:
            The updated markdown text. The caller persists it, e.g. via
            :meth:`save`.
        """
        # Local import: keeps this class self-contained without touching the
        # module's import block.
        import re

        # Read existing doc
        doc_path = self.output_dir / "EXPERIMENT.md"
        if doc_path.exists():
            doc = doc_path.read_text(encoding="utf-8")
        else:
            doc = self.create_initial_doc()

        # Update status
        doc = doc.replace("**Status:** Running", "**Status:** Completed")

        # Add duration if provided
        if duration_seconds is not None:
            minutes = int(duration_seconds // 60)
            seconds = int(duration_seconds % 60)
            duration_str = (
                f"{minutes} minutes {seconds} seconds" if minutes > 0 else f"{seconds} seconds"
            )
            # Match the "Started" line by shape rather than by this instance's
            # own start_time: the file on disk may have been written by another
            # ExperimentDocumenter instance. An already-present Duration line
            # is consumed so repeated updates replace it instead of stacking.
            started_pattern = re.compile(
                r"(\*\*Started:\*\* \d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})"
                r"(?: \n\*\*Duration:\*\* (?:\d+ minutes )?\d+ seconds)?"
            )
            doc = started_pattern.sub(
                lambda m: f"{m.group(1)} \n**Duration:** {duration_str}",
                doc,
                count=1,
            )

        # Update results section
        if results:
            results_md = self._format_results(results)
            doc = doc.replace(
                "## Results\n\n_Results will be added after experiment completes..._",
                f"## Results\n\n{results_md}",
            )

        return doc

    def _format_results(self, results: Dict[str, Any]) -> str:
        """Format results dictionary as markdown.

        Well-known metrics are rendered first with dedicated labels; all
        remaining keys follow in dictionary order with a title-cased label.
        """
        lines = []
        # Results carrying both keys are treated as activation-compare output,
        # for which the sample count line is suppressed.
        is_activation_compare = "num_runs" in results and "pair_count" in results

        # Basic metrics
        if "accuracy" in results:
            lines.append(f"- **Accuracy:** {self._format_metric(results['accuracy'], '.1%')}\n")
        if "total_samples" in results and not is_activation_compare:
            lines.append(f"- **Samples Processed:** {results['total_samples']}\n")
        if "parse_failures" in results:
            lines.append(f"- **Parse Failures:** {results['parse_failures']}\n")
        if "avg_time" in results:
            lines.append(
                f"- **Average Time per Sample:** "
                f"{self._format_metric(results['avg_time'], '.2f')}s\n"
            )

        # Additional metrics
        for key, value in results.items():
            if key not in ["accuracy", "total_samples", "parse_failures", "avg_time"]:
                label = key.replace("_", " ").title()
                if isinstance(value, float):
                    lines.append(f"- **{label}:** {value:.3f}\n")
                else:
                    lines.append(f"- **{label}:** {value}\n")

        return "".join(lines)

    def save(self, content: Optional[str] = None) -> Path:
        """Save experiment documentation to file.

        Args:
            content: Markdown to write; defaults to the initial document.

        Returns:
            Path of the written ``EXPERIMENT.md``.
        """
        self.output_dir.mkdir(parents=True, exist_ok=True)
        doc_path = self.output_dir / "EXPERIMENT.md"

        if content is None:
            content = self.create_initial_doc()

        # Explicit UTF-8: the document embeds arbitrary config text and must
        # not depend on the platform's default encoding.
        doc_path.write_text(content, encoding="utf-8")
        return doc_path
@@ -0,0 +1,43 @@
1
+ """Experiments module."""
2
+
3
+ from .activation_compare import ActivationCompareExperiment
4
+ from .activation_patching import ActivationPatchingExperiment
5
+ from .attention_analysis import AttentionAnalysisExperiment
6
+ from .classification import ClassificationExperiment
7
+ from .composite_shift_detector import CompositeShiftDetectorExperiment
8
+ from .cot_ablation import CoTAblationExperiment
9
+ from .cot_faithfulness import CoTFaithfulnessExperiment
10
+ from .cot_heads import CoTHeadsExperiment
11
+ from .full_layer_cot import FullLayerCoTExperiment
12
+ from .full_layer_patching import FullLayerPatchingExperiment
13
+ from .h_neuron_analysis import HNeuronAnalysisExperiment
14
+ from .logit_lens import LogitLensExperiment
15
+ from .multi_head_cot import MultiHeadCoTExperiment
16
+ from .multi_head_patching import MultiHeadPatchingExperiment
17
+ from .probing_classifier import ProbingClassifierExperiment
18
+ from .residual_norm_ood import ResidualNormOODExperiment
19
+ from .sae_feature_analysis import SAEFeatureAnalysisExperiment
20
+ from .steering_vectors import SteeringVectorsExperiment
21
+ from .sycophancy_heads import SycophancyHeadsExperiment
22
+
23
# Public API of the experiments package: every experiment class imported
# above is re-exported here.
__all__ = [
    "HNeuronAnalysisExperiment",
    "CompositeShiftDetectorExperiment",
    "CoTAblationExperiment",
    "CoTFaithfulnessExperiment",
    "ActivationPatchingExperiment",
    "ActivationCompareExperiment",
    "AttentionAnalysisExperiment",
    "ProbingClassifierExperiment",
    "ClassificationExperiment",
    "ResidualNormOODExperiment",
    "SAEFeatureAnalysisExperiment",
    "SycophancyHeadsExperiment",
    "MultiHeadPatchingExperiment",
    "FullLayerPatchingExperiment",
    "SteeringVectorsExperiment",
    "CoTHeadsExperiment",
    "LogitLensExperiment",
    "MultiHeadCoTExperiment",
    "FullLayerCoTExperiment",
]
@@ -0,0 +1,290 @@
1
+ """Activation Compare experiment — collect mean residual-stream vectors.
2
+
3
+ One run = one condition (dataset + prompt settings).
4
+ Saves per-layer mean activation vectors to results.json so that two or more
5
+ runs can be compared offline with ``scripts/compare_activations.py``.
6
+
7
+ Design follows logit_lens.py: hooks project inside the callback and move
8
+ tensors to CPU immediately to avoid GPU page-faults on long reports.
9
+ """
10
+
11
+ from typing import Any, Dict, List, Optional
12
+
13
+ import torch
14
+ from tqdm import tqdm
15
+
16
+ from ..backends.base import InferenceBackend
17
+ from ..core.base import BaseExperiment, ExperimentResult
18
+ from ..core.registry import Registry
19
+ from ..datasets.loaders import BaseDataset
20
+ from ..logging import ExperimentLogger
21
+
22
+
23
@Registry.register_experiment("activation_compare")
class ActivationCompareExperiment(BaseExperiment):
    """
    Collect layer-wise mean residual-stream activations for one condition.

    Forward-passes N samples through the model. At each selected layer a
    lightweight hook captures the hidden state at the last token (or
    mean-pooled across all real tokens) and moves it to CPU immediately.
    After all samples the per-layer mean vector is computed and returned in
    the ``ExperimentResult``.

    Two saved runs can then be compared with ``scripts/compare_activations.py``
    which computes cosine-similarity and L2-distance profiles per layer.
    """

    def __init__(
        self,
        name: str = "activation_compare",
        description: str = "Collect mean layer activations for representational comparison",
        layer_stride: int = 2,
        num_samples: Optional[int] = None,
        pooling: str = "last_token",  # "last_token" | "mean"
        max_input_tokens: int = 1024,
        seed: int = 42,
        answer_cue: str = "\n\nAnswer:",  # appended so last position mirrors logit_lens
        batch_size: int = 1,
        # Legacy fields kept so old YAML configs don't break
        layers: Optional[List[int]] = None,
        variants: Optional[List[Dict[str, Any]]] = None,
        comparison_mode: str = "pairwise",
        store_per_layer: bool = True,
        log_samples: bool = False,
        **kwargs,
    ):
        """Configure the experiment.

        Args:
            name: Experiment name reported in results.
            description: Human-readable description.
            layer_stride: Every ``layer_stride``-th layer is hooked.
            num_samples: Number of samples to draw; ``None`` uses the whole
                dataset.
            pooling: ``"last_token"`` or ``"mean"``.
            max_input_tokens: Tokenizer truncation length.
            seed: Sampling seed passed to ``dataset.sample``.
            answer_cue: Text appended to every prompt.
            batch_size: Prompts per forward pass (clamped to at least 1).
            layers, variants, comparison_mode, store_per_layer, log_samples,
            **kwargs: Accepted for backward compatibility and not used by
                :meth:`run`.
        """
        self._name = name
        self.description = description
        self.layer_stride = layer_stride
        self.num_samples = num_samples
        self.pooling = pooling
        self.max_input_tokens = max_input_tokens
        self.seed = seed
        self.answer_cue = answer_cue
        self.batch_size = max(1, int(batch_size))
        # Legacy fields silently ignored — kept for backward compat
        self._layers_legacy = layers
        self._variants_legacy = variants

    @property
    def name(self) -> str:
        """Experiment name as passed to the constructor."""
        return self._name

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _resolve_layers(self, backend: InferenceBackend) -> List[int]:
        """Return every ``layer_stride``-th layer index, starting at 0."""
        all_layers = list(range(backend.hook_manager.num_layers))
        return all_layers[:: self.layer_stride]

    def _pool(self, tensor: torch.Tensor) -> torch.Tensor:
        """tensor: [seq_len, hidden_size] → [hidden_size]"""
        if self.pooling == "last_token":
            return tensor[-1]
        elif self.pooling == "mean":
            return tensor.mean(dim=0)
        else:
            raise ValueError(f"Unknown pooling: {self.pooling}")

    def _pool_batch(
        self, tensor: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """tensor: [B, seq_len, hidden_size] → [B, hidden_size]

        Args:
            tensor: Batched hidden states.
            attention_mask: Optional [B, seq_len] mask (1 = real token). When
                given and shape-compatible, ``"mean"`` pooling averages over
                real tokens only, so padded batches give the same per-sample
                mean as unpadded single-sample runs.

        Raises:
            ValueError: If ``self.pooling`` is not a known mode.
        """
        if self.pooling == "last_token":
            return tensor[:, -1, :]  # left-padded: last position = last real token
        elif self.pooling == "mean":
            if attention_mask is None or attention_mask.shape != tensor.shape[:2]:
                # No usable mask: fall back to the plain mean over positions.
                return tensor.mean(dim=1)
            # Reduce in float32 so the masked sum cannot overflow half precision.
            mask = attention_mask.to(device=tensor.device, dtype=torch.float32).unsqueeze(-1)
            token_counts = mask.sum(dim=1).clamp(min=1.0)  # guard all-pad rows
            return (tensor.float() * mask).sum(dim=1) / token_counts
        else:
            raise ValueError(f"Unknown pooling: {self.pooling}")

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------

    def run(
        self,
        backend: InferenceBackend,
        dataset: BaseDataset,
        prompt_strategy: Any,
        logger: Optional[ExperimentLogger] = None,
        **kwargs,
    ) -> ExperimentResult:
        """Collect mean residual-stream activations for one dataset condition.

        Args:
            backend: Inference backend; its tokenizer, model, device and hook
                manager are used directly.
            dataset: Source of samples (iterated fully, or sampled when
                ``num_samples`` is set).
            prompt_strategy: Object exposing ``build_prompt``.
            logger: Unused here.
            **kwargs: Unused here.

        Returns:
            An ``ExperimentResult`` with per-layer norms/stds in ``metrics``
            and the per-layer mean vectors in ``raw_outputs``.

        Note:
            Batches whose forward pass raises are reported and skipped.
            NOTE(review): truncation is applied after ``answer_cue`` is
            appended, so for prompts longer than ``max_input_tokens`` the last
            position may not be the cue — confirm the tokenizer's truncation
            side.
        """

        target_layers = self._resolve_layers(backend)
        tokenizer = backend._tokenizer

        print(f"Model : {backend.model_name}")
        print(f"Layers ({len(target_layers)}): {target_layers}")
        print(f"Pooling: {self.pooling}")
        print(f"Batch size: {self.batch_size}")

        if self.num_samples is None:
            samples = list(dataset)
        else:
            samples = dataset.sample(self.num_samples, seed=self.seed)
        n = len(samples)
        print(f"Samples: {n}\n")

        # Accumulators: layer_idx → running sum tensor (float32, CPU)
        layer_sums: Dict[int, torch.Tensor] = {}
        layer_sq_sums: Dict[int, torch.Tensor] = {}
        layer_counts: Dict[int, int] = {}
        processed = 0

        # Chunk into batches
        batches = [
            samples[i : i + self.batch_size] for i in range(0, len(samples), self.batch_size)
        ]

        for batch in tqdm(batches, desc="Activation collect"):
            # The sample text is exposed under several keys so that different
            # prompt strategies can pick the one they expect.
            prompt_strs = [
                prompt_strategy.build_prompt(
                    {
                        "text": sample.text,
                        "question": sample.text,
                        "report": sample.text,
                        "metadata": sample.metadata or {},
                    }
                )
                + self.answer_cue
                for sample in batch
            ]

            B = len(prompt_strs)

            if B == 1:
                # Single-sample path — no padding overhead
                tokens = tokenizer(
                    prompt_strs[0],
                    return_tensors="pt",
                    truncation=True,
                    max_length=self.max_input_tokens,
                ).to(backend.device)
            else:
                # Batched path — left-pad so position [-1] = last real token
                orig_side = tokenizer.padding_side
                tokenizer.padding_side = "left"
                if tokenizer.pad_token_id is None:
                    tokenizer.pad_token_id = tokenizer.eos_token_id
                try:
                    tokens = tokenizer(
                        prompt_strs,
                        return_tensors="pt",
                        truncation=True,
                        max_length=self.max_input_tokens,
                        padding=True,
                    ).to(backend.device)
                finally:
                    # Restore the shared tokenizer even if tokenization fails.
                    tokenizer.padding_side = orig_side

            # Compute position_ids that skip padding so positional embeddings
            # are identical to the non-padded single-sample case.
            # For left-padded input: e.g. [PAD, PAD, t0, t1, t2] → positions [0,0,0,1,2]
            attention_mask = tokens["attention_mask"]
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids = position_ids.masked_fill(attention_mask == 0, 0)
            tokens["position_ids"] = position_ids

            # Capture hidden states in hooks; move to CPU immediately
            layer_vecs: Dict[int, torch.Tensor] = {}  # layer_idx -> [B, hidden] CPU float32

            def make_hook(layer_idx: int):
                def hook(module, inp, output):
                    tensor = output[0] if isinstance(output, tuple) else output
                    with torch.no_grad():
                        if B == 1:
                            # tensor: [1, seq_len, hidden]
                            vec = self._pool(tensor[0]).unsqueeze(0).cpu().float()  # [1, hidden]
                        else:
                            # tensor: [B, seq_len, hidden]; the mask keeps pad
                            # positions out of "mean" pooling.
                            vec = self._pool_batch(tensor, attention_mask).cpu().float()
                    layer_vecs[layer_idx] = vec
                    return output

                return hook

            handles = []
            try:
                # Registration happens inside the outer try so hooks already
                # attached are removed even if a later registration fails.
                for layer_idx in target_layers:
                    if layer_idx < backend.hook_manager.num_layers:
                        mod = backend.hook_manager.get_residual_module(layer_idx)
                        handles.append(mod.register_forward_hook(make_hook(layer_idx)))

                try:
                    with torch.no_grad():
                        backend._model(**tokens)
                except Exception as e:
                    # Best-effort: report and skip the failing batch.
                    first_idx = getattr(batch[0], "idx", "?")
                    tqdm.write(f" [skip] batch starting at {first_idx}: {type(e).__name__}: {e}")
                    torch.cuda.empty_cache()
                    continue
            finally:
                for h in handles:
                    h.remove()

            # Accumulate — sum over the batch dimension in one step
            for layer_idx, vecs in layer_vecs.items():  # vecs: [B, hidden]
                if layer_idx not in layer_sums:
                    layer_sums[layer_idx] = torch.zeros_like(vecs[0])
                    layer_sq_sums[layer_idx] = torch.zeros_like(vecs[0])
                    layer_counts[layer_idx] = 0
                layer_sums[layer_idx] += vecs.sum(dim=0)
                layer_sq_sums[layer_idx] += (vecs**2).sum(dim=0)
                layer_counts[layer_idx] += vecs.shape[0]

            del layer_vecs
            torch.cuda.empty_cache()
            processed += B

        # --- Compute statistics -----------------------------------------
        mean_activations: Dict[int, List[float]] = {}
        activation_norm: Dict[int, float] = {}
        activation_std: Dict[int, float] = {}

        for layer_idx in target_layers:
            cnt = layer_counts.get(layer_idx, 0)
            if cnt == 0:
                continue
            mean_vec = layer_sums[layer_idx] / cnt  # [hidden]
            # Var = E[x^2] - E[x]^2; clamped below because rounding can push
            # it slightly negative.
            var_vec = layer_sq_sums[layer_idx] / cnt - mean_vec**2
            std_val = float(var_vec.clamp(min=0).sqrt().mean().item())
            norm_val = float(mean_vec.norm().item())
            mean_activations[layer_idx] = mean_vec.tolist()
            activation_norm[layer_idx] = round(norm_val, 4)
            activation_std[layer_idx] = round(std_val, 6)

        # --- Print summary -----------------------------------------------
        print("\n" + "=" * 60)
        print("ACTIVATION COLLECT SUMMARY")
        print("=" * 60)
        print(f"Processed samples : {processed} / {n}")
        print(f"{'Layer':>6} {'Norm':>10} {'Std':>10}")
        print("-" * 32)
        for layer_idx in target_layers:
            if layer_idx in activation_norm:
                print(
                    f"{layer_idx:>6} {activation_norm[layer_idx]:>10.2f} {activation_std[layer_idx]:>10.6f}"
                )
        print("=" * 60)

        return ExperimentResult(
            experiment_name=self.name,
            model_name=backend.model_name,
            prompt_strategy=(
                prompt_strategy.name if hasattr(prompt_strategy, "name") else "custom"
            ),
            metrics={
                "num_samples": processed,
                "pooling": self.pooling,
                "layer_stride": self.layer_stride,
                "activation_norms": activation_norm,
                "activation_std": activation_std,
            },
            raw_outputs={
                "mean_activations_per_layer": mean_activations,
            },
            metadata={
                "target_layers": target_layers,
                "pooling": self.pooling,
                "num_samples": processed,
                "seed": self.seed,
                "answer_cue": self.answer_cue,
            },
        )