wisent 0.7.901__py3-none-any.whl → 0.7.1045__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +669 -0
- wisent/comparison/lora_dpo.py +592 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activations_collector.py +3 -2
- wisent/core/activations/extraction_strategy.py +8 -4
- wisent/core/cli/agent/apply_steering.py +7 -5
- wisent/core/cli/agent/train_classifier.py +4 -3
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/RECORD +27 -91
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/generate_paper_data.py +0 -384
- wisent/examples/scripts/intervention_validation.py +0 -626
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
- wisent/examples/scripts/threshold_analysis.py +0 -434
- wisent/examples/scripts/visualization_gallery.py +0 -582
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
wisent/examples/scripts/intervention_validation.py
@@ -1,626 +0,0 @@
"""
Intervention Validation for RepScan.

Tests whether RepScan diagnosis predicts CAA steering success:
- LINEAR diagnosis -> CAA should work
- NONLINEAR diagnosis -> CAA should fail (but detection still works)
- NO_SIGNAL diagnosis -> neither should work

This is the CRITICAL missing piece identified by reviewers.

Usage:
    python -m wisent.examples.scripts.intervention_validation --model Qwen/Qwen3-8B
"""

import argparse
import json
import subprocess
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field, asdict
import random

import torch
import numpy as np

S3_BUCKET = "wisent-bucket"
S3_PREFIX = "intervention_validation"


def s3_upload_file(local_path: Path, model_name: str) -> None:
    """Upload a single file to S3."""
    model_prefix = model_name.replace('/', '_')
    s3_path = f"s3://{S3_BUCKET}/{S3_PREFIX}/{model_prefix}/{local_path.name}"
    try:
        subprocess.run(
            ["aws", "s3", "cp", str(local_path), s3_path, "--quiet"],
            check=True,
            capture_output=True,
        )
        print(f" Uploaded to S3: {s3_path}")
    except Exception as e:
        print(f" S3 upload failed: {e}")


@dataclass
class SteeringResult:
    """Result of a single steering experiment."""
    benchmark: str
    strategy: str
    layer: int
    diagnosis: str  # LINEAR, NONLINEAR, NO_SIGNAL

    # Before steering
    baseline_accuracy: float  # Model's baseline accuracy on task
    baseline_correct_logprob: float  # Avg logprob of correct answer
    baseline_incorrect_logprob: float  # Avg logprob of incorrect answer

    # After steering (with CAA)
    steered_accuracy: float
    steered_correct_logprob: float
    steered_incorrect_logprob: float

    # Steering effect
    accuracy_change: float  # steered - baseline (positive = improvement)
    logprob_shift: float  # Change in correct - incorrect gap
    steering_success: bool  # Did steering improve in expected direction?

    # Steering parameters
    steering_coefficient: float
    num_test_samples: int


@dataclass
class ValidationResults:
    """Results from full intervention validation."""
    model: str
    results: List[SteeringResult] = field(default_factory=list)

    # Summary statistics by diagnosis
    linear_success_rate: float = 0.0
    nonlinear_success_rate: float = 0.0
    no_signal_success_rate: float = 0.0

    def compute_summary(self):
        """Compute summary statistics."""
        linear = [r for r in self.results if r.diagnosis == "LINEAR"]
        nonlinear = [r for r in self.results if r.diagnosis == "NONLINEAR"]
        no_signal = [r for r in self.results if r.diagnosis == "NO_SIGNAL"]

        if linear:
            self.linear_success_rate = sum(r.steering_success for r in linear) / len(linear)
        if nonlinear:
            self.nonlinear_success_rate = sum(r.steering_success for r in nonlinear) / len(nonlinear)
        if no_signal:
            self.no_signal_success_rate = sum(r.steering_success for r in no_signal) / len(no_signal)


def compute_caa_direction(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
) -> torch.Tensor:
    """
    Compute CAA (Contrastive Activation Addition) direction.

    This is the difference-in-means direction used for steering.

    Args:
        pos_activations: [N, hidden_dim] positive class activations
        neg_activations: [N, hidden_dim] negative class activations

    Returns:
        [hidden_dim] steering direction (normalized)
    """
    pos_mean = pos_activations.float().mean(dim=0)
    neg_mean = neg_activations.float().mean(dim=0)
    direction = pos_mean - neg_mean
    return direction / (direction.norm() + 1e-10)


