wisent 0.7.901__py3-none-any.whl → 0.7.1045__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +669 -0
  6. wisent/comparison/lora_dpo.py +592 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/sae.py +304 -0
  10. wisent/comparison/utils.py +381 -0
  11. wisent/core/activations/activations_collector.py +3 -2
  12. wisent/core/activations/extraction_strategy.py +8 -4
  13. wisent/core/cli/agent/apply_steering.py +7 -5
  14. wisent/core/cli/agent/train_classifier.py +4 -3
  15. wisent/core/cli/generate_vector_from_task.py +11 -20
  16. wisent/core/cli/get_activations.py +1 -1
  17. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
  18. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
  19. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
  20. wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
  21. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  22. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
  23. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/RECORD +27 -91
  24. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  25. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  26. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  27. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  28. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  29. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  30. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  31. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  32. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  33. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  34. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  35. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  36. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  37. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  38. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  39. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  40. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  41. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  42. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  43. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  44. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  45. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  46. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  47. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  48. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  49. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  50. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  51. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  52. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  53. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  54. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  55. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  56. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  57. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  58. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  59. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  60. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  61. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  62. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  63. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  64. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  65. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  66. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  67. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  68. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  69. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  70. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  71. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  72. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  73. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  74. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  75. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  76. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  77. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  78. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  79. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  80. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  81. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  82. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  83. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  84. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  85. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  86. wisent/examples/scripts/generate_paper_data.py +0 -384
  87. wisent/examples/scripts/intervention_validation.py +0 -626
  88. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
  89. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
  90. wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
  91. wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
  92. wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
  93. wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
  94. wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
  95. wisent/examples/scripts/threshold_analysis.py +0 -434
  96. wisent/examples/scripts/visualization_gallery.py +0 -582
  97. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
  98. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
  99. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
  100. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
wisent/comparison/sae.py
@@ -0,0 +1,304 @@
+ """
+ SAE-based steering method for comparison experiments.
+
+ Uses Sparse Autoencoders to identify steering directions from contrastive pairs.
+ Computes steering vector using SAE decoder features weighted by feature differences.
+
+ Supports Gemma models with Gemma Scope SAEs via sae_lens.
+ """
+
+ from __future__ import annotations
+
+ import gc
+ import json
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import torch
+
+ from wisent.comparison.utils import (
+     apply_steering_to_model,
+     remove_steering,
+     convert_to_lm_eval_format,
+     generate_contrastive_pairs,
+     load_model_and_tokenizer,
+     load_sae,
+     SAE_CONFIGS,
+ )
+
+ if TYPE_CHECKING:
+     from wisent.core.models.wisent_model import WisentModel
+
+ __all__ = ["generate_steering_vector", "apply_steering_to_model", "remove_steering", "convert_to_lm_eval_format"]
+
+
+ def _get_residual_stream_activations(
+     model,
+     tokenizer,
+     text: str,
+     layer_idx: int,
+     device: str,
+ ) -> torch.Tensor:
+     """
+     Get residual stream activations from a specific layer.
+
+     Args:
+         model: HuggingFace model
+         tokenizer: Tokenizer
+         text: Input text
+         layer_idx: Layer index (0-indexed)
+         device: Device
+
+     Returns:
+         Tensor of shape (1, seq_len, d_model)
+     """
+     inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         out = model(**inputs, output_hidden_states=True, use_cache=False)
+
+     # hidden_states is tuple: (embedding, layer0, layer1, ..., layerN)
+     # layer_idx=0 -> hs[1], layer_idx=12 -> hs[13]
+     hs = out.hidden_states
+     return hs[layer_idx + 1]
+
+
+ def compute_feature_diff(
+     model,
+     tokenizer,
+     sae,
+     pairs: list[dict],
+     layer_idx: int,
+     device: str,
+ ) -> torch.Tensor:
+     """
+     Compute feature difference between positive and negative examples in SAE space.
+
+     feature_diff = mean(encode(h_pos)) - mean(encode(h_neg))
+
+     Args:
+         model: HuggingFace model
+         tokenizer: Tokenizer
+         sae: SAE object from sae_lens
+         pairs: List of contrastive pairs
+         layer_idx: Layer to extract activations from
+         device: Device
+
+     Returns:
+         feature_diff tensor of shape [d_sae]
+     """
+     pos_features_list = []
+     neg_features_list = []
+
+     print(f" Computing feature diff from {len(pairs)} pairs...")
+
+     for i, pair in enumerate(pairs):
+         prompt = pair["prompt"]
+         pos_response = pair["positive_response"]["model_response"]
+         neg_response = pair["negative_response"]["model_response"]
+
+         pos_text = f"{prompt} {pos_response}"
+         neg_text = f"{prompt} {neg_response}"
+
+         # Get activations and encode through SAE
+         pos_acts = _get_residual_stream_activations(model, tokenizer, pos_text, layer_idx, device)
+         pos_acts = pos_acts.to(device).to(sae.W_enc.dtype)
+         pos_latents = sae.encode(pos_acts)
+         pos_features_list.append(pos_latents.mean(dim=1).detach())  # Mean over sequence
+
+         neg_acts = _get_residual_stream_activations(model, tokenizer, neg_text, layer_idx, device)
+         neg_acts = neg_acts.to(device).to(sae.W_enc.dtype)
+         neg_latents = sae.encode(neg_acts)
+         neg_features_list.append(neg_latents.mean(dim=1).detach())
+
+         if (i + 1) % 10 == 0:
+             print(f" Processed {i + 1}/{len(pairs)} pairs")
+
+     # Stack and compute mean difference
+     pos_features = torch.cat(pos_features_list, dim=0)  # [num_pairs, d_sae]
+     neg_features = torch.cat(neg_features_list, dim=0)
+
+     feature_diff = pos_features.mean(dim=0) - neg_features.mean(dim=0)  # [d_sae]
+
+     print(f" feature_diff computed, shape: {feature_diff.shape}")
+     print(f" feature_diff stats: mean={feature_diff.mean():.6f}, std={feature_diff.std():.6f}")
+
+     return feature_diff
+
+
+ def compute_steering_vector_from_decoder(
+     feature_diff: torch.Tensor,
+     sae,
+     top_k: int = 4,
+     normalize: bool = True,
+ ) -> tuple[torch.Tensor, dict]:
+     """
+     Compute steering vector using SAE decoder features.
+
+     steering_vector = sum(feature_diff[i] * W_dec[i]) for top-k features
+
+     Args:
+         feature_diff: Difference vector in SAE space [d_sae]
+         sae: SAE object with W_dec decoder weights
+         top_k: Number of top features to use
+         normalize: Whether to normalize the final steering vector
+
+     Returns:
+         Tuple of (steering_vector [d_model], feature_info dict)
+     """
+     # Find top-k features by absolute difference
+     abs_diff = feature_diff.abs()
+     top_values, top_indices = abs_diff.topk(min(top_k, len(feature_diff)))
+
+     print(f" Selected top {len(top_indices)} features")
+     print(f" Top feature indices: {top_indices[:10].tolist()}...")
+     print(f" Top feature diff magnitudes: {top_values[:10].tolist()}")
+
+     # Construct steering vector from decoder
+     # W_dec shape: [d_sae, d_model]
+     steering_vector = torch.zeros(sae.W_dec.shape[1], device=sae.W_dec.device, dtype=sae.W_dec.dtype)
+
+     for feat_idx in top_indices:
+         steering_vector += feature_diff[feat_idx] * sae.W_dec[feat_idx]
+
+     if normalize:
+         norm = steering_vector.norm()
+         if norm > 0:
+             steering_vector = steering_vector / norm
+             print(f" Normalized steering vector (original norm: {norm:.4f})")
+
+     print(f" steering_vector shape: {steering_vector.shape}, norm: {steering_vector.norm():.6f}")
+
+     feature_info = {
+         "top_k": top_k,
+         "top_indices": top_indices.tolist(),
+         "top_diff_values": [feature_diff[i].item() for i in top_indices],
+     }
+
+     return steering_vector, feature_info
+
+
+ def generate_steering_vector(
+     task: str,
+     model_name: str,
+     output_path: str | Path,
+     trait_label: str = "correctness",
+     num_pairs: int = 50,
+     method: str = "sae",
+     layers: str | None = None,
+     normalize: bool = True,
+     device: str = "cuda:0",
+     keep_intermediate: bool = False,
+     top_k: int = 4,
+     **kwargs,  # Accept additional kwargs for compatibility
+ ) -> Path:
+     """
+     Generate a steering vector using SAE decoder features.
+
+     Args:
+         task: lm-eval task name (e.g., 'boolq', 'cb')
+         model_name: HuggingFace model name (must be Gemma 2B or 9B)
+         output_path: Where to save the steering vector
+         trait_label: Label for the trait being steered
+         num_pairs: Number of contrastive pairs to use
+         method: Method name (should be 'sae')
+         layers: Layer(s) to use (e.g., '12' or '10,11,12')
+         normalize: Whether to normalize the steering vector
+         device: Device to run on
+         keep_intermediate: Whether to keep intermediate files
+         top_k: Number of top SAE features to use
+
+     Returns:
+         Path to the saved steering vector
+     """
+     output_path = Path(output_path)
+
+     if model_name not in SAE_CONFIGS:
+         raise ValueError(
+             f"No SAE config for model '{model_name}'. "
+             f"Supported models: {list(SAE_CONFIGS.keys())}"
+         )
+
+     config = SAE_CONFIGS[model_name]
+
+     # Parse layers
+     if layers is None:
+         layer_indices = [config["default_layer"]]
+     elif layers == "all":
+         layer_indices = list(range(config["num_layers"]))
+     else:
+         layer_indices = [int(l.strip()) for l in layers.split(",")]
+
+     # Step 1: Generate contrastive pairs
+     print(f"Step 1: Generating contrastive pairs from task: {task}")
+     pairs, pairs_file = generate_contrastive_pairs(task, num_pairs)
+     print(f" Loaded {len(pairs)} contrastive pairs")
+
+     # Step 2: Load model
+     print(f"\nStep 2: Loading model {model_name}...")
+     model, tokenizer = load_model_and_tokenizer(model_name, device)
+
+     steering_vectors = {}
+     feature_info = {}
+
+     for layer_idx in layer_indices:
+         print(f"\nStep 3: Processing layer {layer_idx}")
+
+         # Load SAE for this layer
+         sae, sparsity = load_sae(model_name, layer_idx, device=device)
+
+         # Step 4: Compute feature difference
+         print(f"\nStep 4: Computing feature diff for layer {layer_idx}...")
+         feat_diff = compute_feature_diff(model, tokenizer, sae, pairs, layer_idx, device)
+
+         # Step 5: Compute steering vector from decoder
+         print(f"\nStep 5: Computing steering vector from decoder for layer {layer_idx}...")
+         steering_vec, feat_info = compute_steering_vector_from_decoder(
+             feat_diff, sae, top_k=top_k, normalize=normalize
+         )
+
+         steering_vectors[str(layer_idx)] = steering_vec.cpu().float().tolist()
+         feature_info[str(layer_idx)] = feat_info
+
+         # Cleanup SAE
+         del sae, sparsity, feat_diff
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     # Cleanup model
+     del model
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     # Cleanup temp files
+     if not keep_intermediate:
+         import os
+         os.unlink(pairs_file)
+
+     # Save results
+     result = {
+         "steering_vectors": steering_vectors,
+         "layers": [str(l) for l in layer_indices],
+         "model": model_name,
+         "method": "sae",
+         "trait_label": trait_label,
+         "task": task,
+         "num_pairs": len(pairs),
+         "sae_config": {
+             "release": config["sae_release"],
+             "top_k": top_k,
+             "normalize": normalize,
+         },
+         "feature_info": feature_info,
+     }
+
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+     with open(output_path, "w") as f:
+         json.dump(result, f, indent=2)
+
+     print(f"\nSaved SAE steering vector to {output_path}")
+     return output_path
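
A minimal usage sketch for the generate_steering_vector entry point added above. Illustrative only: the task, output path, and layer value are assumed example inputs, not values taken from this diff.

from wisent.comparison.sae import generate_steering_vector

# Assumed example values: any lm-eval task with a wisent extractor and any
# model listed in SAE_CONFIGS should work the same way.
vector_path = generate_steering_vector(
    task="boolq",                     # lm-eval task name
    model_name="google/gemma-2-2b",   # must have an entry in SAE_CONFIGS
    output_path="vectors/boolq_sae.json",
    layers="12",                      # also accepts "10,11,12" or "all"
    top_k=4,                          # number of SAE decoder features to combine
    normalize=True,
    device="cuda:0",
)
print(f"Saved SAE steering vector to {vector_path}")
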
wisent/comparison/utils.py
@@ -0,0 +1,381 @@
+ """
+ Shared utilities for comparison experiments.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import tempfile
+ from argparse import Namespace
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import torch
+ from wisent.core.utils.device import preferred_dtype
+
+ if TYPE_CHECKING:
+     from wisent.core.models.wisent_model import WisentModel
+
+
+ # SAE configurations for supported Gemma models
+ SAE_CONFIGS = {
+     "google/gemma-2-2b": {
+         "sae_release": "gemma-scope-2b-pt-res-canonical",
+         "sae_id_template": "layer_{layer}/width_16k/canonical",
+         "num_layers": 26,
+         "default_layer": 12,
+         "d_model": 2304,
+         "d_sae": 16384,
+     },
+     "google/gemma-2-9b": {
+         "sae_release": "gemma-scope-9b-pt-res-canonical",
+         "sae_id_template": "layer_{layer}/width_16k/canonical",
+         "num_layers": 42,
+         "default_layer": 12,
+         "d_model": 3584,
+         "d_sae": 16384,
+     },
+ }
+
+
+ def load_sae(model_name: str, layer_idx: int, device: str = "cuda:0"):
+     """
+     Load Gemma Scope SAE for a specific layer.
+
+     Args:
+         model_name: HuggingFace model name (e.g., 'google/gemma-2-2b')
+         layer_idx: Layer index to load SAE for
+         device: Device to load SAE on
+
+     Returns:
+         Tuple of (SAE object, sparsity tensor)
+     """
+     from sae_lens import SAE
+
+     if model_name not in SAE_CONFIGS:
+         raise ValueError(f"No SAE config for model '{model_name}'. Supported: {list(SAE_CONFIGS.keys())}")
+
+     config = SAE_CONFIGS[model_name]
+     sae_id = config["sae_id_template"].format(layer=layer_idx)
+
+     print(f" Loading SAE from {config['sae_release']} / {sae_id}")
+     sae, _, sparsity = SAE.from_pretrained(
+         release=config["sae_release"],
+         sae_id=sae_id,
+         device=device,
+     )
+
+     return sae, sparsity
+
+
+ def load_model_and_tokenizer(
+     model_name: str,
+     device: str = "cuda:0",
+     eval_mode: bool = True,
+ ) -> tuple:
+     """
+     Load HuggingFace model and tokenizer.
+
+     Args:
+         model_name: HuggingFace model name
+         device: Device to load model on
+         eval_mode: Whether to set model to eval mode (default True for inference)
+
+     Returns:
+         Tuple of (model, tokenizer)
+     """
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=preferred_dtype(device),
+         device_map=device,
+         trust_remote_code=True,
+     )
+     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     if eval_mode:
+         model.eval()
+
+     return model, tokenizer
+
+
+ def generate_contrastive_pairs(
+     task: str,
+     num_pairs: int,
+     seed: int = 42,
+     verbose: bool = False,
+ ) -> tuple[list[dict], str]:
+     """
+     Generate contrastive pairs from an lm-eval task.
+
+     Args:
+         task: lm-eval task name (e.g., 'boolq', 'cb')
+         num_pairs: Number of pairs to generate
+         seed: Random seed for reproducibility
+         verbose: Whether to print verbose output
+
+     Returns:
+         Tuple of (pairs list, path to temporary pairs file).
+         Caller is responsible for cleaning up the file if needed.
+     """
+     from wisent.core.cli.generate_pairs_from_task import execute_generate_pairs_from_task
+
+     pairs_file = tempfile.NamedTemporaryFile(mode='w', suffix='_pairs.json', delete=False).name
+     pairs_args = Namespace(
+         task_name=task,
+         limit=num_pairs,
+         output=pairs_file,
+         seed=seed,
+         verbose=verbose,
+     )
+     execute_generate_pairs_from_task(pairs_args)
+
+     with open(pairs_file) as f:
+         pairs_data = json.load(f)
+     pairs = pairs_data["pairs"]
+
+     return pairs, pairs_file
+
+
+ def create_test_only_task(task_name: str, train_ratio: float = 0.8) -> dict:
+     """
+     Create a task that evaluates only on our test split.
+
+     This ensures no overlap with the data used for steering vector training.
+
+     Args:
+         task_name: lm-eval task name (e.g., 'boolq', 'cb')
+         train_ratio: Fraction of data used for training (default 0.8)
+
+     Returns:
+         Task dict with test split configured
+     """
+     from lm_eval.tasks import get_task_dict
+     from wisent.core.utils.dataset_splits import get_test_docs
+
+     task_dict = get_task_dict([task_name])
+     task = task_dict[task_name]
+
+     test_docs = get_test_docs(task, benchmark_name=task_name, train_ratio=train_ratio)
+     test_pct = round((1 - train_ratio) * 100)
+
+     print(f"Test split size: {len(test_docs)} docs ({test_pct}% of pooled data)")
+
+     # Override task's doc methods to use our test split
+     task.test_docs = lambda: test_docs
+     task.has_test_docs = lambda: True
+     task._eval_docs = test_docs
+
+     return {task_name: task}
+
+
+ def extract_accuracy(results: dict, task: str) -> float:
+     """
+     Extract accuracy from lm-eval results.
+
+     Args:
+         results: Results dict from lm-eval evaluator
+         task: Task name to extract accuracy for
+
+     Returns:
+         Accuracy value (0.0 if not found)
+     """
+     task_results = results.get("results", {}).get(task, {})
+     for key in ["acc", "acc,none", "accuracy", "acc_norm", "acc_norm,none"]:
+         if key in task_results:
+             return task_results[key]
+     return 0.0
+
+
+ def run_lm_eval_evaluation(
+     wisent_model: "WisentModel",
+     task_dict: dict,
+     task_name: str,
+     batch_size: int | str = 1,
+     max_batch_size: int = 8,
+     limit: int | None = None,
+ ) -> dict:
+     """
+     Run evaluation using lm-eval-harness.
+
+     Args:
+         wisent_model: WisentModel instance
+         task_dict: Task dict from create_test_only_task
+         task_name: lm-eval task name
+         batch_size: Batch size for evaluation
+         max_batch_size: Max batch size for lm-eval internal batching
+         limit: Max number of examples to evaluate
+
+     Returns:
+         Full results dict from lm-eval
+     """
+     from lm_eval import evaluator
+     from lm_eval.models.huggingface import HFLM
+
+     lm = HFLM(
+         pretrained=wisent_model.hf_model,
+         tokenizer=wisent_model.tokenizer,
+         batch_size=batch_size,
+         max_batch_size=max_batch_size,
+     )
+
+     results = evaluator.evaluate(
+         lm=lm,
+         task_dict=task_dict,
+         limit=limit,
+     )
+
+     return results
+
+
+ def run_ll_evaluation(
+     wisent_model: "WisentModel",
+     task_dict: dict,
+     task_name: str,
+     limit: int | None = None,
+ ) -> float:
+     """
+     Run evaluation using wisent's LogLikelihoodsEvaluator.
+
+     Args:
+         wisent_model: WisentModel instance
+         task_dict: Task dict from create_test_only_task
+         task_name: lm-eval task name
+         limit: Max number of examples to evaluate
+
+     Returns:
+         Accuracy as float
+     """
+     from wisent.core.evaluators.benchmark_specific.log_likelihoods_evaluator import LogLikelihoodsEvaluator
+     from wisent.core.contrastive_pairs.lm_eval_pairs.lm_extractor_registry import get_extractor
+
+     ll_evaluator = LogLikelihoodsEvaluator()
+     extractor = get_extractor(task_name)
+
+     task = task_dict[task_name]
+     docs = list(task.test_docs())
+
+     if limit:
+         docs = docs[:limit]
+
+     print(f"Evaluating {len(docs)} examples with LogLikelihoodsEvaluator")
+
+     correct = 0
+     for i, doc in enumerate(docs):
+         question = task.doc_to_text(doc)
+         choices, expected = extractor.extract_choices_and_answer(task, doc)
+
+         result = ll_evaluator.evaluate(
+             response="",
+             expected=expected,
+             model=wisent_model,
+             question=question,
+             choices=choices,
+             task_name=task_name,
+         )
+
+         if result.ground_truth == "TRUTHFUL":
+             correct += 1
+
+         if (i + 1) % 50 == 0:
+             print(f" Processed {i + 1}/{len(docs)}, acc: {correct/(i+1):.4f}")
+
+     return correct / len(docs) if docs else 0.0
+
+
+ def load_steering_vector(path: str | Path, default_method: str = "unknown") -> dict:
+     """
+     Load a steering vector from file.
+
+     Args:
+         path: Path to steering vector file (.json or .pt)
+         default_method: Default method name if not found in file
+
+     Returns:
+         Dictionary with steering vectors and metadata
+     """
+     path = Path(path)
+
+     if path.suffix == ".pt":
+         from wisent.core.utils.device import resolve_default_device
+         data = torch.load(path, map_location=resolve_default_device(), weights_only=False)
+         layer_idx = str(data.get("layer_index", data.get("layer", 1)))
+         return {
+             "steering_vectors": {layer_idx: data["steering_vector"].tolist()},
+             "layers": [layer_idx],
+             "model": data.get("model", "unknown"),
+             "method": data.get("method", default_method),
+             "trait_label": data.get("trait_label", "unknown"),
+         }
+     else:
+         with open(path) as f:
+             return json.load(f)
+
+
+ def apply_steering_to_model(
+     model: "WisentModel",
+     steering_data: dict,
+     scale: float = 1.0,
+ ) -> None:
+     """
+     Apply loaded steering vectors to a WisentModel.
+
+     Args:
+         model: WisentModel instance
+         steering_data: Dictionary from load_steering_vector()
+         scale: Scaling factor for steering strength
+     """
+     raw_map = {}
+     dtype = preferred_dtype()
+     for layer_str, vec_list in steering_data["steering_vectors"].items():
+         raw_map[layer_str] = torch.tensor(vec_list, dtype=dtype)
+
+     model.set_steering_from_raw(raw_map, scale=scale, normalize=False)
+     model.apply_steering()
+
+
+ def remove_steering(model: "WisentModel") -> None:
+     """Remove steering from a WisentModel."""
+     model.detach()
+     model.clear_steering()
+
+
+ def convert_to_lm_eval_format(
+     steering_data: dict,
+     output_path: str | Path,
+     scale: float = 1.0,
+ ) -> Path:
+     """
+     Convert our steering vector format to lm-eval's steered model format.
+
+     lm-eval expects:
+         {
+             "layers.N": {
+                 "steering_vector": tensor of shape (1, hidden_dim),
+                 "steering_coefficient": float,
+                 "action": "add"
+             }
+         }
+     """
+     output_path = Path(output_path)
+
+     dtype = preferred_dtype()
+     lm_eval_config = {}
+     for layer_str, vec_list in steering_data["steering_vectors"].items():
+         vec = torch.tensor(vec_list, dtype=dtype)
+         # lm-eval expects shape (1, hidden_dim)
+         if vec.dim() == 1:
+             vec = vec.unsqueeze(0)
+
+         layer_key = f"layers.{layer_str}"
+         lm_eval_config[layer_key] = {
+             "steering_vector": vec,
+             "steering_coefficient": scale,
+             "action": "add",
+         }
+
+     torch.save(lm_eval_config, output_path)
+     return output_path
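
The helpers above compose into a load / apply / evaluate / detach loop. A hedged sketch, assuming `model` is an already-constructed WisentModel and the vector file exists (neither is shown in this diff):

from wisent.comparison.utils import (
    apply_steering_to_model,
    create_test_only_task,
    extract_accuracy,
    load_steering_vector,
    remove_steering,
    run_lm_eval_evaluation,
)

steering_data = load_steering_vector("vectors/boolq_sae.json")  # assumed path
apply_steering_to_model(model, steering_data, scale=1.0)        # model: WisentModel (construction not shown)

task_dict = create_test_only_task("boolq", train_ratio=0.8)     # evaluate only the held-out 20%
results = run_lm_eval_evaluation(model, task_dict, "boolq")
print("steered accuracy:", extract_accuracy(results, "boolq"))

remove_steering(model)  # detach hooks and clear steering state
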
wisent/core/activations/activations_collector.py
@@ -66,8 +66,9 @@ class ActivationCollector:
          pos_text = _resp_text(pair.positive_response)
          neg_text = _resp_text(pair.negative_response)

-         other_for_pos = neg_text if strategy == ExtractionStrategy.MC_BALANCED else None
-         other_for_neg = pos_text if strategy == ExtractionStrategy.MC_BALANCED else None
+         needs_other = strategy in (ExtractionStrategy.MC_BALANCED, ExtractionStrategy.MC_COMPLETION)
+         other_for_pos = neg_text if needs_other else None
+         other_for_neg = pos_text if needs_other else None

          pos = self._collect_single(
              pair.prompt, pos_text, strategy, layers, normalize,
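
The change above widens the counterpart-response plumbing from MC_BALANCED alone to both multiple-choice strategies. A self-contained sketch of the pattern, using a stand-in enum rather than wisent's real ExtractionStrategy:

from enum import Enum, auto

class ExtractionStrategy(Enum):  # stand-in for wisent's enum, for illustration only
    MC_BALANCED = auto()
    MC_COMPLETION = auto()
    DEFAULT = auto()

def counterpart_texts(strategy, pos_text, neg_text):
    # One membership test replaces two equality checks, so MC_COMPLETION
    # now also receives the opposing response text.
    needs_other = strategy in (ExtractionStrategy.MC_BALANCED, ExtractionStrategy.MC_COMPLETION)
    return (neg_text if needs_other else None, pos_text if needs_other else None)

assert counterpart_texts(ExtractionStrategy.MC_COMPLETION, "yes", "no") == ("no", "yes")
assert counterpart_texts(ExtractionStrategy.DEFAULT, "yes", "no") == (None, None)
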