wisent 0.7.901__py3-none-any.whl → 0.7.1045__py3-none-any.whl
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +669 -0
- wisent/comparison/lora_dpo.py +592 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activations_collector.py +3 -2
- wisent/core/activations/extraction_strategy.py +8 -4
- wisent/core/cli/agent/apply_steering.py +7 -5
- wisent/core/cli/agent/train_classifier.py +4 -3
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/RECORD +27 -91
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/generate_paper_data.py +0 -384
- wisent/examples/scripts/intervention_validation.py +0 -626
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
- wisent/examples/scripts/threshold_analysis.py +0 -434
- wisent/examples/scripts/visualization_gallery.py +0 -582
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
wisent/comparison/fgaa.py
@@ -0,0 +1,465 @@
+"""
+FGAA (Feature Guided Activation Addition) steering method.
+
+Implements the method from "Interpretable Steering of Large Language Models
+with Feature Guided Activation Additions" (arXiv:2501.09929).
+
+Uses Gemma Scope SAEs and pre-computed effect approximators.
+"""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import torch
+from huggingface_hub import hf_hub_download
+
+from wisent.comparison.utils import (
+    apply_steering_to_model,
+    remove_steering,
+    convert_to_lm_eval_format,
+    generate_contrastive_pairs,
+    load_model_and_tokenizer,
+    load_sae,
+    SAE_CONFIGS,
+)
+
+if TYPE_CHECKING:
+    from wisent.core.models.wisent_model import WisentModel
+
+__all__ = ["generate_steering_vector", "apply_steering_to_model", "remove_steering", "convert_to_lm_eval_format"]
+
+
+# BOS feature indices - these features activate most strongly on the BOS token
+# Paper features from Appendix G (5 features)
+BOS_FEATURES_PAPER = {
+    "google/gemma-2-2b": [11087, 3220, 11752, 12160, 11498],
+    "google/gemma-2-9b": [],  # Not listed in paper
+}
+
+# Detected features from running detect_bos_features.py (top 12 by mean activation)
+BOS_FEATURES_DETECTED = {
+    "google/gemma-2-2b": [1041, 7507, 11087, 3220, 11767, 11752, 14669, 6889, 12160, 13700, 2747, 11498],
+    "google/gemma-2-9b": [8032, 11906, 7768, 14845, 14483, 10562, 8892, 9151, 5721, 15738, 5285, 13895],
+}
+
+# FGAA-specific: effect approximator config (adapter files)
+FGAA_ADAPTER_FILES = {
+    "google/gemma-2-2b": "adapter_2b_layer_12.pt",
+    "google/gemma-2-9b": "adapter_9b_layer_12.pt",
+}
+
+
+def load_effect_approximator(model_name: str, device: str = "cuda:0") -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Load the pre-trained effect approximator (adapter) from HuggingFace.
+
+    The adapter contains:
+    - W: [d_model, d_sae] - maps SAE feature space to model activation space
+    - b: [d_sae] - bias term
+
+    Args:
+        model_name: HuggingFace model name
+        device: Device to load on
+
+    Returns:
+        Tuple of (W, b) tensors
+    """
+    if model_name not in FGAA_ADAPTER_FILES:
+        raise ValueError(f"No effect approximator for model '{model_name}'")
+
+    adapter_file = FGAA_ADAPTER_FILES[model_name]
+
+    print(f" Loading adapter from schalnev/sae-ts-effects / {adapter_file}")
+    path = hf_hub_download(
+        repo_id="schalnev/sae-ts-effects",
+        filename=adapter_file,
+        repo_type="dataset",
+    )
+
+    adapter = torch.load(path, map_location=device, weights_only=False)
+
+    # Adapter is OrderedDict with 'W' and 'b'
+    W = adapter["W"].to(device)  # [d_model, d_sae]
+    b = adapter["b"].to(device)  # [d_sae]
+
+    print(f" Adapter W shape: {W.shape}, b shape: {b.shape}")
+
+    return W, b
+
+
+def compute_v_diff(
+    model,
+    tokenizer,
+    sae,
+    pairs: list[dict],
+    layer_idx: int,
+    device: str,
+) -> torch.Tensor:
+    """
+    Compute v_diff: the difference vector between positive and negative examples in SAE space.
+
+    v_diff = mean(f(h_l(x+))) - mean(f(h_l(x-)))
+
+    Args:
+        model: HuggingFace model
+        tokenizer: Tokenizer
+        sae: SAE object from sae_lens
+        pairs: List of contrastive pairs
+        layer_idx: Layer to extract activations from
+        device: Device
+
+    Returns:
+        v_diff tensor of shape [d_sae]
+    """
+    pos_features_list = []
+    neg_features_list = []
+
+    print(f" Computing v_diff from {len(pairs)} pairs...")
+
+    for i, pair in enumerate(pairs):
+        prompt = pair["prompt"]
+        pos_response = pair["positive_response"]["model_response"]
+        neg_response = pair["negative_response"]["model_response"]
+
+        pos_text = f"{prompt} {pos_response}"
+        neg_text = f"{prompt} {neg_response}"
+
+        # Get activations and encode through SAE
+        pos_acts = _get_residual_stream_activations(model, tokenizer, pos_text, layer_idx, device)
+        pos_acts = pos_acts.to(device).to(sae.W_enc.dtype)
+        # SAE encode: latents = (x - b_dec) @ W_enc + b_enc
+        pos_latents = sae.encode(pos_acts)
+        # Mean over sequence dimension
+        pos_features_list.append(pos_latents.mean(dim=1).detach())  # [1, d_sae]
+
+        neg_acts = _get_residual_stream_activations(model, tokenizer, neg_text, layer_idx, device)
+        neg_acts = neg_acts.to(device).to(sae.W_enc.dtype)
+        neg_latents = sae.encode(neg_acts)
+        neg_features_list.append(neg_latents.mean(dim=1).detach())
+
+        if (i + 1) % 10 == 0:
+            print(f" Processed {i + 1}/{len(pairs)} pairs")
+
+    # Stack and compute mean
+    pos_features = torch.cat(pos_features_list, dim=0)  # [num_pairs, d_sae]
+    neg_features = torch.cat(neg_features_list, dim=0)
+
+    v_diff = pos_features.mean(dim=0) - neg_features.mean(dim=0)  # [d_sae]
+
+    print(f" v_diff computed, shape: {v_diff.shape}")
+    print(f" v_diff stats: mean={v_diff.mean():.6f}, std={v_diff.std():.6f}, "
+          f"min={v_diff.min():.6f}, max={v_diff.max():.6f}")
+
+    return v_diff
+
+
+def compute_v_target(
+    v_diff: torch.Tensor,
+    sparsity: torch.Tensor,
+    model_name: str,
+    bos_features_source: str = "detected",
+    density_threshold: float = 0.01,
+    top_k_positive: int = 50,
+    top_k_negative: int = 0,
+) -> torch.Tensor:
+    """
+    Compute v_target by filtering v_diff.
+
+    Three filtering stages:
+    1. Density filtering: zero out features with activation density > threshold
+    2. BOS token filtering: zero out features that activate mainly on BOS token
+    3. Top-k selection: keep top positive and negative features
+
+    Args:
+        v_diff: Difference vector in SAE space [d_sae]
+        sparsity: Feature sparsity/density values from SAE [d_sae]
+        model_name: Model name to look up BOS features
+        bos_features_source: Source of BOS features - "paper" (5 features), "detected" (12 features), or "none"
+        density_threshold: Zero out features with density above this (default 0.01)
+        top_k_positive: Number of top positive features to keep
+        top_k_negative: Number of top negative features to keep (paper uses 0)
+
+    Returns:
+        v_target tensor of shape [d_sae]
+    """
+    v_filtered = v_diff.clone()
+
+    # Stage 1: Density filtering
+    # Zero out features that are too commonly activated (not specific enough)
+    if sparsity is not None:
+        density_mask = sparsity > density_threshold
+        num_filtered = density_mask.sum().item()
+        v_filtered[density_mask] = 0
+        print(f" Density filtering: zeroed {num_filtered} features (density > {density_threshold})")
+
+    # Stage 2: BOS filtering
+    # Zero out features that activate mainly on BOS tokens
+    if bos_features_source == "paper":
+        bos_features = BOS_FEATURES_PAPER.get(model_name, [])
+    elif bos_features_source == "detected":
+        bos_features = BOS_FEATURES_DETECTED.get(model_name, [])
+    else:  # "none"
+        bos_features = []
+    if bos_features:
+        for idx in bos_features:
+            v_filtered[idx] = 0
+        print(f" BOS filtering: zeroed {len(bos_features)} features {bos_features}")
+    else:
+        print(f" BOS filtering: no known BOS features for {model_name}")
+
+    # Stage 3: Top-k selection
+    v_target = torch.zeros_like(v_filtered)
+
+    # Get top positive features
+    if top_k_positive > 0:
+        pos_values = v_filtered.clone()
+        pos_values[pos_values < 0] = 0
+        top_pos_values, top_pos_indices = pos_values.topk(min(top_k_positive, (pos_values > 0).sum().item()))
+        v_target[top_pos_indices] = v_filtered[top_pos_indices]
+        print(f" Selected top {len(top_pos_indices)} positive features")
+
+    # Get top negative features (paper uses 0)
+    if top_k_negative > 0:
+        neg_values = -v_filtered.clone()
+        neg_values[neg_values < 0] = 0
+        top_neg_values, top_neg_indices = neg_values.topk(min(top_k_negative, (neg_values > 0).sum().item()))
+        v_target[top_neg_indices] = v_filtered[top_neg_indices]
+        print(f" Selected top {len(top_neg_indices)} negative features")
+
+    num_nonzero = (v_target != 0).sum().item()
+    print(f" v_target: {num_nonzero} non-zero features")
+
+    return v_target
+
+
+def compute_v_opt(
+    v_target: torch.Tensor,
+    W: torch.Tensor,
+    b: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Compute v_opt using the effect approximator.
+
+    From paper: v_opt = (W @ v_target_norm) / ||W @ v_target_norm|| - (W @ b) / ||W @ b||
+
+    Args:
+        v_target: Target vector in SAE space [d_sae]
+        W: Effect approximator weight matrix [d_model, d_sae]
+        b: Effect approximator bias [d_sae]
+
+    Returns:
+        v_opt tensor of shape [d_model]
+    """
+    # L1 normalize v_target (as specified in paper)
+    v_target_norm = v_target / (v_target.abs().sum() + 1e-8)
+
+    # W is [d_model, d_sae], v_target_norm is [d_sae]
+    # W @ v_target_norm -> [d_model]
+    Wv = W @ v_target_norm
+    Wv_normalized = Wv / (Wv.norm() + 1e-8)
+
+    # Bias term: W @ b -> [d_model]
+    Wb = W @ b
+    Wb_normalized = Wb / (Wb.norm() + 1e-8)
+
+    # Final v_opt (paper formula)
+    v_opt = Wv_normalized - Wb_normalized
+
+    print(f" v_opt computed, shape: {v_opt.shape}, norm: {v_opt.norm():.6f}")
+
+    return v_opt
+
+
+def _get_residual_stream_activations(
+    model,
+    tokenizer,
+    text: str,
+    layer_idx: int,
+    device: str,
+) -> torch.Tensor:
+    """
+    Get residual stream activations from a specific layer.
+
+    Uses output_hidden_states=True (same as wisent's ActivationCollector).
+
+    Args:
+        model: HuggingFace model
+        tokenizer: Tokenizer
+        text: Input text
+        layer_idx: Layer index (0-indexed)
+        device: Device
+
+    Returns:
+        Tensor of shape (1, seq_len, d_model)
+    """
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        out = model(**inputs, output_hidden_states=True, use_cache=False)
+
+    # hidden_states is tuple: (embedding, layer0, layer1, ..., layerN)
+    # layer_idx=0 -> hs[1], layer_idx=12 -> hs[13]
+    hs = out.hidden_states
+    return hs[layer_idx + 1]  # +1 because hs[0] is embedding layer
+
+
+def generate_steering_vector(
+    task: str,
+    model_name: str,
+    output_path: str | Path,
+    trait_label: str = "correctness",
+    num_pairs: int = 50,
+    method: str = "fgaa",
+    layers: str | None = None,
+    device: str = "cuda:0",
+    keep_intermediate: bool = False,
+    density_threshold: float = 0.01,
+    top_k_positive: int = 50,
+    top_k_negative: int = 0,
+    bos_features_source: str = "detected",
+    **kwargs,  # Accept additional kwargs for compatibility (e.g., extraction_strategy)
+) -> Path:
+    """
+    Generate a steering vector using the FGAA method.
+
+    Args:
+        task: lm-eval task name (e.g., 'boolq', 'cb')
+        model_name: HuggingFace model name (must be Gemma 2B or 9B)
+        output_path: Where to save the steering vector
+        trait_label: Label for the trait being steered
+        num_pairs: Number of contrastive pairs to use
+        method: Method name (should be 'fgaa')
+        layers: Layer(s) to use (e.g., '12' or '10,11,12')
+        device: Device to run on
+        keep_intermediate: Whether to keep intermediate files
+        density_threshold: Density threshold for filtering (default 0.01)
+        top_k_positive: Number of top positive features to keep
+        top_k_negative: Number of top negative features to keep
+        bos_features_source: Source of BOS features - "paper" (5), "detected" (12), or "none"

+    Returns:
+        Path to the saved steering vector
+    """
+    import gc
+
+    output_path = Path(output_path)
+
+    if model_name not in SAE_CONFIGS:
+        raise ValueError(
+            f"No SAE config for model '{model_name}'. "
+            f"Supported models: {list(SAE_CONFIGS.keys())}"
+        )
+
+    config = SAE_CONFIGS[model_name]
+
+    # Parse layers
+    if layers is None:
+        layer_indices = [config["default_layer"]]
+    elif layers == "all":
+        layer_indices = list(range(config["num_layers"]))
+    else:
+        layer_indices = [int(l.strip()) for l in layers.split(",")]
+
+    # Step 1: Generate contrastive pairs
+    print(f"Step 1: Generating contrastive pairs from task: {task}")
+    pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)
+    print(f" Loaded {len(pairs)} contrastive pairs")
+
+    # Step 2: Load model
+    print(f"\nStep 2: Loading model {model_name}...")
+    model, tokenizer = load_model_and_tokenizer(model_name, device)
+
+    # Step 3: Load effect approximator (shared across layers)
+    print(f"\nStep 3: Loading effect approximator...")
+    W, b = load_effect_approximator(model_name, device=device)
+
+    steering_vectors = {}
+    feature_info = {}
+
+    for layer_idx in layer_indices:
+        print(f"\nStep 4: Processing layer {layer_idx}")
+
+        # Load SAE for this layer
+        sae, sparsity = load_sae(model_name, layer_idx, device=device)
+
+        # Compute v_diff
+        print(f"\nStep 5: Computing v_diff for layer {layer_idx}...")
+        v_diff = compute_v_diff(model, tokenizer, sae, pairs, layer_idx, device)
+
+        # Compute v_target
+        print(f"\nStep 6: Computing v_target for layer {layer_idx}...")
+        v_target = compute_v_target(
+            v_diff,
+            sparsity,
+            model_name,
+            bos_features_source=bos_features_source,
+            density_threshold=density_threshold,
+            top_k_positive=top_k_positive,
+            top_k_negative=top_k_negative,
+        )
+
+        # Compute v_opt
+        print(f"\nStep 7: Computing v_opt for layer {layer_idx}...")
+        v_opt = compute_v_opt(v_target, W, b)
+
+        steering_vectors[str(layer_idx)] = v_opt.cpu().float().tolist()
+
+        # Store feature info
+        nonzero_mask = v_target != 0
+        nonzero_indices = nonzero_mask.nonzero().squeeze(-1).tolist()
+        feature_info[str(layer_idx)] = {
+            "num_selected_features": len(nonzero_indices) if isinstance(nonzero_indices, list) else 1,
+            "selected_feature_indices": nonzero_indices[:20] if isinstance(nonzero_indices, list) else [nonzero_indices],
+            "v_diff_stats": {
+                "mean": v_diff.mean().item(),
+                "std": v_diff.std().item(),
+                "min": v_diff.min().item(),
+                "max": v_diff.max().item(),
+            },
+        }
+
+        # Cleanup SAE
+        del sae, sparsity, v_diff, v_target
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    # Cleanup
+    del model, W, b
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    if not keep_intermediate:
+        import os
+        os.unlink(pairs_file)
+
+    # Save results
+    result = {
+        "steering_vectors": steering_vectors,
+        "layers": [str(l) for l in layer_indices],
+        "model": model_name,
+        "method": "fgaa",
+        "trait_label": trait_label,
+        "task": task,
+        "num_pairs": len(pairs),
+        "fgaa_params": {
+            "density_threshold": density_threshold,
+            "top_k_positive": top_k_positive,
+            "top_k_negative": top_k_negative,
+            "bos_features_source": bos_features_source,
+        },
+        "feature_info": feature_info,
+    }
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "w") as f:
+        json.dump(result, f, indent=2)
+
+    print(f"\nSaved FGAA steering vector to {output_path}")
+    return output_path