def apply_steering_hook(
    model: "WisentModel",
    layer: int,
    direction: torch.Tensor,
    coefficient: float,
) -> callable:
    """
    Create a forward hook that adds steering vector to activations.

    Args:
        model: WisentModel instance
        layer: Layer index to apply steering
        direction: [hidden_dim] steering direction
        coefficient: Steering strength

    Returns:
        Hook function
    """
    def hook(module, input, output):
        # output is (hidden_states, ...) or just hidden_states
        if isinstance(output, tuple):
            hidden_states = output[0]
            # Add steering vector to last token position
            hidden_states[:, -1, :] += coefficient * direction.to(hidden_states.device)
            return (hidden_states,) + output[1:]
        else:
            output[:, -1, :] += coefficient * direction.to(output.device)
            return output

    return hook


def get_model_logprobs(
    model: "WisentModel",
    prompt: str,
    completion: str,
) -> float:
    """
    Get log probability of completion given prompt.

    Args:
        model: WisentModel instance
        prompt: Input prompt
        completion: Completion to score

    Returns:
        Average log probability of completion tokens
    """
    full_text = prompt + completion

    inputs = model.tokenizer(
        full_text,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    ).to(model.device)

    prompt_tokens = model.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    ).input_ids.shape[1]

    with torch.no_grad():
        outputs = model.model(**inputs)
        logits = outputs.logits

    # Get logprobs for completion tokens only
    shift_logits = logits[:, prompt_tokens-1:-1, :].contiguous()
    shift_labels = inputs.input_ids[:, prompt_tokens:].contiguous()

    log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
    token_log_probs = log_probs.gather(-1, shift_labels.unsqueeze(-1)).squeeze(-1)

    return token_log_probs.mean().item()


def evaluate_steering(
    model: "WisentModel",
    test_pairs: List,
    layer: int,
    direction: torch.Tensor,
    coefficient: float,
) -> Tuple[float, float, float]:
    """
    Evaluate steering effect on test pairs.

    Args:
        model: WisentModel instance
        test_pairs: List of ContrastivePair objects
        layer: Layer to apply steering
        direction: Steering direction
        coefficient: Steering strength

    Returns:
        (accuracy, avg_correct_logprob, avg_incorrect_logprob)
    """
    # Get the layer module
    layer_module = model.model.model.layers[layer]

    # Register hook
    hook = apply_steering_hook(model, layer, direction, coefficient)
    handle = layer_module.register_forward_hook(hook)

    try:
        correct = 0
        correct_logprobs = []
        incorrect_logprobs = []

        for pair in test_pairs:
            prompt = pair.positive_prompt  # Same prompt for both
            correct_completion = pair.positive_completion
            incorrect_completion = pair.negative_completion

            correct_lp = get_model_logprobs(model, prompt, correct_completion)
            incorrect_lp = get_model_logprobs(model, prompt, incorrect_completion)

            correct_logprobs.append(correct_lp)
            incorrect_logprobs.append(incorrect_lp)

            if correct_lp > incorrect_lp:
                correct += 1

        accuracy = correct / len(test_pairs) if test_pairs else 0.0
        avg_correct = np.mean(correct_logprobs) if correct_logprobs else 0.0
        avg_incorrect = np.mean(incorrect_logprobs) if incorrect_logprobs else 0.0

        return accuracy, avg_correct, avg_incorrect

    finally:
        handle.remove()


