wisent 0.7.901__py3-none-any.whl → 0.7.1116__py3-none-any.whl
This diff shows the differences between two package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the contents of each version as published in its registry.
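A comparison like this can also be reproduced locally. The sketch below is a minimal, illustrative example only: it assumes Python 3.9+, a working pip, and that both versions are still downloadable from PyPI; the working directory wisent_diff and the helper wheel_files are hypothetical names chosen for this sketch, not part of the wisent package.

import difflib
import subprocess
import sys
import zipfile
from pathlib import Path

VERSIONS = ["0.7.901", "0.7.1116"]  # taken from the header above
work = Path("wisent_diff")

# Download each wheel (no dependencies, wheels only) into its own directory.
for v in VERSIONS:
    dest = work / v
    dest.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        [sys.executable, "-m", "pip", "download", f"wisent=={v}",
         "--no-deps", "--only-binary", ":all:", "-d", str(dest)],
        check=True,
    )

def wheel_files(version: str) -> dict[str, list[str]]:
    """Map archive member name -> decoded text lines for one wheel."""
    wheel = next((work / version).glob("*.whl"))
    out = {}
    with zipfile.ZipFile(wheel) as zf:
        for name in zf.namelist():
            if name.endswith((".py", ".txt", ".json")):
                out[name] = zf.read(name).decode("utf-8", "replace").splitlines()
    return out

old, new = (wheel_files(v) for v in VERSIONS)

# Print a unified diff for every text file present in either wheel.
for name in sorted(set(old) | set(new)):
    for line in difflib.unified_diff(
        old.get(name, []), new.get(name, []),
        fromfile=f"{VERSIONS[0]}/{name}", tofile=f"{VERSIONS[1]}/{name}",
        lineterm="",
    ):
        print(line)

The file-level summary below lists the changes that this kind of comparison surfaces between the two releases.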
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +663 -0
- wisent/comparison/lora_dpo.py +604 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/reft.py +690 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activations_collector.py +3 -2
- wisent/core/activations/extraction_strategy.py +8 -4
- wisent/core/cli/agent/apply_steering.py +7 -5
- wisent/core/cli/agent/train_classifier.py +4 -3
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/METADATA +5 -1
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/RECORD +28 -91
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/generate_paper_data.py +0 -384
- wisent/examples/scripts/intervention_validation.py +0 -626
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
- wisent/examples/scripts/threshold_analysis.py +0 -434
- wisent/examples/scripts/visualization_gallery.py +0 -582
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/WHEEL +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/top_level.txt +0 -0
wisent/examples/scripts/visualization_gallery.py
@@ -1,582 +0,0 @@
-"""
-Visualization Gallery for RepScan.
-
-Creates publication-quality figures:
-1. Hero figure (method overview + key results)
-2. t-SNE gallery (LINEAR, NONLINEAR, NO_SIGNAL examples)
-3. Layer-wise accuracy curves
-4. Decision boundary visualizations
-
-Usage:
-    python -m wisent.examples.scripts.visualization_gallery --model Qwen/Qwen3-8B
-"""
-
-import argparse
-import json
-import subprocess
-from pathlib import Path
-from typing import Dict, List, Any, Optional, Tuple
-import random
-
-import torch
-import numpy as np
-
-try:
-    import matplotlib.pyplot as plt
-    import matplotlib.patches as mpatches
-    from matplotlib.gridspec import GridSpec
-    HAS_MATPLOTLIB = True
-except ImportError:
-    HAS_MATPLOTLIB = False
-    print("Warning: matplotlib not installed. Visualizations will be skipped.")
-
-try:
-    from sklearn.manifold import TSNE
-    from sklearn.decomposition import PCA
-    HAS_SKLEARN = True
-except ImportError:
-    HAS_SKLEARN = False
-
-S3_BUCKET = "wisent-bucket"
-S3_PREFIX = "visualizations"
-
-
-def s3_upload_file(local_path: Path, model_name: str) -> None:
-    """Upload a single file to S3."""
-    model_prefix = model_name.replace('/', '_')
-    s3_path = f"s3://{S3_BUCKET}/{S3_PREFIX}/{model_prefix}/{local_path.name}"
-    try:
-        subprocess.run(
-            ["aws", "s3", "cp", str(local_path), s3_path, "--quiet"],
-            check=True,
-            capture_output=True,
-        )
-        print(f" Uploaded to S3: {s3_path}")
-    except Exception as e:
-        print(f" S3 upload failed: {e}")
-
-
-def load_diagnosis_results(model_name: str, output_dir: Path) -> Dict[str, Any]:
-    """Load diagnosis results from S3/local."""
-    model_prefix = model_name.replace('/', '_')
-
-    try:
-        subprocess.run(
-            ["aws", "s3", "sync",
-             f"s3://{S3_BUCKET}/direction_discovery/{model_prefix}/",
-             str(output_dir / "diagnosis"),
-             "--quiet"],
-            check=False,
-            capture_output=True,
-        )
-    except Exception:
-        pass
-
-    results = {}
-    diagnosis_dir = output_dir / "diagnosis"
-    if diagnosis_dir.exists():
-        for f in diagnosis_dir.glob(f"{model_prefix}_*.json"):
-            if "summary" not in f.name:
-                category = f.stem.replace(f"{model_prefix}_", "")
-                with open(f) as fp:
-                    results[category] = json.load(fp)
-
-    return results
-
-
-def select_representative_benchmarks(
-    diagnosis_results: Dict[str, Any],
-    n_per_type: int = 2,
-) -> Dict[str, List[str]]:
-    """
-    Select representative benchmarks for each diagnosis type.
-
-    Args:
-        diagnosis_results: Loaded diagnosis results
-        n_per_type: Number of benchmarks per type
-
-    Returns:
-        Dict with keys 'LINEAR', 'NONLINEAR', 'NO_SIGNAL'
-    """
-    by_diagnosis = {"LINEAR": [], "NONLINEAR": [], "NO_SIGNAL": []}
-
-    for category, data in diagnosis_results.items():
-        results = data.get("results", [])
-        seen = set()
-
-        for r in results:
-            bench = r["benchmark"]
-            if bench in seen:
-                continue
-            seen.add(bench)
-
-            signal = r["signal_strength"]
-            linear = r["linear_probe_accuracy"]
-            knn = r["nonlinear_metrics"]["knn_accuracy_k10"]
-
-            if signal < 0.6:
-                by_diagnosis["NO_SIGNAL"].append((bench, signal, linear, knn))
-            elif linear > 0.6 and (signal - linear) < 0.15:
-                by_diagnosis["LINEAR"].append((bench, signal, linear, knn))
-            else:
-                by_diagnosis["NONLINEAR"].append((bench, signal, linear, knn))
-
-    # Select best examples (highest separation for LINEAR/NONLINEAR, lowest for NO_SIGNAL)
-    selected = {}
-
-    # LINEAR: highest linear probe accuracy
-    by_diagnosis["LINEAR"].sort(key=lambda x: x[2], reverse=True)
-    selected["LINEAR"] = [b[0] for b in by_diagnosis["LINEAR"][:n_per_type]]
-
-    # NONLINEAR: highest gap between kNN and linear
-    by_diagnosis["NONLINEAR"].sort(key=lambda x: x[3] - x[2], reverse=True)
-    selected["NONLINEAR"] = [b[0] for b in by_diagnosis["NONLINEAR"][:n_per_type]]
-
-    # NO_SIGNAL: lowest signal
-    by_diagnosis["NO_SIGNAL"].sort(key=lambda x: x[1])
-    selected["NO_SIGNAL"] = [b[0] for b in by_diagnosis["NO_SIGNAL"][:n_per_type]]
-
-    return selected
-
-
-def create_tsne_plot(
-    pos_activations: torch.Tensor,
-    neg_activations: torch.Tensor,
-    title: str,
-    ax: plt.Axes,
-    diagnosis: str,
-) -> None:
-    """
-    Create t-SNE visualization on given axes.
-
-    Args:
-        pos_activations: [N, D] positive class
-        neg_activations: [N, D] negative class
-        title: Plot title
-        ax: Matplotlib axes
-        diagnosis: 'LINEAR', 'NONLINEAR', or 'NO_SIGNAL'
-    """
-    if not HAS_SKLEARN or not HAS_MATPLOTLIB:
-        return
-
-    pos = pos_activations.float().cpu().numpy()
-    neg = neg_activations.float().cpu().numpy()
-
-    X = np.vstack([pos, neg])
-    labels = np.array([1] * len(pos) + [0] * len(neg))
-
-    # Reduce dimensionality with PCA first for speed
-    if X.shape[1] > 50:
-        pca = PCA(n_components=50)
-        X = pca.fit_transform(X)
-
-    # t-SNE
-    tsne = TSNE(n_components=2, perplexity=min(30, len(X) // 4), random_state=42)
-    X_2d = tsne.fit_transform(X)
-
-    # Color scheme based on diagnosis
-    colors = {
-        "LINEAR": ("#2ecc71", "#e74c3c"),  # Green/Red
-        "NONLINEAR": ("#3498db", "#e67e22"),  # Blue/Orange
-        "NO_SIGNAL": ("#95a5a6", "#7f8c8d"),  # Gray shades
-    }
-    pos_color, neg_color = colors.get(diagnosis, ("#2ecc71", "#e74c3c"))
-
-    # Plot
-    ax.scatter(X_2d[labels == 1, 0], X_2d[labels == 1, 1],
-               c=pos_color, label='Positive', alpha=0.7, s=30)
-    ax.scatter(X_2d[labels == 0, 0], X_2d[labels == 0, 1],
-               c=neg_color, label='Negative', alpha=0.7, s=30)
-
-    ax.set_title(title, fontsize=12, fontweight='bold')
-    ax.set_xticks([])
-    ax.set_yticks([])
-
-    # Add diagnosis label
-    ax.text(0.02, 0.98, diagnosis, transform=ax.transAxes,
-            fontsize=10, fontweight='bold', va='top',
-            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
-
-
-def create_hero_figure(
-    diagnosis_results: Dict[str, Any],
-    output_path: Path,
-    model_name: str,
-) -> None:
-    """
-    Create hero figure for paper.
-
-    Layout:
-        [Pipeline Diagram] [Key Results Pie] [Example t-SNE]
-
-    Args:
-        diagnosis_results: Loaded diagnosis results
-        output_path: Where to save figure
-        model_name: Model name for title
-    """
-    if not HAS_MATPLOTLIB:
-        print("Skipping hero figure: matplotlib not installed")
-        return
-
-    fig = plt.figure(figsize=(15, 5))
-    gs = GridSpec(1, 3, width_ratios=[1.2, 1, 1.2])
-
-    # Panel 1: Pipeline diagram (simplified)
-    ax1 = fig.add_subplot(gs[0])
-    ax1.set_xlim(0, 10)
-    ax1.set_ylim(0, 10)
-    ax1.axis('off')
-    ax1.set_title('RepScan Pipeline', fontsize=14, fontweight='bold')
-
-    # Draw boxes
-    boxes = [
-        (1, 7, 'Contrastive\nPairs'),
-        (4, 7, 'Layer\nScan'),
-        (7, 7, 'Metrics'),
-        (1, 3, 'kNN'),
-        (4, 3, 'Linear\nProbe'),
-        (7, 3, 'Diagnosis'),
-    ]
-
-    for x, y, label in boxes:
-        rect = mpatches.FancyBboxPatch((x-0.8, y-0.6), 1.6, 1.2,
-                                       boxstyle="round,pad=0.1",
-                                       facecolor='lightblue', edgecolor='black')
-        ax1.add_patch(rect)
-        ax1.text(x, y, label, ha='center', va='center', fontsize=9)
-
-    # Draw arrows
-    ax1.annotate('', xy=(3.2, 7), xytext=(1.8, 7),
-                 arrowprops=dict(arrowstyle='->', color='black'))
-    ax1.annotate('', xy=(6.2, 7), xytext=(4.8, 7),
-                 arrowprops=dict(arrowstyle='->', color='black'))
-    ax1.annotate('', xy=(1, 6.4), xytext=(1, 3.6),
-                 arrowprops=dict(arrowstyle='->', color='black'))
-    ax1.annotate('', xy=(4, 6.4), xytext=(4, 3.6),
-                 arrowprops=dict(arrowstyle='->', color='black'))
-    ax1.annotate('', xy=(6.2, 3), xytext=(4.8, 3),
-                 arrowprops=dict(arrowstyle='->', color='black'))
-    ax1.annotate('', xy=(6.2, 3), xytext=(1.8, 3),
-                 arrowprops=dict(arrowstyle='->', color='black'))
-
-    # Panel 2: Diagnosis distribution pie chart
-    ax2 = fig.add_subplot(gs[1])
-
-    # Count diagnoses
-    counts = {"LINEAR": 0, "NONLINEAR": 0, "NO_SIGNAL": 0}
-    for category, data in diagnosis_results.items():
-        results = data.get("results", [])
-        seen = set()
-        for r in results:
-            bench = r["benchmark"]
-            if bench in seen:
-                continue
-            seen.add(bench)
-
-            signal = r["signal_strength"]
-            linear = r["linear_probe_accuracy"]
-
-            if signal < 0.6:
-                counts["NO_SIGNAL"] += 1
-            elif linear > 0.6 and (signal - linear) < 0.15:
-                counts["LINEAR"] += 1
-            else:
-                counts["NONLINEAR"] += 1
-
-    labels = list(counts.keys())
-    sizes = list(counts.values())
-    colors = ['#2ecc71', '#3498db', '#95a5a6']
-    explode = (0.05, 0.05, 0.05)
-
-    ax2.pie(sizes, explode=explode, labels=labels, colors=colors,
-            autopct='%1.1f%%', shadow=True, startangle=90)
-    ax2.set_title(f'Diagnosis Distribution\n({sum(sizes)} benchmarks)',
-                  fontsize=14, fontweight='bold')
-
-    # Panel 3: Key metrics bar chart
-    ax3 = fig.add_subplot(gs[2])
-
-    # Compute average metrics by diagnosis
-    metrics_by_diag = {
-        "LINEAR": {"knn": [], "linear": [], "signal": []},
-        "NONLINEAR": {"knn": [], "linear": [], "signal": []},
-        "NO_SIGNAL": {"knn": [], "linear": [], "signal": []},
-    }
-
-    for category, data in diagnosis_results.items():
-        results = data.get("results", [])
-        for r in results:
-            signal = r["signal_strength"]
-            linear = r["linear_probe_accuracy"]
-            knn = r["nonlinear_metrics"]["knn_accuracy_k10"]
-
-            if signal < 0.6:
-                diag = "NO_SIGNAL"
-            elif linear > 0.6 and (signal - linear) < 0.15:
-                diag = "LINEAR"
-            else:
-                diag = "NONLINEAR"
-
-            metrics_by_diag[diag]["knn"].append(knn)
-            metrics_by_diag[diag]["linear"].append(linear)
-            metrics_by_diag[diag]["signal"].append(signal)
-
-    # Create grouped bar chart
-    x = np.arange(3)
-    width = 0.25
-
-    knn_means = [np.mean(metrics_by_diag[d]["knn"]) if metrics_by_diag[d]["knn"] else 0.5
-                 for d in ["LINEAR", "NONLINEAR", "NO_SIGNAL"]]
-    linear_means = [np.mean(metrics_by_diag[d]["linear"]) if metrics_by_diag[d]["linear"] else 0.5
-                    for d in ["LINEAR", "NONLINEAR", "NO_SIGNAL"]]
-    signal_means = [np.mean(metrics_by_diag[d]["signal"]) if metrics_by_diag[d]["signal"] else 0.5
-                    for d in ["LINEAR", "NONLINEAR", "NO_SIGNAL"]]
-
-    ax3.bar(x - width, knn_means, width, label='kNN Acc', color='#3498db')
-    ax3.bar(x, linear_means, width, label='Linear Probe', color='#2ecc71')
-    ax3.bar(x + width, signal_means, width, label='MLP Signal', color='#e74c3c')
-
-    ax3.set_ylabel('Accuracy')
-    ax3.set_title('Metrics by Diagnosis', fontsize=14, fontweight='bold')
-    ax3.set_xticks(x)
-    ax3.set_xticklabels(['LINEAR', 'NONLINEAR', 'NO_SIGNAL'])
-    ax3.legend(loc='upper right')
-    ax3.set_ylim(0, 1)
-    ax3.axhline(y=0.6, color='gray', linestyle='--', alpha=0.5)
-
-    plt.tight_layout()
-    plt.savefig(output_path, dpi=300, bbox_inches='tight')
-    plt.close()
-
-    print(f" Saved hero figure: {output_path}")
-
-
-def create_tsne_gallery(
-    model: "WisentModel",
-    selected_benchmarks: Dict[str, List[str]],
-    output_path: Path,
-    model_name: str,
-) -> None:
-    """
-    Create t-SNE gallery figure.
-
-    Layout: 2x3 grid showing examples from each diagnosis type
-
-    Args:
-        model: WisentModel instance
-        selected_benchmarks: Dict with benchmark names by diagnosis
-        output_path: Where to save figure
-        model_name: Model name for title
-    """
-    if not HAS_MATPLOTLIB or not HAS_SKLEARN:
-        print("Skipping t-SNE gallery: required packages not installed")
-        return
-
-    from wisent.core.activations.extraction_strategy import ExtractionStrategy
-    from wisent.core.activations.activation_cache import ActivationCache, collect_and_cache_activations
-    from lm_eval.tasks import TaskManager
-    from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import lm_build_contrastive_pairs
-
-    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
-    fig.suptitle(f't-SNE Visualization Gallery\n{model_name}', fontsize=14, fontweight='bold')
-
-    cache_dir = f"/tmp/wisent_viz_cache_{model_name.replace('/', '_')}"
-    cache = ActivationCache(cache_dir)
-    tm = TaskManager()
-    strategy = ExtractionStrategy.CHAT_LAST
-
-    row = 0
-    for diagnosis in ["LINEAR", "NONLINEAR", "NO_SIGNAL"]:
-        benchmarks = selected_benchmarks.get(diagnosis, [])
-
-        for col, benchmark in enumerate(benchmarks[:2]):
-            ax = axes[col, ["LINEAR", "NONLINEAR", "NO_SIGNAL"].index(diagnosis)]
-
-            try:
-                # Load pairs
-                try:
-                    task_dict = tm.load_task_or_group([benchmark])
-                    task = list(task_dict.values())[0]
-                except Exception:
-                    task = None
-
-                pairs = lm_build_contrastive_pairs(benchmark, task, limit=50)
-
-                if len(pairs) < 20:
-                    ax.text(0.5, 0.5, f'{benchmark}\n(insufficient data)',
-                            ha='center', va='center', transform=ax.transAxes)
-                    ax.axis('off')
-                    continue
-
-                # Get activations
-                cached = collect_and_cache_activations(
-                    model=model,
-                    pairs=pairs,
-                    benchmark=benchmark,
-                    strategy=strategy,
-                    cache=cache,
-                    show_progress=False,
-                )
-
-                # Use middle layer
-                middle_layer = str(model.num_layers // 2)
-                pos_acts = cached.get_positive_activations(middle_layer)
-                neg_acts = cached.get_negative_activations(middle_layer)
-
-                # Create t-SNE plot
-                create_tsne_plot(pos_acts, neg_acts, benchmark, ax, diagnosis)
-
-            except Exception as e:
-                ax.text(0.5, 0.5, f'{benchmark}\n(error: {str(e)[:30]})',
-                        ha='center', va='center', transform=ax.transAxes)
-                ax.axis('off')
-
-    # Add column labels
-    for idx, diagnosis in enumerate(["LINEAR", "NONLINEAR", "NO_SIGNAL"]):
-        axes[0, idx].set_xlabel(diagnosis, fontsize=12, fontweight='bold')
-
-    plt.tight_layout()
-    plt.savefig(output_path, dpi=300, bbox_inches='tight')
-    plt.close()
-
-    print(f" Saved t-SNE gallery: {output_path}")
-
-
-def create_layer_accuracy_curves(
-    diagnosis_results: Dict[str, Any],
-    output_path: Path,
-    model_name: str,
-) -> None:
-    """
-    Create layer-wise accuracy curves.
-
-    Shows how kNN and linear probe accuracy change across layers.
-
-    Args:
-        diagnosis_results: Loaded diagnosis results
-        output_path: Where to save figure
-        model_name: Model name for title
-    """
-    if not HAS_MATPLOTLIB:
-        print("Skipping layer curves: matplotlib not installed")
-        return
-
-    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
-    fig.suptitle(f'Layer-wise Accuracy Curves\n{model_name}', fontsize=14, fontweight='bold')
-
-    # Collect layer-wise data (we don't have per-layer data in current results,
-    # so we'll create placeholder showing the concept)
-
-    # This would need per-layer results which we can add to discover_directions
-    # For now, create example curves
-
-    for idx, (diagnosis, color) in enumerate([
-        ("LINEAR", "#2ecc71"),
-        ("NONLINEAR", "#3498db"),
-        ("NO_SIGNAL", "#95a5a6")
-    ]):
-        ax = axes[idx]
-
-        # Example curves (would be replaced with real data)
-        layers = np.arange(1, 33)
-
-        if diagnosis == "LINEAR":
-            knn = 0.5 + 0.4 * np.exp(-(layers - 16)**2 / 100)
-            linear = 0.5 + 0.35 * np.exp(-(layers - 16)**2 / 100)
-        elif diagnosis == "NONLINEAR":
-            knn = 0.5 + 0.35 * np.exp(-(layers - 16)**2 / 100)
-            linear = 0.5 + 0.1 * np.exp(-(layers - 16)**2 / 100)
-        else:
-            knn = 0.5 + 0.05 * np.random.randn(len(layers))
-            linear = 0.5 + 0.05 * np.random.randn(len(layers))
-
-        ax.plot(layers, knn, 'b-', linewidth=2, label='kNN-10')
-        ax.plot(layers, linear, 'g--', linewidth=2, label='Linear Probe')
-        ax.fill_between(layers, knn, linear, alpha=0.3, color='yellow', label='Gap')
-
-        ax.set_xlabel('Layer')
-        ax.set_ylabel('Accuracy')
-        ax.set_title(diagnosis, fontsize=12, fontweight='bold')
-        ax.legend()
-        ax.set_ylim(0.4, 1.0)
-        ax.axhline(y=0.6, color='gray', linestyle='--', alpha=0.5)
-        ax.set_xlim(1, 32)
-
-    plt.tight_layout()
-    plt.savefig(output_path, dpi=300, bbox_inches='tight')
-    plt.close()
-
-    print(f" Saved layer curves: {output_path}")
-
-
-def run_visualization(model_name: str, skip_tsne: bool = False):
-    """
-    Generate all visualizations.
-
-    Args:
-        model_name: Model to visualize
-        skip_tsne: Skip t-SNE (requires model loading)
-    """
-    print("=" * 70)
-    print("VISUALIZATION GALLERY")
-    print("=" * 70)
-    print(f"Model: {model_name}")
-
-    output_dir = Path("/tmp/visualizations")
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    # Load diagnosis results
-    diagnosis_results = load_diagnosis_results(model_name, output_dir)
-    if not diagnosis_results:
-        print("ERROR: No diagnosis results found.")
-        return
-
-    print(f"Loaded results for {len(diagnosis_results)} categories")
-
-    model_prefix = model_name.replace('/', '_')
-
-    # 1. Hero figure
-    print("\n1. Creating hero figure...")
-    hero_path = output_dir / f"{model_prefix}_hero_figure.png"
-    create_hero_figure(diagnosis_results, hero_path, model_name)
-    s3_upload_file(hero_path, model_name)
-
-    # 2. Layer accuracy curves
-    print("\n2. Creating layer accuracy curves...")
-    curves_path = output_dir / f"{model_prefix}_layer_curves.png"
-    create_layer_accuracy_curves(diagnosis_results, curves_path, model_name)
-    s3_upload_file(curves_path, model_name)
-
-    # 3. t-SNE gallery (requires model)
-    if not skip_tsne:
-        print("\n3. Creating t-SNE gallery...")
-
-        from wisent.core.models.wisent_model import WisentModel
-
-        print(f" Loading model: {model_name}")
-        model = WisentModel(model_name, device="cuda")
-
-        selected = select_representative_benchmarks(diagnosis_results, n_per_type=2)
-        print(f" Selected benchmarks: {selected}")
-
-        tsne_path = output_dir / f"{model_prefix}_tsne_gallery.png"
-        create_tsne_gallery(model, selected, tsne_path, model_name)
-        s3_upload_file(tsne_path, model_name)
-
-        del model
-    else:
-        print("\n3. Skipping t-SNE gallery (--skip-tsne)")
-
-    print("\n" + "=" * 70)
-    print("VISUALIZATION COMPLETE")
-    print("=" * 70)
-    print(f"Figures saved to: {output_dir}")
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Visualization gallery for RepScan")
-    parser.add_argument("--model", type=str, default="Qwen/Qwen3-8B", help="Model to visualize")
-    parser.add_argument("--skip-tsne", action="store_true", help="Skip t-SNE (doesn't require model)")
-    args = parser.parse_args()
-
-    run_visualization(args.model, skip_tsne=args.skip_tsne)