cotlab 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. cotlab/__init__.py +3 -0
  2. cotlab/analyse_experiments.py +392 -0
  3. cotlab/analysis/__init__.py +11 -0
  4. cotlab/analysis/cot_parser.py +243 -0
  5. cotlab/analysis/faithfulness_metrics.py +192 -0
  6. cotlab/backends/__init__.py +16 -0
  7. cotlab/backends/base.py +78 -0
  8. cotlab/backends/transformers_backend.py +335 -0
  9. cotlab/backends/vllm_backend.py +227 -0
  10. cotlab/cli.py +83 -0
  11. cotlab/core/__init__.py +34 -0
  12. cotlab/core/base.py +749 -0
  13. cotlab/core/config.py +90 -0
  14. cotlab/core/registry.py +68 -0
  15. cotlab/datasets/__init__.py +45 -0
  16. cotlab/datasets/loaders.py +1889 -0
  17. cotlab/experiment/__init__.py +315 -0
  18. cotlab/experiments/__init__.py +43 -0
  19. cotlab/experiments/activation_compare.py +290 -0
  20. cotlab/experiments/activation_patching.py +1050 -0
  21. cotlab/experiments/attention_analysis.py +885 -0
  22. cotlab/experiments/classification.py +235 -0
  23. cotlab/experiments/composite_shift_detector.py +524 -0
  24. cotlab/experiments/cot_ablation.py +277 -0
  25. cotlab/experiments/cot_faithfulness.py +187 -0
  26. cotlab/experiments/cot_heads.py +208 -0
  27. cotlab/experiments/full_layer_cot.py +232 -0
  28. cotlab/experiments/full_layer_patching.py +225 -0
  29. cotlab/experiments/h_neuron_analysis.py +712 -0
  30. cotlab/experiments/logit_lens.py +439 -0
  31. cotlab/experiments/multi_head_cot.py +220 -0
  32. cotlab/experiments/multi_head_patching.py +229 -0
  33. cotlab/experiments/probing_classifier.py +402 -0
  34. cotlab/experiments/residual_norm_ood.py +413 -0
  35. cotlab/experiments/sae_feature_analysis.py +673 -0
  36. cotlab/experiments/steering_vectors.py +223 -0
  37. cotlab/experiments/sycophancy_heads.py +224 -0
  38. cotlab/logging/__init__.py +5 -0
  39. cotlab/logging/json_logger.py +161 -0
  40. cotlab/main.py +317 -0
  41. cotlab/patching/__init__.py +24 -0
  42. cotlab/patching/cache.py +141 -0
  43. cotlab/patching/hooks.py +558 -0
  44. cotlab/patching/interventions.py +86 -0
  45. cotlab/patching/patcher.py +439 -0
  46. cotlab/patching/sae.py +181 -0
  47. cotlab/prompts/__init__.py +43 -0
  48. cotlab/prompts/cardiology.py +378 -0
  49. cotlab/prompts/histopathology.py +265 -0
  50. cotlab/prompts/length_matched_strategies.py +157 -0
  51. cotlab/prompts/mcq.py +193 -0
  52. cotlab/prompts/neurology.py +353 -0
  53. cotlab/prompts/oncology.py +367 -0
  54. cotlab/prompts/plab.py +162 -0
  55. cotlab/prompts/pubhealthbench.py +82 -0
  56. cotlab/prompts/pubmedqa.py +173 -0
  57. cotlab/prompts/radiology.py +414 -0
  58. cotlab/prompts/strategies.py +939 -0
  59. cotlab/prompts/tcga.py +168 -0
  60. cotlab/runner.py +204 -0
  61. cotlab-0.8.0.dist-info/METADATA +166 -0
  62. cotlab-0.8.0.dist-info/RECORD +65 -0
  63. cotlab-0.8.0.dist-info/WHEEL +4 -0
  64. cotlab-0.8.0.dist-info/entry_points.txt +3 -0
  65. cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,885 @@