def evaluate_baseline(
    model: "WisentModel",
    test_pairs: List,
) -> Tuple[float, float, float]:
    """
    Evaluate baseline (no steering) on test pairs.

    Args:
        model: WisentModel instance
        test_pairs: List of ContrastivePair objects

    Returns:
        (accuracy, avg_correct_logprob, avg_incorrect_logprob)
    """
    correct = 0
    correct_logprobs = []
    incorrect_logprobs = []

    for pair in test_pairs:
        prompt = pair.positive_prompt
        correct_completion = pair.positive_completion
        incorrect_completion = pair.negative_completion

        correct_lp = get_model_logprobs(model, prompt, correct_completion)
        incorrect_lp = get_model_logprobs(model, prompt, incorrect_completion)

        correct_logprobs.append(correct_lp)
        incorrect_logprobs.append(incorrect_lp)

        if correct_lp > incorrect_lp:
            correct += 1

    accuracy = correct / len(test_pairs) if test_pairs else 0.0
    avg_correct = np.mean(correct_logprobs) if correct_logprobs else 0.0
    avg_incorrect = np.mean(incorrect_logprobs) if incorrect_logprobs else 0.0

    return accuracy, avg_correct, avg_incorrect


def load_diagnosis_results(model_name: str, output_dir: Path) -> Dict[str, Any]:
    """Load RepScan diagnosis results from S3/local."""
    model_prefix = model_name.replace('/', '_')

    # Try to download from S3 first
    try:
        subprocess.run(
            ["aws", "s3", "sync",
             f"s3://{S3_BUCKET}/direction_discovery/{model_prefix}/",
             str(output_dir / "diagnosis"),
             "--quiet"],
            check=False,
            capture_output=True,
        )
    except Exception:
        pass

    # Load results
    results = {}
    diagnosis_dir = output_dir / "diagnosis"
    if diagnosis_dir.exists():
        for f in diagnosis_dir.glob(f"{model_prefix}_*.json"):
            if "summary" not in f.name:
                category = f.stem.replace(f"{model_prefix}_", "")
                with open(f) as fp:
                    results[category] = json.load(fp)

    return results


def get_diagnosis_for_benchmark(
    diagnosis_results: Dict[str, Any],
    benchmark: str,
    strategy: str = "chat_last",
) -> Tuple[str, int, float, float]:
    """
    Get RepScan diagnosis for a specific benchmark.

    Args:
        diagnosis_results: Loaded diagnosis results
        benchmark: Benchmark name
        strategy: Extraction strategy

    Returns:
        (diagnosis, best_layer, signal_strength, linear_probe_accuracy)
    """
    for category, data in diagnosis_results.items():
        results = data.get("results", [])
        for r in results:
            if r["benchmark"] == benchmark and r["strategy"] == strategy:
                signal = r["signal_strength"]
                linear = r["linear_probe_accuracy"]
                layers = r["layers"]
                best_layer = layers[0] if layers else 16  # Default to middle layer

                # Determine diagnosis
                if signal < 0.6:
                    diagnosis = "NO_SIGNAL"
                elif linear > 0.6 and (signal - linear) < 0.15:
                    diagnosis = "LINEAR"
                else:
                    diagnosis = "NONLINEAR"

                return diagnosis, best_layer, signal, linear

    return "UNKNOWN", 16, 0.5, 0.5


