wisent 0.7.901__py3-none-any.whl → 0.7.1116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +663 -0
  6. wisent/comparison/lora_dpo.py +604 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/reft.py +690 -0
  10. wisent/comparison/sae.py +304 -0
  11. wisent/comparison/utils.py +381 -0
  12. wisent/core/activations/activations_collector.py +3 -2
  13. wisent/core/activations/extraction_strategy.py +8 -4
  14. wisent/core/cli/agent/apply_steering.py +7 -5
  15. wisent/core/cli/agent/train_classifier.py +4 -3
  16. wisent/core/cli/generate_vector_from_task.py +11 -20
  17. wisent/core/cli/get_activations.py +1 -1
  18. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
  19. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
  20. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
  21. wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
  22. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  23. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/METADATA +5 -1
  24. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/RECORD +28 -91
  25. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  26. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  27. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  28. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  29. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  30. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  31. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  32. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  33. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  34. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  35. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  36. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  37. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  38. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  39. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  40. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  41. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  42. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  43. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  44. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  45. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  46. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  47. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  48. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  49. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  50. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  51. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  52. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  53. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  54. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  55. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  56. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  57. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  58. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  59. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  60. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  61. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  62. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  63. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  64. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  65. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  66. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  67. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  68. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  69. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  70. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  71. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  72. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  73. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  74. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  75. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  76. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  77. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  78. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  79. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  80. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  81. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  82. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  83. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  84. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  85. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  86. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  87. wisent/examples/scripts/generate_paper_data.py +0 -384
  88. wisent/examples/scripts/intervention_validation.py +0 -626
  89. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
  90. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
  91. wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
  92. wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
  93. wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
  94. wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
  95. wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
  96. wisent/examples/scripts/threshold_analysis.py +0 -434
  97. wisent/examples/scripts/visualization_gallery.py +0 -582
  98. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/WHEEL +0 -0
  99. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/entry_points.txt +0 -0
  100. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/licenses/LICENSE +0 -0
  101. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/top_level.txt +0 -0
@@ -1,582 +0,0 @@
1
- """
2
- Visualization Gallery for RepScan.
3
-
4
- Creates publication-quality figures:
5
- 1. Hero figure (method overview + key results)
6
- 2. t-SNE gallery (LINEAR, NONLINEAR, NO_SIGNAL examples)
7
- 3. Layer-wise accuracy curves
8
- 4. Decision boundary visualizations
9
-
10
- Usage:
11
- python -m wisent.examples.scripts.visualization_gallery --model Qwen/Qwen3-8B
12
- """
13
-
14
- import argparse
15
- import json
16
- import subprocess
17
- from pathlib import Path
18
- from typing import Dict, List, Any, Optional, Tuple
19
- import random
20
-
21
- import torch
22
- import numpy as np
23
-
24
- try:
25
- import matplotlib.pyplot as plt
26
- import matplotlib.patches as mpatches
27
- from matplotlib.gridspec import GridSpec
28
- HAS_MATPLOTLIB = True
29
- except ImportError:
30
- HAS_MATPLOTLIB = False
31
- print("Warning: matplotlib not installed. Visualizations will be skipped.")
32
-
33
- try:
34
- from sklearn.manifold import TSNE
35
- from sklearn.decomposition import PCA
36
- HAS_SKLEARN = True
37
- except ImportError:
38
- HAS_SKLEARN = False
39
-
40
- S3_BUCKET = "wisent-bucket"
41
- S3_PREFIX = "visualizations"
42
-
43
-
44
def s3_upload_file(local_path: Path, model_name: str) -> None:
    """Upload a single file to S3 on a best-effort basis.

    The destination key is
    ``s3://{S3_BUCKET}/{S3_PREFIX}/{model_name with '/' -> '_'}/{file name}``.
    The upload shells out to the ``aws`` CLI; failures are reported on stdout
    and never raised, so a missing CLI or missing credentials cannot abort
    figure generation.

    Args:
        local_path: File to upload.
        model_name: Model identifier; ``/`` is replaced by ``_`` in the key.
    """
    # A figure may legitimately be missing (e.g. its generator was skipped
    # because matplotlib is not installed); don't invoke the CLI for it.
    if not local_path.is_file():
        print(f" S3 upload skipped: {local_path} does not exist")
        return

    model_prefix = model_name.replace('/', '_')
    s3_path = f"s3://{S3_BUCKET}/{S3_PREFIX}/{model_prefix}/{local_path.name}"
    try:
        subprocess.run(
            ["aws", "s3", "cp", str(local_path), s3_path, "--quiet"],
            check=True,
            capture_output=True,
        )
    except subprocess.CalledProcessError as e:
        # Output is captured, so surface the CLI's own error text explicitly.
        detail = e.stderr.decode(errors="replace").strip() if e.stderr else ""
        suffix = f" ({detail})" if detail else ""
        print(f" S3 upload failed: {e}{suffix}")
    except OSError as e:
        # Raised when the ``aws`` executable itself cannot be started.
        print(f" S3 upload failed: {e}")
    else:
        print(f" Uploaded to S3: {s3_path}")
