cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,885 @@
|
|
|
1
|
+
"""Attention Pattern Analysis Experiment.
|
|
2
|
+
|
|
3
|
+
Extracts attention weights at critical layers (55-60) and computes
|
|
4
|
+
attention entropy to understand which tokens each prompt strategy focuses on.
|
|
5
|
+
|
|
6
|
+
Enhanced to support multiple dataset samples for statistical robustness.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from collections import defaultdict
|
|
10
|
+
from typing import Any, Dict, List, Optional
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import torch
|
|
14
|
+
from tqdm import tqdm
|
|
15
|
+
|
|
16
|
+
from ..backends.base import InferenceBackend
|
|
17
|
+
from ..core.base import BaseExperiment, ExperimentResult
|
|
18
|
+
from ..core.registry import Registry
|
|
19
|
+
from ..datasets.loaders import BaseDataset
|
|
20
|
+
from ..logging import ExperimentLogger
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@Registry.register_experiment("attention_analysis")
|
|
24
|
+
class AttentionAnalysisExperiment(BaseExperiment):
|
|
25
|
+
"""
|
|
26
|
+
Analyze attention patterns at critical layers.
|
|
27
|
+
|
|
28
|
+
Computes:
|
|
29
|
+
1. Last-token attention entropy per head (legacy metric)
|
|
30
|
+
2. All-tokens mean attention entropy per head (primary metric)
|
|
31
|
+
3. Optional last-k-tokens mean attention entropy per head
|
|
32
|
+
4. Optional generated-answer-token span entropy
|
|
33
|
+
5. Top-attended tokens for focused heads
|
|
34
|
+
6. Aggregated statistics across multiple samples
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
def __init__(
|
|
38
|
+
self,
|
|
39
|
+
name: str = "attention_analysis",
|
|
40
|
+
description: str = "Analyze attention patterns at critical layers",
|
|
41
|
+
target_layers: Optional[List[int]] = None,
|
|
42
|
+
all_layers: bool = False,
|
|
43
|
+
force_eager_reload: bool = True,
|
|
44
|
+
num_samples: Optional[int] = None,
|
|
45
|
+
last_k_tokens: int = 16,
|
|
46
|
+
max_input_tokens: Optional[int] = 1024,
|
|
47
|
+
analyze_generated_tokens: bool = False,
|
|
48
|
+
generated_max_new_tokens: int = 16,
|
|
49
|
+
generated_do_sample: bool = False,
|
|
50
|
+
generated_temperature: float = 0.7,
|
|
51
|
+
generated_top_p: float = 0.9,
|
|
52
|
+
question: str = "Patient presents with chest pain, sweating, and shortness of breath. What is the diagnosis?",
|
|
53
|
+
batch_size: int = 1,
|
|
54
|
+
layer_stride: int = 1,
|
|
55
|
+
**kwargs,
|
|
56
|
+
):
|
|
57
|
+
self._name = name
|
|
58
|
+
self.description = description
|
|
59
|
+
# Default to layers 55-60 (critical reasoning layers found earlier)
|
|
60
|
+
self._target_layers_config = target_layers or [55, 56, 57, 58, 59, 60]
|
|
61
|
+
self.all_layers = all_layers
|
|
62
|
+
self.force_eager_reload = force_eager_reload
|
|
63
|
+
self.target_layers = self._target_layers_config
|
|
64
|
+
self.layer_stride = max(1, int(layer_stride))
|
|
65
|
+
self.num_samples = num_samples
|
|
66
|
+
self.last_k_tokens = max(1, int(last_k_tokens))
|
|
67
|
+
self.max_input_tokens = (
|
|
68
|
+
max(1, int(max_input_tokens)) if max_input_tokens is not None else None
|
|
69
|
+
)
|
|
70
|
+
self.analyze_generated_tokens = bool(analyze_generated_tokens)
|
|
71
|
+
self.generated_max_new_tokens = max(1, int(generated_max_new_tokens))
|
|
72
|
+
self.generated_do_sample = bool(generated_do_sample)
|
|
73
|
+
self.generated_temperature = float(generated_temperature)
|
|
74
|
+
self.generated_top_p = float(generated_top_p)
|
|
75
|
+
self.question = question # Fallback if no dataset
|
|
76
|
+
self.batch_size = max(1, int(batch_size))
|
|
77
|
+
self._generated_analysis_disabled = False
|
|
78
|
+
|
|
79
|
+
@property
def name(self) -> str:
    """Name of this experiment, as reported in the ExperimentResult."""
    experiment_name = self._name
    return experiment_name
