wisent 0.7.901__py3-none-any.whl → 0.7.1116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +663 -0
- wisent/comparison/lora_dpo.py +604 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/reft.py +690 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activations_collector.py +3 -2
- wisent/core/activations/extraction_strategy.py +8 -4
- wisent/core/cli/agent/apply_steering.py +7 -5
- wisent/core/cli/agent/train_classifier.py +4 -3
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/METADATA +5 -1
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/RECORD +28 -91
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/generate_paper_data.py +0 -384
- wisent/examples/scripts/intervention_validation.py +0 -626
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
- wisent/examples/scripts/threshold_analysis.py +0 -434
- wisent/examples/scripts/visualization_gallery.py +0 -582
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/WHEEL +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/top_level.txt +0 -0
wisent/__init__.py
CHANGED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Comparison methods for evaluating steering techniques
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Detect BOS (Beginning of Sequence) features in Gemma Scope SAEs.
|
|
3
|
+
|
|
4
|
+
BOS features are SAE features that activate most strongly at the BOS token position.
|
|
5
|
+
These should be filtered out when computing steering vectors as they introduce
|
|
6
|
+
artifacts without contributing to steering.
|
|
7
|
+
|
|
8
|
+
Reference: "Interpretable Steering of Large Language Models with Feature Guided
|
|
9
|
+
Activation Additions" (arXiv:2501.09929), Appendix G.
|
|
10
|
+
|
|
11
|
+
Known BOS features from paper (Gemma-2-2B, layer 12, 16k SAE):
|
|
12
|
+
- 11087, 3220, 11752, 12160, 11498
|
|
13
|
+
|
|
14
|
+
Usage:
|
|
15
|
+
python -m wisent.comparison.detect_bos_features --model google/gemma-2-2b
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import argparse
|
|
21
|
+
import json
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
|
|
24
|
+
import torch
|
|
25
|
+
from datasets import load_dataset
|
|
26
|
+
from tqdm import tqdm
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Known BOS feature indices from paper (Appendix G)
|
|
30
|
+
KNOWN_BOS_FEATURES = {
|
|
31
|
+
"google/gemma-2-2b": [11087, 3220, 11752, 12160, 11498],
|
|
32
|
+
"google/gemma-2-9b": [], # Not listed in paper
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def load_sample_texts(num_samples: int = 2000, min_length: int = 50) -> list[str]:
    """Collect up to *num_samples* WikiText passages of at least *min_length* characters."""
    print(f"Loading up to {num_samples} sample texts from WikiText...")

    corpus = load_dataset("wikitext", "wikitext-103-v1", split="train")

    collected: list[str] = []
    for record in corpus:
        # Stop as soon as the requested sample count is reached.
        if len(collected) >= num_samples:
            break
        candidate = record["text"].strip()
        # Skip headings / blank rows that are too short to be useful.
        if len(candidate) >= min_length:
            collected.append(candidate)

    print(f" Loaded {len(collected)} texts")
    return collected
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def detect_bos_features(
    model,
    tokenizer,
    sae,
    layer_idx: int,
    device: str,
    texts: list[str],
    top_k: int = 10,
    batch_size: int = 8,
) -> tuple[list[int], dict[str, torch.Tensor]]:
    """
    Detect BOS features by finding features that activate most strongly at position 0.

    Computes statistics (mean, variance, median) of activation at BOS position for each
    SAE feature across all samples, then returns the top-k features with highest mean
    BOS activation.

    Reference: FGAA paper (arXiv:2501.09929) identifies BOS features as those that
    "exclusively had the strongest activation on the BOS token".

    Args:
        model: HuggingFace model
        tokenizer: Tokenizer
        sae: SAE object from sae_lens
        layer_idx: Layer index (0-indexed)
        device: Device
        texts: List of sample texts to analyze
        top_k: Number of top BOS features to return (default 10)
        batch_size: Batch size for processing

    Returns:
        Tuple of (top_bos_feature_indices, stats_dict) where stats_dict contains
        'mean', 'variance', and 'median' tensors of shape [d_sae].
    """
    d_sae = sae.cfg.d_sae

    # Collect all BOS activations on CPU for stable statistics computation
    all_bos_activations = []

    print(f"Detecting BOS features from {len(texts)} samples...")
    print(f" Layer: {layer_idx}, d_sae: {d_sae}")

    # Use hook to capture only the layer we need (not all 26 layers)
    captured_acts = {}
    def capture_hook(module, input, output):
        # Decoder layers return a tuple; element 0 is the hidden-state tensor.
        captured_acts["hidden"] = output[0].detach()

    # Register hook on the specific layer
    target_layer = model.model.layers[layer_idx]
    hook_handle = target_layer.register_forward_hook(capture_hook)

    try:
        for i in tqdm(range(0, len(texts), batch_size), desc="Processing"):
            batch_texts = texts[i:i + batch_size]

            inputs = tokenizer(
                batch_texts,
                return_tensors="pt",
                truncation=True,
                max_length=128,
                padding=True,
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                # Use model.model to skip lm_head (saves ~66MB per forward)
                model.model(**inputs, use_cache=False)

            # Get captured hidden states (only this layer, not all 26)
            # Cast to the SAE encoder's dtype so sae.encode does not dtype-mismatch.
            acts = captured_acts["hidden"].to(sae.W_enc.dtype)
            latents = sae.encode(acts)

            attention_mask = inputs.get("attention_mask")
            for j in range(latents.shape[0]):
                # Real (unpadded) length of this sample within the padded batch.
                seq_len = int(attention_mask[j].sum().item()) if attention_mask is not None else latents.shape[1]
                sample_latents = latents[j, :seq_len, :]  # [seq_len, d_sae]

                # Collect BOS activation (position 0) - move to CPU immediately
                # NOTE(review): treating index 0 as the BOS token assumes the
                # tokenizer right-pads; with left padding this would read a pad
                # position instead — TODO confirm tokenizer.padding_side.
                bos_act = sample_latents[0].float().cpu()
                all_bos_activations.append(bos_act)
    finally:
        # Always detach the hook, even if a batch raises mid-loop.
        hook_handle.remove()

    # Compute statistics (all on CPU)
    # Stack all activations for stable computation
    all_bos_tensor = torch.stack(all_bos_activations, dim=0)  # [num_samples, d_sae]
    mean_bos = all_bos_tensor.mean(dim=0)
    variance_bos = all_bos_tensor.var(dim=0)
    median_bos = all_bos_tensor.median(dim=0).values

    stats = {
        "mean": mean_bos,
        "variance": variance_bos,
        "median": median_bos,
    }

    # Select top features by BOS activation
    top_indices = mean_bos.topk(top_k).indices.tolist()

    print(f"\nDetected top {top_k} BOS features by mean activation")
    return top_indices, stats
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def compare_with_known(model_name: str, detected: list[int], stats: dict[str, torch.Tensor]) -> None:
    """Print a comparison of detected BOS features against the paper's known list."""
    means = stats["mean"]
    variances = stats["variance"]
    medians = stats["median"]

    known = KNOWN_BOS_FEATURES.get(model_name, [])
    found_set = set(detected)
    paper_set = set(known)

    # Shared formatter for a feature's per-statistic summary.
    def _stat_line(idx: int) -> str:
        return f"mean={means[idx].item():.4f}, var={variances[idx].item():.4f}, median={medians[idx].item():.4f}"

    rule = "=" * 60
    print("\n" + rule)
    print(f"BOS Feature Comparison for {model_name}")
    print(rule)
    print(f"Known (paper): {sorted(paper_set)}")
    print(f"Detected: {sorted(found_set)}")
    print(rule)
    print(f"Common: {sorted(found_set & paper_set)}")
    print(f"Only in paper: {sorted(paper_set - found_set)}")
    print(f"Only detected: {sorted(found_set - paper_set)}")

    if known:
        print("\nKnown features - BOS activation stats:")
        for idx in sorted(known):
            status = "detected" if idx in found_set else "missed"
            print(f" Feature {idx}: {_stat_line(idx)} ({status})")

    print("\nTop 20 features by mean BOS activation:")
    for rank, idx in enumerate(means.topk(20).indices.tolist(), 1):
        marker = " (known)" if idx in paper_set else ""
        print(f" {rank:2}. Feature {idx}: {_stat_line(idx)}{marker}")
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def main():
    """CLI entry point: detect BOS features for a Gemma Scope SAE and save a JSON report."""
    cli = argparse.ArgumentParser(description="Detect BOS features in Gemma Scope SAEs")
    cli.add_argument("--model", default="google/gemma-2-2b", help="Model name")
    cli.add_argument("--layer", type=int, default=12, help="Layer index")
    cli.add_argument("--num-samples", type=int, default=1000, help="Number of text samples")
    cli.add_argument("--top-k", type=int, default=20, help="Number of top BOS features to detect")
    cli.add_argument("--batch-size", type=int, default=1, help="Batch size")
    cli.add_argument("--device", default="cuda:0", help="Device")
    cli.add_argument("--output-dir", default="wisent/comparison/results", help="Output directory")
    opts = cli.parse_args()

    print(f"Model: {opts.model}")
    print(f"Layer: {opts.layer}")
    print(f"Top-k: {opts.top_k}")

    # Imported lazily so argument parsing stays fast.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print("\nLoading model...")
    model = AutoModelForCausalLM.from_pretrained(
        opts.model,
        torch_dtype=torch.bfloat16,
        device_map=opts.device,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(opts.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Batched tokenization needs a pad token; reuse EOS when absent.
        tokenizer.pad_token = tokenizer.eos_token
    model.eval()

    print("\nLoading SAE...")
    from sae_lens import SAE

    is_2b = "2b" in opts.model.lower()
    release = "gemma-scope-2b-pt-res-canonical" if is_2b else "gemma-scope-9b-pt-res-canonical"
    sae, _, _ = SAE.from_pretrained(
        release=release,
        sae_id=f"layer_{opts.layer}/width_16k/canonical",
        device=opts.device,
    )

    texts = load_sample_texts(opts.num_samples)

    bos_features, stats = detect_bos_features(
        model=model,
        tokenizer=tokenizer,
        sae=sae,
        layer_idx=opts.layer,
        device=opts.device,
        texts=texts,
        top_k=opts.top_k,
        batch_size=opts.batch_size,
    )

    compare_with_known(opts.model, bos_features, stats)

    # Attach per-feature statistics to each detected index for the report.
    detected_with_stats = []
    for idx in bos_features:
        detected_with_stats.append(
            {
                "feature": idx,
                "mean_bos_activation": stats["mean"][idx].item(),
                "variance_bos_activation": stats["variance"][idx].item(),
                "median_bos_activation": stats["median"][idx].item(),
            }
        )

    report = {
        "model": opts.model,
        "layer": opts.layer,
        "top_k": opts.top_k,
        "num_samples": len(texts),
        "detected_bos_features": detected_with_stats,
        "known_bos_features": KNOWN_BOS_FEATURES.get(opts.model, []),
    }

    out_dir = Path(opts.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"bos_features_{opts.model.replace('/', '_')}_layer{opts.layer}.json"
    with open(out_path, "w") as f:
        json.dump(report, f, indent=2)
    print(f"\nSaved to {out_path}")
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
# Script entry point: python -m wisent.comparison.detect_bos_features
if __name__ == "__main__":
    main()
|