wisent 0.7.901__py3-none-any.whl → 0.7.1045__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +669 -0
  6. wisent/comparison/lora_dpo.py +592 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/sae.py +304 -0
  10. wisent/comparison/utils.py +381 -0
  11. wisent/core/activations/activations_collector.py +3 -2
  12. wisent/core/activations/extraction_strategy.py +8 -4
  13. wisent/core/cli/agent/apply_steering.py +7 -5
  14. wisent/core/cli/agent/train_classifier.py +4 -3
  15. wisent/core/cli/generate_vector_from_task.py +11 -20
  16. wisent/core/cli/get_activations.py +1 -1
  17. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
  18. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
  19. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
  20. wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
  21. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  22. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
  23. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/RECORD +27 -91
  24. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  25. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  26. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  27. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  28. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  29. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  30. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  31. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  32. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  33. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  34. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  35. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  36. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  37. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  38. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  39. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  40. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  41. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  42. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  43. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  44. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  45. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  46. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  47. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  48. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  49. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  50. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  51. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  52. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  53. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  54. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  55. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  56. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  57. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  58. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  59. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  60. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  61. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  62. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  63. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  64. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  65. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  66. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  67. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  68. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  69. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  70. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  71. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  72. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  73. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  74. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  75. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  76. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  77. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  78. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  79. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  80. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  81. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  82. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  83. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  84. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  85. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  86. wisent/examples/scripts/generate_paper_data.py +0 -384
  87. wisent/examples/scripts/intervention_validation.py +0 -626
  88. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
  89. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
  90. wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
  91. wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
  92. wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
  93. wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
  94. wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
  95. wisent/examples/scripts/threshold_analysis.py +0 -434
  96. wisent/examples/scripts/visualization_gallery.py +0 -582
  97. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
  98. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
  99. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
  100. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
wisent/comparison/main.py (new file)
@@ -0,0 +1,444 @@
+ """
+ Comparison of steering methods: ours vs. SAE-based.
+
+ This script:
+ 1. Creates steering vectors from the train split of the pooled data.
+ 2. Runs a base evaluation on the test split (no overlap with train).
+ 3. Runs a steered evaluation on the same test split.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import gc
+ import json
+ from pathlib import Path
+
+ import torch
+ from lm_eval import evaluator
+ from lm_eval.models.hf_steered import SteeredModel
+
+ from wisent.core.models.wisent_model import WisentModel
+ from wisent.comparison import ours
+ from wisent.comparison import sae
+ from wisent.comparison import fgaa
+ from wisent.comparison.utils import (
+     load_steering_vector,
+     apply_steering_to_model,
+     remove_steering,
+     convert_to_lm_eval_format,
+     create_test_only_task,
+     extract_accuracy,
+     run_lm_eval_evaluation,
+     run_ll_evaluation,
+ )
+
+ # Map method names to their implementation modules
+ METHOD_MODULES = {
+     "caa": ours,
+     "sae": sae,
+     "fgaa": fgaa,
+ }
+
+
+ def run_single_task(
+     model_name: str,
+     task: str,
+     methods: list[str] | None = None,
+     num_pairs: int = 50,
+     steering_scales: list[float] | None = None,
+     device: str = "cuda:0",
+     batch_size: int | str = 1,
+     max_batch_size: int = 8,
+     eval_limit: int | None = None,
+     vectors_dir: Path | None = None,
+     train_ratio: float = 0.8,
+     caa_layers: str = "12",
+     sae_layers: str = "12",
+     extraction_strategies: list[str] | None = None,
+     bos_features_source: str = "detected",
+ ) -> list[dict]:
+     """
+     Run the comparison for a single task across methods, scales, and extraction strategies.
+
+     Returns a list of result dicts, one per method/scale/strategy combination.
+     """
+     if methods is None:
+         methods = ["caa"]
+     if steering_scales is None:
+         steering_scales = [1.0]
+     if extraction_strategies is None:
+         extraction_strategies = ["mc_balanced"]
+
+     results_list = []
+
+     # Step 1: Create the test-only task
+     test_pct = round((1 - train_ratio) * 100)
+     print(f"\n{'='*60}")
+     print(f"Creating test task for: {task}")
+     print(f"(using {test_pct}% of pooled data)")
+     print(f"{'='*60}")
+
+     task_dict = create_test_only_task(task, train_ratio=train_ratio)
+
+     # Step 2: Generate ALL steering vectors FIRST for ALL strategies (each runs in a subprocess, which frees GPU memory afterwards)
+     # Structure: steering_vectors_data[strategy][method] = steering_data
+     steering_vectors_data = {}
+     train_pct = round(train_ratio * 100)
+
+     for method in methods:
+         if method not in METHOD_MODULES:
+             print(f"WARNING: Method '{method}' not implemented, skipping")
+             continue
+
+         method_module = METHOD_MODULES[method]
+
+         # CAA uses an extraction strategy; FGAA/SAE don't
+         for extraction_strategy in (extraction_strategies if method == "caa" else [None]):
+             print(f"\n{'@'*60}")
+             print(f"@ METHOD: {method}, EXTRACTION STRATEGY: {extraction_strategy or 'N/A'}")
+             print(f"{'@'*60}")
+
+             # Select layers by method: CAA uses caa_layers, SAE/FGAA use sae_layers (both default to "12")
+             method_layers = caa_layers if method == "caa" else sae_layers
+
+             print(f"\n{'='*60}")
+             print(f"Generating steering vector for: {task} (method={method})")
+             print(f"(using {train_pct}% of pooled data - no overlap with test)")
+             print(f"Layers: {method_layers}")
+             print(f"{'='*60}")
+
+             suffix = f"_{extraction_strategy}" if extraction_strategy else ""
+             vector_path = vectors_dir / f"{task}_{method}{suffix}_steering_vector.json"
+
+             kwargs = {
+                 "task": task,
+                 "model_name": model_name,
+                 "output_path": vector_path,
+                 "num_pairs": num_pairs,
+                 "device": device,
+                 "layers": method_layers,
+             }
+             if extraction_strategy:
+                 kwargs["extraction_strategy"] = extraction_strategy
+             if method == "fgaa":
+                 kwargs["bos_features_source"] = bos_features_source
+
+             method_module.generate_steering_vector(**kwargs)
+
+             steering_data = load_steering_vector(vector_path, default_method=method)
+             if extraction_strategy not in steering_vectors_data:
+                 steering_vectors_data[extraction_strategy] = {}
+             steering_vectors_data[extraction_strategy][method] = steering_data
+             print(f"Loaded steering vector with layers: {steering_data['layers']}")
+
+     # Step 3: Load the model once for ALL evaluations
+     print(f"\n{'='*60}")
+     print(f"Loading model: {model_name}")
+     print(f"{'='*60}")
+     wisent_model = WisentModel(model_name=model_name, device=device)
+
+     # Step 4: Run the base evaluation (no steering applied)
+     print(f"\n{'='*60}")
+     print(f"Running BASE evaluation for: {task}")
+     print(f"{'='*60}")
+
+     base_results = run_lm_eval_evaluation(
+         wisent_model=wisent_model,
+         task_dict=task_dict,
+         task_name=task,
+         batch_size=batch_size,
+         max_batch_size=max_batch_size,
+         limit=eval_limit,
+     )
+     base_acc = extract_accuracy(base_results, task)
+     print(f"Base accuracy (lm-eval): {base_acc:.4f}")
+
+     # Step 4b: Run the base log-likelihood (LL) evaluation (no steering)
+     print(f"\n{'='*60}")
+     print(f"Running BASE LL evaluation for: {task}")
+     print(f"{'='*60}")
+
+     base_ll_acc = run_ll_evaluation(
+         wisent_model=wisent_model,
+         task_dict=task_dict,
+         task_name=task,
+         limit=eval_limit,
+     )
+     print(f"Base accuracy (LL): {base_ll_acc:.4f}")
+
+     # Step 5: Run ALL wisent steered evaluations first (model stays loaded)
+     # Structure: wisent_results[(strategy, method, scale)] = {"lm_eval": ..., "ll": ...}
+     wisent_results = {}
+     for method in methods:
+         # CAA uses an extraction strategy; FGAA/SAE don't
+         for extraction_strategy in (extraction_strategies if method == "caa" else [None]):
+             if extraction_strategy not in steering_vectors_data:
+                 continue
+             if method not in steering_vectors_data[extraction_strategy]:
+                 continue
+
+             steering_data = steering_vectors_data[extraction_strategy][method]
+
+             for scale in steering_scales:
+                 print(f"\n{'='*60}")
+                 print(f"Running STEERED evaluation for: {task} (strategy={extraction_strategy}, method={method}, scale={scale})")
+                 print(f"{'='*60}")
+
+                 # Apply steering to the already-loaded model
+                 apply_steering_to_model(wisent_model, steering_data, scale=scale)
+
+                 steered_results = run_lm_eval_evaluation(
+                     wisent_model=wisent_model,
+                     task_dict=task_dict,
+                     task_name=task,
+                     batch_size=batch_size,
+                     max_batch_size=max_batch_size,
+                     limit=eval_limit,
+                 )
+                 steered_acc = extract_accuracy(steered_results, task)
+                 print(f"Steered accuracy (lm-eval): {steered_acc:.4f}")
+
+                 # Run the steered LL evaluation
+                 steered_ll_acc = run_ll_evaluation(
+                     wisent_model=wisent_model,
+                     task_dict=task_dict,
+                     task_name=task,
+                     limit=eval_limit,
+                 )
+                 print(f"Steered accuracy (LL): {steered_ll_acc:.4f}")
+
+                 # Remove steering before the next iteration
+                 remove_steering(wisent_model)
+
+                 # Store wisent results
+                 wisent_results[(extraction_strategy, method, scale)] = {
+                     "lm_eval": steered_acc,
+                     "ll": steered_ll_acc,
+                 }
+
+     # Step 6: Free wisent_model to make room for SteeredModel
+     del wisent_model
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     # Step 7: Run ALL lm-eval native steered evaluations (one at a time)
+     for method in methods:
+         # CAA uses an extraction strategy; FGAA/SAE don't
+         for extraction_strategy in (extraction_strategies if method == "caa" else [None]):
+             if extraction_strategy not in steering_vectors_data:
+                 continue
+             if method not in steering_vectors_data[extraction_strategy]:
+                 continue
+
+             steering_data = steering_vectors_data[extraction_strategy][method]
+
+             for scale in steering_scales:
+                 print(f"\n{'='*60}")
+                 print(f"Running lm-eval NATIVE steered for: {task} (strategy={extraction_strategy}, method={method}, scale={scale})")
+                 print(f"{'='*60}")
+
+                 # Convert the steering vector to lm-eval's format
+                 suffix = f"_{extraction_strategy}" if extraction_strategy else ""
+                 lm_eval_steer_path = vectors_dir / f"{task}_{method}{suffix}_lm_eval_steer_scale{scale}.pt"
+                 convert_to_lm_eval_format(steering_data, lm_eval_steer_path, scale=scale)
+
+                 lm_steered = SteeredModel(
+                     pretrained=model_name,
+                     steer_path=str(lm_eval_steer_path),
+                     device=device,
+                     batch_size=batch_size,
+                     max_batch_size=max_batch_size,
+                 )
+
+                 lm_eval_native_results = evaluator.evaluate(
+                     lm=lm_steered,
+                     task_dict=task_dict,
+                     limit=eval_limit,
+                 )
+                 lm_eval_native_acc = extract_accuracy(lm_eval_native_results, task)
+                 print(f"lm-eval native steered accuracy: {lm_eval_native_acc:.4f}")
+
+                 # Clean up SteeredModel to free the GPU for the next iteration
+                 del lm_steered
+                 gc.collect()
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+                     torch.cuda.synchronize()
+
+                 # Store combined results
+                 wisent_result = wisent_results[(extraction_strategy, method, scale)]
+                 steered_acc_lm_eval = wisent_result["lm_eval"]
+                 steered_acc_ll = wisent_result["ll"]
+                 results_list.append({
+                     "task": task,
+                     "extraction_strategy": extraction_strategy or "N/A",
+                     "method": method,
+                     "model": model_name,
+                     "layers": steering_data['layers'],
+                     "num_pairs": num_pairs,
+                     "steering_scale": scale,
+                     "base_accuracy_lm_eval": base_acc,
+                     "base_accuracy_ll": base_ll_acc,
+                     "steered_accuracy_lm_eval": steered_acc_lm_eval,
+                     "steered_accuracy_ll": steered_acc_ll,
+                     "steered_accuracy_lm_eval_native": lm_eval_native_acc,
+                     "difference_lm_eval": steered_acc_lm_eval - base_acc,
+                     "difference_ll": steered_acc_ll - base_ll_acc,
+                     "difference_lm_eval_native": lm_eval_native_acc - base_acc,
+                 })
+
+     return results_list
+
+
+ def run_comparison(
+     model_name: str,
+     tasks: list[str],
+     methods: list[str] | None = None,
+     num_pairs: int = 50,
+     steering_scales: list[float] | None = None,
+     device: str = "cuda:0",
+     batch_size: int | str = 1,
+     max_batch_size: int = 8,
+     eval_limit: int | None = None,
+     output_dir: str = "comparison_results",
+     train_ratio: float = 0.8,
+     caa_layers: str = "12",
+     sae_layers: str = "12",
+     extraction_strategies: list[str] | None = None,
+     bos_features_source: str = "detected",
+ ) -> list[dict]:
+     """
+     Run the full comparison across multiple tasks, methods, scales, and extraction strategies.
+     """
+     if methods is None:
+         methods = ["caa"]
+     if steering_scales is None:
+         steering_scales = [1.0]
+     if extraction_strategies is None:
+         extraction_strategies = ["mc_balanced"]
+
+     output_dir = Path(output_dir)
+     # Add the model name to the path (sanitize "/" -> "_")
+     model_dir_name = model_name.replace("/", "_")
+     output_dir = output_dir / model_dir_name
+     vectors_dir = output_dir / "steering_vectors"
+     results_dir = output_dir / "results"
+
+     output_dir.mkdir(parents=True, exist_ok=True)
+     vectors_dir.mkdir(parents=True, exist_ok=True)
+     results_dir.mkdir(parents=True, exist_ok=True)
+
+     all_results = []
+
+     for task in tasks:
+         print(f"\n{'#'*60}")
+         print(f"# TASK: {task}")
+         print(f"{'#'*60}")
+
+         task_results = run_single_task(
+             model_name=model_name,
+             task=task,
+             methods=methods,
+             num_pairs=num_pairs,
+             steering_scales=steering_scales,
+             device=device,
+             batch_size=batch_size,
+             max_batch_size=max_batch_size,
+             eval_limit=eval_limit,
+             vectors_dir=vectors_dir,
+             train_ratio=train_ratio,
+             caa_layers=caa_layers,
+             sae_layers=sae_layers,
+             extraction_strategies=extraction_strategies,
+             bos_features_source=bos_features_source,
+         )
+         all_results.extend(task_results)
+
+         # Save results for this task (includes all strategies)
+         task_results_file = results_dir / f"{task}_results.json"
+         with open(task_results_file, "w") as f:
+             json.dump(task_results, f, indent=2)
+         print(f"Results for {task} saved to: {task_results_file}")
+
+     # Print the final summary table
+     print(f"\n{'='*150}")
+     print("FINAL COMPARISON RESULTS")
+     print(f"{'='*150}")
+     print(f"Model: {model_name}")
+     print(f"Num pairs: {num_pairs}")
+     print(f"CAA layers: {caa_layers}")
+     print(f"SAE/FGAA layers: {sae_layers}")
+     print(f"Strategies: {', '.join(extraction_strategies)}")
+     print(f"{'='*150}")
+     print(f"{'Strategy':<16} {'Task':<10} {'Method':<8} {'Scale':<6} {'Base(E)':<8} {'Base(L)':<8} {'Steer(E)':<9} {'Steer(L)':<9} {'Native':<8} {'Diff(E)':<8} {'Diff(L)':<8} {'Diff(N)':<8}")
+     print(f"{'-'*150}")
+
+     for r in all_results:
+         print(f"{r.get('extraction_strategy', 'N/A'):<16} {r['task']:<10} {r['method']:<8} {r['steering_scale']:<6.1f} "
+               f"{r['base_accuracy_lm_eval']:<8.4f} {r['base_accuracy_ll']:<8.4f} "
+               f"{r['steered_accuracy_lm_eval']:<9.4f} {r['steered_accuracy_ll']:<9.4f} {r['steered_accuracy_lm_eval_native']:<8.4f} "
+               f"{r['difference_lm_eval']:<+8.4f} {r['difference_ll']:<+8.4f} {r['difference_lm_eval_native']:<+8.4f}")
+
+     print(f"{'='*150}")
+
+     print(f"\nSteering vectors saved to: {vectors_dir}")
+     print(f"Results saved to: {results_dir}")
+
+     return all_results
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Compare steering methods")
+     parser.add_argument("--model", default="EleutherAI/gpt-neo-125M", help="Model name")
+     parser.add_argument("--tasks", default="boolq", help="Comma-separated lm-eval tasks (e.g., boolq,cb,copa)")
+     parser.add_argument("--methods", default="caa", help="Comma-separated methods (e.g., caa,sae,fgaa)")
+     parser.add_argument("--num-pairs", type=int, default=50, help="Number of contrastive pairs")
+     parser.add_argument("--scales", default="1.0", help="Comma-separated steering scales (e.g., 0.5,1.0,1.5)")
+     parser.add_argument("--caa-layers", default="12", help="Layer(s) for CAA steering (default: 12)")
+     parser.add_argument("--sae-layers", default="12", help="Layer(s) for SAE/FGAA steering (default: 12)")
+     parser.add_argument("--device", default="cuda:0", help="Device")
+     parser.add_argument("--batch-size", default=1, help="Batch size (int or 'auto')")
+     parser.add_argument("--max-batch-size", type=int, default=8, help="Max batch size for lm-eval internal batching (reduce if OOM)")
+     parser.add_argument("--limit", type=int, default=None, help="Limit eval examples")
+     parser.add_argument("--output-dir", default="wisent/comparison/comparison_results", help="Output directory")
+     parser.add_argument("--train-ratio", type=float, default=0.8, help="Train/test split ratio (default 0.8 = 80%% train, 20%% test)")
+     parser.add_argument("--extraction-strategy", default="mc_balanced",
+                         help="Extraction strategy (comma-separated for multiple). Chat models: chat_mean, chat_first, chat_last, chat_max_norm, chat_weighted, role_play, mc_balanced. Base models: completion_last, completion_mean, mc_completion")
+     parser.add_argument("--bos-features-source", default="detected",
+                         help="BOS features source for FGAA: 'paper' (5 features), 'detected' (12 features), or 'none'")
+
+     args = parser.parse_args()
+
+     # Parse comma-separated values
+     tasks = [t.strip() for t in args.tasks.split(",")]
+     methods = [m.strip() for m in args.methods.split(",")]
+     scales = [float(s.strip()) for s in args.scales.split(",")]
+     extraction_strategies = [s.strip() for s in args.extraction_strategy.split(",")]
+
+     # Parse batch_size (an int or the string "auto")
+     batch_size = args.batch_size if args.batch_size == "auto" else int(args.batch_size)
+
+     run_comparison(
+         model_name=args.model,
+         tasks=tasks,
+         methods=methods,
+         num_pairs=args.num_pairs,
+         steering_scales=scales,
+         device=args.device,
+         batch_size=batch_size,
+         max_batch_size=args.max_batch_size,
+         eval_limit=args.limit,
+         output_dir=args.output_dir,
+         train_ratio=args.train_ratio,
+         caa_layers=args.caa_layers,
+         sae_layers=args.sae_layers,
+         extraction_strategies=extraction_strategies,
+         bos_features_source=args.bos_features_source,
+     )
+
+
+ if __name__ == "__main__":
+     main()
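For reference, run_comparison can also be driven programmatically rather than through the CLI defined above. The following is a minimal sketch using only names defined in main.py; the model, tasks, scales, and limit are illustrative choices, not package requirements, and a CUDA device is assumed (the default is "cuda:0").

    from wisent.comparison.main import run_comparison

    # Hypothetical invocation; roughly equivalent to:
    #   python -m wisent.comparison.main --model EleutherAI/gpt-neo-125M \
    #       --tasks boolq,cb --methods caa,sae --scales 0.5,1.0 --limit 100
    results = run_comparison(
        model_name="EleutherAI/gpt-neo-125M",  # illustrative model choice
        tasks=["boolq", "cb"],                 # lm-eval task names
        methods=["caa", "sae"],
        steering_scales=[0.5, 1.0],
        eval_limit=100,                        # cap test examples while experimenting
    )

    # Each result dict carries base and steered accuracies plus their differences,
    # so the best-performing configuration can be picked directly:
    best = max(results, key=lambda r: r["difference_lm_eval"])
    print(best["method"], best["steering_scale"], best["difference_lm_eval"])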
wisent/comparison/ours.py (new file)
@@ -0,0 +1,76 @@
+ """
+ Our steering method wrapper for comparison experiments.
+
+ Uses the existing wisent infrastructure to create steering vectors.
+ Runs steering vector generation in a subprocess to guarantee memory cleanup.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import subprocess
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import torch
+
+ from wisent.comparison.utils import apply_steering_to_model, remove_steering, convert_to_lm_eval_format
+
+ if TYPE_CHECKING:
+     from wisent.core.models.wisent_model import WisentModel
+
+ __all__ = ["generate_steering_vector", "apply_steering_to_model", "remove_steering", "convert_to_lm_eval_format"]
+
+
+ def generate_steering_vector(
+     task: str,
+     model_name: str,
+     output_path: str | Path,
+     trait_label: str = "correctness",
+     num_pairs: int = 50,
+     method: str = "caa",
+     layers: str | None = None,
+     normalize: bool = True,
+     device: str = "cuda:0",
+     keep_intermediate: bool = False,
+     extraction_strategy: str = "mc_balanced",
+ ) -> Path:
+     """
+     Generate a steering vector by invoking the wisent CLI in a subprocess.
+
+     Running in a subprocess guarantees that GPU memory is freed when done.
+     """
+     output_path = Path(output_path)
+
+     cmd = [
+         "wisent", "generate-vector-from-task",
+         "--task", task,
+         "--trait-label", trait_label,
+         "--model", model_name,
+         "--num-pairs", str(num_pairs),
+         "--method", method,
+         "--output", str(output_path),
+         "--device", device,
+         "--extraction-strategy", extraction_strategy,
+         "--accept-low-quality-vector",
+     ]
+
+     if layers:
+         cmd.extend(["--layers", layers])
+
+     if normalize:
+         cmd.append("--normalize")
+
+     if keep_intermediate:
+         cmd.append("--keep-intermediate")
+
+     result = subprocess.run(cmd)
+
+     if result.returncode != 0:
+         raise RuntimeError(f"Failed to generate steering vector (exit code {result.returncode})")
+
+     return output_path
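As a usage note, this wrapper composes with the helpers in wisent.comparison.utils exactly as run_single_task in main.py does. Below is a minimal sketch under stated assumptions: the wisent CLI entry point is on PATH, a CUDA device is available, and the task, model, and output path are illustrative values only.

    from pathlib import Path

    from wisent.comparison.ours import generate_steering_vector
    from wisent.comparison.utils import apply_steering_to_model, load_steering_vector, remove_steering
    from wisent.core.models.wisent_model import WisentModel

    # Generate the vector in a subprocess (GPU memory is freed when it exits).
    vector_path = generate_steering_vector(
        task="boolq",                          # illustrative task
        model_name="EleutherAI/gpt-neo-125M",  # illustrative model
        output_path=Path("vectors/boolq_caa_steering_vector.json"),
        layers="12",
    )

    # Load the vector and steer an in-process model, mirroring main.py.
    steering_data = load_steering_vector(vector_path, default_method="caa")
    model = WisentModel(model_name="EleutherAI/gpt-neo-125M", device="cuda:0")
    apply_steering_to_model(model, steering_data, scale=1.0)
    # ... run evaluations here ...
    remove_steering(model)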