57
-
58
-
59
def load_diagnosis_results(model_name: str, output_dir: Path) -> Dict[str, Any]:
    """Load diagnosis results, syncing them from S3 first when possible.

    The S3 sync is best-effort: if the ``aws`` CLI cannot be started, whatever
    is already present under ``output_dir / "diagnosis"`` is used.

    Args:
        model_name: Model identifier; ``/`` is replaced by ``_`` to form the
            file-name prefix.
        output_dir: Directory that holds (or receives) the ``diagnosis`` folder.

    Returns:
        Mapping from category name (the file stem with the leading
        ``{model_prefix}_`` removed) to the parsed JSON content. Files whose
        name contains ``"summary"`` are skipped. Empty if nothing was found.
    """
    model_prefix = model_name.replace('/', '_')
    diagnosis_dir = output_dir / "diagnosis"

    try:
        subprocess.run(
            ["aws", "s3", "sync",
             f"s3://{S3_BUCKET}/direction_discovery/{model_prefix}/",
             str(diagnosis_dir),
             "--quiet"],
            check=False,
            capture_output=True,
        )
    except OSError as e:
        # ``aws`` is not installed / not executable: fall back to local files.
        print(f" S3 sync skipped ({e}); using local diagnosis files only")

    results: Dict[str, Any] = {}
    if not diagnosis_dir.exists():
        return results

    file_prefix = f"{model_prefix}_"
    for f in diagnosis_dir.glob(f"{file_prefix}*.json"):
        if "summary" in f.name:
            continue
        # Strip only the leading prefix; str.replace would also remove any
        # later occurrence of the prefix inside the category name.
        category = f.stem[len(file_prefix):]
        with open(f, encoding="utf-8") as fp:
            results[category] = json.load(fp)

    return results
