wisent 0.7.901__py3-none-any.whl → 0.7.1045__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +669 -0
- wisent/comparison/lora_dpo.py +592 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activations_collector.py +3 -2
- wisent/core/activations/extraction_strategy.py +8 -4
- wisent/core/cli/agent/apply_steering.py +7 -5
- wisent/core/cli/agent/train_classifier.py +4 -3
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/RECORD +27 -91
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/generate_paper_data.py +0 -384
- wisent/examples/scripts/intervention_validation.py +0 -626
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
- wisent/examples/scripts/threshold_analysis.py +0 -434
- wisent/examples/scripts/visualization_gallery.py +0 -582
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
|
@@ -1,434 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Threshold Analysis for RepScan.
|
|
3
|
-
|
|
4
|
-
Analyzes sensitivity of diagnosis to threshold choices:
|
|
5
|
-
- ROC curves for existence threshold
|
|
6
|
-
- Precision/recall tradeoff
|
|
7
|
-
- Null distribution analysis
|
|
8
|
-
- Synthetic validation
|
|
9
|
-
|
|
10
|
-
Usage:
|
|
11
|
-
python -m wisent.examples.scripts.threshold_analysis --model Qwen/Qwen3-8B
|
|
12
|
-
"""
|
|
13
|
-
|
|
14
|
-
import argparse
|
|
15
|
-
import json
|
|
16
|
-
import subprocess
|
|
17
|
-
from pathlib import Path
|
|
18
|
-
from typing import Dict, List, Any, Optional, Tuple
|
|
19
|
-
from dataclasses import dataclass, field, asdict
|
|
20
|
-
import random
|
|
21
|
-
|
|
22
|
-
import torch
|
|
23
|
-
import numpy as np
|
|
24
|
-
from sklearn.metrics import roc_curve, auc, precision_recall_curve
|
|
25
|
-
|
|
26
|
-
# S3 destination for uploaded analysis artifacts. Uploads go to
# s3://<S3_BUCKET>/<S3_PREFIX>/<model_prefix>/<file name>; diagnosis inputs are
# synced from the same bucket under a different prefix.
S3_BUCKET, S3_PREFIX = "wisent-bucket", "threshold_analysis"
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def s3_upload_file(local_path: Path, model_name: str) -> None:
    """Upload a single file to S3 on a best-effort basis.

    The object is written to
    ``s3://{S3_BUCKET}/{S3_PREFIX}/{model_prefix}/{local_path.name}``, where
    ``model_prefix`` is ``model_name`` with every ``/`` replaced by ``_``.
    The copy shells out to the ``aws`` CLI. Failures are printed and never
    raised, so a missing CLI or bad credentials cannot abort a run whose
    results were already written locally.

    Args:
        local_path: File to upload.
        model_name: Model identifier (e.g. ``"Qwen/Qwen3-8B"``) used to build
            the S3 key prefix.
    """
    model_prefix = model_name.replace('/', '_')
    s3_path = f"s3://{S3_BUCKET}/{S3_PREFIX}/{model_prefix}/{local_path.name}"
    try:
        subprocess.run(
            ["aws", "s3", "cp", str(local_path), s3_path, "--quiet"],
            check=True,
            capture_output=True,
        )
    except subprocess.CalledProcessError as e:
        # capture_output=True hides the CLI's own error text; include it so the
        # reason (auth, bucket, path) is visible rather than just the exit code.
        stderr = (e.stderr or b"").decode(errors="replace").strip()
        detail = f" ({stderr})" if stderr else ""
        print(f" S3 upload failed: {e}{detail}")
    except (OSError, subprocess.SubprocessError) as e:
        # OSError covers a missing or non-executable `aws` binary.
        print(f" S3 upload failed: {e}")
    else:
        print(f" Uploaded to S3: {s3_path}")
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
@dataclass
class ThresholdAnalysisResult:
    """Result of a threshold analysis run, serialisable with ``asdict``.

    Field order defines the positional constructor signature; do not reorder.
    """

    # --- Existence threshold analysis (ROC: real benchmarks vs. null data) ---
    existence_thresholds: List[float]  # ROC score thresholds
    existence_tpr: List[float]  # True positive rate at each threshold
    existence_fpr: List[float]  # False positive rate at each threshold
    existence_auc: float  # Area under the ROC curve
    optimal_existence_threshold: float  # Chosen existence threshold

    # --- Gap threshold analysis (gap = signal strength - linear probe accuracy) ---
    gap_thresholds: List[float]
    gap_precision: List[float]
    gap_recall: List[float]
    gap_f1: List[float]
    optimal_gap_threshold: float

    # --- Null distribution statistics (classifier accuracy on random data) ---
    null_mean_knn: float
    null_std_knn: float
    null_mean_linear: float
    null_std_linear: float

    # --- Sensitivity analysis ---
    # {existence_threshold: {gap_threshold: {diagnosis: percentage}}}
    sensitivity_matrix: Dict[str, Dict[str, Dict[str, float]]]
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
def generate_null_distribution(
    model: "WisentModel",
    n_samples: int = 100,
    hidden_dim: int = 4096,
    n_per_class: int = 50,
    k: int = 10,
    seed: Optional[int] = None,
) -> Tuple[List[float], List[float]]:
    """
    Generate a null distribution by scoring classifiers on pure-noise data.

    Every trial draws two standard-normal "classes" from the same
    distribution, so any accuracy above chance reflects classifier bias or
    overfitting rather than real signal.

    Args:
        model: Unused by this function (run_threshold_analysis passes None);
            kept for call-site compatibility.
        n_samples: Number of random trials.
        hidden_dim: Hidden dimension of the random activations.
        n_per_class: Random activations drawn per class in each trial.
        k: Neighbour count passed to the kNN accuracy.
        seed: If given, draws come from a private ``torch.Generator`` seeded
            with this value (reproducible, global RNG untouched). ``None``
            uses torch's global RNG.

    Returns:
        (knn_scores, linear_scores), one entry per trial.
    """
    from wisent.core.geometry_runner import compute_knn_accuracy, compute_linear_probe_accuracy

    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    knn_scores: List[float] = []
    linear_scores: List[float] = []

    for _ in range(n_samples):
        # Generate random activations (no real signal)
        pos = torch.randn(n_per_class, hidden_dim, generator=generator)
        neg = torch.randn(n_per_class, hidden_dim, generator=generator)

        knn_scores.append(compute_knn_accuracy(pos, neg, k=k))
        linear_scores.append(compute_linear_probe_accuracy(pos, neg))

    return knn_scores, linear_scores
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
def generate_synthetic_data(
    structure: str,
    n_samples: int = 50,
    hidden_dim: int = 100,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate synthetic data with known structure for validation.

    Args:
        structure: 'linear', 'xor', 'spirals' or 'random'. Any other value
            falls through to 'random'.
        n_samples: Samples per class.
        hidden_dim: Dimension of each sample. 'xor' and 'spirals' put their
            structure in dims 0 and 1, so they need hidden_dim >= 2.

    Returns:
        (pos_activations, neg_activations), each of shape (n_samples, hidden_dim).
    """
    if structure == "linear":
        # Standard-normal clouds shifted by +/-2 along one random unit direction.
        direction = torch.randn(hidden_dim)
        direction = direction / direction.norm()

        pos = torch.randn(n_samples, hidden_dim) + 2 * direction
        neg = torch.randn(n_samples, hidden_dim) - 2 * direction

    elif structure == "xor":
        # XOR pattern: the label is whether the signs of dims 0 and 1 agree.
        # |x| + 1 keeps every point at least 1 away from both axes.
        half = n_samples // 2

        # Positive: first half in the (+, +) quadrant, second half in (-, -).
        pos = torch.randn(n_samples, hidden_dim)
        pos[:half, 0] = torch.abs(pos[:half, 0]) + 1
        pos[:half, 1] = torch.abs(pos[:half, 1]) + 1
        pos[half:, 0] = -torch.abs(pos[half:, 0]) - 1
        pos[half:, 1] = -torch.abs(pos[half:, 1]) - 1

        # Negative: first half in (+, -), second half in (-, +).
        neg = torch.randn(n_samples, hidden_dim)
        neg[:half, 0] = torch.abs(neg[:half, 0]) + 1
        neg[:half, 1] = -torch.abs(neg[:half, 1]) - 1
        neg[half:, 0] = -torch.abs(neg[half:, 0]) - 1
        neg[half:, 1] = torch.abs(neg[half:, 1]) + 1

    elif structure == "spirals":
        # Two interleaved noisy spirals in dims 0-1 (the negative spiral's
        # parameter is offset by pi); remaining dims carry small noise.
        t_pos = torch.linspace(0, 4 * np.pi, n_samples)
        t_neg = torch.linspace(0, 4 * np.pi, n_samples) + np.pi

        pos = torch.zeros(n_samples, hidden_dim)
        pos[:, 0] = t_pos * torch.cos(t_pos) + 0.5 * torch.randn(n_samples)
        pos[:, 1] = t_pos * torch.sin(t_pos) + 0.5 * torch.randn(n_samples)
        pos[:, 2:] = torch.randn(n_samples, hidden_dim - 2) * 0.1

        neg = torch.zeros(n_samples, hidden_dim)
        neg[:, 0] = t_neg * torch.cos(t_neg) + 0.5 * torch.randn(n_samples)
        neg[:, 1] = t_neg * torch.sin(t_neg) + 0.5 * torch.randn(n_samples)
        neg[:, 2:] = torch.randn(n_samples, hidden_dim - 2) * 0.1

    else:  # random (also the fallback for unrecognised structures)
        pos = torch.randn(n_samples, hidden_dim)
        neg = torch.randn(n_samples, hidden_dim)

    return pos, neg
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
def compute_roc_for_existence(
    real_results: List[Dict],
    null_scores: List[float],
) -> Tuple[List[float], List[float], List[float], float]:
    """
    Build the ROC curve that separates real benchmarks from null samples.

    Real benchmarks form the positive class (label 1), scored by
    ``result["nonlinear_metrics"]["knn_accuracy_k10"]``. Null samples form
    the negative class (label 0), scored by the supplied kNN accuracies.

    Args:
        real_results: Diagnosis result dicts from real benchmarks.
        null_scores: kNN accuracies measured on random (null) data.

    Returns:
        (thresholds, tpr, fpr, auc). The first three are plain lists taken
        directly from ``sklearn.metrics.roc_curve``. Its first threshold is a
        sentinel that no score reaches.
    """
    positive_scores = [
        result["nonlinear_metrics"]["knn_accuracy_k10"] for result in real_results
    ]

    y_score = positive_scores + null_scores
    y_true = [1 for _ in positive_scores] + [0 for _ in null_scores]

    false_pos_rate, true_pos_rate, cutoffs = roc_curve(y_true, y_score)
    area = auc(false_pos_rate, true_pos_rate)

    return cutoffs.tolist(), true_pos_rate.tolist(), false_pos_rate.tolist(), area
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
def compute_precision_recall_for_gap(
    results: List[Dict],
    ground_truth_linear: List[bool],
) -> Tuple[List[float], List[float], List[float], List[float]]:
    """
    Precision/recall/F1 curve for the linear-vs-nonlinear gap threshold.

    Each benchmark is scored by its gap,
    ``signal_strength - linear_probe_accuracy``. The positive class is
    "nonlinear", i.e. entries whose ``ground_truth_linear`` flag is falsy.

    Args:
        results: Diagnosis result dicts from benchmarks.
        ground_truth_linear: Ground truth per result (True = linear,
            False = nonlinear).

    Returns:
        (thresholds, precision, recall, f1), as laid out by
        ``sklearn.metrics.precision_recall_curve``. Per its documentation,
        precision and recall (and so f1) have one more entry than thresholds.
        The final point is precision=1, recall=0.
    """
    gap_scores = [
        entry["signal_strength"] - entry["linear_probe_accuracy"] for entry in results
    ]
    is_nonlinear = [0 if linear else 1 for linear in ground_truth_linear]

    prec, rec, cutoffs = precision_recall_curve(is_nonlinear, gap_scores)

    # Harmonic mean per curve point; the epsilon avoids 0/0 when p == r == 0.
    f1_scores = [2 * p * r / (p + r + 1e-10) for p, r in zip(prec, rec)]

    return cutoffs.tolist(), prec.tolist(), rec.tolist(), f1_scores
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
def run_sensitivity_analysis(
    results: List[Dict],
    existence_thresholds: Optional[List[float]] = None,
    gap_thresholds: Optional[List[float]] = None,
) -> Dict[str, Dict[str, Dict[str, float]]]:
    """
    Run sensitivity analysis across threshold combinations.

    For every (existence, gap) threshold pair, each result is diagnosed:
    ``NO_SIGNAL`` if ``signal_strength < existence``, otherwise ``LINEAR``
    if ``signal_strength - linear_probe_accuracy < gap``, otherwise
    ``NONLINEAR``.

    Args:
        results: Result dicts with ``signal_strength`` and
            ``linear_probe_accuracy`` keys.
        existence_thresholds: Existence thresholds to test. ``None`` means
            [0.5, 0.55, 0.6, 0.65, 0.7].
        gap_thresholds: Gap thresholds to test. ``None`` means
            [0.05, 0.10, 0.15, 0.20, 0.25].

    Returns:
        Nested dict ``{str(exist_t): {str(gap_t): {diagnosis: percentage}}}``.
        Percentages are 0.0 when ``results`` is empty.
    """
    # None sentinels instead of list defaults: a default list object would be
    # shared across calls.
    if existence_thresholds is None:
        existence_thresholds = [0.5, 0.55, 0.6, 0.65, 0.7]
    if gap_thresholds is None:
        gap_thresholds = [0.05, 0.10, 0.15, 0.20, 0.25]

    total = len(results)
    sensitivity: Dict[str, Dict[str, Dict[str, float]]] = {}

    for exist_t in existence_thresholds:
        sensitivity[str(exist_t)] = {}

        for gap_t in gap_thresholds:
            diagnoses = {"LINEAR": 0, "NONLINEAR": 0, "NO_SIGNAL": 0}

            for r in results:
                signal = r["signal_strength"]
                gap = signal - r["linear_probe_accuracy"]

                if signal < exist_t:
                    diagnoses["NO_SIGNAL"] += 1
                elif gap < gap_t:
                    diagnoses["LINEAR"] += 1
                else:
                    diagnoses["NONLINEAR"] += 1

            # An empty result list would otherwise raise ZeroDivisionError.
            sensitivity[str(exist_t)][str(gap_t)] = {
                k: (v / total * 100 if total else 0.0) for k, v in diagnoses.items()
            }

    return sensitivity
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
def load_diagnosis_results(model_name: str, output_dir: Path) -> List[Dict]:
    """Load all diagnosis results for ``model_name``.

    First tries to mirror
    ``s3://{S3_BUCKET}/direction_discovery/{model_prefix}/`` into
    ``output_dir / "diagnosis"`` with ``aws s3 sync``. The sync is best
    effort: if it cannot start or exits non-zero, a note is printed and
    whatever is already on disk is used.

    Then every ``{model_prefix}_*.json`` file in that directory whose name
    does not contain ``"summary"`` is read, and the entries of its
    ``"results"`` list are concatenated. Files are read in sorted name order,
    so the returned list is deterministic.

    Args:
        model_name: Model identifier; ``/`` is replaced by ``_`` to form
            ``model_prefix``.
        output_dir: Local working directory holding the ``diagnosis`` folder.

    Returns:
        The concatenated result dicts (empty if nothing was found).
    """
    model_prefix = model_name.replace('/', '_')
    diagnosis_dir = output_dir / "diagnosis"

    # Try to download from S3
    try:
        sync = subprocess.run(
            ["aws", "s3", "sync",
             f"s3://{S3_BUCKET}/direction_discovery/{model_prefix}/",
             str(diagnosis_dir),
             "--quiet"],
            check=False,
            capture_output=True,
        )
    except (OSError, subprocess.SubprocessError) as e:
        # e.g. the aws CLI is not installed; fall back to local files.
        print(f" S3 sync skipped: {e}")
    else:
        if sync.returncode != 0:
            print(f" S3 sync failed (exit {sync.returncode}); using local files only")

    # Load all results
    all_results: List[Dict] = []
    if diagnosis_dir.exists():
        for f in sorted(diagnosis_dir.glob(f"{model_prefix}_*.json")):
            if "summary" in f.name:
                continue
            with open(f, encoding="utf-8") as fp:
                data = json.load(fp)
            all_results.extend(data.get("results", []))

    return all_results
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
def run_threshold_analysis(model_name: str) -> Optional[ThresholdAnalysisResult]:
    """
    Run full threshold analysis.

    Steps:
      1. Score kNN / linear probes on random data (null distribution).
      2. ROC of real-benchmark kNN accuracy vs. null kNN scores; pick the
         existence threshold maximising Youden's J (TPR - FPR).
      3. Score synthetic linear / xor / spirals / random data and suggest a
         gap threshold midway between the linear gap and the smaller of the
         xor / spirals gaps.
      4. Tabulate diagnosis percentages over a grid of threshold pairs.
      5. Save the result as JSON under /tmp/threshold_analysis and upload it.

    Args:
        model_name: Model to analyze

    Returns:
        The ThresholdAnalysisResult, or None when no diagnosis results exist.
    """
    print("=" * 70)
    print("THRESHOLD ANALYSIS")
    print("=" * 70)
    print(f"Model: {model_name}")

    output_dir = Path("/tmp/threshold_analysis")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load diagnosis results
    results = load_diagnosis_results(model_name, output_dir)
    if not results:
        print("ERROR: No diagnosis results found.")
        return None

    print(f"Loaded {len(results)} results")

    # 1. Generate null distribution
    print("\n1. Generating null distribution...")
    null_knn, null_linear = generate_null_distribution(None, n_samples=100, hidden_dim=4096)

    print(f" Null kNN: mean={np.mean(null_knn):.3f}, std={np.std(null_knn):.3f}")
    print(f" Null linear: mean={np.mean(null_linear):.3f}, std={np.std(null_linear):.3f}")

    # 2. ROC for existence threshold
    print("\n2. Computing ROC for existence threshold...")
    thresholds, tpr, fpr, roc_auc = compute_roc_for_existence(results, null_knn)

    # Find optimal threshold (Youden's J = TPR - FPR). roc_curve's first
    # threshold is a non-finite sentinel for the "predict nothing" point,
    # which plain argmax picks whenever no J exceeds 0; mask non-finite
    # thresholds out and fall back to 0.6 only if none are finite.
    finite = np.isfinite(np.asarray(thresholds, dtype=float))
    if finite.any():
        j_scores = np.where(finite, np.asarray(tpr) - np.asarray(fpr), -np.inf)
        optimal_exist = float(thresholds[int(np.argmax(j_scores))])
    else:
        optimal_exist = 0.6

    print(f" AUC: {roc_auc:.3f}")
    print(f" Optimal existence threshold: {optimal_exist:.3f}")

    # 3. Synthetic validation
    print("\n3. Synthetic validation...")
    from wisent.core.geometry_runner import compute_knn_accuracy, compute_linear_probe_accuracy

    synthetic_results = {}
    for structure in ["linear", "xor", "spirals", "random"]:
        pos, neg = generate_synthetic_data(structure)
        knn = compute_knn_accuracy(pos, neg, k=10)
        linear = compute_linear_probe_accuracy(pos, neg)
        gap = knn - linear

        synthetic_results[structure] = {
            "knn": knn,
            "linear": linear,
            "gap": gap,
        }
        print(f" {structure}: kNN={knn:.3f}, linear={linear:.3f}, gap={gap:.3f}")

    # Validate that gap threshold separates linear from nonlinear
    linear_gap = synthetic_results["linear"]["gap"]
    xor_gap = synthetic_results["xor"]["gap"]
    spirals_gap = synthetic_results["spirals"]["gap"]

    # Good gap threshold should be > linear_gap and < min(xor_gap, spirals_gap)
    optimal_gap = (linear_gap + min(xor_gap, spirals_gap)) / 2
    print(f"\n Suggested gap threshold: {optimal_gap:.3f}")

    # 4. Sensitivity analysis. The gap grid is passed explicitly so the
    # table and the saved gap_thresholds always describe the same grid.
    gap_grid = [0.05, 0.10, 0.15, 0.20, 0.25]
    print("\n4. Running sensitivity analysis...")
    sensitivity = run_sensitivity_analysis(results, gap_thresholds=gap_grid)

    print("\n Diagnosis distribution (% of benchmarks):")
    print(" " + "-" * 60)
    print(f" {'Exist':>6} | {'Gap':>6} | {'LINEAR':>8} | {'NONLINEAR':>10} | {'NO_SIGNAL':>10}")
    print(" " + "-" * 60)

    for exist_t, gap_data in sensitivity.items():
        for gap_t, diagnoses in gap_data.items():
            print(f" {exist_t:>6} | {gap_t:>6} | {diagnoses['LINEAR']:>7.1f}% | "
                  f"{diagnoses['NONLINEAR']:>9.1f}% | {diagnoses['NO_SIGNAL']:>9.1f}%")

    # 5. Save results. Keep at most 100 ROC points for the JSON, spread evenly
    # over the whole curve: a [:100] prefix would silently drop its high-FPR end.
    n_points = len(thresholds)
    keep = np.unique(
        np.linspace(0, n_points - 1, num=min(100, n_points)).round().astype(int)
    ).tolist()

    analysis_result = ThresholdAnalysisResult(
        existence_thresholds=[thresholds[i] for i in keep],
        existence_tpr=[tpr[i] for i in keep],
        existence_fpr=[fpr[i] for i in keep],
        existence_auc=float(roc_auc),
        optimal_existence_threshold=float(optimal_exist),
        gap_thresholds=gap_grid,
        gap_precision=[],  # Would need ground truth
        gap_recall=[],
        gap_f1=[],
        optimal_gap_threshold=float(optimal_gap),
        null_mean_knn=float(np.mean(null_knn)),
        null_std_knn=float(np.std(null_knn)),
        null_mean_linear=float(np.mean(null_linear)),
        null_std_linear=float(np.std(null_linear)),
        sensitivity_matrix=sensitivity,
    )

    model_prefix = model_name.replace('/', '_')
    results_file = output_dir / f"{model_prefix}_threshold_analysis.json"

    with open(results_file, "w") as f:
        json.dump(asdict(analysis_result), f, indent=2)

    print(f"\nResults saved to: {results_file}")
    s3_upload_file(results_file, model_name)

    # Summary
    print("\n" + "=" * 70)
    print("RECOMMENDATIONS")
    print("=" * 70)
    print(f"\n1. Existence threshold: {optimal_exist:.2f}")
    print(f" - Based on ROC analysis (AUC={roc_auc:.3f})")
    print(f" - Null distribution: kNN={np.mean(null_knn):.3f} ± {np.std(null_knn):.3f}")

    print(f"\n2. Gap threshold: {optimal_gap:.2f}")
    print(" - Based on synthetic validation")
    print(f" - Linear structure gap: {linear_gap:.3f}")
    print(f" - XOR structure gap: {xor_gap:.3f}")
    print(f" - Spirals structure gap: {spirals_gap:.3f}")

    return analysis_result
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
if __name__ == "__main__":
    # Command-line entry point: analyze a single model (default Qwen/Qwen3-8B).
    cli = argparse.ArgumentParser(description="Threshold analysis for RepScan")
    cli.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen3-8B",
        help="Model to analyze",
    )
    run_threshold_analysis(cli.parse_args().model)
|