wisent 0.7.901__py3-none-any.whl → 0.7.1116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +663 -0
- wisent/comparison/lora_dpo.py +604 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/reft.py +690 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activations_collector.py +3 -2
- wisent/core/activations/extraction_strategy.py +8 -4
- wisent/core/cli/agent/apply_steering.py +7 -5
- wisent/core/cli/agent/train_classifier.py +4 -3
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/METADATA +5 -1
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/RECORD +28 -91
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/generate_paper_data.py +0 -384
- wisent/examples/scripts/intervention_validation.py +0 -626
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
- wisent/examples/scripts/threshold_analysis.py +0 -434
- wisent/examples/scripts/visualization_gallery.py +0 -582
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/WHEEL +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/top_level.txt +0 -0
wisent/comparison/sae.py
ADDED
@@ -0,0 +1,304 @@
+"""
+SAE-based steering method for comparison experiments.
+
+Uses Sparse Autoencoders to identify steering directions from contrastive pairs.
+Computes steering vector using SAE decoder features weighted by feature differences.
+
+Supports Gemma models with Gemma Scope SAEs via sae_lens.
+"""
+
+from __future__ import annotations
+
+import gc
+import json
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import torch
+
+from wisent.comparison.utils import (
+    apply_steering_to_model,
+    remove_steering,
+    convert_to_lm_eval_format,
+    generate_contrastive_pairs,
+    load_model_and_tokenizer,
+    load_sae,
+    SAE_CONFIGS,
+)
+
+if TYPE_CHECKING:
+    from wisent.core.models.wisent_model import WisentModel
+
+__all__ = ["generate_steering_vector", "apply_steering_to_model", "remove_steering", "convert_to_lm_eval_format"]
+
+
+def _get_residual_stream_activations(
+    model,
+    tokenizer,
+    text: str,
+    layer_idx: int,
+    device: str,
+) -> torch.Tensor:
+    """
+    Get residual stream activations from a specific layer.
+
+    Args:
+        model: HuggingFace model
+        tokenizer: Tokenizer
+        text: Input text
+        layer_idx: Layer index (0-indexed)
+        device: Device
+
+    Returns:
+        Tensor of shape (1, seq_len, d_model)
+    """
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        out = model(**inputs, output_hidden_states=True, use_cache=False)
+
+    # hidden_states is tuple: (embedding, layer0, layer1, ..., layerN)
+    # layer_idx=0 -> hs[1], layer_idx=12 -> hs[13]
+    hs = out.hidden_states
+    return hs[layer_idx + 1]
+
+
+def compute_feature_diff(
+    model,
+    tokenizer,
+    sae,
+    pairs: list[dict],
+    layer_idx: int,
+    device: str,
+) -> torch.Tensor:
+    """
+    Compute feature difference between positive and negative examples in SAE space.
+
+    feature_diff = mean(encode(h_pos)) - mean(encode(h_neg))
+
+    Args:
+        model: HuggingFace model
+        tokenizer: Tokenizer
+        sae: SAE object from sae_lens
+        pairs: List of contrastive pairs
+        layer_idx: Layer to extract activations from
+        device: Device
+
+    Returns:
+        feature_diff tensor of shape [d_sae]
+    """
+    pos_features_list = []
+    neg_features_list = []
+
+    print(f"  Computing feature diff from {len(pairs)} pairs...")
+
+    for i, pair in enumerate(pairs):
+        prompt = pair["prompt"]
+        pos_response = pair["positive_response"]["model_response"]
+        neg_response = pair["negative_response"]["model_response"]
+
+        pos_text = f"{prompt} {pos_response}"
+        neg_text = f"{prompt} {neg_response}"
+
+        # Get activations and encode through SAE
+        pos_acts = _get_residual_stream_activations(model, tokenizer, pos_text, layer_idx, device)
+        pos_acts = pos_acts.to(device).to(sae.W_enc.dtype)
+        pos_latents = sae.encode(pos_acts)
+        pos_features_list.append(pos_latents.mean(dim=1).detach())  # Mean over sequence
+
+        neg_acts = _get_residual_stream_activations(model, tokenizer, neg_text, layer_idx, device)
+        neg_acts = neg_acts.to(device).to(sae.W_enc.dtype)
+        neg_latents = sae.encode(neg_acts)
+        neg_features_list.append(neg_latents.mean(dim=1).detach())
+
+        if (i + 1) % 10 == 0:
+            print(f"  Processed {i + 1}/{len(pairs)} pairs")
+
+    # Stack and compute mean difference
+    pos_features = torch.cat(pos_features_list, dim=0)  # [num_pairs, d_sae]
+    neg_features = torch.cat(neg_features_list, dim=0)
+
+    feature_diff = pos_features.mean(dim=0) - neg_features.mean(dim=0)  # [d_sae]
+
+    print(f"  feature_diff computed, shape: {feature_diff.shape}")
+    print(f"  feature_diff stats: mean={feature_diff.mean():.6f}, std={feature_diff.std():.6f}")
+
+    return feature_diff
+
+
+def compute_steering_vector_from_decoder(
+    feature_diff: torch.Tensor,
+    sae,
+    top_k: int = 4,
+    normalize: bool = True,
+) -> tuple[torch.Tensor, dict]:
+    """
+    Compute steering vector using SAE decoder features.
+
+    steering_vector = sum(feature_diff[i] * W_dec[i]) for top-k features
+
+    Args:
+        feature_diff: Difference vector in SAE space [d_sae]
+        sae: SAE object with W_dec decoder weights
+        top_k: Number of top features to use
+        normalize: Whether to normalize the final steering vector
+
+    Returns:
+        Tuple of (steering_vector [d_model], feature_info dict)
+    """
+    # Find top-k features by absolute difference
+    abs_diff = feature_diff.abs()
+    top_values, top_indices = abs_diff.topk(min(top_k, len(feature_diff)))
+
+    print(f"  Selected top {len(top_indices)} features")
+    print(f"  Top feature indices: {top_indices[:10].tolist()}...")
+    print(f"  Top feature diff magnitudes: {top_values[:10].tolist()}")
+
+    # Construct steering vector from decoder
+    # W_dec shape: [d_sae, d_model]
+    steering_vector = torch.zeros(sae.W_dec.shape[1], device=sae.W_dec.device, dtype=sae.W_dec.dtype)
+
+    for feat_idx in top_indices:
+        steering_vector += feature_diff[feat_idx] * sae.W_dec[feat_idx]
+
+    if normalize:
+        norm = steering_vector.norm()
+        if norm > 0:
+            steering_vector = steering_vector / norm
+            print(f"  Normalized steering vector (original norm: {norm:.4f})")
+
+    print(f"  steering_vector shape: {steering_vector.shape}, norm: {steering_vector.norm():.6f}")
+
+    feature_info = {
+        "top_k": top_k,
+        "top_indices": top_indices.tolist(),
+        "top_diff_values": [feature_diff[i].item() for i in top_indices],
+    }
+
+    return steering_vector, feature_info
+
+
+def generate_steering_vector(
+    task: str,
+    model_name: str,
+    output_path: str | Path,
+    trait_label: str = "correctness",
+    num_pairs: int = 50,
+    method: str = "sae",
+    layers: str | None = None,
+    normalize: bool = True,
+    device: str = "cuda:0",
+    keep_intermediate: bool = False,
+    top_k: int = 4,
+    **kwargs,  # Accept additional kwargs for compatibility
+) -> Path:
+    """
+    Generate a steering vector using SAE decoder features.
+
+    Args:
+        task: lm-eval task name (e.g., 'boolq', 'cb')
+        model_name: HuggingFace model name (must be Gemma 2B or 9B)
+        output_path: Where to save the steering vector
+        trait_label: Label for the trait being steered
+        num_pairs: Number of contrastive pairs to use
+        method: Method name (should be 'sae')
+        layers: Layer(s) to use (e.g., '12' or '10,11,12')
+        normalize: Whether to normalize the steering vector
+        device: Device to run on
+        keep_intermediate: Whether to keep intermediate files
+        top_k: Number of top SAE features to use
+
+    Returns:
+        Path to the saved steering vector
+    """
+    output_path = Path(output_path)
+
+    if model_name not in SAE_CONFIGS:
+        raise ValueError(
+            f"No SAE config for model '{model_name}'. "
+            f"Supported models: {list(SAE_CONFIGS.keys())}"
+        )
+
+    config = SAE_CONFIGS[model_name]
+
+    # Parse layers
+    if layers is None:
+        layer_indices = [config["default_layer"]]
+    elif layers == "all":
+        layer_indices = list(range(config["num_layers"]))
+    else:
+        layer_indices = [int(l.strip()) for l in layers.split(",")]
+
+    # Step 1: Generate contrastive pairs
+    print(f"Step 1: Generating contrastive pairs from task: {task}")
+    pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)
+    print(f"  Loaded {len(pairs)} contrastive pairs")
+
+    # Step 2: Load model
+    print(f"\nStep 2: Loading model {model_name}...")
+    model, tokenizer = load_model_and_tokenizer(model_name, device)
+
+    steering_vectors = {}
+    feature_info = {}
+
+    for layer_idx in layer_indices:
+        print(f"\nStep 3: Processing layer {layer_idx}")
+
+        # Load SAE for this layer
+        sae, sparsity = load_sae(model_name, layer_idx, device=device)
+
+        # Step 4: Compute feature difference
+        print(f"\nStep 4: Computing feature diff for layer {layer_idx}...")
+        feat_diff = compute_feature_diff(model, tokenizer, sae, pairs, layer_idx, device)
+
+        # Step 5: Compute steering vector from decoder
+        print(f"\nStep 5: Computing steering vector from decoder for layer {layer_idx}...")
+        steering_vec, feat_info = compute_steering_vector_from_decoder(
+            feat_diff, sae, top_k=top_k, normalize=normalize
+        )
+
+        steering_vectors[str(layer_idx)] = steering_vec.cpu().float().tolist()
+        feature_info[str(layer_idx)] = feat_info
+
+        # Cleanup SAE
+        del sae, sparsity, feat_diff
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    # Cleanup model
+    del model
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    # Cleanup temp files
+    if not keep_intermediate:
+        import os
+        os.unlink(pairs_file)
+
+    # Save results
+    result = {
+        "steering_vectors": steering_vectors,
+        "layers": [str(l) for l in layer_indices],
+        "model": model_name,
+        "method": "sae",
+        "trait_label": trait_label,
+        "task": task,
+        "num_pairs": len(pairs),
+        "sae_config": {
+            "release": config["sae_release"],
+            "top_k": top_k,
+            "normalize": normalize,
+        },
+        "feature_info": feature_info,
+    }
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "w") as f:
+        json.dump(result, f, indent=2)
+
+    print(f"\nSaved SAE steering vector to {output_path}")
+    return output_path
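For orientation, a minimal usage sketch of the module above. The signature of generate_steering_vector comes from the diff; the task name, output path, and layer choice are illustrative assumptions, not values shipped with the package.

# Hypothetical driver; only generate_steering_vector comes from the diff above.
from pathlib import Path

from wisent.comparison.sae import generate_steering_vector

vector_path = generate_steering_vector(
    task="boolq",                    # any lm-eval task with a pair extractor
    model_name="google/gemma-2-2b",  # must be a key in SAE_CONFIGS
    output_path=Path("out/sae_boolq.json"),  # assumed output location
    layers="12",                     # or "10,11,12", or "all"
    top_k=4,                         # SAE decoder features combined into the vector
    device="cuda:0",
)
print(f"Steering vector written to {vector_path}")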
wisent/comparison/utils.py
ADDED
@@ -0,0 +1,381 @@
+"""
+Shared utilities for comparison experiments.
+"""
+
+from __future__ import annotations
+
+import json
+import tempfile
+from argparse import Namespace
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import torch
+from wisent.core.utils.device import preferred_dtype
+
+if TYPE_CHECKING:
+    from wisent.core.models.wisent_model import WisentModel
+
+
+# SAE configurations for supported Gemma models
+SAE_CONFIGS = {
+    "google/gemma-2-2b": {
+        "sae_release": "gemma-scope-2b-pt-res-canonical",
+        "sae_id_template": "layer_{layer}/width_16k/canonical",
+        "num_layers": 26,
+        "default_layer": 12,
+        "d_model": 2304,
+        "d_sae": 16384,
+    },
+    "google/gemma-2-9b": {
+        "sae_release": "gemma-scope-9b-pt-res-canonical",
+        "sae_id_template": "layer_{layer}/width_16k/canonical",
+        "num_layers": 42,
+        "default_layer": 12,
+        "d_model": 3584,
+        "d_sae": 16384,
+    },
+}
+
+
+def load_sae(model_name: str, layer_idx: int, device: str = "cuda:0"):
+    """
+    Load Gemma Scope SAE for a specific layer.
+
+    Args:
+        model_name: HuggingFace model name (e.g., 'google/gemma-2-2b')
+        layer_idx: Layer index to load SAE for
+        device: Device to load SAE on
+
+    Returns:
+        Tuple of (SAE object, sparsity tensor)
+    """
+    from sae_lens import SAE
+
+    if model_name not in SAE_CONFIGS:
+        raise ValueError(f"No SAE config for model '{model_name}'. Supported: {list(SAE_CONFIGS.keys())}")
+
+    config = SAE_CONFIGS[model_name]
+    sae_id = config["sae_id_template"].format(layer=layer_idx)
+
+    print(f"  Loading SAE from {config['sae_release']} / {sae_id}")
+    sae, _, sparsity = SAE.from_pretrained(
+        release=config["sae_release"],
+        sae_id=sae_id,
+        device=device,
+    )
+
+    return sae, sparsity
+
+
+def load_model_and_tokenizer(
+    model_name: str,
+    device: str = "cuda:0",
+    eval_mode: bool = True,
+) -> tuple:
+    """
+    Load HuggingFace model and tokenizer.
+
+    Args:
+        model_name: HuggingFace model name
+        device: Device to load model on
+        eval_mode: Whether to set model to eval mode (default True for inference)
+
+    Returns:
+        Tuple of (model, tokenizer)
+    """
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=preferred_dtype(device),
+        device_map=device,
+        trust_remote_code=True,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    if eval_mode:
+        model.eval()
+
+    return model, tokenizer
+
+
+def generate_contrastive_pairs(
+    task: str,
+    num_pairs: int,
+    seed: int = 42,
+    verbose: bool = False,
+) -> tuple[list[dict], str]:
+    """
+    Generate contrastive pairs from an lm-eval task.
+
+    Args:
+        task: lm-eval task name (e.g., 'boolq', 'cb')
+        num_pairs: Number of pairs to generate
+        seed: Random seed for reproducibility
+        verbose: Whether to print verbose output
+
+    Returns:
+        Tuple of (pairs list, path to temporary pairs file).
+        Caller is responsible for cleaning up the file if needed.
+    """
+    from wisent.core.cli.generate_pairs_from_task import execute_generate_pairs_from_task
+
+    pairs_file = tempfile.NamedTemporaryFile(mode='w', suffix='_pairs.json', delete=False).name
+    pairs_args = Namespace(
+        task_name=task,
+        limit=num_pairs,
+        output=pairs_file,
+        seed=seed,
+        verbose=verbose,
+    )
+    execute_generate_pairs_from_task(pairs_args)
+
+    with open(pairs_file) as f:
+        pairs_data = json.load(f)
+    pairs = pairs_data["pairs"]
+
+    return pairs, pairs_file
+
+
+def create_test_only_task(task_name: str, train_ratio: float = 0.8) -> dict:
+    """
+    Create a task that evaluates only on our test split.
+
+    This ensures no overlap with the data used for steering vector training.
+
+    Args:
+        task_name: lm-eval task name (e.g., 'boolq', 'cb')
+        train_ratio: Fraction of data used for training (default 0.8)
+
+    Returns:
+        Task dict with test split configured
+    """
+    from lm_eval.tasks import get_task_dict
+    from wisent.core.utils.dataset_splits import get_test_docs
+
+    task_dict = get_task_dict([task_name])
+    task = task_dict[task_name]
+
+    test_docs = get_test_docs(task, benchmark_name=task_name, train_ratio=train_ratio)
+    test_pct = round((1 - train_ratio) * 100)
+
+    print(f"Test split size: {len(test_docs)} docs ({test_pct}% of pooled data)")
+
+    # Override task's doc methods to use our test split
+    task.test_docs = lambda: test_docs
+    task.has_test_docs = lambda: True
+    task._eval_docs = test_docs
+
+    return {task_name: task}
+
+
+def extract_accuracy(results: dict, task: str) -> float:
+    """
+    Extract accuracy from lm-eval results.
+
+    Args:
+        results: Results dict from lm-eval evaluator
+        task: Task name to extract accuracy for
+
+    Returns:
+        Accuracy value (0.0 if not found)
+    """
+    task_results = results.get("results", {}).get(task, {})
+    for key in ["acc", "acc,none", "accuracy", "acc_norm", "acc_norm,none"]:
+        if key in task_results:
+            return task_results[key]
+    return 0.0
+
+
+def run_lm_eval_evaluation(
+    wisent_model: "WisentModel",
+    task_dict: dict,
+    task_name: str,
+    batch_size: int | str = 1,
+    max_batch_size: int = 8,
+    limit: int | None = None,
+) -> dict:
+    """
+    Run evaluation using lm-eval-harness.
+
+    Args:
+        wisent_model: WisentModel instance
+        task_dict: Task dict from create_test_only_task
+        task_name: lm-eval task name
+        batch_size: Batch size for evaluation
+        max_batch_size: Max batch size for lm-eval internal batching
+        limit: Max number of examples to evaluate
+
+    Returns:
+        Full results dict from lm-eval
+    """
+    from lm_eval import evaluator
+    from lm_eval.models.huggingface import HFLM
+
+    lm = HFLM(
+        pretrained=wisent_model.hf_model,
+        tokenizer=wisent_model.tokenizer,
+        batch_size=batch_size,
+        max_batch_size=max_batch_size,
+    )
+
+    results = evaluator.evaluate(
+        lm=lm,
+        task_dict=task_dict,
+        limit=limit,
+    )
+
+    return results
+
+
+def run_ll_evaluation(
+    wisent_model: "WisentModel",
+    task_dict: dict,
+    task_name: str,
+    limit: int | None = None,
+) -> float:
+    """
+    Run evaluation using wisent's LogLikelihoodsEvaluator.
+
+    Args:
+        wisent_model: WisentModel instance
+        task_dict: Task dict from create_test_only_task
+        task_name: lm-eval task name
+        limit: Max number of examples to evaluate
+
+    Returns:
+        Accuracy as float
+    """
+    from wisent.core.evaluators.benchmark_specific.log_likelihoods_evaluator import LogLikelihoodsEvaluator
+    from wisent.core.contrastive_pairs.lm_eval_pairs.lm_extractor_registry import get_extractor
+
+    ll_evaluator = LogLikelihoodsEvaluator()
+    extractor = get_extractor(task_name)
+
+    task = task_dict[task_name]
+    docs = list(task.test_docs())
+
+    if limit:
+        docs = docs[:limit]
+
+    print(f"Evaluating {len(docs)} examples with LogLikelihoodsEvaluator")
+
+    correct = 0
+    for i, doc in enumerate(docs):
+        question = task.doc_to_text(doc)
+        choices, expected = extractor.extract_choices_and_answer(task, doc)
+
+        result = ll_evaluator.evaluate(
+            response="",
+            expected=expected,
+            model=wisent_model,
+            question=question,
+            choices=choices,
+            task_name=task_name,
+        )
+
+        if result.ground_truth == "TRUTHFUL":
+            correct += 1
+
+        if (i + 1) % 50 == 0:
+            print(f"  Processed {i + 1}/{len(docs)}, acc: {correct/(i+1):.4f}")
+
+    return correct / len(docs) if docs else 0.0
+
+
+def load_steering_vector(path: str | Path, default_method: str = "unknown") -> dict:
+    """
+    Load a steering vector from file.
+
+    Args:
+        path: Path to steering vector file (.json or .pt)
+        default_method: Default method name if not found in file
+
+    Returns:
+        Dictionary with steering vectors and metadata
+    """
+    path = Path(path)
+
+    if path.suffix == ".pt":
+        from wisent.core.utils.device import resolve_default_device
+        data = torch.load(path, map_location=resolve_default_device(), weights_only=False)
+        layer_idx = str(data.get("layer_index", data.get("layer", 1)))
+        return {
+            "steering_vectors": {layer_idx: data["steering_vector"].tolist()},
+            "layers": [layer_idx],
+            "model": data.get("model", "unknown"),
+            "method": data.get("method", default_method),
+            "trait_label": data.get("trait_label", "unknown"),
+        }
+    else:
+        with open(path) as f:
+            return json.load(f)
+
+
+def apply_steering_to_model(
+    model: "WisentModel",
+    steering_data: dict,
+    scale: float = 1.0,
+) -> None:
+    """
+    Apply loaded steering vectors to a WisentModel.
+
+    Args:
+        model: WisentModel instance
+        steering_data: Dictionary from load_steering_vector()
+        scale: Scaling factor for steering strength
+    """
+    raw_map = {}
+    dtype = preferred_dtype()
+    for layer_str, vec_list in steering_data["steering_vectors"].items():
+        raw_map[layer_str] = torch.tensor(vec_list, dtype=dtype)
+
+    model.set_steering_from_raw(raw_map, scale=scale, normalize=False)
+    model.apply_steering()
+
+
+def remove_steering(model: "WisentModel") -> None:
+    """Remove steering from a WisentModel."""
+    model.detach()
+    model.clear_steering()
+
+
+def convert_to_lm_eval_format(
+    steering_data: dict,
+    output_path: str | Path,
+    scale: float = 1.0,
+) -> Path:
+    """
+    Convert our steering vector format to lm-eval's steered model format.
+
+    lm-eval expects:
+        {
+            "layers.N": {
+                "steering_vector": tensor of shape (1, hidden_dim),
+                "steering_coefficient": float,
+                "action": "add"
+            }
+        }
+    """
+    output_path = Path(output_path)
+
+    dtype = preferred_dtype()
+    lm_eval_config = {}
+    for layer_str, vec_list in steering_data["steering_vectors"].items():
+        vec = torch.tensor(vec_list, dtype=dtype)
+        # lm-eval expects shape (1, hidden_dim)
+        if vec.dim() == 1:
+            vec = vec.unsqueeze(0)
+
+        layer_key = f"layers.{layer_str}"
+        lm_eval_config[layer_key] = {
+            "steering_vector": vec,
+            "steering_coefficient": scale,
+            "action": "add",
+        }
+
+    torch.save(lm_eval_config, output_path)
+    return output_path
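Likewise, a sketch of how these utilities compose into a steered evaluation. Everything here mirrors the signatures above except wisent_model, whose construction is not shown in this diff and is assumed to exist, and the input path, which is illustrative.

# Hypothetical evaluation flow built from the helpers above.
from wisent.comparison.utils import (
    apply_steering_to_model,
    create_test_only_task,
    extract_accuracy,
    load_steering_vector,
    remove_steering,
    run_lm_eval_evaluation,
)

steering_data = load_steering_vector("out/sae_boolq.json")   # accepts .json or .pt
task_dict = create_test_only_task("boolq", train_ratio=0.8)  # held-out split only

# wisent_model: a WisentModel instance, construction not shown in this diff.
apply_steering_to_model(wisent_model, steering_data, scale=1.0)
results = run_lm_eval_evaluation(wisent_model, task_dict, "boolq", limit=200)
remove_steering(wisent_model)

print(f"steered accuracy: {extract_accuracy(results, 'boolq'):.4f}")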
wisent/core/activations/activations_collector.py
CHANGED
@@ -66,8 +66,9 @@ class ActivationCollector:
         pos_text = _resp_text(pair.positive_response)
         neg_text = _resp_text(pair.negative_response)
 
-
-
+        needs_other = strategy in (ExtractionStrategy.MC_BALANCED, ExtractionStrategy.MC_COMPLETION)
+        other_for_pos = neg_text if needs_other else None
+        other_for_neg = pos_text if needs_other else None
 
         pos = self._collect_single(
             pair.prompt, pos_text, strategy, layers, normalize,