wisent 0.7.901-py3-none-any.whl → 0.7.1045-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +669 -0
  6. wisent/comparison/lora_dpo.py +592 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/sae.py +304 -0
  10. wisent/comparison/utils.py +381 -0
  11. wisent/core/activations/activations_collector.py +3 -2
  12. wisent/core/activations/extraction_strategy.py +8 -4
  13. wisent/core/cli/agent/apply_steering.py +7 -5
  14. wisent/core/cli/agent/train_classifier.py +4 -3
  15. wisent/core/cli/generate_vector_from_task.py +11 -20
  16. wisent/core/cli/get_activations.py +1 -1
  17. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
  18. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
  19. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
  20. wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
  21. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  22. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
  23. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/RECORD +27 -91
  24. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  25. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  26. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  27. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  28. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  29. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  30. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  31. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  32. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  33. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  34. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  35. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  36. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  37. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  38. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  39. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  40. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  41. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  42. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  43. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  44. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  45. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  46. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  47. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  48. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  49. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  50. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  51. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  52. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  53. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  54. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  55. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  56. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  57. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  58. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  59. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  60. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  61. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  62. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  63. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  64. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  65. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  66. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  67. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  68. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  69. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  70. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  71. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  72. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  73. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  74. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  75. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  76. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  77. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  78. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  79. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  80. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  81. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  82. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  83. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  84. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  85. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  86. wisent/examples/scripts/generate_paper_data.py +0 -384
  87. wisent/examples/scripts/intervention_validation.py +0 -626
  88. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
  89. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
  90. wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
  91. wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
  92. wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
  93. wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
  94. wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
  95. wisent/examples/scripts/threshold_analysis.py +0 -434
  96. wisent/examples/scripts/visualization_gallery.py +0 -582
  97. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
  98. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
  99. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
  100. {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
wisent/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.7.901"
+ __version__ = "0.7.1045"
 
 from wisent.core.diversity_processors import (
     OpenerPenaltyProcessor,
wisent/comparison/__init__.py ADDED
@@ -0,0 +1 @@
+ # Comparison methods for evaluating steering techniques
wisent/comparison/detect_bos_features.py ADDED
@@ -0,0 +1,275 @@
+ """
+ Detect BOS (Beginning of Sequence) features in Gemma Scope SAEs.
+
+ BOS features are SAE features that activate most strongly at the BOS token position.
+ These should be filtered out when computing steering vectors as they introduce
+ artifacts without contributing to steering.
+
+ Reference: "Interpretable Steering of Large Language Models with Feature Guided
+ Activation Additions" (arXiv:2501.09929), Appendix G.
+
+ Known BOS features from paper (Gemma-2-2B, layer 12, 16k SAE):
+ - 11087, 3220, 11752, 12160, 11498
+
+ Usage:
+     python -m wisent.comparison.detect_bos_features --model google/gemma-2-2b
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ from pathlib import Path
+
+ import torch
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+
+ # Known BOS feature indices from paper (Appendix G)
+ KNOWN_BOS_FEATURES = {
+     "google/gemma-2-2b": [11087, 3220, 11752, 12160, 11498],
+     "google/gemma-2-9b": [],  # Not listed in paper
+ }
+
+
+ def load_sample_texts(num_samples: int = 2000, min_length: int = 50) -> list[str]:
+     """Load sample texts from WikiText dataset."""
+     print(f"Loading up to {num_samples} sample texts from WikiText...")
+
+     dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
+
+     texts = []
+     for item in dataset:
+         if len(texts) >= num_samples:
+             break
+         text = item["text"].strip()
+         if len(text) >= min_length:
+             texts.append(text)
+
+     print(f"  Loaded {len(texts)} texts")
+     return texts
+
+
+ def detect_bos_features(
+     model,
+     tokenizer,
+     sae,
+     layer_idx: int,
+     device: str,
+     texts: list[str],
+     top_k: int = 10,
+     batch_size: int = 8,
+ ) -> tuple[list[int], dict[str, torch.Tensor]]:
+     """
+     Detect BOS features by finding features that activate most strongly at position 0.
+
+     Computes statistics (mean, variance, median) of activation at BOS position for each
+     SAE feature across all samples, then returns the top-k features with highest mean
+     BOS activation.
+
+     Reference: FGAA paper (arXiv:2501.09929) identifies BOS features as those that
+     "exclusively had the strongest activation on the BOS token".
+
+     Args:
+         model: HuggingFace model
+         tokenizer: Tokenizer
+         sae: SAE object from sae_lens
+         layer_idx: Layer index (0-indexed)
+         device: Device
+         texts: List of sample texts to analyze
+         top_k: Number of top BOS features to return (default 10)
+         batch_size: Batch size for processing
+
+     Returns:
+         Tuple of (top_bos_feature_indices, stats_dict) where stats_dict contains
+         'mean', 'variance', and 'median' tensors of shape [d_sae].
+     """
+     d_sae = sae.cfg.d_sae
+
+     # Collect all BOS activations on CPU for stable statistics computation
+     all_bos_activations = []
+
+     print(f"Detecting BOS features from {len(texts)} samples...")
+     print(f"  Layer: {layer_idx}, d_sae: {d_sae}")
+
+     # Use hook to capture only the layer we need (not all 26 layers)
+     captured_acts = {}
+     def capture_hook(module, input, output):
+         captured_acts["hidden"] = output[0].detach()
+
+     # Register hook on the specific layer
+     target_layer = model.model.layers[layer_idx]
+     hook_handle = target_layer.register_forward_hook(capture_hook)
+
+     try:
+         for i in tqdm(range(0, len(texts), batch_size), desc="Processing"):
+             batch_texts = texts[i:i + batch_size]
+
+             inputs = tokenizer(
+                 batch_texts,
+                 return_tensors="pt",
+                 truncation=True,
+                 max_length=128,
+                 padding=True,
+             )
+             inputs = {k: v.to(device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 # Use model.model to skip lm_head (saves ~66MB per forward)
+                 model.model(**inputs, use_cache=False)
+
+             # Get captured hidden states (only this layer, not all 26)
+             acts = captured_acts["hidden"].to(sae.W_enc.dtype)
+             latents = sae.encode(acts)
+
+             attention_mask = inputs.get("attention_mask")
+             for j in range(latents.shape[0]):
+                 seq_len = int(attention_mask[j].sum().item()) if attention_mask is not None else latents.shape[1]
+                 sample_latents = latents[j, :seq_len, :]  # [seq_len, d_sae]
+
+                 # Collect BOS activation (position 0 = BOS token; assumes right padding) - move to CPU immediately
+                 bos_act = sample_latents[0].float().cpu()
+                 all_bos_activations.append(bos_act)
+     finally:
+         hook_handle.remove()
+
+     # Compute statistics (all on CPU)
+     # Stack all activations for stable computation
+     all_bos_tensor = torch.stack(all_bos_activations, dim=0)  # [num_samples, d_sae]
+     mean_bos = all_bos_tensor.mean(dim=0)
+     variance_bos = all_bos_tensor.var(dim=0)
+     median_bos = all_bos_tensor.median(dim=0).values
+
+     stats = {
+         "mean": mean_bos,
+         "variance": variance_bos,
+         "median": median_bos,
+     }
+
+     # Select top features by BOS activation
+     top_indices = mean_bos.topk(top_k).indices.tolist()
+
+     print(f"\nDetected top {top_k} BOS features by mean activation")
+     return top_indices, stats
+
+
+ def compare_with_known(model_name: str, detected: list[int], stats: dict[str, torch.Tensor]) -> None:
+     """Compare detected BOS features with known list from paper."""
+     mean_bos = stats["mean"]
+     variance_bos = stats["variance"]
+     median_bos = stats["median"]
+
+     known = KNOWN_BOS_FEATURES.get(model_name, [])
+     detected_set = set(detected)
+     known_set = set(known)
+
+     print(f"\n{'='*60}")
+     print(f"BOS Feature Comparison for {model_name}")
+     print(f"{'='*60}")
+     print(f"Known (paper): {sorted(known_set)}")
+     print(f"Detected: {sorted(detected_set)}")
+     print(f"{'='*60}")
+     print(f"Common: {sorted(detected_set & known_set)}")
+     print(f"Only in paper: {sorted(known_set - detected_set)}")
+     print(f"Only detected: {sorted(detected_set - known_set)}")
+
+     if known:
+         print(f"\nKnown features - BOS activation stats:")
+         for idx in sorted(known):
+             mean_val = mean_bos[idx].item()
+             var_val = variance_bos[idx].item()
+             median_val = median_bos[idx].item()
+             status = "detected" if idx in detected_set else "missed"
+             print(f"  Feature {idx}: mean={mean_val:.4f}, var={var_val:.4f}, median={median_val:.4f} ({status})")
+
+     print(f"\nTop 20 features by mean BOS activation:")
+     for rank, idx in enumerate(mean_bos.topk(20).indices.tolist(), 1):
+         mean_val = mean_bos[idx].item()
+         var_val = variance_bos[idx].item()
+         median_val = median_bos[idx].item()
+         marker = " (known)" if idx in known_set else ""
+         print(f"  {rank:2}. Feature {idx}: mean={mean_val:.4f}, var={var_val:.4f}, median={median_val:.4f}{marker}")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Detect BOS features in Gemma Scope SAEs")
+     parser.add_argument("--model", default="google/gemma-2-2b", help="Model name")
+     parser.add_argument("--layer", type=int, default=12, help="Layer index")
+     parser.add_argument("--num-samples", type=int, default=1000, help="Number of text samples")
+     parser.add_argument("--top-k", type=int, default=20, help="Number of top BOS features to detect")
+     parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
+     parser.add_argument("--device", default="cuda:0", help="Device")
+     parser.add_argument("--output-dir", default="wisent/comparison/results", help="Output directory")
+     args = parser.parse_args()
+
+     print(f"Model: {args.model}")
+     print(f"Layer: {args.layer}")
+     print(f"Top-k: {args.top_k}")
+
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+
+     print(f"\nLoading model...")
+     model = AutoModelForCausalLM.from_pretrained(
+         args.model,
+         torch_dtype=torch.bfloat16,
+         device_map=args.device,
+         trust_remote_code=True,
+     )
+     tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     model.eval()
+
+     print(f"\nLoading SAE...")
+     from sae_lens import SAE
+
+     release = "gemma-scope-2b-pt-res-canonical" if "2b" in args.model.lower() else "gemma-scope-9b-pt-res-canonical"
+     sae_id = f"layer_{args.layer}/width_16k/canonical"
+     sae, _, _ = SAE.from_pretrained(release=release, sae_id=sae_id, device=args.device)
+
+     texts = load_sample_texts(args.num_samples)
+
+     bos_features, stats = detect_bos_features(
+         model=model,
+         tokenizer=tokenizer,
+         sae=sae,
+         layer_idx=args.layer,
+         device=args.device,
+         texts=texts,
+         top_k=args.top_k,
+         batch_size=args.batch_size,
+     )
+
+     compare_with_known(args.model, bos_features, stats)
+
+     # Build detected features with their stats
+     detected_with_stats = [
+         {
+             "feature": idx,
+             "mean_bos_activation": stats["mean"][idx].item(),
+             "variance_bos_activation": stats["variance"][idx].item(),
+             "median_bos_activation": stats["median"][idx].item(),
+         }
+         for idx in bos_features
+     ]
+
+     output = {
+         "model": args.model,
+         "layer": args.layer,
+         "top_k": args.top_k,
+         "num_samples": len(texts),
+         "detected_bos_features": detected_with_stats,
+         "known_bos_features": KNOWN_BOS_FEATURES.get(args.model, []),
+     }
+
+     output_dir = Path(args.output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+     output_path = output_dir / f"bos_features_{args.model.replace('/', '_')}_layer{args.layer}.json"
+     with open(output_path, "w") as f:
+         json.dump(output, f, indent=2)
+     print(f"\nSaved to {output_path}")
+
+
+ if __name__ == "__main__":
+     main()
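
The module docstring above notes that detected BOS features "should be filtered out when computing steering vectors". For context, here is a minimal sketch of how the JSON this script saves could be consumed downstream, assuming the same Gemma Scope 16k SAE; `feature_vector` is a hypothetical stand-in for a feature-space steering vector (e.g. an FGAA-style difference of mean SAE activations), not an API defined in this diff:

    # Sketch: zero out detected BOS features in a feature-space steering vector,
    # then decode it back to residual-stream space via the SAE decoder.
    import json

    import torch
    from sae_lens import SAE

    sae, _, _ = SAE.from_pretrained(
        release="gemma-scope-2b-pt-res-canonical",
        sae_id="layer_12/width_16k/canonical",
        device="cpu",
    )

    # Output written by detect_bos_features.py with default arguments
    with open("wisent/comparison/results/bos_features_google_gemma-2-2b_layer12.json") as f:
        bos_info = json.load(f)
    bos_idx = [d["feature"] for d in bos_info["detected_bos_features"]]

    # Hypothetical feature-space steering vector of shape [d_sae]
    feature_vector = torch.randn(sae.cfg.d_sae)

    feature_vector[bos_idx] = 0.0  # drop BOS artifacts before decoding
    steering_vector = feature_vector @ sae.W_dec  # [d_model] residual-stream direction

Zeroing the indices before the decoder projection keeps the BOS artifacts out of the resulting residual-stream steering direction.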