85
-
86
-
87
def select_representative_benchmarks(
    diagnosis_results: Dict[str, Any],
    n_per_type: int = 2,
    signal_threshold: float = 0.6,
    linear_threshold: float = 0.6,
    max_linear_gap: float = 0.15,
) -> Dict[str, List[str]]:
    """
    Select representative benchmarks for each diagnosis type.

    Each result row is classified as:
      * ``NO_SIGNAL``  if ``signal_strength < signal_threshold``;
      * ``LINEAR``     if ``linear_probe_accuracy > linear_threshold`` and
                       ``signal_strength - linear_probe_accuracy < max_linear_gap``;
      * ``NONLINEAR``  otherwise.

    Only the first row per benchmark is used within a category (benchmarks
    are de-duplicated per category, not across categories).

    Args:
        diagnosis_results: Loaded diagnosis results; each value may carry a
            ``"results"`` list of rows with ``benchmark``, ``signal_strength``,
            ``linear_probe_accuracy`` and ``nonlinear_metrics.knn_accuracy_k10``.
        n_per_type: Number of benchmarks per type (negative values yield none).
        signal_threshold: Minimum signal strength to count as having signal.
        linear_threshold: Minimum linear-probe accuracy for ``LINEAR``.
        max_linear_gap: Maximum (signal - linear) gap still counted as ``LINEAR``.

    Returns:
        Dict with keys 'LINEAR', 'NONLINEAR', 'NO_SIGNAL', each a list of
        benchmark names: highest linear accuracy first for LINEAR, largest
        (kNN - linear) gap first for NONLINEAR, lowest signal first for
        NO_SIGNAL.
    """
    by_diagnosis = {"LINEAR": [], "NONLINEAR": [], "NO_SIGNAL": []}

    for data in diagnosis_results.values():
        seen = set()
        for r in data.get("results", []):
            bench = r["benchmark"]
            if bench in seen:
                continue
            seen.add(bench)

            signal = r["signal_strength"]
            linear = r["linear_probe_accuracy"]
            knn = r["nonlinear_metrics"]["knn_accuracy_k10"]

            if signal < signal_threshold:
                diag = "NO_SIGNAL"
            elif linear > linear_threshold and (signal - linear) < max_linear_gap:
                diag = "LINEAR"
            else:
                diag = "NONLINEAR"
            by_diagnosis[diag].append((bench, signal, linear, knn))

    # (sort key, descending?) per diagnosis; tuples are (bench, signal, linear, knn).
    ranking = {
        "LINEAR": (lambda x: x[2], True),
        "NONLINEAR": (lambda x: x[3] - x[2], True),
        "NO_SIGNAL": (lambda x: x[1], False),
    }

    # A negative slice bound would drop items from the end instead of
    # selecting none, so clamp it.
    limit = max(n_per_type, 0)

    selected = {}
    for diag, (key, descending) in ranking.items():
        ranked = sorted(by_diagnosis[diag], key=key, reverse=descending)
        selected[diag] = [entry[0] for entry in ranked[:limit]]

    return selected