|
|
82
|
+
|
|
83
|
+
def _compute_entropy(self, attn_dist: torch.Tensor) -> float:
|
|
84
|
+
"""Compute entropy of attention distribution.
|
|
85
|
+
|
|
86
|
+
Note: Use bfloat16 (not float16) for the model to avoid NaN attention weights.
|
|
87
|
+
"""
|
|
88
|
+
eps = 1e-10
|
|
89
|
+
# Compute entropy in float32 for numerical stability.
|
|
90
|
+
probs = attn_dist.float()
|
|
91
|
+
return -torch.sum(probs * torch.log(probs + eps)).item()
|
|
92
|
+
|
|
93
|
+
def _compute_mean_entropy_over_queries(self, attn_qk: torch.Tensor) -> float:
|
|
94
|
+
"""Compute mean entropy over query positions for one head.
|
|
95
|
+
|
|
96
|
+
Args:
|
|
97
|
+
attn_qk: Attention tensor of shape (num_queries, seq_len).
|
|
98
|
+
"""
|
|
99
|
+
eps = 1e-10
|
|
100
|
+
probs = attn_qk.float()
|
|
101
|
+
entropies = -torch.sum(probs * torch.log(probs + eps), dim=-1)
|
|
102
|
+
return float(entropies.mean().item())
|
|
103
|
+
|
|
104
|
+
def _analyze_generated_token_span(
|
|
105
|
+
self,
|
|
106
|
+
model,
|
|
107
|
+
tokenizer,
|
|
108
|
+
inputs: Dict[str, torch.Tensor],
|
|
109
|
+
num_heads: int,
|
|
110
|
+
) -> tuple[Dict[int, Dict[str, Any]], int]:
|
|
111
|
+
"""Analyze attention entropy over generated answer-token steps.
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
per_layer_stats, num_generated_steps
|
|
115
|
+
"""
|
|
116
|
+
pad_token_id = tokenizer.eos_token_id
|
|
117
|
+
if pad_token_id is None:
|
|
118
|
+
pad_token_id = getattr(model.config, "eos_token_id", None)
|
|
119
|
+
|
|
120
|
+
generate_kwargs: Dict[str, Any] = {
|
|
121
|
+
"max_new_tokens": self.generated_max_new_tokens,
|
|
122
|
+
"output_attentions": True,
|
|
123
|
+
"return_dict_in_generate": True,
|
|
124
|
+
"pad_token_id": pad_token_id,
|
|
125
|
+
"do_sample": self.generated_do_sample,
|
|
126
|
+
}
|
|
127
|
+
if self.generated_do_sample:
|
|
128
|
+
generate_kwargs["temperature"] = self.generated_temperature
|
|
129
|
+
generate_kwargs["top_p"] = self.generated_top_p
|
|
130
|
+
|
|
131
|
+
with torch.no_grad():
|
|
132
|
+
gen_outputs = model.generate(**inputs, **generate_kwargs)
|
|
133
|
+
|
|
134
|
+
gen_attentions = getattr(gen_outputs, "attentions", None)
|
|
135
|
+
if not gen_attentions:
|
|
136
|
+
return {}, 0
|
|
137
|
+
|
|
138
|
+
num_generated_steps = len(gen_attentions)
|
|
139
|
+
per_layer_stats: Dict[int, Dict[str, Any]] = {}
|
|
140
|
+
|
|
141
|
+
for layer_idx in self.target_layers:
|
|
142
|
+
per_head_step_entropies: List[List[float]] = [[] for _ in range(num_heads)]
|
|
143
|
+
|
|
144
|
+
for step_attn in gen_attentions:
|
|
145
|
+
layer_attn = None
|
|
146
|
+
if isinstance(step_attn, (tuple, list)):
|
|
147
|
+
if layer_idx < len(step_attn):
|
|
148
|
+
layer_attn = step_attn[layer_idx]
|
|
149
|
+
elif torch.is_tensor(step_attn):
|
|
150
|
+
layer_attn = step_attn
|
|
151
|
+
|
|
152
|
+
if layer_attn is None:
|
|
153
|
+
continue
|
|
154
|
+
|
|
155
|
+
# Typical shape: (batch, heads, q_len, k_len)
|
|
156
|
+
# Some impls may provide (batch, heads, k_len)
|
|
157
|
+
if layer_attn.dim() == 4:
|
|
158
|
+
attn_qk = layer_attn[0] # (heads, q_len, k_len)
|
|
159
|
+
elif layer_attn.dim() == 3:
|
|
160
|
+
attn_qk = layer_attn[0].unsqueeze(1) # (heads, 1, k_len)
|
|
161
|
+
else:
|
|
162
|
+
continue
|
|
163
|
+
|
|
164
|
+
n_heads_eff = min(num_heads, attn_qk.shape[0])
|
|
165
|
+
for h in range(n_heads_eff):
|
|
166
|
+
per_head_step_entropies[h].append(
|
|
167
|
+
self._compute_mean_entropy_over_queries(attn_qk[h])
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
head_entropies_generated_tokens: List[float] = []
|
|
171
|
+
for vals in per_head_step_entropies:
|
|
172
|
+
if vals:
|
|
173
|
+
head_entropies_generated_tokens.append(float(np.nanmean(vals)))
|
|
174
|
+
else:
|
|
175
|
+
head_entropies_generated_tokens.append(float("nan"))
|
|
176
|
+
|
|
177
|
+
if np.all(np.isnan(head_entropies_generated_tokens)):
|
|
178
|
+
continue
|
|
179
|
+
|
|
180
|
+
per_layer_stats[layer_idx] = {
|
|
181
|
+
"avg_entropy_generated_tokens": float(np.nanmean(head_entropies_generated_tokens)),
|
|
182
|
+
"std_entropy_generated_tokens": float(np.nanstd(head_entropies_generated_tokens)),
|
|
183
|
+
"head_entropies_generated_tokens": head_entropies_generated_tokens,
|
|
184
|
+
"generated_tokens_analyzed": num_generated_steps,
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
return per_layer_stats, num_generated_steps
|
|
188
|
+
|
|
189
|
+
def _analyze_batch(
|
|
190
|
+
self,
|
|
191
|
+
model,
|
|
192
|
+
tokenizer,
|
|
193
|
+
prompts: List[str],
|
|
194
|
+
device: str,
|
|
195
|
+
num_heads: int,
|
|
196
|
+
) -> List[Optional[Dict[str, Any]]]:
|
|
197
|
+
"""Analyze attention for a batch of samples in a single forward pass.
|
|
198
|
+
|
|
199
|
+
Tokenizes all prompts together with padding, runs one batched forward pass
|
|
200
|
+
with output_attentions=True, then slices out per-sample entropy stats
|
|
201
|
+
accounting for left/right padding. Attention tensors for each layer are
|
|
202
|
+
freed immediately after processing to keep peak VRAM low.
|
|
203
|
+
"""
|
|
204
|
+
tokenizer_kwargs: Dict[str, Any] = {
|
|
205
|
+
"return_tensors": "pt",
|
|
206
|
+
"padding": True,
|
|
207
|
+
}
|
|
208
|
+
if self.max_input_tokens is not None:
|
|
209
|
+
tokenizer_kwargs.update({"truncation": True, "max_length": self.max_input_tokens})
|
|
210
|
+
|
|
211
|
+
tokens = tokenizer(prompts, **tokenizer_kwargs).to(device)
|
|
212
|
+
input_ids = tokens["input_ids"] # (B, padded_seq)
|
|
213
|
+
attention_mask = tokens["attention_mask"] # (B, padded_seq)
|
|
214
|
+
batch_size_actual = input_ids.shape[0]
|
|
215
|
+
total_len = input_ids.shape[1]
|
|
216
|
+
|
|
217
|
+
# Number of real tokens per sample (padding tokens have mask=0)
|
|
218
|
+
seq_lengths = attention_mask.sum(dim=1).tolist()
|
|
219
|
+
pad_left = getattr(tokenizer, "padding_side", "right") == "left"
|
|
220
|
+
|
|
221
|
+
with torch.no_grad():
|
|
222
|
+
outputs = model(**tokens, output_attentions=True, return_dict=True)
|
|
223
|
+
|
|
224
|
+
attentions = outputs.attentions # tuple[(B, heads, padded_seq, padded_seq)] * num_layers
|
|
225
|
+
del outputs # release non-attention outputs immediately
|
|
226
|
+
|
|
227
|
+
if attentions is None or len(attentions) == 0:
|
|
228
|
+
return [None] * batch_size_actual
|
|
229
|
+
|
|
230
|
+
# Build per-sample result containers
|
|
231
|
+
results: List[Dict] = [{} for _ in range(batch_size_actual)]
|
|
232
|
+
|
|
233
|
+
for layer_idx in self.target_layers:
|
|
234
|
+
if layer_idx >= len(attentions):
|
|
235
|
+
continue
|
|
236
|
+
|
|
237
|
+
layer_attn = attentions[layer_idx] # (B, heads, padded_seq, padded_seq)
|
|
238
|
+
|
|
239
|
+
for sample_i in range(batch_size_actual):
|
|
240
|
+
seq_len = int(seq_lengths[sample_i])
|
|
241
|
+
if seq_len == 0:
|
|
242
|
+
continue
|
|
243
|
+
|
|
244
|
+
# Determine real-token slice (exclude padding positions)
|
|
245
|
+
if pad_left:
|
|
246
|
+
start, end = total_len - seq_len, total_len
|
|
247
|
+
else:
|
|
248
|
+
start, end = 0, seq_len
|
|
249
|
+
|
|
250
|
+
# sample_attn: (heads, seq_len, seq_len) — padding stripped
|
|
251
|
+
sample_attn = layer_attn[sample_i, :, start:end, start:end]
|
|
252
|
+
|
|
253
|
+
last_token_attn = sample_attn[:, -1, :] # (heads, seq_len)
|
|
254
|
+
last_k = min(self.last_k_tokens, seq_len)
|
|
255
|
+
last_k_tokens_attn = sample_attn[:, seq_len - last_k :, :] # (heads, k, seq_len)
|
|
256
|
+
|
|
257
|
+
head_entropies_last_token: List[float] = []
|
|
258
|
+
head_entropies_all_tokens: List[float] = []
|
|
259
|
+
head_entropies_last_k_tokens: List[float] = []
|
|
260
|
+
for h in range(num_heads):
|
|
261
|
+
head_entropies_last_token.append(self._compute_entropy(last_token_attn[h]))
|
|
262
|
+
head_entropies_all_tokens.append(
|
|
263
|
+
self._compute_mean_entropy_over_queries(sample_attn[h])
|
|
264
|
+
)
|
|
265
|
+
head_entropies_last_k_tokens.append(
|
|
266
|
+
self._compute_mean_entropy_over_queries(last_k_tokens_attn[h])
|
|
267
|
+
)
|
|
268
|
+
|
|
269
|
+
avg_entropy_last_token = np.mean(head_entropies_last_token)
|
|
270
|
+
avg_entropy_all_tokens = np.mean(head_entropies_all_tokens)
|
|
271
|
+
avg_entropy_last_k_tokens = np.mean(head_entropies_last_k_tokens)
|
|
272
|
+
min_head = int(np.argmin(head_entropies_last_token))
|
|
273
|
+
|
|
274
|
+
# Top-attended tokens for the most focused head
|
|
275
|
+
focused_attn = last_token_attn[min_head]
|
|
276
|
+
top_positions = torch.topk(focused_attn, k=min(5, seq_len))
|
|
277
|
+
top_tokens = []
|
|
278
|
+
for pos, weight in zip(
|
|
279
|
+
top_positions.indices.tolist(), top_positions.values.tolist()
|
|
280
|
+
):
|
|
281
|
+
actual_pos = start + pos
|
|
282
|
+
token_str = tokenizer.decode([input_ids[sample_i, actual_pos]])
|
|
283
|
+
top_tokens.append({"token": token_str, "weight": weight})
|
|
284
|
+
|
|
285
|
+
results[sample_i][layer_idx] = {
|
|
286
|
+
# Legacy fields preserved
|
|
287
|
+
"avg_entropy": avg_entropy_last_token,
|
|
288
|
+
"head_entropies": head_entropies_last_token,
|
|
289
|
+
"min_entropy": min(head_entropies_last_token),
|
|
290
|
+
"max_entropy": max(head_entropies_last_token),
|
|
291
|
+
# Explicit metrics
|
|
292
|
+
"avg_entropy_last_token": avg_entropy_last_token,
|
|
293
|
+
"avg_entropy_all_tokens": avg_entropy_all_tokens,
|
|
294
|
+
"avg_entropy_last_k_tokens": avg_entropy_last_k_tokens,
|
|
295
|
+
"head_entropies_last_token": head_entropies_last_token,
|
|
296
|
+
"head_entropies_all_tokens": head_entropies_all_tokens,
|
|
297
|
+
"head_entropies_last_k_tokens": head_entropies_last_k_tokens,
|
|
298
|
+
"last_k_tokens_used": last_k,
|
|
299
|
+
# Generated-token fields filled in below
|
|
300
|
+
"avg_entropy_generated_tokens": None,
|
|
301
|
+
"std_entropy_generated_tokens": None,
|
|
302
|
+
"head_entropies_generated_tokens": None,
|
|
303
|
+
"generated_tokens_analyzed": 0,
|
|
304
|
+
"focused_head": min_head,
|
|
305
|
+
"top_tokens": top_tokens,
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
# Free this layer's tensor immediately to keep VRAM headroom
|
|
309
|
+
del layer_attn
|
|
310
|
+
|
|
311
|
+
del attentions
|
|
312
|
+
torch.cuda.empty_cache()
|
|
313
|
+
|
|
314
|
+
# Generated-token analysis: run per-sample (auto-regressive, inherently sequential)
|
|
315
|
+
if self.analyze_generated_tokens and not self._generated_analysis_disabled:
|
|
316
|
+
for sample_i in range(batch_size_actual):
|
|
317
|
+
if not results[sample_i]:
|
|
318
|
+
continue
|
|
319
|
+
seq_len = int(seq_lengths[sample_i])
|
|
320
|
+
if pad_left:
|
|
321
|
+
s = total_len - seq_len
|
|
322
|
+
single_ids = input_ids[sample_i : sample_i + 1, s:]
|
|
323
|
+
single_mask = attention_mask[sample_i : sample_i + 1, s:]
|
|
324
|
+
else:
|
|
325
|
+
single_ids = input_ids[sample_i : sample_i + 1, :seq_len]
|
|
326
|
+
single_mask = attention_mask[sample_i : sample_i + 1, :seq_len]
|
|
327
|
+
single_inputs = {"input_ids": single_ids, "attention_mask": single_mask}
|
|
328
|
+
try:
|
|
329
|
+
gen_stats, gen_steps = self._analyze_generated_token_span(
|
|
330
|
+
model=model,
|
|
331
|
+
tokenizer=tokenizer,
|
|
332
|
+
inputs=single_inputs,
|
|
333
|
+
num_heads=num_heads,
|
|
334
|
+
)
|
|
335
|
+
for layer_idx in self.target_layers:
|
|
336
|
+
if layer_idx in results[sample_i] and layer_idx in gen_stats:
|
|
337
|
+
results[sample_i][layer_idx].update(
|
|
338
|
+
{
|
|
339
|
+
"avg_entropy_generated_tokens": gen_stats[layer_idx].get(
|
|
340
|
+
"avg_entropy_generated_tokens"
|
|
341
|
+
),
|
|
342
|
+
"std_entropy_generated_tokens": gen_stats[layer_idx].get(
|
|
343
|
+
"std_entropy_generated_tokens"
|
|
344
|
+
),
|
|
345
|
+
"head_entropies_generated_tokens": gen_stats[layer_idx].get(
|
|
346
|
+
"head_entropies_generated_tokens"
|
|
347
|
+
),
|
|
348
|
+
"generated_tokens_analyzed": gen_stats[layer_idx].get(
|
|
349
|
+
"generated_tokens_analyzed", gen_steps
|
|
350
|
+
),
|
|
351
|
+
}
|
|
352
|
+
)
|
|
353
|
+
except Exception as e:
|
|
354
|
+
print(
|
|
355
|
+
f"Warning: generated-token analysis failed for sample {sample_i} "
|
|
356
|
+
f"and will be disabled for this run: {type(e).__name__}: {e}"
|
|
357
|
+
)
|
|
358
|
+
self._generated_analysis_disabled = True
|
|
359
|
+
break
|
|
360
|
+
|
|
361
|
+
return [r if r else None for r in results]
|
|
362
|
+
|
|
363
|
+
def _analyze_single_sample(
|
|
364
|
+
self,
|
|
365
|
+
model,
|
|
366
|
+
tokenizer,
|
|
367
|
+
prompt: str,
|
|
368
|
+
device: str,
|
|
369
|
+
num_heads: int,
|
|
370
|
+
) -> Dict[str, Any]:
|
|
371
|
+
"""Analyze attention for a single sample (delegates to _analyze_batch)."""
|
|
372
|
+
results = self._analyze_batch(model, tokenizer, [prompt], device, num_heads)
|
|
373
|
+
return results[0]
|
|
374
|
+
|
|
375
|
+
def _analyze_single_sample_legacy(
|
|
376
|
+
self,
|
|
377
|
+
model,
|
|
378
|
+
tokenizer,
|
|
379
|
+
prompt: str,
|
|
380
|
+
device: str,
|
|
381
|
+
num_heads: int,
|
|
382
|
+
) -> Dict[str, Any]:
|
|
383
|
+
"""Original single-sample implementation kept for reference."""
|
|
384
|
+
tokenizer_kwargs: Dict[str, Any] = {"return_tensors": "pt"}
|
|
385
|
+
if self.max_input_tokens is not None:
|
|
386
|
+
tokenizer_kwargs.update(
|
|
387
|
+
{
|
|
388
|
+
"truncation": True,
|
|
389
|
+
"max_length": self.max_input_tokens,
|
|
390
|
+
}
|
|
391
|
+
)
|
|
392
|
+
tokens = tokenizer(prompt, **tokenizer_kwargs).to(device)
|
|
393
|
+
input_ids = tokens["input_ids"]
|
|
394
|
+
|
|
395
|
+
with torch.no_grad():
|
|
396
|
+
outputs = model(**tokens, output_attentions=True, return_dict=True)
|
|
397
|
+
|
|
398
|
+
attentions = outputs.attentions
|
|
399
|
+
|
|
400
|
+
if attentions is None or len(attentions) == 0:
|
|
401
|
+
return None
|
|
402
|
+
|
|
403
|
+
sample_results = {}
|
|
404
|
+
|
|
405
|
+
generated_layer_stats: Dict[int, Dict[str, Any]] = {}
|
|
406
|
+
generated_steps = 0
|
|
407
|
+
if self.analyze_generated_tokens and not self._generated_analysis_disabled:
|
|
408
|
+
try:
|
|
409
|
+
generated_layer_stats, generated_steps = self._analyze_generated_token_span(
|
|
410
|
+
model=model,
|
|
411
|
+
tokenizer=tokenizer,
|
|
412
|
+
inputs=tokens,
|
|
413
|
+
num_heads=num_heads,
|
|
414
|
+
)
|
|
415
|
+
except Exception as e:
|
|
416
|
+
print(
|
|
417
|
+
"Warning: generated-token attention analysis failed once and will be disabled "
|
|
418
|
+
f"for this run: {type(e).__name__}: {e}"
|
|
419
|
+
)
|
|
420
|
+
self._generated_analysis_disabled = True
|
|
421
|
+
|
|
422
|
+
for layer_idx in self.target_layers:
|
|
423
|
+
if layer_idx >= len(attentions):
|
|
424
|
+
continue
|
|
425
|
+
|
|
426
|
+
attn = attentions[layer_idx] # (batch, heads, seq, seq)
|
|
427
|
+
seq_len = attn.shape[-1]
|
|
428
|
+
last_token_attn = attn[0, :, -1, :] # (heads, seq)
|
|
429
|
+
all_tokens_attn = attn[0, :, :, :] # (heads, seq, seq)
|
|
430
|
+
|
|
431
|
+
last_k = min(self.last_k_tokens, seq_len)
|
|
432
|
+
last_k_tokens_attn = attn[0, :, seq_len - last_k :, :] # (heads, k, seq)
|
|
433
|
+
|
|
434
|
+
head_entropies_last_token = []
|
|
435
|
+
head_entropies_all_tokens = []
|
|
436
|
+
head_entropies_last_k_tokens = []
|
|
437
|
+
for h in range(num_heads):
|
|
438
|
+
head_entropies_last_token.append(self._compute_entropy(last_token_attn[h]))
|
|
439
|
+
head_entropies_all_tokens.append(
|
|
440
|
+
self._compute_mean_entropy_over_queries(all_tokens_attn[h])
|
|
441
|
+
)
|
|
442
|
+
head_entropies_last_k_tokens.append(
|
|
443
|
+
self._compute_mean_entropy_over_queries(last_k_tokens_attn[h])
|
|
444
|
+
)
|
|
445
|
+
|
|
446
|
+
avg_entropy_last_token = np.mean(head_entropies_last_token)
|
|
447
|
+
avg_entropy_all_tokens = np.mean(head_entropies_all_tokens)
|
|
448
|
+
avg_entropy_last_k_tokens = np.mean(head_entropies_last_k_tokens)
|
|
449
|
+
min_head = int(np.argmin(head_entropies_last_token))
|
|
450
|
+
|
|
451
|
+
# Get top-attended tokens for the most focused head
|
|
452
|
+
focused_head_attn = last_token_attn[min_head]
|
|
453
|
+
top_positions = torch.topk(focused_head_attn, k=min(5, input_ids.shape[1]))
|
|
454
|
+
|
|
455
|
+
top_tokens = []
|
|
456
|
+
for pos, weight in zip(top_positions.indices.tolist(), top_positions.values.tolist()):
|
|
457
|
+
token_str = tokenizer.decode([input_ids[0, pos]])
|
|
458
|
+
top_tokens.append({"token": token_str, "weight": weight})
|
|
459
|
+
|
|
460
|
+
sample_results[layer_idx] = {
|
|
461
|
+
# Legacy fields preserved (last-token)
|
|
462
|
+
"avg_entropy": avg_entropy_last_token,
|
|
463
|
+
"head_entropies": head_entropies_last_token,
|
|
464
|
+
"min_entropy": min(head_entropies_last_token),
|
|
465
|
+
"max_entropy": max(head_entropies_last_token),
|
|
466
|
+
# Explicit metrics
|
|
467
|
+
"avg_entropy_last_token": avg_entropy_last_token,
|
|
468
|
+
"avg_entropy_all_tokens": avg_entropy_all_tokens,
|
|
469
|
+
"avg_entropy_last_k_tokens": avg_entropy_last_k_tokens,
|
|
470
|
+
"head_entropies_last_token": head_entropies_last_token,
|
|
471
|
+
"head_entropies_all_tokens": head_entropies_all_tokens,
|
|
472
|
+
"head_entropies_last_k_tokens": head_entropies_last_k_tokens,
|
|
473
|
+
"last_k_tokens_used": last_k,
|
|
474
|
+
"avg_entropy_generated_tokens": generated_layer_stats.get(layer_idx, {}).get(
|
|
475
|
+
"avg_entropy_generated_tokens"
|
|
476
|
+
),
|
|
477
|
+
"std_entropy_generated_tokens": generated_layer_stats.get(layer_idx, {}).get(
|
|
478
|
+
"std_entropy_generated_tokens"
|
|
479
|
+
),
|
|
480
|
+
"head_entropies_generated_tokens": generated_layer_stats.get(layer_idx, {}).get(
|
|
481
|
+
"head_entropies_generated_tokens"
|
|
482
|
+
),
|
|
483
|
+
"generated_tokens_analyzed": generated_layer_stats.get(layer_idx, {}).get(
|
|
484
|
+
"generated_tokens_analyzed", generated_steps
|
|
485
|
+
),
|
|
486
|
+
"focused_head": min_head,
|
|
487
|
+
"top_tokens": top_tokens,
|
|
488
|
+
}
|
|
489
|
+
|
|
490
|
+
return sample_results
|
|
491
|
+
|
|
492
|
+
def run(
|
|
493
|
+
self,
|
|
494
|
+
backend: InferenceBackend,
|
|
495
|
+
dataset: BaseDataset,
|
|
496
|
+
prompt_strategy: Any,
|
|
497
|
+
num_samples: Optional[int] = None,
|
|
498
|
+
logger: Optional[ExperimentLogger] = None,
|
|
499
|
+
) -> ExperimentResult:
|
|
500
|
+
"""Run attention analysis experiment on multiple samples."""
|
|
501
|
+
|
|
502
|
+
tokenizer = backend._tokenizer
|
|
503
|
+
model = backend._model
|
|
504
|
+
|
|
505
|
+
# Get model config
|
|
506
|
+
config = model.config
|
|
507
|
+
if hasattr(config, "text_config"):
|
|
508
|
+
config = config.text_config
|
|
509
|
+
num_heads = config.num_attention_heads
|
|
510
|
+
|
|
511
|
+
if self.all_layers:
|
|
512
|
+
num_layers = getattr(config, "num_hidden_layers", None) or getattr(
|
|
513
|
+
config, "num_layers", None
|
|
514
|
+
)
|
|
515
|
+
if num_layers is None:
|
|
516
|
+
num_layers = backend.num_layers()
|
|
517
|
+
self.target_layers = list(range(0, int(num_layers), self.layer_stride))
|
|
518
|
+
|
|
519
|
+
print(f"Model: {backend.model_name}")
|
|
520
|
+
print(f"Attention heads: {num_heads}")
|
|
521
|
+
print(f"All layers enabled: {self.all_layers}")
|
|
522
|
+
print(f"Layer stride: {self.layer_stride}")
|
|
523
|
+
print(f"Resolved layers: {self.target_layers}")
|
|
524
|
+
print(
|
|
525
|
+
f"Max input tokens: {self.max_input_tokens if self.max_input_tokens is not None else 'None'}"
|
|
526
|
+
)
|
|
527
|
+
print(f"Batch size: {self.batch_size}")
|
|
528
|
+
print(f"Analyze generated tokens: {self.analyze_generated_tokens}")
|
|
529
|
+
if self.analyze_generated_tokens:
|
|
530
|
+
print(f"Generated max_new_tokens: {self.generated_max_new_tokens}")
|
|
531
|
+
|
|
532
|
+
# Set eager attention to enable output_attentions by reloading if necessary
|
|
533
|
+
# We need to check if the model is already using eager attention
|
|
534
|
+
current_attn = getattr(model, "config", None) and getattr(
|
|
535
|
+
model.config, "_attn_implementation", None
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
if current_attn != "eager":
|
|
539
|
+
if hasattr(model, "set_attn_implementation") and not self.force_eager_reload:
|
|
540
|
+
print(f"Current attention implementation: {current_attn}")
|
|
541
|
+
print("Switching attention implementation to 'eager' in-place...")
|
|
542
|
+
model.set_attn_implementation("eager")
|
|
543
|
+
current_attn = getattr(model.config, "_attn_implementation", None)
|
|
544
|
+
elif self.force_eager_reload:
|
|
545
|
+
print(f"Current attention implementation: {current_attn}")
|
|
546
|
+
print(
|
|
547
|
+
"Reloading model with attn_implementation='eager' to support output_attentions=True..."
|
|
548
|
+
)
|
|
549
|
+
# We need to preserve the model name before unloading
|
|
550
|
+
model_name = backend.model_name
|
|
551
|
+
backend.unload()
|
|
552
|
+
# Reload with eager attention
|
|
553
|
+
backend.load_model(model_name, attn_implementation="eager")
|
|
554
|
+
model = backend._model
|
|
555
|
+
tokenizer = backend._tokenizer
|
|
556
|
+
|
|
557
|
+
# Get samples from dataset
|
|
558
|
+
n_samples = num_samples if num_samples is not None else self.num_samples
|
|
559
|
+
samples = (
|
|
560
|
+
list(dataset)
|
|
561
|
+
if n_samples is None
|
|
562
|
+
else (dataset.sample(n_samples) if n_samples < len(dataset) else list(dataset))
|
|
563
|
+
)
|
|
564
|
+
print(f"\nAnalyzing attention on {len(samples)} samples (batch_size={self.batch_size})...")
|
|
565
|
+
|
|
566
|
+
# Aggregate statistics across samples
|
|
567
|
+
layer_entropy_stats_last_token: Dict[int, List[float]] = defaultdict(list)
|
|
568
|
+
layer_entropy_stats_all_tokens: Dict[int, List[float]] = defaultdict(list)
|
|
569
|
+
layer_entropy_stats_last_k_tokens: Dict[int, List[float]] = defaultdict(list)
|
|
570
|
+
layer_entropy_stats_generated_tokens: Dict[int, List[float]] = defaultdict(list)
|
|
571
|
+
layer_head_entropy_stats_last_token: Dict[int, List[List[float]]] = defaultdict(list)
|
|
572
|
+
layer_head_entropy_stats_all_tokens: Dict[int, List[List[float]]] = defaultdict(list)
|
|
573
|
+
layer_head_entropy_stats_last_k_tokens: Dict[int, List[List[float]]] = defaultdict(list)
|
|
574
|
+
layer_head_entropy_stats_generated_tokens: Dict[int, List[List[float]]] = defaultdict(list)
|
|
575
|
+
all_top_tokens: Dict[int, List[str]] = defaultdict(list)
|
|
576
|
+
|
|
577
|
+
sample_results = []
|
|
578
|
+
|
|
579
|
+
# Build batches
|
|
580
|
+
batches = [
|
|
581
|
+
samples[i : i + self.batch_size] for i in range(0, len(samples), self.batch_size)
|
|
582
|
+
]
|
|
583
|
+
|
|
584
|
+
for batch_samples in tqdm(batches, desc="Processing batches"):
|
|
585
|
+
prompts = [
|
|
586
|
+
prompt_strategy.build_prompt(
|
|
587
|
+
{"question": s.text, "text": s.text, "metadata": s.metadata or {}}
|
|
588
|
+
)
|
|
589
|
+
for s in batch_samples
|
|
590
|
+
]
|
|
591
|
+
|
|
592
|
+
batch_results = self._analyze_batch(
|
|
593
|
+
model, tokenizer, prompts, backend.device, num_heads
|
|
594
|
+
)
|
|
595
|
+
|
|
596
|
+
for sample, result in zip(batch_samples, batch_results):
|
|
597
|
+
if result is None:
|
|
598
|
+
print(f"\nWarning: Attention not available for sample {sample.idx}")
|
|
599
|
+
continue
|
|
600
|
+
|
|
601
|
+
sample_results.append(
|
|
602
|
+
{
|
|
603
|
+
"sample_idx": sample.idx,
|
|
604
|
+
"layer_results": result,
|
|
605
|
+
}
|
|
606
|
+
)
|
|
607
|
+
|
|
608
|
+
# Aggregate stats for this sample
|
|
609
|
+
for layer_idx, layer_data in result.items():
|
|
610
|
+
layer_entropy_stats_last_token[layer_idx].append(
|
|
611
|
+
layer_data["avg_entropy_last_token"]
|
|
612
|
+
)
|
|
613
|
+
layer_entropy_stats_all_tokens[layer_idx].append(
|
|
614
|
+
layer_data["avg_entropy_all_tokens"]
|
|
615
|
+
)
|
|
616
|
+
layer_entropy_stats_last_k_tokens[layer_idx].append(
|
|
617
|
+
layer_data["avg_entropy_last_k_tokens"]
|
|
618
|
+
)
|
|
619
|
+
layer_head_entropy_stats_last_token[layer_idx].append(
|
|
620
|
+
layer_data["head_entropies_last_token"]
|
|
621
|
+
)
|
|
622
|
+
layer_head_entropy_stats_all_tokens[layer_idx].append(
|
|
623
|
+
layer_data["head_entropies_all_tokens"]
|
|
624
|
+
)
|
|
625
|
+
layer_head_entropy_stats_last_k_tokens[layer_idx].append(
|
|
626
|
+
layer_data["head_entropies_last_k_tokens"]
|
|
627
|
+
)
|
|
628
|
+
gen_entropy = layer_data.get("avg_entropy_generated_tokens")
|
|
629
|
+
gen_head_entropies = layer_data.get("head_entropies_generated_tokens")
|
|
630
|
+
if gen_entropy is not None:
|
|
631
|
+
layer_entropy_stats_generated_tokens[layer_idx].append(gen_entropy)
|
|
632
|
+
if gen_head_entropies:
|
|
633
|
+
layer_head_entropy_stats_generated_tokens[layer_idx].append(
|
|
634
|
+
gen_head_entropies
|
|
635
|
+
)
|
|
636
|
+
for tok in layer_data["top_tokens"][:3]: # Top 3 tokens
|
|
637
|
+
all_top_tokens[layer_idx].append(tok["token"])
|
|
638
|
+
|
|
639
|
+
if not sample_results:
|
|
640
|
+
return ExperimentResult(
|
|
641
|
+
experiment_name=self.name,
|
|
642
|
+
model_name=backend.model_name,
|
|
643
|
+
prompt_strategy=prompt_strategy.name
|
|
644
|
+
if hasattr(prompt_strategy, "name")
|
|
645
|
+
else "custom",
|
|
646
|
+
metrics={"error": "attention_not_supported", "num_layers_analyzed": 0},
|
|
647
|
+
raw_outputs=[],
|
|
648
|
+
metadata={"target_layers": self.target_layers},
|
|
649
|
+
)
|
|
650
|
+
|
|
651
|
+
# Compute aggregated statistics
|
|
652
|
+
print("\n" + "=" * 70)
|
|
653
|
+
print("ATTENTION ANALYSIS: Aggregated Statistics Across Samples")
|
|
654
|
+
print("=" * 70)
|
|
655
|
+
header = (
|
|
656
|
+
f"{'Layer':<8} | {'LastTok μ':<10} | {'AllTok μ':<10} | "
|
|
657
|
+
f"{f'Last{self.last_k_tokens} μ':<10} | {'AllTok σ':<10}"
|
|
658
|
+
)
|
|
659
|
+
if self.analyze_generated_tokens:
|
|
660
|
+
header += f" | {'GenTok μ':<10}"
|
|
661
|
+
header += " | Top Tokens"
|
|
662
|
+
print(header)
|
|
663
|
+
print("-" * 106)
|
|
664
|
+
|
|
665
|
+
aggregated_results = []
|
|
666
|
+
|
|
667
|
+
for layer_idx in sorted(layer_entropy_stats_last_token.keys()):
|
|
668
|
+
entropies_last_token = layer_entropy_stats_last_token[layer_idx]
|
|
669
|
+
entropies_all_tokens = layer_entropy_stats_all_tokens[layer_idx]
|
|
670
|
+
entropies_last_k_tokens = layer_entropy_stats_last_k_tokens[layer_idx]
|
|
671
|
+
|
|
672
|
+
mean_entropy_last_token = float(np.nanmean(entropies_last_token))
|
|
673
|
+
std_entropy_last_token = float(np.nanstd(entropies_last_token))
|
|
674
|
+
mean_entropy_all_tokens = float(np.nanmean(entropies_all_tokens))
|
|
675
|
+
std_entropy_all_tokens = float(np.nanstd(entropies_all_tokens))
|
|
676
|
+
mean_entropy_last_k_tokens = float(np.nanmean(entropies_last_k_tokens))
|
|
677
|
+
std_entropy_last_k_tokens = float(np.nanstd(entropies_last_k_tokens))
|
|
678
|
+
if layer_entropy_stats_generated_tokens[layer_idx]:
|
|
679
|
+
mean_entropy_generated_tokens = float(
|
|
680
|
+
np.nanmean(layer_entropy_stats_generated_tokens[layer_idx])
|
|
681
|
+
)
|
|
682
|
+
std_entropy_generated_tokens = float(
|
|
683
|
+
np.nanstd(layer_entropy_stats_generated_tokens[layer_idx])
|
|
684
|
+
)
|
|
685
|
+
else:
|
|
686
|
+
mean_entropy_generated_tokens = float("nan")
|
|
687
|
+
std_entropy_generated_tokens = float("nan")
|
|
688
|
+
|
|
689
|
+
# Count most common top tokens
|
|
690
|
+
tokens = all_top_tokens[layer_idx]
|
|
691
|
+
from collections import Counter
|
|
692
|
+
|
|
693
|
+
token_counts = Counter(tokens)
|
|
694
|
+
top_3_tokens = token_counts.most_common(5)
|
|
695
|
+
top_tokens_str = ", ".join([f"'{t}'" for t, _ in top_3_tokens[:3]])
|
|
696
|
+
|
|
697
|
+
# Aggregate head-level entropies for each metric
|
|
698
|
+
head_entropies_last_token = np.array(layer_head_entropy_stats_last_token[layer_idx])
|
|
699
|
+
head_entropies_all_tokens = np.array(layer_head_entropy_stats_all_tokens[layer_idx])
|
|
700
|
+
head_entropies_last_k_tokens = np.array(
|
|
701
|
+
layer_head_entropy_stats_last_k_tokens[layer_idx]
|
|
702
|
+
)
|
|
703
|
+
mean_per_head_last_token = np.nanmean(head_entropies_last_token, axis=0).tolist()
|
|
704
|
+
std_per_head_last_token = np.nanstd(head_entropies_last_token, axis=0).tolist()
|
|
705
|
+
mean_per_head_all_tokens = np.nanmean(head_entropies_all_tokens, axis=0).tolist()
|
|
706
|
+
std_per_head_all_tokens = np.nanstd(head_entropies_all_tokens, axis=0).tolist()
|
|
707
|
+
mean_per_head_last_k_tokens = np.nanmean(head_entropies_last_k_tokens, axis=0).tolist()
|
|
708
|
+
std_per_head_last_k_tokens = np.nanstd(head_entropies_last_k_tokens, axis=0).tolist()
|
|
709
|
+
if layer_head_entropy_stats_generated_tokens[layer_idx]:
|
|
710
|
+
head_entropies_generated_tokens = np.array(
|
|
711
|
+
layer_head_entropy_stats_generated_tokens[layer_idx]
|
|
712
|
+
)
|
|
713
|
+
mean_per_head_generated_tokens = np.nanmean(
|
|
714
|
+
head_entropies_generated_tokens, axis=0
|
|
715
|
+
).tolist()
|
|
716
|
+
std_per_head_generated_tokens = np.nanstd(
|
|
717
|
+
head_entropies_generated_tokens, axis=0
|
|
718
|
+
).tolist()
|
|
719
|
+
else:
|
|
720
|
+
mean_per_head_generated_tokens = []
|
|
721
|
+
std_per_head_generated_tokens = []
|
|
722
|
+
|
|
723
|
+
aggregated_results.append(
|
|
724
|
+
{
|
|
725
|
+
"layer": layer_idx,
|
|
726
|
+
# Legacy keys (last-token metric)
|
|
727
|
+
"mean_entropy": mean_entropy_last_token,
|
|
728
|
+
"std_entropy": std_entropy_last_token,
|
|
729
|
+
"mean_per_head": mean_per_head_last_token,
|
|
730
|
+
"std_per_head": std_per_head_last_token,
|
|
731
|
+
# Explicit metrics
|
|
732
|
+
"mean_entropy_last_token": mean_entropy_last_token,
|
|
733
|
+
"std_entropy_last_token": std_entropy_last_token,
|
|
734
|
+
"mean_entropy_all_tokens": mean_entropy_all_tokens,
|
|
735
|
+
"std_entropy_all_tokens": std_entropy_all_tokens,
|
|
736
|
+
"mean_entropy_last_k_tokens": mean_entropy_last_k_tokens,
|
|
737
|
+
"std_entropy_last_k_tokens": std_entropy_last_k_tokens,
|
|
738
|
+
"mean_entropy_generated_tokens": (
|
|
739
|
+
None
|
|
740
|
+
if np.isnan(mean_entropy_generated_tokens)
|
|
741
|
+
else mean_entropy_generated_tokens
|
|
742
|
+
),
|
|
743
|
+
"std_entropy_generated_tokens": (
|
|
744
|
+
None
|
|
745
|
+
if np.isnan(std_entropy_generated_tokens)
|
|
746
|
+
else std_entropy_generated_tokens
|
|
747
|
+
),
|
|
748
|
+
"last_k_tokens": self.last_k_tokens,
|
|
749
|
+
"mean_per_head_last_token": mean_per_head_last_token,
|
|
750
|
+
"std_per_head_last_token": std_per_head_last_token,
|
|
751
|
+
"mean_per_head_all_tokens": mean_per_head_all_tokens,
|
|
752
|
+
"std_per_head_all_tokens": std_per_head_all_tokens,
|
|
753
|
+
"mean_per_head_last_k_tokens": mean_per_head_last_k_tokens,
|
|
754
|
+
"std_per_head_last_k_tokens": std_per_head_last_k_tokens,
|
|
755
|
+
"mean_per_head_generated_tokens": mean_per_head_generated_tokens,
|
|
756
|
+
"std_per_head_generated_tokens": std_per_head_generated_tokens,
|
|
757
|
+
"top_tokens": [{"token": t, "count": c} for t, c in top_3_tokens],
|
|
758
|
+
}
|
|
759
|
+
)
|
|
760
|
+
|
|
761
|
+
row = (
|
|
762
|
+
f"L{layer_idx:<7} | {mean_entropy_last_token:<10.4f} | "
|
|
763
|
+
f"{mean_entropy_all_tokens:<10.4f} | {mean_entropy_last_k_tokens:<10.4f} | "
|
|
764
|
+
f"{std_entropy_all_tokens:<10.4f}"
|
|
765
|
+
)
|
|
766
|
+
if self.analyze_generated_tokens:
|
|
767
|
+
if np.isnan(mean_entropy_generated_tokens):
|
|
768
|
+
row += f" | {'NA':<10}"
|
|
769
|
+
else:
|
|
770
|
+
row += f" | {mean_entropy_generated_tokens:<10.4f}"
|
|
771
|
+
row += f" | {top_tokens_str}"
|
|
772
|
+
print(row)
|
|
773
|
+
|
|
774
|
+
print("-" * 106)
|
|
775
|
+
|
|
776
|
+
# Overall metrics
|
|
777
|
+
all_mean_entropies_last_token = [r["mean_entropy_last_token"] for r in aggregated_results]
|
|
778
|
+
all_mean_entropies_all_tokens = [r["mean_entropy_all_tokens"] for r in aggregated_results]
|
|
779
|
+
all_mean_entropies_last_k_tokens = [
|
|
780
|
+
r["mean_entropy_last_k_tokens"] for r in aggregated_results
|
|
781
|
+
]
|
|
782
|
+
all_mean_entropies_generated_tokens = [
|
|
783
|
+
r["mean_entropy_generated_tokens"]
|
|
784
|
+
for r in aggregated_results
|
|
785
|
+
if r["mean_entropy_generated_tokens"] is not None
|
|
786
|
+
]
|
|
787
|
+
overall_mean_last_token = (
|
|
788
|
+
float(np.nanmean(all_mean_entropies_last_token))
|
|
789
|
+
if all_mean_entropies_last_token
|
|
790
|
+
else 0.0
|
|
791
|
+
)
|
|
792
|
+
overall_mean_all_tokens = (
|
|
793
|
+
float(np.nanmean(all_mean_entropies_all_tokens))
|
|
794
|
+
if all_mean_entropies_all_tokens
|
|
795
|
+
else 0.0
|
|
796
|
+
)
|
|
797
|
+
overall_mean_last_k_tokens = (
|
|
798
|
+
float(np.nanmean(all_mean_entropies_last_k_tokens))
|
|
799
|
+
if all_mean_entropies_last_k_tokens
|
|
800
|
+
else 0.0
|
|
801
|
+
)
|
|
802
|
+
overall_mean_generated_tokens = (
|
|
803
|
+
float(np.nanmean(all_mean_entropies_generated_tokens))
|
|
804
|
+
if all_mean_entropies_generated_tokens
|
|
805
|
+
else None
|
|
806
|
+
)
|
|
807
|
+
|
|
808
|
+
# Most focused layer using all-tokens metric (primary)
|
|
809
|
+
valid_layers_all_tokens = [
|
|
810
|
+
r for r in aggregated_results if not np.isnan(r["mean_entropy_all_tokens"])
|
|
811
|
+
]
|
|
812
|
+
most_focused_layer = (
|
|
813
|
+
min(valid_layers_all_tokens, key=lambda x: x["mean_entropy_all_tokens"])["layer"]
|
|
814
|
+
if valid_layers_all_tokens
|
|
815
|
+
else None
|
|
816
|
+
)
|
|
817
|
+
most_focused_entropy = (
|
|
818
|
+
min(r["mean_entropy_all_tokens"] for r in valid_layers_all_tokens)
|
|
819
|
+
if valid_layers_all_tokens
|
|
820
|
+
else 0.0
|
|
821
|
+
)
|
|
822
|
+
|
|
823
|
+
# Legacy most-focused values for last-token metric
|
|
824
|
+
valid_layers_last_token = [
|
|
825
|
+
r for r in aggregated_results if not np.isnan(r["mean_entropy_last_token"])
|
|
826
|
+
]
|
|
827
|
+
most_focused_layer_last_token = (
|
|
828
|
+
min(valid_layers_last_token, key=lambda x: x["mean_entropy_last_token"])["layer"]
|
|
829
|
+
if valid_layers_last_token
|
|
830
|
+
else None
|
|
831
|
+
)
|
|
832
|
+
most_focused_entropy_last_token = (
|
|
833
|
+
min(r["mean_entropy_last_token"] for r in valid_layers_last_token)
|
|
834
|
+
if valid_layers_last_token
|
|
835
|
+
else 0.0
|
|
836
|
+
)
|
|
837
|
+
|
|
838
|
+
metrics = {
|
|
839
|
+
"num_samples_analyzed": len(sample_results),
|
|
840
|
+
"num_layers_analyzed": len(aggregated_results),
|
|
841
|
+
"num_heads": num_heads,
|
|
842
|
+
# Primary metrics (all-tokens)
|
|
843
|
+
"overall_mean_entropy": float(overall_mean_all_tokens),
|
|
844
|
+
"overall_mean_entropy_all_tokens": float(overall_mean_all_tokens),
|
|
845
|
+
"overall_mean_entropy_last_token": float(overall_mean_last_token),
|
|
846
|
+
"overall_mean_entropy_last_k_tokens": float(overall_mean_last_k_tokens),
|
|
847
|
+
"last_k_tokens": self.last_k_tokens,
|
|
848
|
+
"most_focused_layer": most_focused_layer,
|
|
849
|
+
"most_focused_entropy": float(most_focused_entropy),
|
|
850
|
+
"most_focused_layer_all_tokens": most_focused_layer,
|
|
851
|
+
"most_focused_entropy_all_tokens": float(most_focused_entropy),
|
|
852
|
+
"most_focused_layer_last_token": most_focused_layer_last_token,
|
|
853
|
+
"most_focused_entropy_last_token": float(most_focused_entropy_last_token),
|
|
854
|
+
"analyze_generated_tokens": self.analyze_generated_tokens,
|
|
855
|
+
}
|
|
856
|
+
if overall_mean_generated_tokens is not None:
|
|
857
|
+
metrics["overall_mean_entropy_generated_tokens"] = float(overall_mean_generated_tokens)
|
|
858
|
+
|
|
859
|
+
print(f"\nOverall mean entropy (all tokens): {overall_mean_all_tokens:.4f}")
|
|
860
|
+
print(f"Overall mean entropy (last token): {overall_mean_last_token:.4f}")
|
|
861
|
+
print(
|
|
862
|
+
f"Overall mean entropy (last {self.last_k_tokens} tokens): {overall_mean_last_k_tokens:.4f}"
|
|
863
|
+
)
|
|
864
|
+
if overall_mean_generated_tokens is not None:
|
|
865
|
+
print(f"Overall mean entropy (generated tokens): {overall_mean_generated_tokens:.4f}")
|
|
866
|
+
print(
|
|
867
|
+
f"Most focused layer (all tokens): L{most_focused_layer} (entropy: {most_focused_entropy:.4f})"
|
|
868
|
+
)
|
|
869
|
+
|
|
870
|
+
return ExperimentResult(
|
|
871
|
+
experiment_name=self.name,
|
|
872
|
+
model_name=backend.model_name,
|
|
873
|
+
prompt_strategy=prompt_strategy.name if hasattr(prompt_strategy, "name") else "custom",
|
|
874
|
+
metrics=metrics,
|
|
875
|
+
raw_outputs=aggregated_results,
|
|
876
|
+
metadata={
|
|
877
|
+
"target_layers": self.target_layers,
|
|
878
|
+
"last_k_tokens": self.last_k_tokens,
|
|
879
|
+
"analyze_generated_tokens": self.analyze_generated_tokens,
|
|
880
|
+
"generated_max_new_tokens": self.generated_max_new_tokens,
|
|
881
|
+
"batch_size": self.batch_size,
|
|
882
|
+
"num_samples": len(samples),
|
|
883
|
+
"sample_results": sample_results, # Include per-sample data
|
|
884
|
+
},
|
|
885
|
+
)
|