wisent 0.7.901__py3-none-any.whl → 0.7.1116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. wisent/__init__.py +1 -1
  2. wisent/comparison/__init__.py +1 -0
  3. wisent/comparison/detect_bos_features.py +275 -0
  4. wisent/comparison/fgaa.py +465 -0
  5. wisent/comparison/lora.py +663 -0
  6. wisent/comparison/lora_dpo.py +604 -0
  7. wisent/comparison/main.py +444 -0
  8. wisent/comparison/ours.py +76 -0
  9. wisent/comparison/reft.py +690 -0
  10. wisent/comparison/sae.py +304 -0
  11. wisent/comparison/utils.py +381 -0
  12. wisent/core/activations/activations_collector.py +3 -2
  13. wisent/core/activations/extraction_strategy.py +8 -4
  14. wisent/core/cli/agent/apply_steering.py +7 -5
  15. wisent/core/cli/agent/train_classifier.py +4 -3
  16. wisent/core/cli/generate_vector_from_task.py +11 -20
  17. wisent/core/cli/get_activations.py +1 -1
  18. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
  19. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
  20. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
  21. wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
  22. wisent/core/parser_arguments/get_activations_parser.py +5 -14
  23. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/METADATA +5 -1
  24. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/RECORD +28 -91
  25. wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
  26. wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
  27. wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
  28. wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
  29. wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
  30. wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
  31. wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
  32. wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
  33. wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
  34. wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
  35. wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
  36. wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
  37. wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
  38. wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
  39. wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
  40. wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
  41. wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
  42. wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
  43. wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
  44. wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
  45. wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
  46. wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
  47. wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
  48. wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
  49. wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
  50. wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
  51. wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
  52. wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
  53. wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
  54. wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
  55. wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
  56. wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
  57. wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
  58. wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
  59. wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
  60. wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
  61. wisent/examples/scripts/1/test_cola_pairs.json +0 -8
  62. wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
  63. wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
  64. wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
  65. wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
  66. wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
  67. wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
  68. wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
  69. wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
  70. wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
  71. wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
  72. wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
  73. wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
  74. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
  75. wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
  76. wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
  77. wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
  78. wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
  79. wisent/examples/scripts/2/test_atis_pairs.json +0 -8
  80. wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
  81. wisent/examples/scripts/2/test_babi_pairs.json +0 -8
  82. wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
  83. wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
  84. wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
  85. wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
  86. wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
  87. wisent/examples/scripts/generate_paper_data.py +0 -384
  88. wisent/examples/scripts/intervention_validation.py +0 -626
  89. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
  90. wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
  91. wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
  92. wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
  93. wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
  94. wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
  95. wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
  96. wisent/examples/scripts/threshold_analysis.py +0 -434
  97. wisent/examples/scripts/visualization_gallery.py +0 -582
  98. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/WHEEL +0 -0
  99. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/entry_points.txt +0 -0
  100. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/licenses/LICENSE +0 -0
  101. {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/top_level.txt +0 -0
@@ -1,434 +0,0 @@
1
- """
2
- Threshold Analysis for RepScan.
3
-
4
- Analyzes sensitivity of diagnosis to threshold choices:
5
- - ROC curves for existence threshold
6
- - Precision/recall tradeoff
7
- - Null distribution analysis
8
- - Synthetic validation
9
-
10
- Usage:
11
- python -m wisent.examples.scripts.threshold_analysis --model Qwen/Qwen3-8B
12
- """
13
-
14
- import argparse
15
- import json
16
- import subprocess
17
- from pathlib import Path
18
- from typing import Dict, List, Any, Optional, Tuple
19
- from dataclasses import dataclass, field, asdict
20
- import random
21
-
22
- import torch
23
- import numpy as np
24
- from sklearn.metrics import roc_curve, auc, precision_recall_curve
25
-
26
# Destination bucket and key prefix for uploaded analysis artifacts.
S3_BUCKET = "wisent-bucket"
S3_PREFIX = "threshold_analysis"


def s3_upload_file(local_path: Path, model_name: str) -> None:
    """Upload a single file to S3 with the ``aws`` CLI (best effort).

    The object is written to
    ``s3://{S3_BUCKET}/{S3_PREFIX}/{model_prefix}/{local_path.name}`` where
    ``model_prefix`` is ``model_name`` with every ``/`` replaced by ``_``.
    Failures are reported on stdout and never raised, so a missing CLI or
    missing credentials do not abort the caller.

    Args:
        local_path: File to upload; only its base name is used in the key.
        model_name: Model identifier, e.g. ``"Qwen/Qwen3-8B"``.
    """
    model_prefix = model_name.replace("/", "_")
    s3_path = f"s3://{S3_BUCKET}/{S3_PREFIX}/{model_prefix}/{local_path.name}"
    command = ["aws", "s3", "cp", str(local_path), s3_path, "--quiet"]

    try:
        subprocess.run(command, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        print(f" S3 upload failed: {e}")
        # The CLI's own explanation lives in the captured stderr; without it
        # the message above only carries the exit status.
        stderr = (e.stderr or b"").decode("utf-8", errors="replace").strip()
        if stderr:
            print(f" aws stderr: {stderr}")
    except OSError as e:
        # Raised when the ``aws`` executable is missing or not runnable.
        print(f" S3 upload failed: {e}")
    else:
        print(f" Uploaded to S3: {s3_path}")
43
-
44
-
45
@dataclass
class ThresholdAnalysisResult:
    """Result of threshold analysis.

    Plain data container with no defaults: every field must be supplied.
    All fields are lists, floats or nested dicts of floats so the instance
    can be converted with ``dataclasses.asdict`` and written as JSON.
    """

    # Existence threshold analysis (ROC curve points, index-aligned).
    existence_thresholds: List[float]
    existence_tpr: List[float]  # True positive rate
    existence_fpr: List[float]  # False positive rate
    existence_auc: float
    optimal_existence_threshold: float

    # Gap threshold analysis
    gap_thresholds: List[float]
    gap_precision: List[float]
    gap_recall: List[float]
    gap_f1: List[float]
    optimal_gap_threshold: float

    # Null distribution stats
    null_mean_knn: float
    null_std_knn: float
    null_mean_linear: float
    null_std_linear: float

    # Sensitivity analysis:
    # existence threshold -> gap threshold -> diagnosis -> percentage
    sensitivity_matrix: Dict[str, Dict[str, Dict[str, float]]]
70
-
71
-
72
def generate_null_distribution(
    model: "WisentModel",
    n_samples: int = 100,
    hidden_dim: int = 4096,
    samples_per_class: int = 50,
) -> Tuple[List[float], List[float]]:
    """
    Generate a null distribution by scoring pure-noise activations.

    Each trial draws two independent standard-normal matrices (one per
    class) and scores them with the kNN (k=10) and linear-probe helpers
    from ``wisent.core.geometry_runner``. No seed is set here, so results
    vary between runs unless the caller seeds torch.

    Args:
        model: Not used by this function; kept for interface compatibility
            (``None`` is accepted).
        n_samples: Number of random trials to run.
        hidden_dim: Width of each random activation vector.
        samples_per_class: Rows drawn per class in every trial.

    Returns:
        (knn_scores, linear_scores), one entry per trial.
    """
    from wisent.core.geometry_runner import compute_knn_accuracy, compute_linear_probe_accuracy

    knn_scores = []
    linear_scores = []

    for _ in range(n_samples):
        # Both classes come from the same distribution: no real signal.
        pos = torch.randn(samples_per_class, hidden_dim)
        neg = torch.randn(samples_per_class, hidden_dim)

        knn_scores.append(compute_knn_accuracy(pos, neg, k=10))
        linear_scores.append(compute_linear_probe_accuracy(pos, neg))

    return knn_scores, linear_scores
105
-
106
-
107
def generate_synthetic_data(
    structure: str,
    n_samples: int = 50,
    hidden_dim: int = 100,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate synthetic data with known structure for validation.

    Structures:
        'linear':  Gaussian clouds shifted by +2 / -2 along one random unit
                   direction.
        'xor':     Gaussian clouds whose first two coordinates are pushed at
                   least 1 away from zero. Positives have matching signs
                   (+,+ for the first half of the rows, -,- for the rest);
                   negatives have opposite signs (+,- then -,+).
        'spirals': first two coordinates follow ``t*cos(t)``, ``t*sin(t)``
                   for ``t`` in [0, 4*pi] (negatives offset by pi) plus
                   noise of scale 0.5; remaining coordinates are noise of
                   scale 0.1.
        anything else (including 'random'): two independent Gaussian clouds.

    Args:
        structure: One of the names above; unknown names fall back to random.
        n_samples: Samples per class.
        hidden_dim: Dimension. 'xor' and 'spirals' index coordinates 0 and 1,
            so they need ``hidden_dim >= 2``.

    Returns:
        (pos_activations, neg_activations), each ``(n_samples, hidden_dim)``.
    """
    if structure == "linear":
        direction = torch.randn(hidden_dim)
        direction = direction / direction.norm()

        pos = torch.randn(n_samples, hidden_dim) + 2 * direction
        neg = torch.randn(n_samples, hidden_dim) - 2 * direction

    elif structure == "xor":
        half = n_samples // 2
        first, second = slice(None, half), slice(half, None)

        # Positive: (high dim0 AND high dim1) OR (low dim0 AND low dim1)
        pos = torch.randn(n_samples, hidden_dim)
        pos[first, 0] = torch.abs(pos[first, 0]) + 1
        pos[first, 1] = torch.abs(pos[first, 1]) + 1
        pos[second, 0] = -torch.abs(pos[second, 0]) - 1
        pos[second, 1] = -torch.abs(pos[second, 1]) - 1

        # Negative: the two mixed-sign quadrants
        neg = torch.randn(n_samples, hidden_dim)
        neg[first, 0] = torch.abs(neg[first, 0]) + 1
        neg[first, 1] = -torch.abs(neg[first, 1]) - 1
        neg[second, 0] = -torch.abs(neg[second, 0]) - 1
        neg[second, 1] = torch.abs(neg[second, 1]) + 1

    elif structure == "spirals":
        t_pos = torch.linspace(0, 4 * np.pi, n_samples)
        t_neg = torch.linspace(0, 4 * np.pi, n_samples) + np.pi

        pos = torch.zeros(n_samples, hidden_dim)
        pos[:, 0] = t_pos * torch.cos(t_pos) + 0.5 * torch.randn(n_samples)
        pos[:, 1] = t_pos * torch.sin(t_pos) + 0.5 * torch.randn(n_samples)
        pos[:, 2:] = torch.randn(n_samples, hidden_dim - 2) * 0.1

        neg = torch.zeros(n_samples, hidden_dim)
        neg[:, 0] = t_neg * torch.cos(t_neg) + 0.5 * torch.randn(n_samples)
        neg[:, 1] = t_neg * torch.sin(t_neg) + 0.5 * torch.randn(n_samples)
        neg[:, 2:] = torch.randn(n_samples, hidden_dim - 2) * 0.1

    else:  # random
        pos = torch.randn(n_samples, hidden_dim)
        neg = torch.randn(n_samples, hidden_dim)

    return pos, neg
171
-
172
-
173
def compute_roc_for_existence(
    real_results: List[Dict],
    null_scores: List[float],
) -> Tuple[List[float], List[float], List[float], float]:
    """
    Compute the ROC curve that separates real benchmark scores from null scores.

    Args:
        real_results: Benchmark result dicts; the score is read from
            ``result["nonlinear_metrics"]["knn_accuracy_k10"]``.
        null_scores: kNN scores from the null distribution.

    Returns:
        (thresholds, tpr, fpr, auc), with the three curves as plain lists.
    """
    real_knn = []
    for entry in real_results:
        real_knn.append(entry["nonlinear_metrics"]["knn_accuracy_k10"])

    # Positive class (1) = real data that should be detected,
    # negative class (0) = null samples that should not.
    labels = [1] * len(real_knn) + [0] * len(null_scores)
    scores = real_knn + null_scores

    fpr, tpr, thresholds = roc_curve(labels, scores)
    area = auc(fpr, tpr)

    return thresholds.tolist(), tpr.tolist(), fpr.tolist(), area
197
-
198
-
199
def compute_precision_recall_for_gap(
    results: List[Dict],
    ground_truth_linear: List[bool],
) -> Tuple[List[float], List[float], List[float], List[float]]:
    """
    Compute the precision-recall curve for the gap threshold (linear vs nonlinear).

    The score of each result is its gap,
    ``signal_strength - linear_probe_accuracy``; the positive class is
    "nonlinear".

    Args:
        results: Benchmark result dicts with ``signal_strength`` and
            ``linear_probe_accuracy`` keys.
        ground_truth_linear: Ground truth labels (True = linear,
            False = nonlinear), one per result.

    Returns:
        (thresholds, precision, recall, f1)

    NOTE(review): scikit-learn documents precision/recall as one element
    longer than thresholds, so f1 would share that extra element - confirm
    before zipping the returned lists together.
    """
    gaps = []
    for entry in results:
        gaps.append(entry["signal_strength"] - entry["linear_probe_accuracy"])

    # 1 marks a nonlinear case, 0 a linear one.
    labels = [int(not is_linear) for is_linear in ground_truth_linear]

    precision, recall, thresholds = precision_recall_curve(labels, gaps)

    f1 = []
    for p, r in zip(precision, recall):
        # The small epsilon avoids 0/0 when precision and recall are both zero.
        f1.append(2 * p * r / (p + r + 1e-10))

    return thresholds.tolist(), precision.tolist(), recall.tolist(), f1
225
-
226
-
227
def run_sensitivity_analysis(
    results: List[Dict],
    existence_thresholds: Optional[List[float]] = None,
    gap_thresholds: Optional[List[float]] = None,
) -> Dict[str, Dict[str, Dict[str, float]]]:
    """
    Run sensitivity analysis across threshold combinations.

    For every (existence, gap) threshold pair each result is diagnosed as:
        NO_SIGNAL  if ``signal_strength < existence threshold``
        LINEAR     else if ``signal_strength - linear_probe_accuracy < gap threshold``
        NONLINEAR  otherwise

    Args:
        results: Benchmark result dicts with ``signal_strength`` and
            ``linear_probe_accuracy`` keys.
        existence_thresholds: Thresholds to test for existence; defaults to
            [0.5, 0.55, 0.6, 0.65, 0.7].
        gap_thresholds: Thresholds to test for gap; defaults to
            [0.05, 0.10, 0.15, 0.20, 0.25].

    Returns:
        Nested dict ``{str(exist_thresh): {str(gap_thresh): {diagnosis: percentage}}}``.
        With empty ``results`` every percentage is 0.0.
    """
    if existence_thresholds is None:
        existence_thresholds = [0.5, 0.55, 0.6, 0.65, 0.7]
    if gap_thresholds is None:
        gap_thresholds = [0.05, 0.10, 0.15, 0.20, 0.25]

    # Signal and gap do not depend on the thresholds; compute them once.
    signal_gap_pairs = [
        (r["signal_strength"], r["signal_strength"] - r["linear_probe_accuracy"])
        for r in results
    ]
    total = len(signal_gap_pairs)

    sensitivity = {}
    for exist_t in existence_thresholds:
        per_gap = {}

        for gap_t in gap_thresholds:
            diagnoses = {"LINEAR": 0, "NONLINEAR": 0, "NO_SIGNAL": 0}

            for signal, gap in signal_gap_pairs:
                if signal < exist_t:
                    diagnoses["NO_SIGNAL"] += 1
                elif gap < gap_t:
                    diagnoses["LINEAR"] += 1
                else:
                    diagnoses["NONLINEAR"] += 1

            # Guard the division so an empty result set yields zeros
            # instead of ZeroDivisionError.
            per_gap[str(gap_t)] = {
                k: (v / total * 100 if total else 0.0)
                for k, v in diagnoses.items()
            }

        sensitivity[str(exist_t)] = per_gap

    return sensitivity
268
-
269
-
270
def load_diagnosis_results(model_name: str, output_dir: Path) -> List[Dict]:
    """Load all diagnosis results for a model.

    First tries (best effort) to sync
    ``s3://{S3_BUCKET}/direction_discovery/{model_prefix}/`` into
    ``output_dir / "diagnosis"`` with the ``aws`` CLI, then reads every
    ``{model_prefix}_*.json`` file in that directory whose name does not
    contain ``"summary"`` and concatenates their ``"results"`` lists.

    Args:
        model_name: Model identifier; ``/`` is replaced by ``_`` to build
            the S3 prefix and the file-name prefix.
        output_dir: Local working directory containing (or receiving) the
            ``diagnosis`` sub-directory.

    Returns:
        All result dicts found, in sorted file-name order; an empty list
        when the directory does not exist or holds no matching files.
    """
    model_prefix = model_name.replace("/", "_")
    diagnosis_dir = output_dir / "diagnosis"

    # Try to download from S3. The exit status is deliberately not checked:
    # previously downloaded local files are still usable if the sync fails.
    try:
        subprocess.run(
            [
                "aws", "s3", "sync",
                f"s3://{S3_BUCKET}/direction_discovery/{model_prefix}/",
                str(diagnosis_dir),
                "--quiet",
            ],
            check=False,
            capture_output=True,
        )
    except (OSError, subprocess.SubprocessError) as e:
        # E.g. the ``aws`` executable is not installed; fall back to local files.
        print(f" S3 sync skipped: {e}")

    all_results = []
    if diagnosis_dir.exists():
        # Sorted so the order of the combined results is reproducible.
        for f in sorted(diagnosis_dir.glob(f"{model_prefix}_*.json")):
            if "summary" in f.name:
                continue
            with open(f, encoding="utf-8") as fp:
                data = json.load(fp)
            all_results.extend(data.get("results", []))

    return all_results
299
-
300
-
301
def _youden_optimal_threshold(thresholds, tpr, fpr, fallback=0.6):
    """Return the finite threshold maximizing Youden's J (TPR - FPR), or ``fallback`` if none is finite."""
    best_j = None
    best_threshold = fallback
    for t, tp, fp in zip(thresholds, tpr, fpr):
        if not np.isfinite(t):
            # ROC sentinel point ("predict nothing"); not a usable cut-off.
            continue
        j = tp - fp
        if best_j is None or j > best_j:
            best_j, best_threshold = j, t
    return best_threshold


def run_threshold_analysis(model_name: str) -> Optional["ThresholdAnalysisResult"]:
    """
    Run full threshold analysis.

    Steps: load diagnosis results, build a random-data null distribution,
    compute the ROC curve of real vs. null kNN scores and pick the
    existence threshold by Youden's J, score synthetic datasets to suggest
    a gap threshold, run the sensitivity grid, then write the result as
    JSON under ``/tmp/threshold_analysis`` and upload it to S3.

    Args:
        model_name: Model to analyze

    Returns:
        The ThresholdAnalysisResult that was saved, or None when no
        diagnosis results could be loaded.
    """
    print("=" * 70)
    print("THRESHOLD ANALYSIS")
    print("=" * 70)
    print(f"Model: {model_name}")

    output_dir = Path("/tmp/threshold_analysis")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load diagnosis results
    results = load_diagnosis_results(model_name, output_dir)
    if not results:
        print("ERROR: No diagnosis results found.")
        return None

    print(f"Loaded {len(results)} results")

    # 1. Generate null distribution
    print("\n1. Generating null distribution...")
    null_knn, null_linear = generate_null_distribution(None, n_samples=100, hidden_dim=4096)

    print(f" Null kNN: mean={np.mean(null_knn):.3f}, std={np.std(null_knn):.3f}")
    print(f" Null linear: mean={np.mean(null_linear):.3f}, std={np.std(null_linear):.3f}")

    # 2. ROC for existence threshold
    print("\n2. Computing ROC for existence threshold...")
    thresholds, tpr, fpr, roc_auc = compute_roc_for_existence(results, null_knn)

    # Find optimal threshold (Youden's J), ignoring non-finite sentinel
    # thresholds so the recommendation is always a usable number.
    optimal_exist = _youden_optimal_threshold(thresholds, tpr, fpr)

    print(f" AUC: {roc_auc:.3f}")
    print(f" Optimal existence threshold: {optimal_exist:.3f}")

    # 3. Synthetic validation
    print("\n3. Synthetic validation...")
    from wisent.core.geometry_runner import compute_knn_accuracy, compute_linear_probe_accuracy

    synthetic_results = {}
    for structure in ["linear", "xor", "spirals", "random"]:
        pos, neg = generate_synthetic_data(structure)
        knn = compute_knn_accuracy(pos, neg, k=10)
        linear = compute_linear_probe_accuracy(pos, neg)
        gap = knn - linear

        synthetic_results[structure] = {
            "knn": knn,
            "linear": linear,
            "gap": gap,
        }
        print(f" {structure}: kNN={knn:.3f}, linear={linear:.3f}, gap={gap:.3f}")

    # Validate that gap threshold separates linear from nonlinear
    linear_gap = synthetic_results["linear"]["gap"]
    xor_gap = synthetic_results["xor"]["gap"]
    spirals_gap = synthetic_results["spirals"]["gap"]

    # Good gap threshold should be > linear_gap and < min(xor_gap, spirals_gap);
    # the midpoint between the two is suggested.
    optimal_gap = (linear_gap + min(xor_gap, spirals_gap)) / 2
    print(f"\n Suggested gap threshold: {optimal_gap:.3f}")

    # 4. Sensitivity analysis
    print("\n4. Running sensitivity analysis...")
    # One list feeds both the grid and the saved record so they cannot drift.
    gap_thresholds = [0.05, 0.10, 0.15, 0.20, 0.25]
    sensitivity = run_sensitivity_analysis(results, gap_thresholds=gap_thresholds)

    print("\n Diagnosis distribution (% of benchmarks):")
    print(" " + "-" * 60)
    print(f" {'Exist':>6} | {'Gap':>6} | {'LINEAR':>8} | {'NONLINEAR':>10} | {'NO_SIGNAL':>10}")
    print(" " + "-" * 60)

    for exist_t, gap_data in sensitivity.items():
        for gap_t, diagnoses in gap_data.items():
            print(f" {exist_t:>6} | {gap_t:>6} | {diagnoses['LINEAR']:>7.1f}% | "
                  f"{diagnoses['NONLINEAR']:>9.1f}% | {diagnoses['NO_SIGNAL']:>9.1f}%")

    # 5. Save results
    # Drop ROC points whose threshold is not finite: json.dump would emit
    # them as ``Infinity``, which strict JSON parsers reject. The three
    # curves stay index-aligned, then are capped at 100 points for JSON.
    roc_points = [
        (t, tp, fp)
        for t, tp, fp in zip(thresholds, tpr, fpr)
        if np.isfinite(t)
    ][:100]

    analysis_result = ThresholdAnalysisResult(
        existence_thresholds=[p[0] for p in roc_points],
        existence_tpr=[p[1] for p in roc_points],
        existence_fpr=[p[2] for p in roc_points],
        existence_auc=float(roc_auc),
        optimal_existence_threshold=float(optimal_exist),
        gap_thresholds=gap_thresholds,
        gap_precision=[],  # Would need ground truth
        gap_recall=[],
        gap_f1=[],
        optimal_gap_threshold=float(optimal_gap),
        null_mean_knn=float(np.mean(null_knn)),
        null_std_knn=float(np.std(null_knn)),
        null_mean_linear=float(np.mean(null_linear)),
        null_std_linear=float(np.std(null_linear)),
        sensitivity_matrix=sensitivity,
    )

    model_prefix = model_name.replace("/", "_")
    results_file = output_dir / f"{model_prefix}_threshold_analysis.json"

    with open(results_file, "w", encoding="utf-8") as f:
        json.dump(asdict(analysis_result), f, indent=2)

    print(f"\nResults saved to: {results_file}")
    s3_upload_file(results_file, model_name)

    # Summary
    print("\n" + "=" * 70)
    print("RECOMMENDATIONS")
    print("=" * 70)
    print(f"\n1. Existence threshold: {optimal_exist:.2f}")
    print(f" - Based on ROC analysis (AUC={roc_auc:.3f})")
    print(f" - Null distribution: kNN={np.mean(null_knn):.3f} ± {np.std(null_knn):.3f}")

    print(f"\n2. Gap threshold: {optimal_gap:.2f}")
    print(" - Based on synthetic validation")
    print(f" - Linear structure gap: {linear_gap:.3f}")
    print(f" - XOR structure gap: {xor_gap:.3f}")
    print(f" - Spirals structure gap: {spirals_gap:.3f}")

    return analysis_result
427
-
428
-
429
if __name__ == "__main__":
    # Command-line entry point: the only option is the model to analyze.
    cli = argparse.ArgumentParser(description="Threshold analysis for RepScan")
    cli.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen3-8B",
        help="Model to analyze",
    )
    options = cli.parse_args()

    run_threshold_analysis(options.model)