1
+ """Attention Pattern Analysis Experiment.
2
+
3
+ Extracts attention weights at critical layers (55-60) and computes
4
+ attention entropy to understand which tokens each prompt strategy focuses on.
5
+
6
+ Enhanced to support multiple dataset samples for statistical robustness.
7
+ """
8
+
9
+ from collections import defaultdict
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ import numpy as np
13
+ import torch
14
+ from tqdm import tqdm
15
+
16
+ from ..backends.base import InferenceBackend
17
+ from ..core.base import BaseExperiment, ExperimentResult
18
+ from ..core.registry import Registry
19
+ from ..datasets.loaders import BaseDataset
20
+ from ..logging import ExperimentLogger
21
+
22
+
23
+ @Registry.register_experiment("attention_analysis")
24
+ class AttentionAnalysisExperiment(BaseExperiment):
25
+ """
26
+ Analyze attention patterns at critical layers.
27
+
28
+ Computes:
29
+ 1. Last-token attention entropy per head (legacy metric)
30
+ 2. All-tokens mean attention entropy per head (primary metric)
31
+ 3. Optional last-k-tokens mean attention entropy per head
32
+ 4. Optional generated-answer-token span entropy
33
+ 5. Top-attended tokens for focused heads
34
+ 6. Aggregated statistics across multiple samples
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ name: str = "attention_analysis",
40
+ description: str = "Analyze attention patterns at critical layers",
41
+ target_layers: Optional[List[int]] = None,
42
+ all_layers: bool = False,
43
+ force_eager_reload: bool = True,
44
+ num_samples: Optional[int] = None,
45
+ last_k_tokens: int = 16,
46
+ max_input_tokens: Optional[int] = 1024,
47
+ analyze_generated_tokens: bool = False,
48
+ generated_max_new_tokens: int = 16,
49
+ generated_do_sample: bool = False,
50
+ generated_temperature: float = 0.7,
51
+ generated_top_p: float = 0.9,
52
+ question: str = "Patient presents with chest pain, sweating, and shortness of breath. What is the diagnosis?",
53
+ batch_size: int = 1,
54
+ layer_stride: int = 1,
55
+ **kwargs,
56
+ ):
57
+ self._name = name
58
+ self.description = description
59
+ # Default to layers 55-60 (critical reasoning layers found earlier)
60
+ self._target_layers_config = target_layers or [55, 56, 57, 58, 59, 60]
61
+ self.all_layers = all_layers
62
+ self.force_eager_reload = force_eager_reload
63
+ self.target_layers = self._target_layers_config
64
+ self.layer_stride = max(1, int(layer_stride))
65
+ self.num_samples = num_samples
66
+ self.last_k_tokens = max(1, int(last_k_tokens))
67
+ self.max_input_tokens = (
68
+ max(1, int(max_input_tokens)) if max_input_tokens is not None else None
69
+ )
70
+ self.analyze_generated_tokens = bool(analyze_generated_tokens)
71
+ self.generated_max_new_tokens = max(1, int(generated_max_new_tokens))
72
+ self.generated_do_sample = bool(generated_do_sample)
73
+ self.generated_temperature = float(generated_temperature)
74
+ self.generated_top_p = float(generated_top_p)
75
+ self.question = question # Fallback if no dataset
76
+ self.batch_size = max(1, int(batch_size))
77
+ self._generated_analysis_disabled = False
78
+
79
+ @property
80
+ def name(self) -> str:
81
+ return self._name
82
+
83
+ def _compute_entropy(self, attn_dist: torch.Tensor) -> float:
84
+ """Compute entropy of attention distribution.
85
+
86
+ Note: Use bfloat16 (not float16) for the model to avoid NaN attention weights.
87
+ """
88
+ eps = 1e-10
89
+ # Compute entropy in float32 for numerical stability.
90
+ probs = attn_dist.float()
91
+ return -torch.sum(probs * torch.log(probs + eps)).item()
92
+
93
+ def _compute_mean_entropy_over_queries(self, attn_qk: torch.Tensor) -> float:
94
+ """Compute mean entropy over query positions for one head.
95
+
96
+ Args:
97
+ attn_qk: Attention tensor of shape (num_queries, seq_len).
98
+ """
99
+ eps = 1e-10
100
+ probs = attn_qk.float()
101
+ entropies = -torch.sum(probs * torch.log(probs + eps), dim=-1)
102
+ return float(entropies.mean().item())
103
+
104
+ def _analyze_generated_token_span(
105
+ self,
106
+ model,
107
+ tokenizer,
108
+ inputs: Dict[str, torch.Tensor],
109
+ num_heads: int,
110
+ ) -> tuple[Dict[int, Dict[str, Any]], int]:
111
+ """Analyze attention entropy over generated answer-token steps.
112
+
113
+ Returns:
114
+ per_layer_stats, num_generated_steps
115
+ """
116
+ pad_token_id = tokenizer.eos_token_id
117
+ if pad_token_id is None:
118
+ pad_token_id = getattr(model.config, "eos_token_id", None)
119
+
120
+ generate_kwargs: Dict[str, Any] = {
121
+ "max_new_tokens": self.generated_max_new_tokens,
122
+ "output_attentions": True,
123
+ "return_dict_in_generate": True,
124
+ "pad_token_id": pad_token_id,
125
+ "do_sample": self.generated_do_sample,
126
+ }
127
+ if self.generated_do_sample:
128
+ generate_kwargs["temperature"] = self.generated_temperature
129
+ generate_kwargs["top_p"] = self.generated_top_p
130
+
131
+ with torch.no_grad():
132
+ gen_outputs = model.generate(**inputs, **generate_kwargs)
133
+
134
+ gen_attentions = getattr(gen_outputs, "attentions", None)
135
+ if not gen_attentions:
136
+ return {}, 0
137
+
138
+ num_generated_steps = len(gen_attentions)
139
+ per_layer_stats: Dict[int, Dict[str, Any]] = {}
140
+
141
+ for layer_idx in self.target_layers:
142
+ per_head_step_entropies: List[List[float]] = [[] for _ in range(num_heads)]
143
+
144
+ for step_attn in gen_attentions:
145
+ layer_attn = None
146
+ if isinstance(step_attn, (tuple, list)):
147
+ if layer_idx < len(step_attn):
148
+ layer_attn = step_attn[layer_idx]
149
+ elif torch.is_tensor(step_attn):
150
+ layer_attn = step_attn
151
+
152
+ if layer_attn is None:
153
+ continue
154
+
155
+ # Typical shape: (batch, heads, q_len, k_len)
156
+ # Some impls may provide (batch, heads, k_len)
157
+ if layer_attn.dim() == 4:
158
+ attn_qk = layer_attn[0] # (heads, q_len, k_len)
159
+ elif layer_attn.dim() == 3:
160
+ attn_qk = layer_attn[0].unsqueeze(1) # (heads, 1, k_len)
161
+ else:
162
+ continue
163
+
164
+ n_heads_eff = min(num_heads, attn_qk.shape[0])
165
+ for h in range(n_heads_eff):
166
+ per_head_step_entropies[h].append(
167
+ self._compute_mean_entropy_over_queries(attn_qk[h])
168
+ )
169
+
170
+ head_entropies_generated_tokens: List[float] = []
171
+ for vals in per_head_step_entropies:
172
+ if vals:
173
+ head_entropies_generated_tokens.append(float(np.nanmean(vals)))
174
+ else:
175
+ head_entropies_generated_tokens.append(float("nan"))
176
+
177
+ if np.all(np.isnan(head_entropies_generated_tokens)):
178
+ continue
179
+
180
+ per_layer_stats[layer_idx] = {
181
+ "avg_entropy_generated_tokens": float(np.nanmean(head_entropies_generated_tokens)),
182
+ "std_entropy_generated_tokens": float(np.nanstd(head_entropies_generated_tokens)),
183
+ "head_entropies_generated_tokens": head_entropies_generated_tokens,
184
+ "generated_tokens_analyzed": num_generated_steps,
185
+ }
186
+
187
+ return per_layer_stats, num_generated_steps
188
+
189
+ def _analyze_batch(
190
+ self,
191
+ model,
192
+ tokenizer,
193
+ prompts: List[str],
194
+ device: str,
195
+ num_heads: int,
196
+ ) -> List[Optional[Dict[str, Any]]]:
197
+ """Analyze attention for a batch of samples in a single forward pass.
198
+
199
+ Tokenizes all prompts together with padding, runs one batched forward pass
200
+ with output_attentions=True, then slices out per-sample entropy stats
201
+ accounting for left/right padding. Attention tensors for each layer are
202
+ freed immediately after processing to keep peak VRAM low.
203
+ """
204
+ tokenizer_kwargs: Dict[str, Any] = {
205
+ "return_tensors": "pt",
206
+ "padding": True,
207
+ }
208
+ if self.max_input_tokens is not None:
209
+ tokenizer_kwargs.update({"truncation": True, "max_length": self.max_input_tokens})
210
+
211
+ tokens = tokenizer(prompts, **tokenizer_kwargs).to(device)
212
+ input_ids = tokens["input_ids"] # (B, padded_seq)
213
+ attention_mask = tokens["attention_mask"] # (B, padded_seq)
214
+ batch_size_actual = input_ids.shape[0]
215
+ total_len = input_ids.shape[1]
216
+
217
+ # Number of real tokens per sample (padding tokens have mask=0)
218
+ seq_lengths = attention_mask.sum(dim=1).tolist()
219
+ pad_left = getattr(tokenizer, "padding_side", "right") == "left"
220
+
221
+ with torch.no_grad():
222
+ outputs = model(**tokens, output_attentions=True, return_dict=True)
223
+
224
+ attentions = outputs.attentions # tuple[(B, heads, padded_seq, padded_seq)] * num_layers
225
+ del outputs # release non-attention outputs immediately
226
+
227
+ if attentions is None or len(attentions) == 0:
228
+ return [None] * batch_size_actual
229
+
230
+ # Build per-sample result containers
231
+ results: List[Dict] = [{} for _ in range(batch_size_actual)]
232
+
233
+ for layer_idx in self.target_layers:
234
+ if layer_idx >= len(attentions):
235
+ continue
236
+
237
+ layer_attn = attentions[layer_idx] # (B, heads, padded_seq, padded_seq)
238
+
239
+ for sample_i in range(batch_size_actual):
240
+ seq_len = int(seq_lengths[sample_i])
241
+ if seq_len == 0:
242
+ continue
243
+
244
+ # Determine real-token slice (exclude padding positions)
245
+ if pad_left:
246
+ start, end = total_len - seq_len, total_len
247
+ else:
248
+ start, end = 0, seq_len
249
+
250
+ # sample_attn: (heads, seq_len, seq_len) — padding stripped
251
+ sample_attn = layer_attn[sample_i, :, start:end, start:end]
252
+
253
+ last_token_attn = sample_attn[:, -1, :] # (heads, seq_len)
254
+ last_k = min(self.last_k_tokens, seq_len)
255
+ last_k_tokens_attn = sample_attn[:, seq_len - last_k :, :] # (heads, k, seq_len)
256
+
257
+ head_entropies_last_token: List[float] = []
258
+ head_entropies_all_tokens: List[float] = []
259
+ head_entropies_last_k_tokens: List[float] = []
260
+ for h in range(num_heads):
261
+ head_entropies_last_token.append(self._compute_entropy(last_token_attn[h]))
262
+ head_entropies_all_tokens.append(
263
+ self._compute_mean_entropy_over_queries(sample_attn[h])
264
+ )
265
+ head_entropies_last_k_tokens.append(
266
+ self._compute_mean_entropy_over_queries(last_k_tokens_attn[h])
267
+ )
268
+
269
+ avg_entropy_last_token = np.mean(head_entropies_last_token)
270
+ avg_entropy_all_tokens = np.mean(head_entropies_all_tokens)
271
+ avg_entropy_last_k_tokens = np.mean(head_entropies_last_k_tokens)
272
+ min_head = int(np.argmin(head_entropies_last_token))
273
+
274
+ # Top-attended tokens for the most focused head
275
+ focused_attn = last_token_attn[min_head]
276
+ top_positions = torch.topk(focused_attn, k=min(5, seq_len))
277
+ top_tokens = []
278
+ for pos, weight in zip(
279
+ top_positions.indices.tolist(), top_positions.values.tolist()
280
+ ):
281
+ actual_pos = start + pos
282
+ token_str = tokenizer.decode([input_ids[sample_i, actual_pos]])
283
+ top_tokens.append({"token": token_str, "weight": weight})
284
+
285
+ results[sample_i][layer_idx] = {
286
+ # Legacy fields preserved
287
+ "avg_entropy": avg_entropy_last_token,
288
+ "head_entropies": head_entropies_last_token,
289
+ "min_entropy": min(head_entropies_last_token),
290
+ "max_entropy": max(head_entropies_last_token),
291
+ # Explicit metrics
292
+ "avg_entropy_last_token": avg_entropy_last_token,
293
+ "avg_entropy_all_tokens": avg_entropy_all_tokens,
294
+ "avg_entropy_last_k_tokens": avg_entropy_last_k_tokens,
295
+ "head_entropies_last_token": head_entropies_last_token,
296
+ "head_entropies_all_tokens": head_entropies_all_tokens,
297
+ "head_entropies_last_k_tokens": head_entropies_last_k_tokens,
298
+ "last_k_tokens_used": last_k,
299
+ # Generated-token fields filled in below
300
+ "avg_entropy_generated_tokens": None,
301
+ "std_entropy_generated_tokens": None,
302
+ "head_entropies_generated_tokens": None,
303
+ "generated_tokens_analyzed": 0,
304
+ "focused_head": min_head,
305
+ "top_tokens": top_tokens,
306
+ }
307
+
308
+ # Free this layer's tensor immediately to keep VRAM headroom
309
+ del layer_attn
310
+
311
+ del attentions
312
+ torch.cuda.empty_cache()
313
+
314
+ # Generated-token analysis: run per-sample (auto-regressive, inherently sequential)
315
+ if self.analyze_generated_tokens and not self._generated_analysis_disabled:
316
+ for sample_i in range(batch_size_actual):
317
+ if not results[sample_i]:
318
+ continue
319
+ seq_len = int(seq_lengths[sample_i])
320
+ if pad_left:
321
+ s = total_len - seq_len
322
+ single_ids = input_ids[sample_i : sample_i + 1, s:]
323
+ single_mask = attention_mask[sample_i : sample_i + 1, s:]
324
+ else:
325
+ single_ids = input_ids[sample_i : sample_i + 1, :seq_len]
326
+ single_mask = attention_mask[sample_i : sample_i + 1, :seq_len]
327
+ single_inputs = {"input_ids": single_ids, "attention_mask": single_mask}
328
+ try:
329
+ gen_stats, gen_steps = self._analyze_generated_token_span(
330
+ model=model,
331
+ tokenizer=tokenizer,
332
+ inputs=single_inputs,
333
+ num_heads=num_heads,
334
+ )
335
+ for layer_idx in self.target_layers:
336
+ if layer_idx in results[sample_i] and layer_idx in gen_stats:
337
+ results[sample_i][layer_idx].update(
338
+ {
339
+ "avg_entropy_generated_tokens": gen_stats[layer_idx].get(
340
+ "avg_entropy_generated_tokens"
341
+ ),
342
+ "std_entropy_generated_tokens": gen_stats[layer_idx].get(
343
+ "std_entropy_generated_tokens"
344
+ ),
345
+ "head_entropies_generated_tokens": gen_stats[layer_idx].get(
346
+ "head_entropies_generated_tokens"
347
+ ),
348
+ "generated_tokens_analyzed": gen_stats[layer_idx].get(
349
+ "generated_tokens_analyzed", gen_steps
350
+ ),
351
+ }
352
+ )
353
+ except Exception as e:
354
+ print(
355
+ f"Warning: generated-token analysis failed for sample {sample_i} "
356
+ f"and will be disabled for this run: {type(e).__name__}: {e}"
357
+ )
358
+ self._generated_analysis_disabled = True
359
+ break
360
+
361
+ return [r if r else None for r in results]
362
+
363
+ def _analyze_single_sample(
364
+ self,
365
+ model,
366
+ tokenizer,
367
+ prompt: str,
368
+ device: str,
369
+ num_heads: int,
370
+ ) -> Dict[str, Any]:
371
+ """Analyze attention for a single sample (delegates to _analyze_batch)."""
372
+ results = self._analyze_batch(model, tokenizer, [prompt], device, num_heads)
373
+ return results[0]
374
+
375
+ def _analyze_single_sample_legacy(
376
+ self,
377
+ model,
378
+ tokenizer,
379
+ prompt: str,
380
+ device: str,
381
+ num_heads: int,
382
+ ) -> Dict[str, Any]:
383
+ """Original single-sample implementation kept for reference."""
384
+ tokenizer_kwargs: Dict[str, Any] = {"return_tensors": "pt"}
385
+ if self.max_input_tokens is not None:
386
+ tokenizer_kwargs.update(
387
+ {
388
+ "truncation": True,
389
+ "max_length": self.max_input_tokens,
390
+ }
391
+ )
392
+ tokens = tokenizer(prompt, **tokenizer_kwargs).to(device)
393
+ input_ids = tokens["input_ids"]
394
+
395
+ with torch.no_grad():
396
+ outputs = model(**tokens, output_attentions=True, return_dict=True)
397
+
398
+ attentions = outputs.attentions
399
+
400
+ if attentions is None or len(attentions) == 0:
401
+ return None
402
+
403
+ sample_results = {}
404
+
405
+ generated_layer_stats: Dict[int, Dict[str, Any]] = {}
406
+ generated_steps = 0
407
+ if self.analyze_generated_tokens and not self._generated_analysis_disabled:
408
+ try:
409
+ generated_layer_stats, generated_steps = self._analyze_generated_token_span(
410
+ model=model,
411
+ tokenizer=tokenizer,
412
+ inputs=tokens,
413
+ num_heads=num_heads,
414
+ )
415
+ except Exception as e:
416
+ print(
417
+ "Warning: generated-token attention analysis failed once and will be disabled "
418
+ f"for this run: {type(e).__name__}: {e}"
419
+ )
420
+ self._generated_analysis_disabled = True
421
+
422
+ for layer_idx in self.target_layers:
423
+ if layer_idx >= len(attentions):
424
+ continue
425
+
426
+ attn = attentions[layer_idx] # (batch, heads, seq, seq)
427
+ seq_len = attn.shape[-1]
428
+ last_token_attn = attn[0, :, -1, :] # (heads, seq)
429
+ all_tokens_attn = attn[0, :, :, :] # (heads, seq, seq)
430
+
431
+ last_k = min(self.last_k_tokens, seq_len)
432
+ last_k_tokens_attn = attn[0, :, seq_len - last_k :, :] # (heads, k, seq)
433
+
434
+ head_entropies_last_token = []
435
+ head_entropies_all_tokens = []
436
+ head_entropies_last_k_tokens = []
437
+ for h in range(num_heads):
438
+ head_entropies_last_token.append(self._compute_entropy(last_token_attn[h]))
439
+ head_entropies_all_tokens.append(
440
+ self._compute_mean_entropy_over_queries(all_tokens_attn[h])
441
+ )
442
+ head_entropies_last_k_tokens.append(
443
+ self._compute_mean_entropy_over_queries(last_k_tokens_attn[h])
444
+ )
445
+
446
+ avg_entropy_last_token = np.mean(head_entropies_last_token)
447
+ avg_entropy_all_tokens = np.mean(head_entropies_all_tokens)
448
+ avg_entropy_last_k_tokens = np.mean(head_entropies_last_k_tokens)
449
+ min_head = int(np.argmin(head_entropies_last_token))
450
+
451
+ # Get top-attended tokens for the most focused head
452
+ focused_head_attn = last_token_attn[min_head]
453
+ top_positions = torch.topk(focused_head_attn, k=min(5, input_ids.shape[1]))
454
+
455
+ top_tokens = []
456
+ for pos, weight in zip(top_positions.indices.tolist(), top_positions.values.tolist()):
457
+ token_str = tokenizer.decode([input_ids[0, pos]])
458
+ top_tokens.append({"token": token_str, "weight": weight})
459
+
460
+ sample_results[layer_idx] = {
461
+ # Legacy fields preserved (last-token)
462
+ "avg_entropy": avg_entropy_last_token,
463
+ "head_entropies": head_entropies_last_token,
464
+ "min_entropy": min(head_entropies_last_token),
465
+ "max_entropy": max(head_entropies_last_token),
466
+ # Explicit metrics
467
+ "avg_entropy_last_token": avg_entropy_last_token,
468
+ "avg_entropy_all_tokens": avg_entropy_all_tokens,
469
+ "avg_entropy_last_k_tokens": avg_entropy_last_k_tokens,
470
+ "head_entropies_last_token": head_entropies_last_token,
471
+ "head_entropies_all_tokens": head_entropies_all_tokens,
472
+ "head_entropies_last_k_tokens": head_entropies_last_k_tokens,
473
+ "last_k_tokens_used": last_k,
474
+ "avg_entropy_generated_tokens": generated_layer_stats.get(layer_idx, {}).get(
475
+ "avg_entropy_generated_tokens"
476
+ ),
477
+ "std_entropy_generated_tokens": generated_layer_stats.get(layer_idx, {}).get(
478
+ "std_entropy_generated_tokens"
479
+ ),
480
+ "head_entropies_generated_tokens": generated_layer_stats.get(layer_idx, {}).get(
481
+ "head_entropies_generated_tokens"
482
+ ),
483
+ "generated_tokens_analyzed": generated_layer_stats.get(layer_idx, {}).get(
484
+ "generated_tokens_analyzed", generated_steps
485
+ ),
486
+ "focused_head": min_head,
487
+ "top_tokens": top_tokens,
488
+ }
489
+
490
+ return sample_results
491
+
492
+ def run(
493
+ self,
494
+ backend: InferenceBackend,
495
+ dataset: BaseDataset,
496
+ prompt_strategy: Any,
497
+ num_samples: Optional[int] = None,
498
+ logger: Optional[ExperimentLogger] = None,
499
+ ) -> ExperimentResult:
500
+ """Run attention analysis experiment on multiple samples."""
501
+
502
+ tokenizer = backend._tokenizer
503
+ model = backend._model
504
+
505
+ # Get model config
506
+ config = model.config
507
+ if hasattr(config, "text_config"):
508
+ config = config.text_config
509
+ num_heads = config.num_attention_heads
510
+
511
+ if self.all_layers:
512
+ num_layers = getattr(config, "num_hidden_layers", None) or getattr(
513
+ config, "num_layers", None
514
+ )
515
+ if num_layers is None:
516
+ num_layers = backend.num_layers()
517
+ self.target_layers = list(range(0, int(num_layers), self.layer_stride))
518
+
519
+ print(f"Model: {backend.model_name}")
520
+ print(f"Attention heads: {num_heads}")
521
+ print(f"All layers enabled: {self.all_layers}")
522
+ print(f"Layer stride: {self.layer_stride}")
523
+ print(f"Resolved layers: {self.target_layers}")
524
+ print(
525
+ f"Max input tokens: {self.max_input_tokens if self.max_input_tokens is not None else 'None'}"
526
+ )
527
+ print(f"Batch size: {self.batch_size}")
528
+ print(f"Analyze generated tokens: {self.analyze_generated_tokens}")
529
+ if self.analyze_generated_tokens:
530
+ print(f"Generated max_new_tokens: {self.generated_max_new_tokens}")
531
+
532
+ # Set eager attention to enable output_attentions by reloading if necessary
533
+ # We need to check if the model is already using eager attention
534
+ current_attn = getattr(model, "config", None) and getattr(
535
+ model.config, "_attn_implementation", None
536
+ )
537
+
538
+ if current_attn != "eager":
539
+ if hasattr(model, "set_attn_implementation") and not self.force_eager_reload:
540
+ print(f"Current attention implementation: {current_attn}")
541
+ print("Switching attention implementation to 'eager' in-place...")
542
+ model.set_attn_implementation("eager")
543
+ current_attn = getattr(model.config, "_attn_implementation", None)
544
+ elif self.force_eager_reload:
545
+ print(f"Current attention implementation: {current_attn}")
546
+ print(
547
+ "Reloading model with attn_implementation='eager' to support output_attentions=True..."
548
+ )
549
+ # We need to preserve the model name before unloading
550
+ model_name = backend.model_name
551
+ backend.unload()
552
+ # Reload with eager attention
553
+ backend.load_model(model_name, attn_implementation="eager")
554
+ model = backend._model
555
+ tokenizer = backend._tokenizer
556
+
557
+ # Get samples from dataset
558
+ n_samples = num_samples if num_samples is not None else self.num_samples
559
+ samples = (
560
+ list(dataset)
561
+ if n_samples is None
562
+ else (dataset.sample(n_samples) if n_samples < len(dataset) else list(dataset))
563
+ )
564
+ print(f"\nAnalyzing attention on {len(samples)} samples (batch_size={self.batch_size})...")
565
+
566
+ # Aggregate statistics across samples
567
+ layer_entropy_stats_last_token: Dict[int, List[float]] = defaultdict(list)
568
+ layer_entropy_stats_all_tokens: Dict[int, List[float]] = defaultdict(list)
569
+ layer_entropy_stats_last_k_tokens: Dict[int, List[float]] = defaultdict(list)
570
+ layer_entropy_stats_generated_tokens: Dict[int, List[float]] = defaultdict(list)
571
+ layer_head_entropy_stats_last_token: Dict[int, List[List[float]]] = defaultdict(list)
572
+ layer_head_entropy_stats_all_tokens: Dict[int, List[List[float]]] = defaultdict(list)
573
+ layer_head_entropy_stats_last_k_tokens: Dict[int, List[List[float]]] = defaultdict(list)
574
+ layer_head_entropy_stats_generated_tokens: Dict[int, List[List[float]]] = defaultdict(list)
575
+ all_top_tokens: Dict[int, List[str]] = defaultdict(list)
576
+
577
+ sample_results = []
578
+
579
+ # Build batches
580
+ batches = [
581
+ samples[i : i + self.batch_size] for i in range(0, len(samples), self.batch_size)
582
+ ]
583
+
584
+ for batch_samples in tqdm(batches, desc="Processing batches"):
585
+ prompts = [
586
+ prompt_strategy.build_prompt(
587
+ {"question": s.text, "text": s.text, "metadata": s.metadata or {}}
588
+ )
589
+ for s in batch_samples
590
+ ]
591
+
592
+ batch_results = self._analyze_batch(
593
+ model, tokenizer, prompts, backend.device, num_heads
594
+ )
595
+
596
+ for sample, result in zip(batch_samples, batch_results):
597
+ if result is None:
598
+ print(f"\nWarning: Attention not available for sample {sample.idx}")
599
+ continue
600
+
601
+ sample_results.append(
602
+ {
603
+ "sample_idx": sample.idx,
604
+ "layer_results": result,
605
+ }
606
+ )
607
+
608
+ # Aggregate stats for this sample
609
+ for layer_idx, layer_data in result.items():
610
+ layer_entropy_stats_last_token[layer_idx].append(
611
+ layer_data["avg_entropy_last_token"]
612
+ )
613
+ layer_entropy_stats_all_tokens[layer_idx].append(
614
+ layer_data["avg_entropy_all_tokens"]
615
+ )
616
+ layer_entropy_stats_last_k_tokens[layer_idx].append(
617
+ layer_data["avg_entropy_last_k_tokens"]
618
+ )
619
+ layer_head_entropy_stats_last_token[layer_idx].append(
620
+ layer_data["head_entropies_last_token"]
621
+ )
622
+ layer_head_entropy_stats_all_tokens[layer_idx].append(
623
+ layer_data["head_entropies_all_tokens"]
624
+ )
625
+ layer_head_entropy_stats_last_k_tokens[layer_idx].append(
626
+ layer_data["head_entropies_last_k_tokens"]
627
+ )
628
+ gen_entropy = layer_data.get("avg_entropy_generated_tokens")
629
+ gen_head_entropies = layer_data.get("head_entropies_generated_tokens")
630
+ if gen_entropy is not None:
631
+ layer_entropy_stats_generated_tokens[layer_idx].append(gen_entropy)
632
+ if gen_head_entropies:
633
+ layer_head_entropy_stats_generated_tokens[layer_idx].append(
634
+ gen_head_entropies
635
+ )
636
+ for tok in layer_data["top_tokens"][:3]: # Top 3 tokens
637
+ all_top_tokens[layer_idx].append(tok["token"])
638
+
639
+ if not sample_results:
640
+ return ExperimentResult(
641
+ experiment_name=self.name,
642
+ model_name=backend.model_name,
643
+ prompt_strategy=prompt_strategy.name
644
+ if hasattr(prompt_strategy, "name")
645
+ else "custom",
646
+ metrics={"error": "attention_not_supported", "num_layers_analyzed": 0},
647
+ raw_outputs=[],
648
+ metadata={"target_layers": self.target_layers},
649
+ )
650
+
651
+ # Compute aggregated statistics
652
+ print("\n" + "=" * 70)
653
+ print("ATTENTION ANALYSIS: Aggregated Statistics Across Samples")
654
+ print("=" * 70)
655
+ header = (
656
+ f"{'Layer':<8} | {'LastTok μ':<10} | {'AllTok μ':<10} | "
657
+ f"{f'Last{self.last_k_tokens} μ':<10} | {'AllTok σ':<10}"
658
+ )
659
+ if self.analyze_generated_tokens:
660
+ header += f" | {'GenTok μ':<10}"
661
+ header += " | Top Tokens"
662
+ print(header)
663
+ print("-" * 106)
664
+
665
+ aggregated_results = []
666
+
667
+ for layer_idx in sorted(layer_entropy_stats_last_token.keys()):
668
+ entropies_last_token = layer_entropy_stats_last_token[layer_idx]
669
+ entropies_all_tokens = layer_entropy_stats_all_tokens[layer_idx]
670
+ entropies_last_k_tokens = layer_entropy_stats_last_k_tokens[layer_idx]
671
+
672
+ mean_entropy_last_token = float(np.nanmean(entropies_last_token))
673
+ std_entropy_last_token = float(np.nanstd(entropies_last_token))
674
+ mean_entropy_all_tokens = float(np.nanmean(entropies_all_tokens))
675
+ std_entropy_all_tokens = float(np.nanstd(entropies_all_tokens))
676
+ mean_entropy_last_k_tokens = float(np.nanmean(entropies_last_k_tokens))
677
+ std_entropy_last_k_tokens = float(np.nanstd(entropies_last_k_tokens))
678
+ if layer_entropy_stats_generated_tokens[layer_idx]:
679
+ mean_entropy_generated_tokens = float(
680
+ np.nanmean(layer_entropy_stats_generated_tokens[layer_idx])
681
+ )
682
+ std_entropy_generated_tokens = float(
683
+ np.nanstd(layer_entropy_stats_generated_tokens[layer_idx])
684
+ )
685
+ else:
686
+ mean_entropy_generated_tokens = float("nan")
687
+ std_entropy_generated_tokens = float("nan")
688
+
689
+ # Count most common top tokens
690
+ tokens = all_top_tokens[layer_idx]
691
+ from collections import Counter
692
+
693
+ token_counts = Counter(tokens)
694
+ top_3_tokens = token_counts.most_common(5)
695
+ top_tokens_str = ", ".join([f"'{t}'" for t, _ in top_3_tokens[:3]])
696
+
697
+ # Aggregate head-level entropies for each metric
698
+ head_entropies_last_token = np.array(layer_head_entropy_stats_last_token[layer_idx])
699
+ head_entropies_all_tokens = np.array(layer_head_entropy_stats_all_tokens[layer_idx])
700
+ head_entropies_last_k_tokens = np.array(
701
+ layer_head_entropy_stats_last_k_tokens[layer_idx]
702
+ )
703
+ mean_per_head_last_token = np.nanmean(head_entropies_last_token, axis=0).tolist()
704
+ std_per_head_last_token = np.nanstd(head_entropies_last_token, axis=0).tolist()
705
+ mean_per_head_all_tokens = np.nanmean(head_entropies_all_tokens, axis=0).tolist()
706
+ std_per_head_all_tokens = np.nanstd(head_entropies_all_tokens, axis=0).tolist()
707
+ mean_per_head_last_k_tokens = np.nanmean(head_entropies_last_k_tokens, axis=0).tolist()
708
+ std_per_head_last_k_tokens = np.nanstd(head_entropies_last_k_tokens, axis=0).tolist()
709
+ if layer_head_entropy_stats_generated_tokens[layer_idx]:
710
+ head_entropies_generated_tokens = np.array(
711
+ layer_head_entropy_stats_generated_tokens[layer_idx]
712
+ )
713
+ mean_per_head_generated_tokens = np.nanmean(
714
+ head_entropies_generated_tokens, axis=0
715
+ ).tolist()
716
+ std_per_head_generated_tokens = np.nanstd(
717
+ head_entropies_generated_tokens, axis=0
718
+ ).tolist()
719
+ else:
720
+ mean_per_head_generated_tokens = []
721
+ std_per_head_generated_tokens = []
722
+
723
+ aggregated_results.append(
724
+ {
725
+ "layer": layer_idx,
726
+ # Legacy keys (last-token metric)
727
+ "mean_entropy": mean_entropy_last_token,
728
+ "std_entropy": std_entropy_last_token,
729
+ "mean_per_head": mean_per_head_last_token,
730
+ "std_per_head": std_per_head_last_token,
731
+ # Explicit metrics
732
+ "mean_entropy_last_token": mean_entropy_last_token,
733
+ "std_entropy_last_token": std_entropy_last_token,
734
+ "mean_entropy_all_tokens": mean_entropy_all_tokens,
735
+ "std_entropy_all_tokens": std_entropy_all_tokens,
736
+ "mean_entropy_last_k_tokens": mean_entropy_last_k_tokens,
737
+ "std_entropy_last_k_tokens": std_entropy_last_k_tokens,
738
+ "mean_entropy_generated_tokens": (
739
+ None
740
+ if np.isnan(mean_entropy_generated_tokens)
741
+ else mean_entropy_generated_tokens
742
+ ),
743
+ "std_entropy_generated_tokens": (
744
+ None
745
+ if np.isnan(std_entropy_generated_tokens)
746
+ else std_entropy_generated_tokens
747
+ ),
748
+ "last_k_tokens": self.last_k_tokens,
749
+ "mean_per_head_last_token": mean_per_head_last_token,
750
+ "std_per_head_last_token": std_per_head_last_token,
751
+ "mean_per_head_all_tokens": mean_per_head_all_tokens,
752
+ "std_per_head_all_tokens": std_per_head_all_tokens,
753
+ "mean_per_head_last_k_tokens": mean_per_head_last_k_tokens,
754
+ "std_per_head_last_k_tokens": std_per_head_last_k_tokens,
755
+ "mean_per_head_generated_tokens": mean_per_head_generated_tokens,
756
+ "std_per_head_generated_tokens": std_per_head_generated_tokens,
757
+ "top_tokens": [{"token": t, "count": c} for t, c in top_3_tokens],
758
+ }
759
+ )
760
+
761
+ row = (
762
+ f"L{layer_idx:<7} | {mean_entropy_last_token:<10.4f} | "
763
+ f"{mean_entropy_all_tokens:<10.4f} | {mean_entropy_last_k_tokens:<10.4f} | "
764
+ f"{std_entropy_all_tokens:<10.4f}"
765
+ )
766
+ if self.analyze_generated_tokens:
767
+ if np.isnan(mean_entropy_generated_tokens):
768
+ row += f" | {'NA':<10}"
769
+ else:
770
+ row += f" | {mean_entropy_generated_tokens:<10.4f}"
771
+ row += f" | {top_tokens_str}"
772
+ print(row)
773
+
774
+ print("-" * 106)
775
+
776
+ # Overall metrics
777
+ all_mean_entropies_last_token = [r["mean_entropy_last_token"] for r in aggregated_results]
778
+ all_mean_entropies_all_tokens = [r["mean_entropy_all_tokens"] for r in aggregated_results]
779
+ all_mean_entropies_last_k_tokens = [
780
+ r["mean_entropy_last_k_tokens"] for r in aggregated_results
781
+ ]
782
+ all_mean_entropies_generated_tokens = [
783
+ r["mean_entropy_generated_tokens"]
784
+ for r in aggregated_results
785
+ if r["mean_entropy_generated_tokens"] is not None
786
+ ]
787
+ overall_mean_last_token = (
788
+ float(np.nanmean(all_mean_entropies_last_token))
789
+ if all_mean_entropies_last_token
790
+ else 0.0
791
+ )
792
+ overall_mean_all_tokens = (
793
+ float(np.nanmean(all_mean_entropies_all_tokens))
794
+ if all_mean_entropies_all_tokens
795
+ else 0.0
796
+ )
797
+ overall_mean_last_k_tokens = (
798
+ float(np.nanmean(all_mean_entropies_last_k_tokens))
799
+ if all_mean_entropies_last_k_tokens
800
+ else 0.0
801
+ )
802
+ overall_mean_generated_tokens = (
803
+ float(np.nanmean(all_mean_entropies_generated_tokens))
804
+ if all_mean_entropies_generated_tokens
805
+ else None
806
+ )
807
+
808
+ # Most focused layer using all-tokens metric (primary)
809
+ valid_layers_all_tokens = [
810
+ r for r in aggregated_results if not np.isnan(r["mean_entropy_all_tokens"])
811
+ ]
812
+ most_focused_layer = (
813
+ min(valid_layers_all_tokens, key=lambda x: x["mean_entropy_all_tokens"])["layer"]
814
+ if valid_layers_all_tokens
815
+ else None
816
+ )
817
+ most_focused_entropy = (
818
+ min(r["mean_entropy_all_tokens"] for r in valid_layers_all_tokens)
819
+ if valid_layers_all_tokens
820
+ else 0.0
821
+ )
822
+
823
+ # Legacy most-focused values for last-token metric
824
+ valid_layers_last_token = [
825
+ r for r in aggregated_results if not np.isnan(r["mean_entropy_last_token"])
826
+ ]
827
+ most_focused_layer_last_token = (
828
+ min(valid_layers_last_token, key=lambda x: x["mean_entropy_last_token"])["layer"]
829
+ if valid_layers_last_token
830
+ else None
831
+ )
832
+ most_focused_entropy_last_token = (
833
+ min(r["mean_entropy_last_token"] for r in valid_layers_last_token)
834
+ if valid_layers_last_token
835
+ else 0.0
836
+ )
837
+
838
+ metrics = {
839
+ "num_samples_analyzed": len(sample_results),
840
+ "num_layers_analyzed": len(aggregated_results),
841
+ "num_heads": num_heads,
842
+ # Primary metrics (all-tokens)
843
+ "overall_mean_entropy": float(overall_mean_all_tokens),
844
+ "overall_mean_entropy_all_tokens": float(overall_mean_all_tokens),
845
+ "overall_mean_entropy_last_token": float(overall_mean_last_token),
846
+ "overall_mean_entropy_last_k_tokens": float(overall_mean_last_k_tokens),
847
+ "last_k_tokens": self.last_k_tokens,
848
+ "most_focused_layer": most_focused_layer,
849
+ "most_focused_entropy": float(most_focused_entropy),
850
+ "most_focused_layer_all_tokens": most_focused_layer,
851
+ "most_focused_entropy_all_tokens": float(most_focused_entropy),
852
+ "most_focused_layer_last_token": most_focused_layer_last_token,
853
+ "most_focused_entropy_last_token": float(most_focused_entropy_last_token),
854
+ "analyze_generated_tokens": self.analyze_generated_tokens,
855
+ }
856
+ if overall_mean_generated_tokens is not None:
857
+ metrics["overall_mean_entropy_generated_tokens"] = float(overall_mean_generated_tokens)
858
+
859
+ print(f"\nOverall mean entropy (all tokens): {overall_mean_all_tokens:.4f}")
860
+ print(f"Overall mean entropy (last token): {overall_mean_last_token:.4f}")
861
+ print(
862
+ f"Overall mean entropy (last {self.last_k_tokens} tokens): {overall_mean_last_k_tokens:.4f}"
863
+ )
864
+ if overall_mean_generated_tokens is not None:
865
+ print(f"Overall mean entropy (generated tokens): {overall_mean_generated_tokens:.4f}")
866
+ print(
867
+ f"Most focused layer (all tokens): L{most_focused_layer} (entropy: {most_focused_entropy:.4f})"
868
+ )
869
+
870
+ return ExperimentResult(
871
+ experiment_name=self.name,
872
+ model_name=backend.model_name,
873
+ prompt_strategy=prompt_strategy.name if hasattr(prompt_strategy, "name") else "custom",
874
+ metrics=metrics,
875
+ raw_outputs=aggregated_results,
876
+ metadata={
877
+ "target_layers": self.target_layers,
878
+ "last_k_tokens": self.last_k_tokens,
879
+ "analyze_generated_tokens": self.analyze_generated_tokens,
880
+ "generated_max_new_tokens": self.generated_max_new_tokens,
881
+ "batch_size": self.batch_size,
882
+ "num_samples": len(samples),
883
+ "sample_results": sample_results, # Include per-sample data
884
+ },
885
+ )