wisent 0.7.901__py3-none-any.whl → 0.7.1116__py3-none-any.whl
This diff shows the changes between package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in that registry.
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +663 -0
- wisent/comparison/lora_dpo.py +604 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/reft.py +690 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activations_collector.py +3 -2
- wisent/core/activations/extraction_strategy.py +8 -4
- wisent/core/cli/agent/apply_steering.py +7 -5
- wisent/core/cli/agent/train_classifier.py +4 -3
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +20 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +8 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +8 -1
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +4 -11
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/METADATA +5 -1
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/RECORD +28 -91
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/generate_paper_data.py +0 -384
- wisent/examples/scripts/intervention_validation.py +0 -626
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +0 -324
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +0 -92
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +0 -324
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +0 -92
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +0 -324
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +0 -92
- wisent/examples/scripts/threshold_analysis.py +0 -434
- wisent/examples/scripts/visualization_gallery.py +0 -582
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/WHEEL +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.901.dist-info → wisent-0.7.1116.dist-info}/top_level.txt +0 -0
wisent/examples/scripts/generate_paper_data.py
@@ -1,384 +0,0 @@
-"""
-Generate data for RepScan paper.
-
-Produces:
-1. Main results table (LaTeX)
-2. Per-category summary table
-3. Benchmark list with contrastive definitions
-4. Data for figures (JSON)
-
-Usage:
-    python -m wisent.examples.scripts.generate_paper_data
-"""
-
-import json
-import subprocess
-from pathlib import Path
-from typing import Dict, List, Any
-from collections import defaultdict
-
-S3_BUCKET = "wisent-bucket"
-
-
-def download_all_results(output_dir: Path) -> Dict[str, Path]:
-    """Download all results from S3."""
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    subprocess.run(
-        ["aws", "s3", "sync",
-         f"s3://{S3_BUCKET}/direction_discovery/",
-         str(output_dir),
-         "--quiet"],
-        check=False,
-        capture_output=True,
-    )
-
-    models = {}
-    for d in output_dir.iterdir():
-        if d.is_dir():
-            models[d.name] = d
-
-    return models
-
-
-def load_model_results(model_dir: Path) -> Dict[str, Any]:
-    """Load all category results for a model."""
-    results = {}
-    for f in model_dir.glob("*.json"):
-        if "summary" in f.name:
-            continue
-        category = f.stem.split("_")[-1]
-        with open(f) as fp:
-            results[category] = json.load(fp)
-    return results
-
-
-def compute_diagnosis(signal: float, linear: float) -> str:
-    """Compute diagnosis from signal and linear probe accuracy."""
-    if signal < 0.6:
-        return "NO_SIGNAL"
-    elif linear > 0.6 and (signal - linear) < 0.15:
-        return "LINEAR"
-    else:
-        return "NONLINEAR"
-
-
-def generate_main_results_table(all_models: Dict[str, Dict]) -> str:
-    """Generate main results table in LaTeX."""
-    lines = [
-        r"\begin{table}[t]",
-        r"\centering",
-        r"\caption{RepScan diagnosis results across models and categories. Signal = MLP CV accuracy, Linear = Linear probe CV accuracy, kNN = k-NN CV accuracy (k=10). Diagnosis: LINEAR indicates CAA-viable representation, NONLINEAR indicates manifold structure, NO\_SIGNAL indicates no detectable representation.}",
-        r"\label{tab:main_results}",
-        r"\small",
-        r"\begin{tabular}{llccccl}",
-        r"\toprule",
-        r"\textbf{Model} & \textbf{Category} & \textbf{Signal} & \textbf{Linear} & \textbf{kNN} & \textbf{Gap} & \textbf{Diagnosis} \\",
-        r"\midrule",
-    ]
-
-    for model_name, categories in sorted(all_models.items()):
-        model_short = model_name.replace("meta-llama_", "").replace("Qwen_", "").replace("openai_", "")
-        first_row = True
-
-        for cat_name in sorted(categories.keys()):
-            data = categories[cat_name]
-            results = data.get("results", [])
-            if not results:
-                continue
-
-            n = len(results)
-            avg_signal = sum(r["signal_strength"] for r in results) / n
-            avg_linear = sum(r["linear_probe_accuracy"] for r in results) / n
-            avg_knn = sum(r["nonlinear_metrics"]["knn_accuracy_k10"] for r in results) / n
-            gap = avg_signal - avg_linear
-            diagnosis = compute_diagnosis(avg_signal, avg_linear)
-
-            # Color coding for diagnosis
-            if diagnosis == "LINEAR":
-                diag_str = r"\textcolor{green!60!black}{LINEAR}"
-            elif diagnosis == "NONLINEAR":
-                diag_str = r"\textcolor{blue}{NONLINEAR}"
-            else:
-                diag_str = r"\textcolor{gray}{NO\_SIGNAL}"
-
-            model_col = model_short if first_row else ""
-            first_row = False
-
-            lines.append(f"{model_col} & {cat_name} & {avg_signal:.2f} & {avg_linear:.2f} & {avg_knn:.2f} & {gap:+.2f} & {diag_str} \\\\")
-
-        lines.append(r"\midrule")
-
-    lines[-1] = r"\bottomrule"  # Replace last midrule
-    lines.extend([
-        r"\end{tabular}",
-        r"\end{table}",
-    ])
-
-    return "\n".join(lines)
-
-
-def generate_benchmark_table(all_models: Dict[str, Dict]) -> str:
-    """Generate benchmark list with contrastive definitions."""
-    # Collect unique benchmarks
-    benchmarks = defaultdict(lambda: {"categories": set(), "signal": [], "linear": [], "knn": []})
-
-    for model_name, categories in all_models.items():
-        for cat_name, data in categories.items():
-            results = data.get("results", [])
-            seen = set()
-            for r in results:
-                bench = r["benchmark"]
-                if bench in seen:
-                    continue
-                seen.add(bench)
-
-                benchmarks[bench]["categories"].add(cat_name)
-                benchmarks[bench]["signal"].append(r["signal_strength"])
-                benchmarks[bench]["linear"].append(r["linear_probe_accuracy"])
-                benchmarks[bench]["knn"].append(r["nonlinear_metrics"]["knn_accuracy_k10"])
-
-    lines = [
-        r"\begin{longtable}{p{3cm}p{2.5cm}cccl}",
-        r"\caption{Per-benchmark RepScan results (averaged across models and strategies).} \label{tab:benchmarks} \\",
-        r"\toprule",
-        r"\textbf{Benchmark} & \textbf{Category} & \textbf{Signal} & \textbf{Linear} & \textbf{kNN} & \textbf{Diagnosis} \\",
-        r"\midrule",
-        r"\endfirsthead",
-        r"\multicolumn{6}{c}{\tablename\ \thetable{} -- continued} \\",
-        r"\toprule",
-        r"\textbf{Benchmark} & \textbf{Category} & \textbf{Signal} & \textbf{Linear} & \textbf{kNN} & \textbf{Diagnosis} \\",
-        r"\midrule",
-        r"\endhead",
-    ]
-
-    for bench, data in sorted(benchmarks.items(), key=lambda x: -max(x[1]["signal"]) if x[1]["signal"] else 0):
-        cats = ", ".join(sorted(data["categories"]))[:20]
-        avg_signal = sum(data["signal"]) / len(data["signal"]) if data["signal"] else 0
-        avg_linear = sum(data["linear"]) / len(data["linear"]) if data["linear"] else 0
-        avg_knn = sum(data["knn"]) / len(data["knn"]) if data["knn"] else 0
-        diagnosis = compute_diagnosis(avg_signal, avg_linear)
-
-        if diagnosis == "LINEAR":
-            diag_str = r"\textcolor{green!60!black}{LINEAR}"
-        elif diagnosis == "NONLINEAR":
-            diag_str = r"\textcolor{blue}{NONLINEAR}"
-        else:
-            diag_str = r"\textcolor{gray}{NO\_SIG}"
-
-        bench_escaped = bench.replace("_", r"\_")
-        lines.append(f"{bench_escaped} & {cats} & {avg_signal:.2f} & {avg_linear:.2f} & {avg_knn:.2f} & {diag_str} \\\\")
-
-    lines.extend([
-        r"\bottomrule",
-        r"\end{longtable}",
-    ])
-
-    return "\n".join(lines)
-
-
-def generate_figure_data(all_models: Dict[str, Dict]) -> Dict[str, Any]:
-    """Generate JSON data for figures."""
-    figure_data = {
-        "diagnosis_distribution": {"LINEAR": 0, "NONLINEAR": 0, "NO_SIGNAL": 0},
-        "per_category": {},
-        "top_benchmarks": {"linear": [], "nonlinear": [], "no_signal": []},
-        "metrics_by_diagnosis": {
-            "LINEAR": {"signal": [], "linear": [], "knn": [], "mmd": []},
-            "NONLINEAR": {"signal": [], "linear": [], "knn": [], "mmd": []},
-            "NO_SIGNAL": {"signal": [], "linear": [], "knn": [], "mmd": []},
-        },
-    }
-
-    all_results = []
-
-    for model_name, categories in all_models.items():
-        for cat_name, data in categories.items():
-            results = data.get("results", [])
-            all_results.extend(results)
-
-            if cat_name not in figure_data["per_category"]:
-                figure_data["per_category"][cat_name] = {
-                    "signal": [], "linear": [], "knn": []
-                }
-
-            for r in results:
-                signal = r["signal_strength"]
-                linear = r["linear_probe_accuracy"]
-                knn = r["nonlinear_metrics"]["knn_accuracy_k10"]
-                mmd = r["nonlinear_metrics"]["mmd_rbf"]
-                diagnosis = compute_diagnosis(signal, linear)
-
-                figure_data["diagnosis_distribution"][diagnosis] += 1
-                figure_data["metrics_by_diagnosis"][diagnosis]["signal"].append(signal)
-                figure_data["metrics_by_diagnosis"][diagnosis]["linear"].append(linear)
-                figure_data["metrics_by_diagnosis"][diagnosis]["knn"].append(knn)
-                figure_data["metrics_by_diagnosis"][diagnosis]["mmd"].append(mmd)
-
-                figure_data["per_category"][cat_name]["signal"].append(signal)
-                figure_data["per_category"][cat_name]["linear"].append(linear)
-                figure_data["per_category"][cat_name]["knn"].append(knn)
-
-    # Compute averages
-    for diag in figure_data["metrics_by_diagnosis"]:
-        for metric in list(figure_data["metrics_by_diagnosis"][diag].keys()):
-            values = figure_data["metrics_by_diagnosis"][diag][metric]
-            if values and isinstance(values, list):
-                figure_data["metrics_by_diagnosis"][diag][f"{metric}_mean"] = sum(values) / len(values)
-                figure_data["metrics_by_diagnosis"][diag][f"{metric}_std"] = (
-                    sum((v - sum(values)/len(values))**2 for v in values) / len(values)
-                ) ** 0.5
-
-    # Top benchmarks per diagnosis
-    benchmarks_by_diag = {"LINEAR": [], "NONLINEAR": [], "NO_SIGNAL": []}
-    seen = set()
-
-    for r in all_results:
-        bench = r["benchmark"]
-        if bench in seen:
-            continue
-        seen.add(bench)
-
-        signal = r["signal_strength"]
-        linear = r["linear_probe_accuracy"]
-        knn = r["nonlinear_metrics"]["knn_accuracy_k10"]
-        diagnosis = compute_diagnosis(signal, linear)
-
-        benchmarks_by_diag[diagnosis].append({
-            "benchmark": bench,
-            "signal": signal,
-            "linear": linear,
-            "knn": knn,
-            "gap": knn - linear,
-        })
-
-    # Sort and take top 5
-    benchmarks_by_diag["LINEAR"].sort(key=lambda x: x["linear"], reverse=True)
-    benchmarks_by_diag["NONLINEAR"].sort(key=lambda x: x["gap"], reverse=True)
-    benchmarks_by_diag["NO_SIGNAL"].sort(key=lambda x: x["signal"])
-
-    figure_data["top_benchmarks"] = {
-        diag: benches[:5] for diag, benches in benchmarks_by_diag.items()
-    }
-
-    return figure_data
-
-
-def generate_summary_statistics(all_models: Dict[str, Dict]) -> str:
-    """Generate summary statistics for paper text."""
-    total_results = 0
-    total_linear = 0
-    total_nonlinear = 0
-    total_no_signal = 0
-
-    categories = set()
-    benchmarks = set()
-
-    for model_name, model_categories in all_models.items():
-        for cat_name, data in model_categories.items():
-            categories.add(cat_name)
-            results = data.get("results", [])
-            total_results += len(results)
-
-            for r in results:
-                benchmarks.add(r["benchmark"])
-                signal = r["signal_strength"]
-                linear = r["linear_probe_accuracy"]
-                diagnosis = compute_diagnosis(signal, linear)
-
-                if diagnosis == "LINEAR":
-                    total_linear += 1
-                elif diagnosis == "NONLINEAR":
-                    total_nonlinear += 1
-                else:
-                    total_no_signal += 1
-
-    text = f"""
-## Summary Statistics for Paper
-
-- **Total tests**: {total_results:,}
-- **Models tested**: {len(all_models)}
-- **Categories**: {len(categories)} ({', '.join(sorted(categories))})
-- **Unique benchmarks**: {len(benchmarks)}
-
-### Diagnosis Distribution:
-- **LINEAR (CAA-viable)**: {total_linear:,} ({100*total_linear/total_results:.1f}%)
-- **NONLINEAR (manifold)**: {total_nonlinear:,} ({100*total_nonlinear/total_results:.1f}%)
-- **NO_SIGNAL**: {total_no_signal:,} ({100*total_no_signal/total_results:.1f}%)
-
-### Key Findings:
-1. {100*total_linear/total_results:.0f}% of benchmarks have LINEAR representations suitable for CAA
-2. {100*total_nonlinear/total_results:.0f}% have NONLINEAR representations requiring different methods
-3. {100*total_no_signal/total_results:.0f}% show no detectable signal
-"""
-    return text
-
-
-def main():
-    """Generate all paper data."""
-    print("=" * 70)
-    print("GENERATING PAPER DATA")
-    print("=" * 70)
-
-    output_dir = Path("/tmp/paper_data")
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    # Download results
-    print("\n1. Downloading results from S3...")
-    results_dir = output_dir / "results"
-    models = download_all_results(results_dir)
-    print(f" Found {len(models)} models: {list(models.keys())}")
-
-    # Load all results
-    print("\n2. Loading results...")
-    all_models = {}
-    for model_name, model_dir in models.items():
-        all_models[model_name] = load_model_results(model_dir)
-        print(f" {model_name}: {len(all_models[model_name])} categories")
-
-    # Generate main table
-    print("\n3. Generating main results table...")
-    main_table = generate_main_results_table(all_models)
-    with open(output_dir / "main_results_table.tex", "w") as f:
-        f.write(main_table)
-    print(f" Saved: {output_dir / 'main_results_table.tex'}")
-
-    # Generate benchmark table
-    print("\n4. Generating benchmark table...")
-    bench_table = generate_benchmark_table(all_models)
-    with open(output_dir / "benchmark_table.tex", "w") as f:
-        f.write(bench_table)
-    print(f" Saved: {output_dir / 'benchmark_table.tex'}")
-
-    # Generate figure data
-    print("\n5. Generating figure data...")
-    figure_data = generate_figure_data(all_models)
-    with open(output_dir / "figure_data.json", "w") as f:
-        json.dump(figure_data, f, indent=2)
-    print(f" Saved: {output_dir / 'figure_data.json'}")
-
-    # Generate summary statistics
-    print("\n6. Generating summary statistics...")
-    summary = generate_summary_statistics(all_models)
-    with open(output_dir / "summary_statistics.md", "w") as f:
-        f.write(summary)
-    print(f" Saved: {output_dir / 'summary_statistics.md'}")
-    print(summary)
-
-    # Upload to S3
-    print("\n7. Uploading to S3...")
-    for f in output_dir.glob("*"):
-        if f.is_file():
-            subprocess.run(
-                ["aws", "s3", "cp", str(f), f"s3://{S3_BUCKET}/paper_data/{f.name}", "--quiet"],
-                check=False,
-            )
-
-    print("\n" + "=" * 70)
-    print("PAPER DATA GENERATION COMPLETE")
-    print("=" * 70)
-    print(f"Output directory: {output_dir}")
-
-
-if __name__ == "__main__":
-    main()