claude-turing 2.3.0 → 2.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/plugin.json +2 -2
- package/README.md +8 -2
- package/commands/budget.md +52 -0
- package/commands/distill.md +56 -0
- package/commands/ensemble.md +54 -0
- package/commands/scale.md +55 -0
- package/commands/stitch.md +49 -0
- package/commands/turing.md +12 -0
- package/commands/warm.md +53 -0
- package/package.json +1 -1
- package/src/install.js +2 -0
- package/src/verify.js +6 -0
- package/templates/scripts/__pycache__/budget_manager.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/build_ensemble.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/generate_brief.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/model_distiller.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/pipeline_manager.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/scaffold.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/scaling_estimator.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/warm_start.cpython-314.pyc +0 -0
- package/templates/scripts/budget_manager.py +419 -0
- package/templates/scripts/build_ensemble.py +696 -0
- package/templates/scripts/generate_brief.py +95 -0
- package/templates/scripts/model_distiller.py +478 -0
- package/templates/scripts/pipeline_manager.py +457 -0
- package/templates/scripts/scaffold.py +11 -0
- package/templates/scripts/scaling_estimator.py +523 -0
- package/templates/scripts/warm_start.py +493 -0
|
@@ -309,6 +309,52 @@ def load_regression_checks(regress_dir: str = "experiments/regressions") -> list
|
|
|
309
309
|
return reports
|
|
310
310
|
|
|
311
311
|
|
|
312
|
+
def load_ensemble_results(ensemble_dir: str = "experiments/ensembles") -> list[dict]:
    """Load ensemble result reports from YAML files.

    Args:
        ensemble_dir: Directory containing ``ensemble-*.yaml`` reports.

    Returns:
        Parsed report dicts in filename order; unreadable files and
        non-dict documents are silently skipped.
    """
    base = Path(ensemble_dir)
    if not base.exists():
        return []

    loaded: list[dict] = []
    for report_file in sorted(base.glob("ensemble-*.yaml")):
        try:
            with open(report_file) as handle:
                parsed = yaml.safe_load(handle)
        except (yaml.YAMLError, OSError):
            # Best-effort: a corrupt report must not break brief generation.
            continue
        # Keep only non-empty mapping documents.
        if parsed and isinstance(parsed, dict):
            loaded.append(parsed)
    return loaded
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def load_budget_status(state_path: str = "experiment_state.yaml", log_path: str = "experiments/log.jsonl") -> dict | None:
|
|
330
|
+
"""Load budget status if active."""
|
|
331
|
+
try:
|
|
332
|
+
from scripts.budget_manager import get_budget_status
|
|
333
|
+
result = get_budget_status(state_path, log_path)
|
|
334
|
+
if "error" not in result:
|
|
335
|
+
return result
|
|
336
|
+
except (ImportError, Exception):
|
|
337
|
+
pass
|
|
338
|
+
return None
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def load_scaling_results(scaling_dir: str = "experiments/scaling") -> list[dict]:
    """Load scaling study results from YAML files.

    Args:
        scaling_dir: Directory containing ``scale-*.yaml`` studies.

    Returns:
        Parsed study dicts in filename order. Only non-empty mapping
        documents that carry a "verdict" key are kept; unreadable files
        are silently skipped.
    """
    base = Path(scaling_dir)
    if not base.exists():
        return []

    studies: list[dict] = []
    for study_file in sorted(base.glob("scale-*.yaml")):
        try:
            with open(study_file) as handle:
                doc = yaml.safe_load(handle)
        except (yaml.YAMLError, OSError):
            # Best-effort: skip corrupt or unreadable study files.
            continue
        if doc and isinstance(doc, dict) and "verdict" in doc:
            studies.append(doc)
    return studies
