wisent 0.7.901__py3-none-any.whl → 0.7.1045__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +669 -0
- wisent/comparison/lora_dpo.py +592 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activations_collector.py +3 -2
- wisent/core/activations/extraction_strategy.py +8 -4
- wisent/core/cli/agent/apply_steering.py +7 -5
- wisent/core/cli/agent/train_classifier.py +4 -3
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/RECORD +27 -91
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/generate_paper_data.py +0 -384
- wisent/examples/scripts/intervention_validation.py +0 -626
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
- wisent/examples/scripts/threshold_analysis.py +0 -434
- wisent/examples/scripts/visualization_gallery.py +0 -582
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
wisent/comparison/main.py
@@ -0,0 +1,444 @@
+"""
+Comparison of steering methods: Ours vs SAE-based.
+
+This script:
+1. Creates steering vectors using train split of pooled data
+2. Runs base evaluation on test split (no overlap)
+3. Runs steered evaluation on same test split
+"""
+
+from __future__ import annotations
+
+import argparse
+import gc
+import json
+from pathlib import Path
+
+import torch
+from lm_eval import evaluator
+from lm_eval.models.hf_steered import SteeredModel
+
+from wisent.core.models.wisent_model import WisentModel
+from wisent.comparison import ours
+from wisent.comparison import sae
+from wisent.comparison import fgaa
+from wisent.comparison.utils import (
+    load_steering_vector,
+    apply_steering_to_model,
+    remove_steering,
+    convert_to_lm_eval_format,
+    create_test_only_task,
+    extract_accuracy,
+    run_lm_eval_evaluation,
+    run_ll_evaluation,
+)
+
+# Map method names to modules
+METHOD_MODULES = {
+    "caa": ours,
+    "sae": sae,
+    "fgaa": fgaa,
+}
+
+
+def run_single_task(
+    model_name: str,
+    task: str,
+    methods: list[str] = None,
+    num_pairs: int = 50,
+    steering_scales: list[float] = None,
+    device: str = "cuda:0",
+    batch_size: int | str = 1,
+    max_batch_size: int = 8,
+    eval_limit: int | None = None,
+    vectors_dir: Path = None,
+    train_ratio: float = 0.8,
+    caa_layers: str = "12",
+    sae_layers: str = "12",
+    extraction_strategies: list[str] = None,
+    bos_features_source: str = "detected",
+) -> list[dict]:
+    """
+    Run comparison for a single task with multiple methods, scales, and extraction strategies.
+
+    Returns list of result dicts, one per method/scale/strategy combination.
+    """
+    if methods is None:
+        methods = ["caa"]
+    if steering_scales is None:
+        steering_scales = [1.0]
+    if extraction_strategies is None:
+        extraction_strategies = ["mc_balanced"]
+
+    results_list = []
+
+    # Step 1: Create test task
+    test_pct = round((1 - train_ratio) * 100)
+    print(f"\n{'='*60}")
+    print(f"Creating test task for: {task}")
+    print(f"(using {test_pct}% of pooled data)")
+    print(f"{'='*60}")
+
+    task_dict = create_test_only_task(task, train_ratio=train_ratio)
+
+    # Step 2: Generate ALL steering vectors FIRST for ALL strategies (subprocess frees GPU memory after each)
+    # Structure: steering_vectors_data[strategy][method] = steering_data
+    steering_vectors_data = {}
+    train_pct = round(train_ratio * 100)
+
+    for method in methods:
+        if method not in METHOD_MODULES:
+            print(f"WARNING: Method '{method}' not implemented, skipping")
+            continue
+
+        method_module = METHOD_MODULES[method]
+
+        # CAA uses extraction strategy, FGAA/SAE don't
+        for extraction_strategy in (extraction_strategies if method == "caa" else [None]):
+            print(f"\n{'@'*60}")
+            print(f"@ METHOD: {method}, EXTRACTION STRATEGY: {extraction_strategy or 'N/A'}")
+            print(f"{'@'*60}")
+
+            # Select layers based on method: CAA uses caa_layers (default=middle), SAE/FGAA use sae_layers (default=12)
+            method_layers = caa_layers if method == "caa" else sae_layers
+
+            print(f"\n{'='*60}")
+            print(f"Generating steering vector for: {task} (method={method})")
+            print(f"(using {train_pct}% of pooled data - no overlap with test)")
+            print(f"Layers: {method_layers}")
+            print(f"{'='*60}")
+
+            suffix = f"_{extraction_strategy}" if extraction_strategy else ""
+            vector_path = vectors_dir / f"{task}_{method}{suffix}_steering_vector.json"
+
+            kwargs = {
+                "task": task,
+                "model_name": model_name,
+                "output_path": vector_path,
+                "num_pairs": num_pairs,
+                "device": device,
+                "layers": method_layers,
+            }
+            if extraction_strategy:
+                kwargs["extraction_strategy"] = extraction_strategy
+            if method == "fgaa":
+                kwargs["bos_features_source"] = bos_features_source
+
+            method_module.generate_steering_vector(**kwargs)
+
+            steering_data = load_steering_vector(vector_path, default_method=method)
+            if extraction_strategy not in steering_vectors_data:
+                steering_vectors_data[extraction_strategy] = {}
+            steering_vectors_data[extraction_strategy][method] = steering_data
+            print(f"Loaded steering vector with layers: {steering_data['layers']}")
+
+    # Step 3: Load model once for ALL evaluations
+    print(f"\n{'='*60}")
+    print(f"Loading model: {model_name}")
+    print(f"{'='*60}")
+    wisent_model = WisentModel(model_name=model_name, device=device)
+
+    # Step 4: Run base evaluation (no steering applied)
+    print(f"\n{'='*60}")
+    print(f"Running BASE evaluation for: {task}")
+    print(f"{'='*60}")
+
+    base_results = run_lm_eval_evaluation(
+        wisent_model=wisent_model,
+        task_dict=task_dict,
+        task_name=task,
+        batch_size=batch_size,
+        max_batch_size=max_batch_size,
+        limit=eval_limit,
+    )
+    base_acc = extract_accuracy(base_results, task)
+    print(f"Base accuracy (lm-eval): {base_acc:.4f}")
+
+    # Step 4b: Run base LL evaluation (no steering)
+    print(f"\n{'='*60}")
+    print(f"Running BASE LL evaluation for: {task}")
+    print(f"{'='*60}")
+
+    base_ll_acc = run_ll_evaluation(
+        wisent_model=wisent_model,
+        task_dict=task_dict,
+        task_name=task,
+        limit=eval_limit,
+    )
+    print(f"Base accuracy (LL): {base_ll_acc:.4f}")
+
+    # Step 5: Run ALL wisent steered evaluations first (model stays loaded)
+    # Structure: wisent_results[(strategy, method, scale)] = steered_acc
+    wisent_results = {}
+    for method in methods:
+        # CAA uses extraction strategy, FGAA/SAE don't
+        for extraction_strategy in (extraction_strategies if method == "caa" else [None]):
+            if extraction_strategy not in steering_vectors_data:
+                continue
+            if method not in steering_vectors_data[extraction_strategy]:
+                continue
+
+            steering_data = steering_vectors_data[extraction_strategy][method]
+
+            for scale in steering_scales:
+                print(f"\n{'='*60}")
+                print(f"Running STEERED evaluation for: {task} (strategy={extraction_strategy}, method={method}, scale={scale})")
+                print(f"{'='*60}")
+
+                # Apply steering to existing model
+                apply_steering_to_model(wisent_model, steering_data, scale=scale)
+
+                steered_results = run_lm_eval_evaluation(
+                    wisent_model=wisent_model,
+                    task_dict=task_dict,
+                    task_name=task,
+                    batch_size=batch_size,
+                    max_batch_size=max_batch_size,
+                    limit=eval_limit,
+                )
+                steered_acc = extract_accuracy(steered_results, task)
+                print(f"Steered accuracy (lm-eval): {steered_acc:.4f}")
+
+                # Run steered LL evaluation
+                steered_ll_acc = run_ll_evaluation(
+                    wisent_model=wisent_model,
+                    task_dict=task_dict,
+                    task_name=task,
+                    limit=eval_limit,
+                )
+                print(f"Steered accuracy (LL): {steered_ll_acc:.4f}")
+
+                # Remove steering for next iteration
+                remove_steering(wisent_model)
+
+                # Store wisent results
+                wisent_results[(extraction_strategy, method, scale)] = {
+                    "lm_eval": steered_acc,
+                    "ll": steered_ll_acc,
+                }
+
+    # Step 6: Free wisent_model to make room for SteeredModel
+    del wisent_model
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    # Step 7: Run ALL lm-eval native steered evaluations (one at a time)
+    for method in methods:
+        # CAA uses extraction strategy, FGAA/SAE don't
+        for extraction_strategy in (extraction_strategies if method == "caa" else [None]):
+            if extraction_strategy not in steering_vectors_data:
+                continue
+            if method not in steering_vectors_data[extraction_strategy]:
+                continue
+
+            steering_data = steering_vectors_data[extraction_strategy][method]
+
+            for scale in steering_scales:
+                print(f"\n{'='*60}")
+                print(f"Running lm-eval NATIVE steered for: {task} (strategy={extraction_strategy}, method={method}, scale={scale})")
+                print(f"{'='*60}")
+
+                # Convert steering vector to lm-eval format
+                suffix = f"_{extraction_strategy}" if extraction_strategy else ""
+                lm_eval_steer_path = vectors_dir / f"{task}_{method}{suffix}_lm_eval_steer_scale{scale}.pt"
+                convert_to_lm_eval_format(steering_data, lm_eval_steer_path, scale=scale)
+
+                lm_steered = SteeredModel(
+                    pretrained=model_name,
+                    steer_path=str(lm_eval_steer_path),
+                    device=device,
+                    batch_size=batch_size,
+                    max_batch_size=max_batch_size,
+                )
+
+                lm_eval_native_results = evaluator.evaluate(
+                    lm=lm_steered,
+                    task_dict=task_dict,
+                    limit=eval_limit,
+                )
+                lm_eval_native_acc = extract_accuracy(lm_eval_native_results, task)
+                print(f"lm-eval native steered accuracy: {lm_eval_native_acc:.4f}")
+
+                # Clean up SteeredModel to free GPU for next iteration
+                del lm_steered
+                gc.collect()
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+                    torch.cuda.synchronize()
+
+                # Store combined results
+                wisent_result = wisent_results[(extraction_strategy, method, scale)]
+                steered_acc_lm_eval = wisent_result["lm_eval"]
+                steered_acc_ll = wisent_result["ll"]
+                results_list.append({
+                    "task": task,
+                    "extraction_strategy": extraction_strategy or "N/A",
+                    "method": method,
+                    "model": model_name,
+                    "layers": steering_data['layers'],
+                    "num_pairs": num_pairs,
+                    "steering_scale": scale,
+                    "base_accuracy_lm_eval": base_acc,
+                    "base_accuracy_ll": base_ll_acc,
+                    "steered_accuracy_lm_eval": steered_acc_lm_eval,
+                    "steered_accuracy_ll": steered_acc_ll,
+                    "steered_accuracy_lm_eval_native": lm_eval_native_acc,
+                    "difference_lm_eval": steered_acc_lm_eval - base_acc,
+                    "difference_ll": steered_acc_ll - base_ll_acc,
+                    "difference_lm_eval_native": lm_eval_native_acc - base_acc,
+                })
+
+    return results_list
+
+
+def run_comparison(
+    model_name: str,
+    tasks: list[str],
+    methods: list[str] = None,
+    num_pairs: int = 50,
+    steering_scales: list[float] = None,
+    device: str = "cuda:0",
+    batch_size: int | str = 1,
+    max_batch_size: int = 8,
+    eval_limit: int | None = None,
+    output_dir: str = "comparison_results",
+    train_ratio: float = 0.8,
+    caa_layers: str = "12",
+    sae_layers: str = "12",
+    extraction_strategies: list[str] = None,
+    bos_features_source: str = "detected",
+) -> list[dict]:
+    """
+    Run full comparison for multiple tasks, methods, scales, and extraction strategies.
+    """
+    if methods is None:
+        methods = ["caa"]
+    if steering_scales is None:
+        steering_scales = [1.0]
+    if extraction_strategies is None:
+        extraction_strategies = ["mc_balanced"]
+
+    output_dir = Path(output_dir)
+    # Add model name to path (sanitize "/" -> "_")
+    model_dir_name = model_name.replace("/", "_")
+    output_dir = output_dir / model_dir_name
+    vectors_dir = output_dir / "steering_vectors"
+    results_dir = output_dir / "results"
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    vectors_dir.mkdir(parents=True, exist_ok=True)
+    results_dir.mkdir(parents=True, exist_ok=True)
+
+    all_results = []
+
+    for task in tasks:
+        print(f"\n{'#'*60}")
+        print(f"# TASK: {task}")
+        print(f"{'#'*60}")
+
+        task_results = run_single_task(
+            model_name=model_name,
+            task=task,
+            methods=methods,
+            num_pairs=num_pairs,
+            steering_scales=steering_scales,
+            device=device,
+            batch_size=batch_size,
+            max_batch_size=max_batch_size,
+            eval_limit=eval_limit,
+            vectors_dir=vectors_dir,
+            train_ratio=train_ratio,
+            caa_layers=caa_layers,
+            sae_layers=sae_layers,
+            extraction_strategies=extraction_strategies,
+            bos_features_source=bos_features_source,
+        )
+        all_results.extend(task_results)
+
+        # Save results for this task (includes all strategies)
+        task_results_file = results_dir / f"{task}_results.json"
+        with open(task_results_file, "w") as f:
+            json.dump(task_results, f, indent=2)
+        print(f"Results for {task} saved to: {task_results_file}")
+
+    # Print final summary table
+    print(f"\n{'='*150}")
+    print(f"FINAL COMPARISON RESULTS")
+    print(f"{'='*150}")
+    print(f"Model: {model_name}")
+    print(f"Num pairs: {num_pairs}")
+    print(f"CAA Layers: {caa_layers}")
+    print(f"SAE/FGAA Layers: {sae_layers}")
+    print(f"Strategies: {', '.join(extraction_strategies)}")
+    print(f"{'='*150}")
+    print(f"{'Strategy':<16} {'Task':<10} {'Method':<8} {'Scale':<6} {'Base(E)':<8} {'Base(L)':<8} {'Steer(E)':<9} {'Steer(L)':<9} {'Native':<8} {'Diff(E)':<8} {'Diff(L)':<8} {'Diff(N)':<8}")
+    print(f"{'-'*150}")
+
+    for r in all_results:
+        print(f"{r.get('extraction_strategy', 'N/A'):<16} {r['task']:<10} {r['method']:<8} {r['steering_scale']:<6.1f} "
+              f"{r['base_accuracy_lm_eval']:<8.4f} {r['base_accuracy_ll']:<8.4f} "
+              f"{r['steered_accuracy_lm_eval']:<9.4f} {r['steered_accuracy_ll']:<9.4f} {r['steered_accuracy_lm_eval_native']:<8.4f} "
+              f"{r['difference_lm_eval']:+<8.4f} {r['difference_ll']:+<8.4f} {r['difference_lm_eval_native']:+<8.4f}")
+
+    print(f"{'='*150}")
+
+    print(f"\nSteering vectors saved to: {vectors_dir}")
+    print(f"Results saved to: {results_dir}")
+
+    return all_results
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Compare steering methods")
+    parser.add_argument("--model", default="EleutherAI/gpt-neo-125M", help="Model name")
+    parser.add_argument("--tasks", default="boolq", help="Comma-separated lm-eval tasks (e.g., boolq,cb,copa)")
+    parser.add_argument("--methods", default="caa", help="Comma-separated methods (e.g., caa,sae,fgaa)")
+    parser.add_argument("--num-pairs", type=int, default=50, help="Number of contrastive pairs")
+    parser.add_argument("--scales", default="1.0", help="Comma-separated steering scales (e.g., 0.5,1.0,1.5)")
+    parser.add_argument("--caa-layers", default="12", help="Layer(s) for CAA steering (default: 12)")
+    parser.add_argument("--sae-layers", default="12", help="Layer(s) for SAE/FGAA steering (default: 12)")
+    parser.add_argument("--device", default="cuda:0", help="Device")
+    parser.add_argument("--batch-size", default=1, help="Batch size (int or 'auto')")
+    parser.add_argument("--max-batch-size", type=int, default=8, help="Max batch size for lm-eval internal batching (reduce if OOM)")
+    parser.add_argument("--limit", type=int, default=None, help="Limit eval examples")
+    parser.add_argument("--output-dir", default="wisent/comparison/comparison_results", help="Output directory")
+    parser.add_argument("--train-ratio", type=float, default=0.8, help="Train/test split ratio (default 0.8 = 80%% train, 20%% test)")
+    parser.add_argument("--extraction-strategy", default="mc_balanced",
+                        help="Extraction strategy (comma-separated for multiple). Chat models: chat_mean, chat_first, chat_last, chat_max_norm, chat_weighted, role_play, mc_balanced. Base models: completion_last, completion_mean, mc_completion")
+    parser.add_argument("--bos-features-source", default="detected",
+                        help="BOS features source for FGAA: 'paper' (5 features), 'detected' (12 features), or 'none'")
+
+    args = parser.parse_args()
+
+    # Parse comma-separated values
+    tasks = [t.strip() for t in args.tasks.split(",")]
+    methods = [m.strip() for m in args.methods.split(",")]
+    scales = [float(s.strip()) for s in args.scales.split(",")]
+    extraction_strategies = [s.strip() for s in args.extraction_strategy.split(",")]
+
+    # Parse batch_size (can be int or "auto")
+    batch_size = args.batch_size if args.batch_size == "auto" else int(args.batch_size)
+
+    run_comparison(
+        model_name=args.model,
+        tasks=tasks,
+        methods=methods,
+        num_pairs=args.num_pairs,
+        steering_scales=scales,
+        device=args.device,
+        batch_size=batch_size,
+        max_batch_size=args.max_batch_size,
+        eval_limit=args.limit,
+        output_dir=args.output_dir,
+        train_ratio=args.train_ratio,
+        caa_layers=args.caa_layers,
+        sae_layers=args.sae_layers,
+        extraction_strategies=extraction_strategies,
+        bos_features_source=args.bos_features_source,
+    )
+
+
+if __name__ == "__main__":
+    main()
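For orientation, the driver above can also be invoked programmatically rather than through `main()`. Below is a minimal sketch using `run_comparison` from the hunk; the model, task list, scales, and limit are illustrative values (not taken from the diff), and a CUDA device is assumed by the defaults:

```python
# Minimal sketch: drive the comparison module added above from Python.
# run_comparison() is defined in wisent/comparison/main.py; the concrete
# model, tasks, scales, and limit below are illustrative values only.
from wisent.comparison.main import run_comparison

results = run_comparison(
    model_name="EleutherAI/gpt-neo-125M",  # argparse default in the hunk above
    tasks=["boolq", "cb"],                 # illustrative lm-eval task list
    methods=["caa", "sae"],                # compare CAA against the SAE baseline
    steering_scales=[0.5, 1.0, 1.5],       # illustrative scale sweep
    eval_limit=200,                        # cap eval examples, as --limit would
)
# Each dict in `results` carries base/steered accuracies and their differences.
```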
wisent/comparison/ours.py
@@ -0,0 +1,76 @@
+"""
+Our steering method wrapper for comparison experiments.
+
+Uses the existing wisent infrastructure to create steering vectors.
+Runs steering vector generation in subprocess to guarantee memory cleanup.
+"""
+
+from __future__ import annotations
+
+import json
+import subprocess
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import torch
+
+from wisent.comparison.utils import apply_steering_to_model, remove_steering, convert_to_lm_eval_format
+
+if TYPE_CHECKING:
+    from wisent.core.models.wisent_model import WisentModel
+
+__all__ = ["generate_steering_vector", "apply_steering_to_model", "remove_steering", "convert_to_lm_eval_format"]
+
+
+def generate_steering_vector(
+    task: str,
+    model_name: str,
+    output_path: str | Path,
+    trait_label: str = "correctness",
+    num_pairs: int = 50,
+    method: str = "caa",
+    layers: str | None = None,
+    normalize: bool = True,
+    device: str = "cuda:0",
+    keep_intermediate: bool = False,
+    extraction_strategy: str = "mc_balanced",
+) -> Path:
+    """
+    Generate a steering vector using wisent CLI in subprocess.
+
+    Runs in subprocess to guarantee GPU memory is freed when done.
+    """
+    output_path = Path(output_path)
+
+    cmd = [
+        "wisent", "generate-vector-from-task",
+        "--task", task,
+        "--trait-label", trait_label,
+        "--model", model_name,
+        "--num-pairs", str(num_pairs),
+        "--method", method,
+        "--output", str(output_path),
+        "--device", device,
+        "--extraction-strategy", extraction_strategy,
+        "--accept-low-quality-vector",
+    ]
+
+    if layers:
+        cmd.extend(["--layers", layers])
+
+    if normalize:
+        cmd.append("--normalize")
+
+    if keep_intermediate:
+        cmd.append("--keep-intermediate")
+
+    result = subprocess.run(cmd)
+
+    if result.returncode != 0:
+        raise RuntimeError(f"Failed to generate steering vector (exit code {result.returncode})")
+
+    return output_path
+
+
+
+
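As a usage sketch, the wrapper above can be called directly; since it shells out to `wisent generate-vector-from-task`, the `wisent` CLI must be on PATH. The task, model, and output path below are illustrative, not from the diff:

```python
# Minimal sketch of calling the subprocess wrapper defined above. It builds
# and runs `wisent generate-vector-from-task ... --normalize
# --accept-low-quality-vector` and returns the path of the written vector.
from pathlib import Path

from wisent.comparison.ours import generate_steering_vector

vector_path = generate_steering_vector(
    task="boolq",                                       # lm-eval task (illustrative)
    model_name="EleutherAI/gpt-neo-125M",               # illustrative model
    output_path=Path("vectors/boolq_caa_vector.json"),  # illustrative output file
    num_pairs=50,
    layers="12",  # forwarded to the CLI as --layers 12
)
print(f"steering vector written to {vector_path}")
```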