def run_intervention_validation(
    model_name: str,
    benchmarks_to_test: Optional[List[str]] = None,
    samples_per_benchmark: int = 20,
    test_samples: int = 30,
    steering_coefficients: List[float] = [0.5, 1.0, 2.0, 3.0],
):
    """
    Run intervention validation experiments.

    Args:
        model_name: Model to test
        benchmarks_to_test: Specific benchmarks (default: sample from each diagnosis)
        samples_per_benchmark: Pairs for computing steering direction
        test_samples: Pairs for evaluating steering
        steering_coefficients: Coefficients to test
    """
    from wisent.core.models.wisent_model import WisentModel
    from wisent.core.activations.extraction_strategy import ExtractionStrategy
    from wisent.core.activations.activation_cache import ActivationCache, collect_and_cache_activations
    from lm_eval.tasks import TaskManager
    from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import lm_build_contrastive_pairs

    print("=" * 70)
    print("INTERVENTION VALIDATION")
    print("=" * 70)
    print(f"Model: {model_name}")

    output_dir = Path("/tmp/intervention_validation")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load diagnosis results
    diagnosis_results = load_diagnosis_results(model_name, output_dir)
    if not diagnosis_results:
        print("ERROR: No diagnosis results found. Run discover_directions first.")
        return

    # Select benchmarks to test (sample from each diagnosis type)
    if benchmarks_to_test is None:
        benchmarks_to_test = []

        # Collect all benchmarks with their diagnoses
        by_diagnosis = {"LINEAR": [], "NONLINEAR": [], "NO_SIGNAL": []}

        for category, data in diagnosis_results.items():
            results = data.get("results", [])
            seen_benchmarks = set()
            for r in results:
                bench = r["benchmark"]
                if bench in seen_benchmarks:
                    continue
                seen_benchmarks.add(bench)

                signal = r["signal_strength"]
                linear = r["linear_probe_accuracy"]

                if signal < 0.6:
                    by_diagnosis["NO_SIGNAL"].append(bench)
                elif linear > 0.6 and (signal - linear) < 0.15:
                    by_diagnosis["LINEAR"].append(bench)
                else:
                    by_diagnosis["NONLINEAR"].append(bench)

        # Sample 3 from each category
        random.seed(42)
        for diag, benches in by_diagnosis.items():
            if benches:
                sampled = random.sample(benches, min(3, len(benches)))
                benchmarks_to_test.extend(sampled)
                print(f" {diag}: {sampled}")

    print(f"\nBenchmarks to test: {benchmarks_to_test}")

    # Load model
    print(f"\nLoading model: {model_name}")
    model = WisentModel(model_name, device="cuda")
    print(f" Layers: {model.num_layers}, Hidden: {model.hidden_size}")

    # Cache directory
    model_prefix = model_name.replace('/', '_')
    cache_dir = f"/tmp/wisent_intervention_cache_{model_prefix}"
    cache = ActivationCache(cache_dir)

    # Results
    validation_results = ValidationResults(model=model_name)

    tm = TaskManager()
    strategy = ExtractionStrategy.CHAT_LAST

    for benchmark in benchmarks_to_test:
        print(f"\n{'-' * 50}")
        print(f"Benchmark: {benchmark}")
        print("-" * 50)

        # Get diagnosis
        diagnosis, best_layer, signal, linear_acc = get_diagnosis_for_benchmark(
            diagnosis_results, benchmark, strategy.value
        )
        print(f" Diagnosis: {diagnosis}")
        print(f" Signal: {signal:.3f}, Linear: {linear_acc:.3f}")
        print(f" Best layer: {best_layer}")

        # Load pairs
        try:
            task_dict = tm.load_task_or_group([benchmark])
            task = list(task_dict.values())[0]
        except Exception:
            task = None

        try:
            all_pairs = lm_build_contrastive_pairs(
                benchmark,
                task,
                limit=samples_per_benchmark + test_samples
            )
        except Exception as e:
            print(f" ERROR loading pairs: {e}")
            continue

        if len(all_pairs) < samples_per_benchmark + test_samples:
            print(f" SKIP: Not enough pairs ({len(all_pairs)})")
            continue

        # Split into train (for direction) and test (for evaluation)
        random.shuffle(all_pairs)
        train_pairs = all_pairs[:samples_per_benchmark]
        test_pairs = all_pairs[samples_per_benchmark:samples_per_benchmark + test_samples]

        print(f" Train pairs: {len(train_pairs)}, Test pairs: {len(test_pairs)}")

        # Get activations for training pairs
        print(f" Extracting activations...")
        try:
            cached = collect_and_cache_activations(
                model=model,
                pairs=train_pairs,
                benchmark=benchmark,
                strategy=strategy,
                cache=cache,
                show_progress=False,
            )
        except Exception as e:
            print(f" ERROR extracting activations: {e}")
            continue

        # Get activations at best layer
        layer_name = str(best_layer + 1)  # 1-based
        try:
            pos_acts = cached.get_positive_activations(layer_name)
            neg_acts = cached.get_negative_activations(layer_name)
        except Exception as e:
            print(f" ERROR getting activations: {e}")
            continue

        # Compute CAA direction
        direction = compute_caa_direction(pos_acts, neg_acts)
        print(f" Direction norm: {direction.norm().item():.4f}")

        # Evaluate baseline
        print(f" Evaluating baseline...")
        base_acc, base_correct_lp, base_incorrect_lp = evaluate_baseline(model, test_pairs)
        print(f" Baseline accuracy: {base_acc:.3f}")
        print(f" Baseline logprob gap: {base_correct_lp - base_incorrect_lp:.4f}")

        # Test steering at different coefficients
        best_result = None
        best_improvement = -float('inf')

        for coef in steering_coefficients:
            print(f" Testing coefficient={coef}...")
            steered_acc, steered_correct_lp, steered_incorrect_lp = evaluate_steering(
                model, test_pairs, best_layer, direction, coef
            )

            acc_change = steered_acc - base_acc
            lp_shift = (steered_correct_lp - steered_incorrect_lp) - (base_correct_lp - base_incorrect_lp)

            print(f"  Steered accuracy: {steered_acc:.3f} (change: {acc_change:+.3f})")
            print(f"  Logprob shift: {lp_shift:+.4f}")

            # Steering is successful if it improves accuracy OR logprob gap
            steering_success = acc_change > 0.05 or lp_shift > 0.1

            if acc_change > best_improvement:
                best_improvement = acc_change
                best_result = SteeringResult(
                    benchmark=benchmark,
                    strategy=strategy.value,
                    layer=best_layer,
                    diagnosis=diagnosis,
                    baseline_accuracy=base_acc,
                    baseline_correct_logprob=base_correct_lp,
                    baseline_incorrect_logprob=base_incorrect_lp,
                    steered_accuracy=steered_acc,
                    steered_correct_logprob=steered_correct_lp,
                    steered_incorrect_logprob=steered_incorrect_lp,
                    accuracy_change=acc_change,
                    logprob_shift=lp_shift,
                    steering_success=steering_success,
                    steering_coefficient=coef,
                    num_test_samples=len(test_pairs),
                )

        if best_result:
            validation_results.results.append(best_result)
            print(f"\n Best result: coef={best_result.steering_coefficient}, "
                  f"acc_change={best_result.accuracy_change:+.3f}, "
                  f"success={best_result.steering_success}")

    # Compute summary
    validation_results.compute_summary()

    # Print summary
    print("\n" + "=" * 70)
    print("VALIDATION SUMMARY")
    print("=" * 70)
    print(f"\nLinear diagnosis -> CAA success rate: {validation_results.linear_success_rate:.1%}")
    print(f"Nonlinear diagnosis -> CAA success rate: {validation_results.nonlinear_success_rate:.1%}")
    print(f"No signal diagnosis -> CAA success rate: {validation_results.no_signal_success_rate:.1%}")

    # Expected pattern:
    # LINEAR -> high success rate
    # NONLINEAR -> low success rate (CAA doesn't work, but detection does)
    # NO_SIGNAL -> low success rate

    if validation_results.linear_success_rate > validation_results.nonlinear_success_rate:
        print("\n✓ VALIDATION PASSED: LINEAR diagnosis predicts higher CAA success!")
    else:
        print("\n✗ VALIDATION FAILED: LINEAR diagnosis does not predict higher CAA success")

    # Save results
    results_file = output_dir / f"{model_prefix}_validation.json"
    with open(results_file, "w") as f:
        json.dump({
            "model": model_name,
            "results": [asdict(r) for r in validation_results.results],
            "summary": {
                "linear_success_rate": validation_results.linear_success_rate,
                "nonlinear_success_rate": validation_results.nonlinear_success_rate,
                "no_signal_success_rate": validation_results.no_signal_success_rate,
            }
        }, f, indent=2)

    print(f"\nResults saved to: {results_file}")
    s3_upload_file(results_file, model_name)

    # Cleanup
    del model

    return validation_results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Intervention validation for RepScan")
    parser.add_argument("--model", type=str, default="Qwen/Qwen3-8B", help="Model to test")
    parser.add_argument("--benchmarks", type=str, nargs="+", default=None, help="Specific benchmarks to test")
    parser.add_argument("--samples", type=int, default=20, help="Samples for direction computation")
    parser.add_argument("--test-samples", type=int, default=30, help="Samples for evaluation")
    args = parser.parse_args()

    run_intervention_validation(
        model_name=args.model,
        benchmarks_to_test=args.benchmarks,
        samples_per_benchmark=args.samples,
        test_samples=args.test_samples,
    )