|
|
356
|
+
|
|
357
|
+
|
|
312
358
|
def format_brief(
|
|
313
359
|
campaign: dict,
|
|
314
360
|
best: dict | None,
|
|
@@ -327,6 +373,9 @@ def format_brief(
|
|
|
327
373
|
profiles: list[dict] | None = None,
|
|
328
374
|
queue_summary: dict | None = None,
|
|
329
375
|
regression_checks: list[dict] | None = None,
|
|
376
|
+
ensemble_results: list[dict] | None = None,
|
|
377
|
+
budget_status: dict | None = None,
|
|
378
|
+
scaling_results: list[dict] | None = None,
|
|
330
379
|
) -> str:
|
|
331
380
|
"""Format the research briefing as markdown."""
|
|
332
381
|
direction = "lower" if lower_is_better else "higher"
|
|
@@ -546,6 +595,46 @@ def format_brief(
|
|
|
546
595
|
if auto_hyps:
|
|
547
596
|
lines.append(f"\n*{auto_hyps} auto-generated hypotheses from failure analysis.*")
|
|
548
597
|
|
|
598
|
+
# Ensemble results
|
|
599
|
+
if ensemble_results:
|
|
600
|
+
lines.extend(["", "## Ensembles", ""])
|
|
601
|
+
for ens in ensemble_results:
|
|
602
|
+
best_method = ens.get("best_method", "?")
|
|
603
|
+
improvement = ens.get("improvement", 0)
|
|
604
|
+
n_models = ens.get("n_candidates", 0)
|
|
605
|
+
if best_method != "best_single" and improvement > 0:
|
|
606
|
+
lines.append(
|
|
607
|
+
f"- **{best_method}** ({n_models} models): "
|
|
608
|
+
f"{metric} improvement {improvement:+.4f} over best single"
|
|
609
|
+
)
|
|
610
|
+
else:
|
|
611
|
+
lines.append(f"- {n_models}-model ensemble: no improvement over best single")
|
|
612
|
+
|
|
613
|
+
# Budget status
|
|
614
|
+
if budget_status and budget_status.get("usage"):
|
|
615
|
+
usage = budget_status["usage"]
|
|
616
|
+
phase = budget_status.get("phase", "?")
|
|
617
|
+
lines.extend(["", "## Budget", ""])
|
|
618
|
+
if usage.get("experiments_max"):
|
|
619
|
+
lines.append(
|
|
620
|
+
f"- **Experiments:** {usage['experiments_used']}/{usage['experiments_max']} "
|
|
621
|
+
f"({usage['budget_fraction']:.0%} used)"
|
|
622
|
+
)
|
|
623
|
+
if usage.get("hours_max"):
|
|
624
|
+
lines.append(f"- **Time:** {usage['hours_used']:.1f}/{usage['hours_max']:.1f}h")
|
|
625
|
+
lines.append(f"- **Phase:** {phase}")
|
|
626
|
+
if budget_status.get("exhausted"):
|
|
627
|
+
lines.append("- **STATUS: EXHAUSTED** — no more experiments will run")
|
|
628
|
+
|
|
629
|
+
# Scaling predictions
|
|
630
|
+
if scaling_results:
|
|
631
|
+
lines.extend(["", "## Scaling Predictions", ""])
|
|
632
|
+
for study in scaling_results:
|
|
633
|
+
verdict = study.get("verdict", {})
|
|
634
|
+
v = verdict.get("verdict", "?")
|
|
635
|
+
reason = verdict.get("reason", "")
|
|
636
|
+
lines.append(f"- **{v.upper()}**: {reason}")
|
|
637
|
+
|
|
549
638
|
# Regression check history (stability)
|
|
550
639
|
if regression_checks:
|
|
551
640
|
lines.extend(["", "## Stability", ""])
|
|
@@ -636,6 +725,9 @@ def generate_brief(
|
|
|
636
725
|
profiles = load_profiles()
|
|
637
726
|
queue_summary = load_queue_summary()
|
|
638
727
|
regression_checks = load_regression_checks()
|
|
728
|
+
ensemble_results = load_ensemble_results()
|
|
729
|
+
budget_status = load_budget_status(log_path=log_path)
|
|
730
|
+
scaling_results = load_scaling_results()
|
|
639
731
|
|
|
640
732
|
return format_brief(
|
|
641
733
|
campaign, best, trajectory, model_types, hypotheses,
|
|
@@ -648,6 +740,9 @@ def generate_brief(
|
|
|
648
740
|
profiles=profiles if profiles else None,
|
|
649
741
|
queue_summary=queue_summary,
|
|
650
742
|
regression_checks=regression_checks if regression_checks else None,
|
|
743
|
+
ensemble_results=ensemble_results if ensemble_results else None,
|
|
744
|
+
budget_status=budget_status,
|
|
745
|
+
scaling_results=scaling_results if scaling_results else None,
|
|
651
746
|
)
|
|
652
747
|
|
|
653
748
|
|
|
@@ -0,0 +1,478 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Model compression via distillation for the autoresearch pipeline.
|
|
3
|
+
|
|
4
|
+
Takes a large accurate model (teacher) and plans/evaluates a smaller
|
|
5
|
+
model (student) that matches its predictions. Measures the accuracy/size/
|
|
6
|
+
latency tradeoff to bridge "best research model" and "production-ready model."
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
python scripts/model_distiller.py exp-042
|
|
10
|
+
python scripts/model_distiller.py exp-042 --compression 4
|
|
11
|
+
python scripts/model_distiller.py exp-042 --method soft-labels
|
|
12
|
+
python scripts/model_distiller.py exp-042 --target-latency 5
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
import json
|
|
19
|
+
import math
|
|
20
|
+
import sys
|
|
21
|
+
from datetime import datetime, timezone
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
|
|
24
|
+
import yaml
|
|
25
|
+
|
|
26
|
+
from scripts.turing_io import load_config, load_experiments
|
|
27
|
+
|
|
28
|
+
DEFAULT_LOG_PATH = "experiments/log.jsonl"
|
|
29
|
+
DEFAULT_COMPRESSION = 4 # 4x compression
|
|
30
|
+
DISTILLATION_METHODS = ["soft_labels", "feature_matching", "dataset_distillation"]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# --- Student Architecture Selection ---
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def select_student_architecture(
    teacher_config: dict,
    compression: float,
) -> dict:
    """Auto-select student architecture based on teacher and compression target.

    Args:
        teacher_config: Teacher model config.
        compression: Compression ratio (e.g., 4 = 4x smaller).

    Returns:
        Student config dict with model type and hyperparameters.
    """
    teacher_type = teacher_config.get("model_type", "").lower()
    teacher_params = teacher_config.get("hyperparams", {})

    # Base record; family-specific selection below fills/overrides fields.
    student: dict = {
        "model_type": teacher_type,
        "hyperparams": {},
        "compression_strategy": "",
    }

    if _is_tree_model(teacher_type):
        overrides = _select_tree_student(teacher_params, compression)
    elif _is_neural_model(teacher_type):
        overrides = _select_neural_student(teacher_params, compression)
    elif _is_sklearn_model(teacher_type):
        overrides = _select_sklearn_student(teacher_type, teacher_params, compression)
    else:
        # Generic fallback: shrink every integer hyperparameter by the ratio.
        overrides = {
            "compression_strategy": "generic_reduction",
            "hyperparams": {
                k: max(1, int(v / compression)) if isinstance(v, int) else v
                for k, v in teacher_params.items()
            },
        }

    student.update(overrides)
    return student
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _is_tree_model(model_type: str) -> bool:
    """True if the model type names a gradient-boosted / tree-ensemble library."""
    tree_markers = ("xgboost", "lightgbm", "catboost", "gbm", "gradient_boosting")
    return any(marker in model_type for marker in tree_markers)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _is_neural_model(model_type: str) -> bool:
    """True if the model type looks like a neural-network family."""
    neural_markers = ("mlp", "neural", "nn", "pytorch", "tensorflow", "transformer")
    return any(marker in model_type for marker in neural_markers)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _is_sklearn_model(model_type: str) -> bool:
    """True if the model type names a classic sklearn-style estimator."""
    sklearn_markers = ("random_forest", "svm", "knn", "logistic", "ridge")
    return any(marker in model_type for marker in sklearn_markers)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _select_tree_student(hyperparams: dict, compression: float) -> dict:
    """Student plan for tree ensembles: fewer estimators, shallower trees."""
    teacher_trees = hyperparams.get("n_estimators", 100)
    teacher_depth = hyperparams.get("max_depth", 6)

    # Estimator count absorbs the full ratio; depth only shrinks by
    # sqrt(compression) so individual trees keep some expressiveness.
    student_params = {
        "n_estimators": max(1, int(teacher_trees / compression)),
        "max_depth": max(1, int(teacher_depth / math.sqrt(compression))),
        "learning_rate": hyperparams.get("learning_rate", 0.1),
    }
    return {
        "compression_strategy": "reduce_trees",
        "hyperparams": student_params,
    }
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _select_neural_student(hyperparams: dict, compression: float) -> dict:
    """Student plan for neural models: narrower and shallower by sqrt(compression)."""
    width = hyperparams.get("hidden_size", 256)
    depth = hyperparams.get("n_layers", hyperparams.get("layers", 4))
    shrink = math.sqrt(compression)

    # Splitting the ratio across width and depth keeps the parameter
    # count roughly proportional to 1/compression.
    return {
        "compression_strategy": "reduce_architecture",
        "hyperparams": {
            "hidden_size": max(8, int(width / shrink)),
            "n_layers": max(1, int(depth / shrink)),
            "learning_rate": hyperparams.get("learning_rate", 0.001),
        },
    }
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _select_sklearn_student(model_type: str, hyperparams: dict, compression: float) -> dict:
    """Student plan for sklearn models: swap in a simpler model family."""
    # Complex families degrade to cheaper ones; anything unmapped keeps
    # its own family.
    simpler_family = {
        "random_forest": "decision_tree",
        "svm": "logistic_regression",
        "knn": "logistic_regression",
    }
    student_type = simpler_family.get(model_type, model_type)

    params: dict = {}
    if student_type == "decision_tree":
        params["max_depth"] = max(1, int(hyperparams.get("max_depth", 10) / compression))
    elif student_type == "logistic_regression":
        params["C"] = hyperparams.get("C", 1.0)

    return {
        "model_type": student_type,
        "compression_strategy": "simpler_family",
        "hyperparams": params,
    }
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# --- Distillation Configuration ---
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def plan_distillation(
    teacher_exp: dict,
    compression: float = DEFAULT_COMPRESSION,
    method: str = "soft_labels",
    target_latency: float | None = None,
) -> dict:
    """Plan a distillation run.

    Args:
        teacher_exp: Teacher experiment dict.
        compression: Compression ratio.
        method: Distillation method.
        target_latency: Optional target latency in ms.

    Returns:
        Distillation plan dict.
    """
    teacher_id = teacher_exp.get("experiment_id", "unknown")
    teacher_config = teacher_exp.get("config", {})
    teacher_metrics = teacher_exp.get("metrics", {})

    teacher_size = teacher_metrics.get("model_size_bytes", teacher_metrics.get("n_params", 0))
    teacher_latency = teacher_metrics.get("latency_ms", teacher_metrics.get("inference_ms", 0))

    # If a target latency is specified, raise compression as needed first.
    # Latency is modeled as scaling with sqrt(compression), so a speedup of
    # s requires compression s**2.
    if target_latency and teacher_latency and teacher_latency > 0:
        needed_speedup = teacher_latency / target_latency
        adjusted_compression = needed_speedup ** 2
        if adjusted_compression > compression:
            compression = adjusted_compression

    # Select the student architecture for the FINAL compression ratio.
    student = select_student_architecture(teacher_config, compression)

    # Compute estimates AFTER any latency-driven adjustment, so the plan's
    # "estimates" are consistent with the "compression" it reports.
    # (Previously estimates were derived from the pre-adjustment ratio.)
    estimated_student_size = teacher_size / compression if teacher_size else None
    estimated_student_latency = teacher_latency / math.sqrt(compression) if teacher_latency else None

    plan = {
        "teacher_id": teacher_id,
        "teacher_metrics": teacher_metrics,
        "teacher_config": teacher_config,
        "compression": round(compression, 2),
        "method": method,
        "student": student,
        "estimates": {
            "student_size_bytes": int(estimated_student_size) if estimated_student_size else None,
            "student_latency_ms": round(estimated_student_latency, 2) if estimated_student_latency else None,
            "size_reduction": f"{(1 - 1/compression) * 100:.0f}%" if compression > 0 else "N/A",
        },
        "distillation_config": _build_distillation_config(method),
    }

    return plan
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _build_distillation_config(method: str) -> dict:
    """Build distillation-specific configuration for a known method."""
    # Fresh dict per call (literal built inside the function) so callers
    # can safely embed/mutate the result.
    method_configs = {
        "soft_labels": {
            "temperature": 3.0,
            "alpha": 0.7,  # Weight of soft labels vs hard labels
            "description": "Train student on teacher's probability outputs with temperature scaling",
        },
        "feature_matching": {
            "match_layers": "last_hidden",
            "loss": "mse",
            "description": "Align student's intermediate representations with teacher's",
        },
        "dataset_distillation": {
            "synthetic_samples": 1000,
            "description": "Train student on teacher-labeled synthetic data",
        },
    }
    return method_configs.get(method, {"description": "Unknown method"})
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
# --- Verdict ---
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def compute_distillation_verdict(
    teacher_metrics: dict,
    student_metrics: dict,
    primary_metric: str,
    compression: float,
) -> dict:
    """Compute verdict on distillation quality.

    Args:
        teacher_metrics: Teacher model metrics.
        student_metrics: Student model metrics.
        primary_metric: Name of primary metric.
        compression: Achieved compression ratio.

    Returns:
        Verdict dict.
    """
    teacher_score = teacher_metrics.get(primary_metric, 0)
    student_score = student_metrics.get(primary_metric, 0)

    if teacher_score == 0:
        return {"verdict": "no_baseline", "reason": "Teacher has no metric to compare against"}

    delta = student_score - teacher_score
    relative_loss = abs(delta) / abs(teacher_score)

    # Ordered thresholds: the first bound exceeding the loss wins.
    bands = (
        (0.01, "excellent", "Excellent tradeoff."),
        (0.03, "acceptable", "Acceptable for production."),
        (0.05, "marginal", "Consider lower compression."),
    )
    verdict, advice = "too_much_loss", "Try a less aggressive compression."
    for bound, label, note in bands:
        if relative_loss < bound:
            verdict, advice = label, note
            break

    reason = f"{relative_loss:.1%} accuracy loss for {compression:.0f}x compression. {advice}"

    return {
        "verdict": verdict,
        "delta": round(delta, 6),
        "relative_loss": round(relative_loss, 6),
        "compression": compression,
        "reason": reason,
    }
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
# --- Full Pipeline ---
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def distill_model(
    teacher_exp_id: str,
    compression: float = DEFAULT_COMPRESSION,
    method: str = "soft_labels",
    target_latency: float | None = None,
    config_path: str = "config.yaml",
    log_path: str = DEFAULT_LOG_PATH,
) -> dict:
    """Plan and report a model distillation.

    Args:
        teacher_exp_id: Teacher experiment ID.
        compression: Compression ratio.
        method: Distillation method.
        target_latency: Target inference latency in ms.
        config_path: Path to config.yaml.
        log_path: Path to experiment log.

    Returns:
        Complete distillation report, or an {"error": ...} dict when the
        teacher experiment cannot be found.
    """
    config = load_config(config_path)
    primary_metric = config.get("evaluation", {}).get("primary_metric", "accuracy")

    experiments = load_experiments(log_path)

    # First experiment whose ID matches, or None.
    teacher = next(
        (exp for exp in experiments if exp.get("experiment_id") == teacher_exp_id),
        None,
    )
    if not teacher:
        return {"error": f"Teacher experiment {teacher_exp_id} not found in {log_path}"}

    return {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "primary_metric": primary_metric,
        "plan": plan_distillation(teacher, compression, method, target_latency),
    }
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
# --- Report Formatting ---
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def save_distillation_report(report: dict, output_dir: str = "experiments/distillations") -> Path:
    """Save distillation report to YAML and return the written file path."""
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # File is keyed by teacher ID, so re-running overwrites the old plan.
    teacher_id = report.get("plan", {}).get("teacher_id", "unknown")
    target_file = target_dir / f"distill-{teacher_id}.yaml"

    with open(target_file, "w") as handle:
        yaml.dump(report, handle, default_flow_style=False, sort_keys=False)

    return target_file
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def format_distillation_report(report: dict) -> str:
    """Format distillation report as markdown.

    Renders the plan produced by ``distill_model`` as a human-readable
    markdown document. Error reports short-circuit to a one-line message.
    An optional "verdict" section is appended only when the report carries
    one (i.e. student metrics were available).
    """
    if "error" in report:
        return f"ERROR: {report['error']}"

    plan = report.get("plan", {})
    teacher_id = plan.get("teacher_id", "?")
    compression = plan.get("compression", 0)
    method = plan.get("method", "?")
    student = plan.get("student", {})
    estimates = plan.get("estimates", {})
    dist_cfg = plan.get("distillation_config", {})

    lines = [
        f"# Distillation Plan: {teacher_id}",
        "",
        # [:19] trims an ISO-8601 timestamp to seconds precision.
        f"*Generated {report.get('generated_at', 'N/A')[:19]}*",
        "",
        f"**Compression:** {compression:.0f}x",
        f"**Method:** {method}",
        f"**Strategy:** {student.get('compression_strategy', '?')}",
        "",
    ]

    # Teacher info
    teacher_metrics = plan.get("teacher_metrics", {})
    if teacher_metrics:
        lines.extend(["## Teacher Model", ""])
        for k, v in teacher_metrics.items():
            # Floats get fixed 4-decimal formatting; everything else str().
            v_str = f"{v:.4f}" if isinstance(v, float) else str(v)
            lines.append(f"- **{k}:** {v_str}")
        lines.append("")

    # Student architecture
    lines.extend(["## Student Architecture", ""])
    lines.append(f"- **Model type:** {student.get('model_type', '?')}")
    for k, v in student.get("hyperparams", {}).items():
        lines.append(f"- **{k}:** {v}")
    lines.append("")

    # Estimates (only when at least one estimate field is populated)
    if any(v is not None for v in estimates.values()):
        lines.extend(["## Estimates", ""])
        if estimates.get("size_reduction"):
            lines.append(f"- **Size reduction:** {estimates['size_reduction']}")
        if estimates.get("student_latency_ms"):
            lines.append(f"- **Estimated latency:** {estimates['student_latency_ms']:.1f} ms")
        lines.append("")

    # Distillation config ("description" is shown as the header blurb,
    # so it is excluded from the key/value bullet list below)
    lines.extend([
        "## Distillation Config",
        "",
        f"*{dist_cfg.get('description', method)}*",
        "",
    ])
    for k, v in dist_cfg.items():
        if k != "description":
            lines.append(f"- **{k}:** {v}")

    # Verdict (if student metrics available)
    verdict = report.get("verdict")
    if verdict:
        # Display labels for the verdict codes emitted by
        # compute_distillation_verdict; unknown codes fall through as-is.
        labels = {
            "excellent": "EXCELLENT",
            "acceptable": "ACCEPTABLE",
            "marginal": "MARGINAL",
            "too_much_loss": "TOO MUCH LOSS",
        }
        lines.extend([
            "",
            "## Verdict",
            "",
            f"**{labels.get(verdict.get('verdict', ''), verdict.get('verdict', '?'))}**",
            "",
            verdict.get("reason", ""),
        ])

    return "\n".join(lines)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def main() -> None:
    """CLI entry point.

    Parses arguments, plans the distillation, persists the report (unless
    planning failed), and prints it as JSON or formatted markdown.
    """
    parser = argparse.ArgumentParser(
        description="Model compression via distillation",
    )
    parser.add_argument(
        "teacher_exp_id",
        help="Teacher experiment ID (e.g., exp-042)",
    )
    parser.add_argument(
        "--compression", type=float, default=DEFAULT_COMPRESSION,
        help=f"Compression ratio (default: {DEFAULT_COMPRESSION}x)",
    )
    # NOTE(review): choices use underscores (e.g. soft_labels) while the
    # module docstring's usage example shows "--method soft-labels" — the
    # hyphenated form would be rejected here; confirm which is intended.
    parser.add_argument(
        "--method", choices=DISTILLATION_METHODS, default="soft_labels",
        help="Distillation method (default: soft_labels)",
    )
    parser.add_argument(
        "--target-latency", type=float,
        help="Target inference latency in ms (auto-adjusts compression)",
    )
    parser.add_argument(
        "--config", default="config.yaml",
        help="Path to config.yaml",
    )
    parser.add_argument(
        "--log", default=DEFAULT_LOG_PATH,
        help="Path to experiment log",
    )
    parser.add_argument(
        "--json", action="store_true",
        help="Output raw JSON instead of formatted report",
    )
    args = parser.parse_args()

    report = distill_model(
        teacher_exp_id=args.teacher_exp_id,
        compression=args.compression,
        method=args.method,
        target_latency=args.target_latency,
        config_path=args.config,
        log_path=args.log,
    )

    # Only persist successful plans; the save path goes to stderr so that
    # stdout stays clean for the report itself (pipe-friendly).
    if "error" not in report:
        filepath = save_distillation_report(report)
        print(f"Saved to {filepath}", file=sys.stderr)

    if args.json:
        print(json.dumps(report, indent=2, default=str))
    else:
        print(format_distillation_report(report))


if __name__ == "__main__":
    main()
|