wisent 0.7.901__py3-none-any.whl → 0.7.1116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +663 -0
  6. wisent/comparison/lora_dpo.py +604 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/reft.py +690 -0
  10. wisent/comparison/sae.py +304 -0
  11. wisent/comparison/utils.py +381 -0
  12. wisent/core/activations/activations_collector.py +3 -2
  13. wisent/core/activations/extraction_strategy.py +8 -4
  14. wisent/core/cli/agent/apply_steering.py +7 -5
  15. wisent/core/cli/agent/train_classifier.py +4 -3
  16. wisent/core/cli/generate_vector_from_task.py +11 -20
  17. wisent/core/cli/get_activations.py +1 -1
  18. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
  19. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
  20. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
  21. wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
  22. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  23. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/METADATA +5 -1
  24. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/RECORD +28 -91
  25. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  26. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  27. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  28. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  29. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  30. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  31. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  32. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  33. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  34. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  35. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  36. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  37. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  38. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  39. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  40. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  41. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  42. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  43. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  44. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  45. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  46. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  47. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  48. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  49. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  50. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  51. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  52. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  53. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  54. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  55. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  56. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  57. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  58. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  59. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  60. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  61. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  62. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  63. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  64. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  65. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  66. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  67. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  68. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  69. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  70. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  71. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  72. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  73. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  74. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  75. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  76. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  77. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  78. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  79. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  80. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  81. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  82. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  83. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  84. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  85. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  86. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  87. wisent/examples/scripts/generate_paper_data.py +0 -384
  88. wisent/examples/scripts/intervention_validation.py +0 -626
  89. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
  90. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
  91. wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
  92. wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
  93. wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
  94. wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
  95. wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
  96. wisent/examples/scripts/threshold_analysis.py +0 -434
  97. wisent/examples/scripts/visualization_gallery.py +0 -582
  98. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/WHEEL +0 -0
  99. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/entry_points.txt +0 -0
  100. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/licenses/LICENSE +0 -0
  101. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/top_level.txt +0 -0
wisent/examples/scripts/intervention_validation.py
@@ -1,626 +0,0 @@
- """
- Intervention Validation for RepScan.
-
- Tests whether RepScan diagnosis predicts CAA steering success:
- - LINEAR diagnosis -> CAA should work
- - NONLINEAR diagnosis -> CAA should fail (but detection still works)
- - NO_SIGNAL diagnosis -> neither should work
-
- This is the CRITICAL missing piece identified by reviewers.
-
- Usage:
-     python -m wisent.examples.scripts.intervention_validation --model Qwen/Qwen3-8B
- """
-
- import argparse
- import json
- import subprocess
- from pathlib import Path
- from typing import Dict, List, Any, Optional, Tuple
- from dataclasses import dataclass, field, asdict
- import random
-
- import torch
- import numpy as np
-
- S3_BUCKET = "wisent-bucket"
- S3_PREFIX = "intervention_validation"
-
-
- def s3_upload_file(local_path: Path, model_name: str) -> None:
-     """Upload a single file to S3."""
-     model_prefix = model_name.replace('/', '_')
-     s3_path = f"s3://{S3_BUCKET}/{S3_PREFIX}/{model_prefix}/{local_path.name}"
-     try:
-         subprocess.run(
-             ["aws", "s3", "cp", str(local_path), s3_path, "--quiet"],
-             check=True,
-             capture_output=True,
-         )
-         print(f" Uploaded to S3: {s3_path}")
-     except Exception as e:
-         print(f" S3 upload failed: {e}")
-
-
- @dataclass
- class SteeringResult:
-     """Result of a single steering experiment."""
-     benchmark: str
-     strategy: str
-     layer: int
-     diagnosis: str  # LINEAR, NONLINEAR, NO_SIGNAL
-
-     # Before steering
-     baseline_accuracy: float  # Model's baseline accuracy on task
-     baseline_correct_logprob: float  # Avg logprob of correct answer
-     baseline_incorrect_logprob: float  # Avg logprob of incorrect answer
-
-     # After steering (with CAA)
-     steered_accuracy: float
-     steered_correct_logprob: float
-     steered_incorrect_logprob: float
-
-     # Steering effect
-     accuracy_change: float  # steered - baseline (positive = improvement)
-     logprob_shift: float  # Change in correct - incorrect gap
-     steering_success: bool  # Did steering improve in expected direction?
-
-     # Steering parameters
-     steering_coefficient: float
-     num_test_samples: int
-
-
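Aside (illustrative, not part of the removed file): SteeringResult is a plain dataclass, so each record serializes directly via dataclasses.asdict for the JSON dump at the end of the script. A minimal self-contained sketch with a two-field stand-in and invented values:

    from dataclasses import dataclass, asdict
    import json

    @dataclass
    class _Result:  # hypothetical stand-in for SteeringResult above
        benchmark: str
        steering_success: bool

    print(json.dumps(asdict(_Result(benchmark="boolq", steering_success=True))))
    # {"benchmark": "boolq", "steering_success": true}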
- @dataclass
- class ValidationResults:
-     """Results from full intervention validation."""
-     model: str
-     results: List[SteeringResult] = field(default_factory=list)
-
-     # Summary statistics by diagnosis
-     linear_success_rate: float = 0.0
-     nonlinear_success_rate: float = 0.0
-     no_signal_success_rate: float = 0.0
-
-     def compute_summary(self):
-         """Compute summary statistics."""
-         linear = [r for r in self.results if r.diagnosis == "LINEAR"]
-         nonlinear = [r for r in self.results if r.diagnosis == "NONLINEAR"]
-         no_signal = [r for r in self.results if r.diagnosis == "NO_SIGNAL"]
-
-         if linear:
-             self.linear_success_rate = sum(r.steering_success for r in linear) / len(linear)
-         if nonlinear:
-             self.nonlinear_success_rate = sum(r.steering_success for r in nonlinear) / len(nonlinear)
-         if no_signal:
-             self.no_signal_success_rate = sum(r.steering_success for r in no_signal) / len(no_signal)
-
-
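Aside (illustrative, not from the package): compute_summary leans on Python summing booleans as 0/1, so each success rate is just a sum over a length:

    outcomes = [True, False, True, True]   # steering_success flags, invented
    rate = sum(outcomes) / len(outcomes)   # 0.75 -> e.g. linear_success_rate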
- def compute_caa_direction(
-     pos_activations: torch.Tensor,
-     neg_activations: torch.Tensor,
- ) -> torch.Tensor:
-     """
-     Compute CAA (Contrastive Activation Addition) direction.
-
-     This is the difference-in-means direction used for steering.
-
-     Args:
-         pos_activations: [N, hidden_dim] positive class activations
-         neg_activations: [N, hidden_dim] negative class activations
-
-     Returns:
-         [hidden_dim] steering direction (normalized)
-     """
-     pos_mean = pos_activations.float().mean(dim=0)
-     neg_mean = neg_activations.float().mean(dim=0)
-     direction = pos_mean - neg_mean
-     return direction / (direction.norm() + 1e-10)
-
-
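Aside (illustrative, not from the package): the difference-in-means direction can be sanity-checked with random stand-in activations; the returned vector is unit-norm up to the 1e-10 stabilizer:

    import torch

    pos = torch.randn(20, 4096)  # stand-in for [N, hidden_dim] positive activations
    neg = torch.randn(20, 4096)  # stand-in for negative activations
    d = pos.float().mean(dim=0) - neg.float().mean(dim=0)
    d = d / (d.norm() + 1e-10)   # same normalization as compute_caa_direction
    assert abs(d.norm().item() - 1.0) < 1e-4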
- def apply_steering_hook(
-     model: "WisentModel",
-     layer: int,
-     direction: torch.Tensor,
-     coefficient: float,
- ) -> callable:
-     """
-     Create a forward hook that adds steering vector to activations.
-
-     Args:
-         model: WisentModel instance
-         layer: Layer index to apply steering
-         direction: [hidden_dim] steering direction
-         coefficient: Steering strength
-
-     Returns:
-         Hook function
-     """
-     def hook(module, input, output):
-         # output is (hidden_states, ...) or just hidden_states
-         if isinstance(output, tuple):
-             hidden_states = output[0]
-             # Add steering vector to last token position
-             hidden_states[:, -1, :] += coefficient * direction.to(hidden_states.device)
-             return (hidden_states,) + output[1:]
-         else:
-             output[:, -1, :] += coefficient * direction.to(output.device)
-             return output
-
-     return hook
-
-
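Aside (illustrative, not from the package): the hook protocol here is standard PyTorch; a self-contained sketch with nn.Linear standing in for a transformer block shows the register/remove lifecycle that evaluate_steering depends on below:

    import torch

    layer = torch.nn.Linear(8, 8)            # stand-in for one decoder layer
    direction = torch.randn(8)
    direction = direction / direction.norm()

    def hook(module, inputs, output):
        output[:, -1, :] += 2.0 * direction  # same last-token shift as above
        return output

    handle = layer.register_forward_hook(hook)
    out = layer(torch.randn(1, 4, 8))        # position -1 is steered, others untouched
    handle.remove()                          # mirrors the finally: handle.remove()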
- def get_model_logprobs(
-     model: "WisentModel",
-     prompt: str,
-     completion: str,
- ) -> float:
-     """
-     Get log probability of completion given prompt.
-
-     Args:
-         model: WisentModel instance
-         prompt: Input prompt
-         completion: Completion to score
-
-     Returns:
-         Average log probability of completion tokens
-     """
-     full_text = prompt + completion
-
-     inputs = model.tokenizer(
-         full_text,
-         return_tensors="pt",
-         truncation=True,
-         max_length=2048,
-     ).to(model.device)
-
-     prompt_tokens = model.tokenizer(
-         prompt,
-         return_tensors="pt",
-         truncation=True,
-         max_length=2048,
-     ).input_ids.shape[1]
-
-     with torch.no_grad():
-         outputs = model.model(**inputs)
-         logits = outputs.logits
-
-     # Get logprobs for completion tokens only
-     shift_logits = logits[:, prompt_tokens-1:-1, :].contiguous()
-     shift_labels = inputs.input_ids[:, prompt_tokens:].contiguous()
-
-     log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
-     token_log_probs = log_probs.gather(-1, shift_labels.unsqueeze(-1)).squeeze(-1)
-
-     return token_log_probs.mean().item()
-
-
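Aside (worked example, token counts invented): the off-by-one slicing above aligns logits with labels because the logits at position i score token i+1:

    # tokens: [t0 t1 t2 c0 c1], prompt_tokens = 3
    # shift_logits = logits[:, 2:4]  -> positions that predict c0 and c1
    # shift_labels = ids[:, 3:5]     -> the completion tokens themselves
    # gather(-1, labels) then picks exactly the completion-token log-probs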
- def evaluate_steering(
-     model: "WisentModel",
-     test_pairs: List,
-     layer: int,
-     direction: torch.Tensor,
-     coefficient: float,
- ) -> Tuple[float, float, float]:
-     """
-     Evaluate steering effect on test pairs.
-
-     Args:
-         model: WisentModel instance
-         test_pairs: List of ContrastivePair objects
-         layer: Layer to apply steering
-         direction: Steering direction
-         coefficient: Steering strength
-
-     Returns:
-         (accuracy, avg_correct_logprob, avg_incorrect_logprob)
-     """
-     # Get the layer module
-     layer_module = model.model.model.layers[layer]
-
-     # Register hook
-     hook = apply_steering_hook(model, layer, direction, coefficient)
-     handle = layer_module.register_forward_hook(hook)
-
-     try:
-         correct = 0
-         correct_logprobs = []
-         incorrect_logprobs = []
-
-         for pair in test_pairs:
-             prompt = pair.positive_prompt  # Same prompt for both
-             correct_completion = pair.positive_completion
-             incorrect_completion = pair.negative_completion
-
-             correct_lp = get_model_logprobs(model, prompt, correct_completion)
-             incorrect_lp = get_model_logprobs(model, prompt, incorrect_completion)
-
-             correct_logprobs.append(correct_lp)
-             incorrect_logprobs.append(incorrect_lp)
-
-             if correct_lp > incorrect_lp:
-                 correct += 1
-
-         accuracy = correct / len(test_pairs) if test_pairs else 0.0
-         avg_correct = np.mean(correct_logprobs) if correct_logprobs else 0.0
-         avg_incorrect = np.mean(incorrect_logprobs) if incorrect_logprobs else 0.0
-
-         return accuracy, avg_correct, avg_incorrect
-
-     finally:
-         handle.remove()
-
-
- def evaluate_baseline(
-     model: "WisentModel",
-     test_pairs: List,
- ) -> Tuple[float, float, float]:
-     """
-     Evaluate baseline (no steering) on test pairs.
-
-     Args:
-         model: WisentModel instance
-         test_pairs: List of ContrastivePair objects
-
-     Returns:
-         (accuracy, avg_correct_logprob, avg_incorrect_logprob)
-     """
-     correct = 0
-     correct_logprobs = []
-     incorrect_logprobs = []
-
-     for pair in test_pairs:
-         prompt = pair.positive_prompt
-         correct_completion = pair.positive_completion
-         incorrect_completion = pair.negative_completion
-
-         correct_lp = get_model_logprobs(model, prompt, correct_completion)
-         incorrect_lp = get_model_logprobs(model, prompt, incorrect_completion)
-
-         correct_logprobs.append(correct_lp)
-         incorrect_logprobs.append(incorrect_lp)
-
-         if correct_lp > incorrect_lp:
-             correct += 1
-
-     accuracy = correct / len(test_pairs) if test_pairs else 0.0
-     avg_correct = np.mean(correct_logprobs) if correct_logprobs else 0.0
-     avg_incorrect = np.mean(incorrect_logprobs) if incorrect_logprobs else 0.0
-
-     return accuracy, avg_correct, avg_incorrect
-
-
- def load_diagnosis_results(model_name: str, output_dir: Path) -> Dict[str, Any]:
-     """Load RepScan diagnosis results from S3/local."""
-     model_prefix = model_name.replace('/', '_')
-
-     # Try to download from S3 first
-     try:
-         subprocess.run(
-             ["aws", "s3", "sync",
-              f"s3://{S3_BUCKET}/direction_discovery/{model_prefix}/",
-              str(output_dir / "diagnosis"),
-              "--quiet"],
-             check=False,
-             capture_output=True,
-         )
-     except Exception:
-         pass
-
-     # Load results
-     results = {}
-     diagnosis_dir = output_dir / "diagnosis"
-     if diagnosis_dir.exists():
-         for f in diagnosis_dir.glob(f"{model_prefix}_*.json"):
-             if "summary" not in f.name:
-                 category = f.stem.replace(f"{model_prefix}_", "")
-                 with open(f) as fp:
-                     results[category] = json.load(fp)
-
-     return results
-
-
- def get_diagnosis_for_benchmark(
-     diagnosis_results: Dict[str, Any],
-     benchmark: str,
-     strategy: str = "chat_last",
- ) -> Tuple[str, int, float, float]:
-     """
-     Get RepScan diagnosis for a specific benchmark.
-
-     Args:
-         diagnosis_results: Loaded diagnosis results
-         benchmark: Benchmark name
-         strategy: Extraction strategy
-
-     Returns:
-         (diagnosis, best_layer, signal_strength, linear_probe_accuracy)
-     """
-     for category, data in diagnosis_results.items():
-         results = data.get("results", [])
-         for r in results:
-             if r["benchmark"] == benchmark and r["strategy"] == strategy:
-                 signal = r["signal_strength"]
-                 linear = r["linear_probe_accuracy"]
-                 layers = r["layers"]
-                 best_layer = layers[0] if layers else 16  # Default to middle layer
-
-                 # Determine diagnosis
-                 if signal < 0.6:
-                     diagnosis = "NO_SIGNAL"
-                 elif linear > 0.6 and (signal - linear) < 0.15:
-                     diagnosis = "LINEAR"
-                 else:
-                     diagnosis = "NONLINEAR"
-
-                 return diagnosis, best_layer, signal, linear
-
-     return "UNKNOWN", 16, 0.5, 0.5
-
-
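Aside (restated for clarity; thresholds copied from the code above, which repeats them inside run_intervention_validation below): the diagnosis rule is a small pure function:

    def classify(signal: float, linear: float) -> str:
        if signal < 0.6:
            return "NO_SIGNAL"
        if linear > 0.6 and (signal - linear) < 0.15:
            return "LINEAR"
        return "NONLINEAR"

    assert classify(0.55, 0.90) == "NO_SIGNAL"   # weak signal dominates
    assert classify(0.90, 0.85) == "LINEAR"      # strong and linearly recoverable
    assert classify(0.90, 0.55) == "NONLINEAR"   # signal present, linear probe lags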
- def run_intervention_validation(
-     model_name: str,
-     benchmarks_to_test: Optional[List[str]] = None,
-     samples_per_benchmark: int = 20,
-     test_samples: int = 30,
-     steering_coefficients: List[float] = [0.5, 1.0, 2.0, 3.0],
- ):
-     """
-     Run intervention validation experiments.
-
-     Args:
-         model_name: Model to test
-         benchmarks_to_test: Specific benchmarks (default: sample from each diagnosis)
-         samples_per_benchmark: Pairs for computing steering direction
-         test_samples: Pairs for evaluating steering
-         steering_coefficients: Coefficients to test
-     """
-     from wisent.core.models.wisent_model import WisentModel
-     from wisent.core.activations.extraction_strategy import ExtractionStrategy
-     from wisent.core.activations.activation_cache import ActivationCache, collect_and_cache_activations
-     from lm_eval.tasks import TaskManager
-     from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import lm_build_contrastive_pairs
-
-     print("=" * 70)
-     print("INTERVENTION VALIDATION")
-     print("=" * 70)
-     print(f"Model: {model_name}")
-
-     output_dir = Path("/tmp/intervention_validation")
-     output_dir.mkdir(parents=True, exist_ok=True)
-
-     # Load diagnosis results
-     diagnosis_results = load_diagnosis_results(model_name, output_dir)
-     if not diagnosis_results:
-         print("ERROR: No diagnosis results found. Run discover_directions first.")
-         return
-
-     # Select benchmarks to test (sample from each diagnosis type)
-     if benchmarks_to_test is None:
-         benchmarks_to_test = []
-
-         # Collect all benchmarks with their diagnoses
-         by_diagnosis = {"LINEAR": [], "NONLINEAR": [], "NO_SIGNAL": []}
-
-         for category, data in diagnosis_results.items():
-             results = data.get("results", [])
-             seen_benchmarks = set()
-             for r in results:
-                 bench = r["benchmark"]
-                 if bench in seen_benchmarks:
-                     continue
-                 seen_benchmarks.add(bench)
-
-                 signal = r["signal_strength"]
-                 linear = r["linear_probe_accuracy"]
-
-                 if signal < 0.6:
-                     by_diagnosis["NO_SIGNAL"].append(bench)
-                 elif linear > 0.6 and (signal - linear) < 0.15:
-                     by_diagnosis["LINEAR"].append(bench)
-                 else:
-                     by_diagnosis["NONLINEAR"].append(bench)
-
-         # Sample 3 from each category
-         random.seed(42)
-         for diag, benches in by_diagnosis.items():
-             if benches:
-                 sampled = random.sample(benches, min(3, len(benches)))
-                 benchmarks_to_test.extend(sampled)
-                 print(f" {diag}: {sampled}")
-
-     print(f"\nBenchmarks to test: {benchmarks_to_test}")
-
-     # Load model
-     print(f"\nLoading model: {model_name}")
-     model = WisentModel(model_name, device="cuda")
-     print(f" Layers: {model.num_layers}, Hidden: {model.hidden_size}")
-
-     # Cache directory
-     model_prefix = model_name.replace('/', '_')
-     cache_dir = f"/tmp/wisent_intervention_cache_{model_prefix}"
-     cache = ActivationCache(cache_dir)
-
-     # Results
-     validation_results = ValidationResults(model=model_name)
-
-     tm = TaskManager()
-     strategy = ExtractionStrategy.CHAT_LAST
-
-     for benchmark in benchmarks_to_test:
-         print(f"\n{'-' * 50}")
-         print(f"Benchmark: {benchmark}")
-         print("-" * 50)
-
-         # Get diagnosis
-         diagnosis, best_layer, signal, linear_acc = get_diagnosis_for_benchmark(
-             diagnosis_results, benchmark, strategy.value
-         )
-         print(f" Diagnosis: {diagnosis}")
-         print(f" Signal: {signal:.3f}, Linear: {linear_acc:.3f}")
-         print(f" Best layer: {best_layer}")
-
-         # Load pairs
-         try:
-             task_dict = tm.load_task_or_group([benchmark])
-             task = list(task_dict.values())[0]
-         except Exception:
-             task = None
-
-         try:
-             all_pairs = lm_build_contrastive_pairs(
-                 benchmark,
-                 task,
-                 limit=samples_per_benchmark + test_samples
-             )
-         except Exception as e:
-             print(f" ERROR loading pairs: {e}")
-             continue
-
-         if len(all_pairs) < samples_per_benchmark + test_samples:
-             print(f" SKIP: Not enough pairs ({len(all_pairs)})")
-             continue
-
-         # Split into train (for direction) and test (for evaluation)
-         random.shuffle(all_pairs)
-         train_pairs = all_pairs[:samples_per_benchmark]
-         test_pairs = all_pairs[samples_per_benchmark:samples_per_benchmark + test_samples]
-
-         print(f" Train pairs: {len(train_pairs)}, Test pairs: {len(test_pairs)}")
-
-         # Get activations for training pairs
-         print(f" Extracting activations...")
-         try:
-             cached = collect_and_cache_activations(
-                 model=model,
-                 pairs=train_pairs,
-                 benchmark=benchmark,
-                 strategy=strategy,
-                 cache=cache,
-                 show_progress=False,
-             )
-         except Exception as e:
-             print(f" ERROR extracting activations: {e}")
-             continue
-
-         # Get activations at best layer
-         layer_name = str(best_layer + 1)  # 1-based
-         try:
-             pos_acts = cached.get_positive_activations(layer_name)
-             neg_acts = cached.get_negative_activations(layer_name)
-         except Exception as e:
-             print(f" ERROR getting activations: {e}")
-             continue
-
-         # Compute CAA direction
-         direction = compute_caa_direction(pos_acts, neg_acts)
-         print(f" Direction norm: {direction.norm().item():.4f}")
-
-         # Evaluate baseline
-         print(f" Evaluating baseline...")
-         base_acc, base_correct_lp, base_incorrect_lp = evaluate_baseline(model, test_pairs)
-         print(f" Baseline accuracy: {base_acc:.3f}")
-         print(f" Baseline logprob gap: {base_correct_lp - base_incorrect_lp:.4f}")
-
-         # Test steering at different coefficients
-         best_result = None
-         best_improvement = -float('inf')
-
-         for coef in steering_coefficients:
-             print(f" Testing coefficient={coef}...")
-             steered_acc, steered_correct_lp, steered_incorrect_lp = evaluate_steering(
-                 model, test_pairs, best_layer, direction, coef
-             )
-
-             acc_change = steered_acc - base_acc
-             lp_shift = (steered_correct_lp - steered_incorrect_lp) - (base_correct_lp - base_incorrect_lp)
-
-             print(f" Steered accuracy: {steered_acc:.3f} (change: {acc_change:+.3f})")
-             print(f" Logprob shift: {lp_shift:+.4f}")
-
-             # Steering is successful if it improves accuracy OR logprob gap
-             steering_success = acc_change > 0.05 or lp_shift > 0.1
-
-             if acc_change > best_improvement:
-                 best_improvement = acc_change
-                 best_result = SteeringResult(
-                     benchmark=benchmark,
-                     strategy=strategy.value,
-                     layer=best_layer,
-                     diagnosis=diagnosis,
-                     baseline_accuracy=base_acc,
-                     baseline_correct_logprob=base_correct_lp,
-                     baseline_incorrect_logprob=base_incorrect_lp,
-                     steered_accuracy=steered_acc,
-                     steered_correct_logprob=steered_correct_lp,
-                     steered_incorrect_logprob=steered_incorrect_lp,
-                     accuracy_change=acc_change,
-                     logprob_shift=lp_shift,
-                     steering_success=steering_success,
-                     steering_coefficient=coef,
-                     num_test_samples=len(test_pairs),
-                 )
-
-         if best_result:
-             validation_results.results.append(best_result)
-             print(f"\n Best result: coef={best_result.steering_coefficient}, "
-                   f"acc_change={best_result.accuracy_change:+.3f}, "
-                   f"success={best_result.steering_success}")
-
-     # Compute summary
-     validation_results.compute_summary()
-
-     # Print summary
-     print("\n" + "=" * 70)
-     print("VALIDATION SUMMARY")
-     print("=" * 70)
-     print(f"\nLinear diagnosis -> CAA success rate: {validation_results.linear_success_rate:.1%}")
-     print(f"Nonlinear diagnosis -> CAA success rate: {validation_results.nonlinear_success_rate:.1%}")
-     print(f"No signal diagnosis -> CAA success rate: {validation_results.no_signal_success_rate:.1%}")
-
-     # Expected pattern:
-     # LINEAR -> high success rate
-     # NONLINEAR -> low success rate (CAA doesn't work, but detection does)
-     # NO_SIGNAL -> low success rate
-
-     if validation_results.linear_success_rate > validation_results.nonlinear_success_rate:
-         print("\n✓ VALIDATION PASSED: LINEAR diagnosis predicts higher CAA success!")
-     else:
-         print("\n✗ VALIDATION FAILED: LINEAR diagnosis does not predict higher CAA success")
-
-     # Save results
-     results_file = output_dir / f"{model_prefix}_validation.json"
-     with open(results_file, "w") as f:
-         json.dump({
-             "model": model_name,
-             "results": [asdict(r) for r in validation_results.results],
-             "summary": {
-                 "linear_success_rate": validation_results.linear_success_rate,
-                 "nonlinear_success_rate": validation_results.nonlinear_success_rate,
-                 "no_signal_success_rate": validation_results.no_signal_success_rate,
-             }
-         }, f, indent=2)
-
-     print(f"\nResults saved to: {results_file}")
-     s3_upload_file(results_file, model_name)
-
-     # Cleanup
-     del model
-
-     return validation_results
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description="Intervention validation for RepScan")
-     parser.add_argument("--model", type=str, default="Qwen/Qwen3-8B", help="Model to test")
-     parser.add_argument("--benchmarks", type=str, nargs="+", default=None, help="Specific benchmarks to test")
-     parser.add_argument("--samples", type=int, default=20, help="Samples for direction computation")
-     parser.add_argument("--test-samples", type=int, default=30, help="Samples for evaluation")
-     args = parser.parse_args()
-
-     run_intervention_validation(
-         model_name=args.model,
-         benchmarks_to_test=args.benchmarks,
-         samples_per_benchmark=args.samples,
-         test_samples=args.test_samples,
-     )