140
-
141
-
142
def create_tsne_plot(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
    title: str,
    ax: "plt.Axes",
    diagnosis: str,
) -> None:
    """
    Create t-SNE visualization on given axes.

    Does nothing when scikit-learn or matplotlib is unavailable. The ``ax``
    annotation is a string so that defining this function does not fail when
    matplotlib (and therefore ``plt``) could not be imported.

    Args:
        pos_activations: [N, D] positive class
        neg_activations: [N, D] negative class
        title: Plot title
        ax: Matplotlib axes
        diagnosis: 'LINEAR', 'NONLINEAR', or 'NO_SIGNAL'; selects the colour
            pair and is printed in the top-left corner. Unknown values fall
            back to the LINEAR colours.
    """
    if not HAS_SKLEARN or not HAS_MATPLOTLIB:
        return

    pos = pos_activations.float().cpu().numpy()
    neg = neg_activations.float().cpu().numpy()

    X = np.vstack([pos, neg])
    labels = np.array([1] * len(pos) + [0] * len(neg))
    n_samples, n_features = X.shape

    # Reduce dimensionality with PCA first for speed. PCA cannot produce more
    # components than there are samples, so cap at n_samples as well as 50.
    if n_features > 50:
        pca = PCA(n_components=min(50, n_samples))
        X = pca.fit_transform(X)

    # t-SNE: same perplexity as before for >= 4 samples, but never 0 for
    # tiny inputs (scikit-learn rejects a non-positive perplexity).
    perplexity = max(1, min(30, n_samples // 4))
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    X_2d = tsne.fit_transform(X)

    # Color scheme based on diagnosis
    colors = {
        "LINEAR": ("#2ecc71", "#e74c3c"),     # Green/Red
        "NONLINEAR": ("#3498db", "#e67e22"),  # Blue/Orange
        "NO_SIGNAL": ("#95a5a6", "#7f8c8d"),  # Gray shades
    }
    pos_color, neg_color = colors.get(diagnosis, ("#2ecc71", "#e74c3c"))

    # Plot
    is_pos = labels == 1
    ax.scatter(X_2d[is_pos, 0], X_2d[is_pos, 1],
               c=pos_color, label='Positive', alpha=0.7, s=30)
    ax.scatter(X_2d[~is_pos, 0], X_2d[~is_pos, 1],
               c=neg_color, label='Negative', alpha=0.7, s=30)

    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xticks([])
    ax.set_yticks([])

    # Add diagnosis label
    ax.text(0.02, 0.98, diagnosis, transform=ax.transAxes,
            fontsize=10, fontweight='bold', va='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
199
-
200
-
201
def _hero_diagnosis_label(signal: float, linear: float) -> str:
    """Classify one result row as 'NO_SIGNAL', 'LINEAR' or 'NONLINEAR'."""
    if signal < 0.6:
        return "NO_SIGNAL"
    if linear > 0.6 and (signal - linear) < 0.15:
        return "LINEAR"
    return "NONLINEAR"


def _hero_draw_pipeline(ax) -> None:
    """Draw the simplified pipeline diagram (labelled boxes plus arrows)."""
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    ax.axis('off')
    ax.set_title('RepScan Pipeline', fontsize=14, fontweight='bold')

    # (centre x, centre y, label)
    boxes = [
        (1, 7, 'Contrastive\nPairs'),
        (4, 7, 'Layer\nScan'),
        (7, 7, 'Metrics'),
        (1, 3, 'kNN'),
        (4, 3, 'Linear\nProbe'),
        (7, 3, 'Diagnosis'),
    ]
    for x, y, label in boxes:
        rect = mpatches.FancyBboxPatch((x - 0.8, y - 0.6), 1.6, 1.2,
                                       boxstyle="round,pad=0.1",
                                       facecolor='lightblue', edgecolor='black')
        ax.add_patch(rect)
        ax.text(x, y, label, ha='center', va='center', fontsize=9)

    # (arrow head, arrow tail) in data coordinates
    arrows = [
        ((3.2, 7), (1.8, 7)),
        ((6.2, 7), (4.8, 7)),
        ((1, 6.4), (1, 3.6)),
        ((4, 6.4), (4, 3.6)),
        ((6.2, 3), (4.8, 3)),
        ((6.2, 3), (1.8, 3)),
    ]
    for head, tail in arrows:
        ax.annotate('', xy=head, xytext=tail,
                    arrowprops=dict(arrowstyle='->', color='black'))


def _hero_draw_distribution(ax, diagnosis_results: Dict[str, Any]) -> None:
    """Draw the pie chart of diagnosis counts (first row per benchmark and category)."""
    counts = {"LINEAR": 0, "NONLINEAR": 0, "NO_SIGNAL": 0}
    for data in diagnosis_results.values():
        seen = set()
        for r in data.get("results", []):
            bench = r["benchmark"]
            if bench in seen:
                continue
            seen.add(bench)
            label = _hero_diagnosis_label(r["signal_strength"],
                                          r["linear_probe_accuracy"])
            counts[label] += 1

    labels = list(counts.keys())
    sizes = list(counts.values())
    total = sum(sizes)

    if total == 0:
        # A pie of all-zero wedges cannot be normalised; show a note instead.
        ax.text(0.5, 0.5, 'No benchmark results',
                ha='center', va='center', transform=ax.transAxes)
        ax.axis('off')
    else:
        ax.pie(sizes, explode=(0.05, 0.05, 0.05), labels=labels,
               colors=['#2ecc71', '#3498db', '#95a5a6'],
               autopct='%1.1f%%', shadow=True, startangle=90)
    ax.set_title(f'Diagnosis Distribution\n({total} benchmarks)',
                 fontsize=14, fontweight='bold')


def _hero_draw_metrics(ax, diagnosis_results: Dict[str, Any]) -> None:
    """Draw grouped bars of mean kNN / linear / signal values per diagnosis."""
    order = ["LINEAR", "NONLINEAR", "NO_SIGNAL"]
    metrics_by_diag = {d: {"knn": [], "linear": [], "signal": []} for d in order}

    # Every result row contributes here (no per-benchmark de-duplication).
    for data in diagnosis_results.values():
        for r in data.get("results", []):
            signal = r["signal_strength"]
            linear = r["linear_probe_accuracy"]
            knn = r["nonlinear_metrics"]["knn_accuracy_k10"]

            bucket = metrics_by_diag[_hero_diagnosis_label(signal, linear)]
            bucket["knn"].append(knn)
            bucket["linear"].append(linear)
            bucket["signal"].append(signal)

    x = np.arange(len(order))
    width = 0.25

    # (metric key, legend label, colour, horizontal offset of the bar group)
    series = [
        ("knn", 'kNN Acc', '#3498db', -width),
        ("linear", 'Linear Probe', '#2ecc71', 0),
        ("signal", 'MLP Signal', '#e74c3c', width),
    ]
    for key, label, color, offset in series:
        # An empty bucket is drawn at 0.5 rather than as NaN.
        means = [np.mean(metrics_by_diag[d][key]) if metrics_by_diag[d][key] else 0.5
                 for d in order]
        ax.bar(x + offset, means, width, label=label, color=color)

    ax.set_ylabel('Accuracy')
    ax.set_title('Metrics by Diagnosis', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(order)
    ax.legend(loc='upper right')
    ax.set_ylim(0, 1)
    ax.axhline(y=0.6, color='gray', linestyle='--', alpha=0.5)


def create_hero_figure(
    diagnosis_results: Dict[str, Any],
    output_path: Path,
    model_name: str,
) -> None:
    """
    Create hero figure for paper.

    Layout:
    [Pipeline Diagram] [Diagnosis Distribution Pie] [Metrics by Diagnosis Bars]

    Skipped (with a message) when matplotlib is not installed.

    Args:
        diagnosis_results: Loaded diagnosis results
        output_path: Where to save figure
        model_name: Model name (accepted for interface compatibility; not
            currently rendered in the figure)
    """
    if not HAS_MATPLOTLIB:
        print("Skipping hero figure: matplotlib not installed")
        return

    fig = plt.figure(figsize=(15, 5))
    gs = GridSpec(1, 3, width_ratios=[1.2, 1, 1.2])

    # Panel 1: Pipeline diagram (simplified)
    _hero_draw_pipeline(fig.add_subplot(gs[0]))
    # Panel 2: Diagnosis distribution pie chart
    _hero_draw_distribution(fig.add_subplot(gs[1]), diagnosis_results)
    # Panel 3: Key metrics bar chart
    _hero_draw_metrics(fig.add_subplot(gs[2]), diagnosis_results)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f" Saved hero figure: {output_path}")
352
-
353
-
354
def create_tsne_gallery(
    model: "WisentModel",
    selected_benchmarks: Dict[str, List[str]],
    output_path: Path,
    model_name: str,
) -> None:
    """
    Create t-SNE gallery figure.

    Layout: 2x3 grid; one column per diagnosis type (LINEAR, NONLINEAR,
    NO_SIGNAL) and up to two benchmarks per column. Panels whose benchmark
    has too few pairs, fails to process, or does not exist are replaced by a
    short text note instead of being left as empty axes.

    Args:
        model: WisentModel instance
        selected_benchmarks: Dict with benchmark names by diagnosis
        output_path: Where to save figure
        model_name: Model name for title (also used in the cache directory name)
    """
    if not HAS_MATPLOTLIB or not HAS_SKLEARN:
        print("Skipping t-SNE gallery: required packages not installed")
        return

    # Imported lazily so the module can be loaded without these packages.
    from wisent.core.activations.extraction_strategy import ExtractionStrategy
    from wisent.core.activations.activation_cache import ActivationCache, collect_and_cache_activations
    from lm_eval.tasks import TaskManager
    from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import lm_build_contrastive_pairs

    diagnoses = ["LINEAR", "NONLINEAR", "NO_SIGNAL"]
    n_rows = 2

    fig, axes = plt.subplots(n_rows, len(diagnoses), figsize=(12, 8))
    fig.suptitle(f't-SNE Visualization Gallery\n{model_name}', fontsize=14, fontweight='bold')

    cache_dir = f"/tmp/wisent_viz_cache_{model_name.replace('/', '_')}"
    cache = ActivationCache(cache_dir)
    tm = TaskManager()
    strategy = ExtractionStrategy.CHAT_LAST

    for col, diagnosis in enumerate(diagnoses):
        benchmarks = selected_benchmarks.get(diagnosis, [])[:n_rows]

        for row, benchmark in enumerate(benchmarks):
            ax = axes[row, col]

            try:
                # Load pairs; fall back to no task object if the task cannot
                # be resolved by lm-eval.
                try:
                    task_dict = tm.load_task_or_group([benchmark])
                    task = next(iter(task_dict.values()))
                except Exception:
                    task = None

                pairs = lm_build_contrastive_pairs(benchmark, task, limit=50)

                if len(pairs) < 20:
                    ax.text(0.5, 0.5, f'{benchmark}\n(insufficient data)',
                            ha='center', va='center', transform=ax.transAxes)
                    ax.axis('off')
                    continue

                # Get activations
                cached = collect_and_cache_activations(
                    model=model,
                    pairs=pairs,
                    benchmark=benchmark,
                    strategy=strategy,
                    cache=cache,
                    show_progress=False,
                )

                # Use middle layer
                middle_layer = str(model.num_layers // 2)
                pos_acts = cached.get_positive_activations(middle_layer)
                neg_acts = cached.get_negative_activations(middle_layer)

                # Create t-SNE plot
                create_tsne_plot(pos_acts, neg_acts, benchmark, ax, diagnosis)

            except Exception as e:
                # One failing benchmark must not abort the whole gallery.
                ax.text(0.5, 0.5, f'{benchmark}\n(error: {str(e)[:30]})',
                        ha='center', va='center', transform=ax.transAxes)
                ax.axis('off')

        # Fewer benchmarks than rows: blank the leftover panels.
        for row in range(len(benchmarks), n_rows):
            axes[row, col].text(0.5, 0.5, f'{diagnosis}\n(no benchmark)',
                                ha='center', va='center',
                                transform=axes[row, col].transAxes)
            axes[row, col].axis('off')

    # Add column labels
    for idx, diagnosis in enumerate(diagnoses):
        axes[0, idx].set_xlabel(diagnosis, fontsize=12, fontweight='bold')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f" Saved t-SNE gallery: {output_path}")
443
-
444
-
445
def create_layer_accuracy_curves(
    diagnosis_results: Dict[str, Any],
    output_path: Path,
    model_name: str,
) -> None:
    """
    Create layer-wise accuracy curves.

    NOTE: the curves are synthetic placeholders illustrating the concept; the
    current diagnosis results carry no per-layer data, so ``diagnosis_results``
    is accepted but not read. The noise for the NO_SIGNAL panel comes from a
    locally seeded generator so the figure is reproducible and the global
    NumPy random state is left untouched.

    Skipped (with a message) when matplotlib is not installed.

    Args:
        diagnosis_results: Loaded diagnosis results (currently unused)
        output_path: Where to save figure
        model_name: Model name for title
    """
    if not HAS_MATPLOTLIB:
        print("Skipping layer curves: matplotlib not installed")
        return

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f'Layer-wise Accuracy Curves\n{model_name}', fontsize=14, fontweight='bold')

    # Example curves (would be replaced with real per-layer data).
    layers = np.arange(1, 33)
    bump = np.exp(-(layers - 16) ** 2 / 100)  # bell shape peaking at layer 16
    rng = np.random.default_rng(42)

    # (kNN amplitude, linear-probe amplitude) above the 0.5 chance level.
    amplitudes = {
        "LINEAR": (0.4, 0.35),
        "NONLINEAR": (0.35, 0.1),
    }

    for ax, diagnosis in zip(axes, ["LINEAR", "NONLINEAR", "NO_SIGNAL"]):
        if diagnosis in amplitudes:
            knn_amp, linear_amp = amplitudes[diagnosis]
            knn = 0.5 + knn_amp * bump
            linear = 0.5 + linear_amp * bump
        else:
            # NO_SIGNAL: both curves hover around chance.
            knn = 0.5 + 0.05 * rng.standard_normal(len(layers))
            linear = 0.5 + 0.05 * rng.standard_normal(len(layers))

        ax.plot(layers, knn, 'b-', linewidth=2, label='kNN-10')
        ax.plot(layers, linear, 'g--', linewidth=2, label='Linear Probe')
        ax.fill_between(layers, knn, linear, alpha=0.3, color='yellow', label='Gap')

        ax.set_xlabel('Layer')
        ax.set_ylabel('Accuracy')
        ax.set_title(diagnosis, fontsize=12, fontweight='bold')
        ax.legend()
        ax.set_ylim(0.4, 1.0)
        ax.axhline(y=0.6, color='gray', linestyle='--', alpha=0.5)
        ax.set_xlim(1, 32)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f" Saved layer curves: {output_path}")
510
-
511
-
512
def run_visualization(model_name: str, skip_tsne: bool = False):
    """
    Produce every gallery figure for one model and upload each to S3.

    Figures are written under ``/tmp/visualizations``. The run stops early,
    after printing an error, when no diagnosis results can be loaded.

    Args:
        model_name: Model to visualize
        skip_tsne: When true, the t-SNE gallery (the only step that loads the
            model) is not generated
    """
    banner = "=" * 70

    print(banner)
    print("VISUALIZATION GALLERY")
    print(banner)
    print(f"Model: {model_name}")

    figures_dir = Path("/tmp/visualizations")
    figures_dir.mkdir(parents=True, exist_ok=True)

    # Everything downstream is driven by the stored diagnosis results.
    loaded = load_diagnosis_results(model_name, figures_dir)
    if not loaded:
        print("ERROR: No diagnosis results found.")
        return

    print(f"Loaded results for {len(loaded)} categories")

    model_prefix = model_name.replace('/', '_')

    # Step 1: hero figure
    print("\n1. Creating hero figure...")
    hero_file = figures_dir / f"{model_prefix}_hero_figure.png"
    create_hero_figure(loaded, hero_file, model_name)
    s3_upload_file(hero_file, model_name)

    # Step 2: layer accuracy curves
    print("\n2. Creating layer accuracy curves...")
    curves_file = figures_dir / f"{model_prefix}_layer_curves.png"
    create_layer_accuracy_curves(loaded, curves_file, model_name)
    s3_upload_file(curves_file, model_name)

    # Step 3: t-SNE gallery (needs the model itself)
    if skip_tsne:
        print("\n3. Skipping t-SNE gallery (--skip-tsne)")
    else:
        print("\n3. Creating t-SNE gallery...")

        from wisent.core.models.wisent_model import WisentModel

        print(f" Loading model: {model_name}")
        model = WisentModel(model_name, device="cuda")

        chosen = select_representative_benchmarks(loaded, n_per_type=2)
        print(f" Selected benchmarks: {chosen}")

        gallery_file = figures_dir / f"{model_prefix}_tsne_gallery.png"
        create_tsne_gallery(model, chosen, gallery_file, model_name)
        s3_upload_file(gallery_file, model_name)

        del model

    print("\n" + banner)
    print("VISUALIZATION COMPLETE")
    print(banner)
    print(f"Figures saved to: {figures_dir}")
574
-
575
-
576
if __name__ == "__main__":
    # Command-line entry point: parse the two options and run the gallery.
    cli = argparse.ArgumentParser(description="Visualization gallery for RepScan")
    cli.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen3-8B",
        help="Model to visualize",
    )
    cli.add_argument(
        "--skip-tsne",
        action="store_true",
        help="Skip t-SNE (doesn't require model)",
    )
    options = cli.parse_args()

    run_visualization(options.model, skip_tsne=options.skip_tsne)