claude-turing 2.4.0 → 3.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/plugin.json +2 -2
- package/README.md +7 -2
- package/commands/audit.md +56 -0
- package/commands/budget.md +52 -0
- package/commands/distill.md +56 -0
- package/commands/scale.md +55 -0
- package/commands/transfer.md +54 -0
- package/commands/turing.md +10 -0
- package/package.json +1 -1
- package/src/install.js +2 -0
- package/src/verify.js +5 -0
- package/templates/scripts/__pycache__/budget_manager.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/generate_brief.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/knowledge_transfer.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/methodology_audit.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/model_distiller.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/scaffold.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/scaling_estimator.cpython-314.pyc +0 -0
- package/templates/scripts/budget_manager.py +419 -0
- package/templates/scripts/generate_brief.py +101 -0
- package/templates/scripts/knowledge_transfer.py +618 -0
- package/templates/scripts/methodology_audit.py +451 -0
- package/templates/scripts/model_distiller.py +478 -0
- package/templates/scripts/scaffold.py +9 -0
- package/templates/scripts/scaling_estimator.py +523 -0
```diff
--- /dev/null
+++ b/package/templates/scripts/model_distiller.py
@@ -0,0 +1,478 @@
+#!/usr/bin/env python3
+"""Model compression via distillation for the autoresearch pipeline.
+
+Takes a large accurate model (teacher) and plans/evaluates a smaller
+model (student) that matches its predictions. Measures the accuracy/size/
+latency tradeoff to bridge "best research model" and "production-ready model."
+
+Usage:
+    python scripts/model_distiller.py exp-042
+    python scripts/model_distiller.py exp-042 --compression 4
+    python scripts/model_distiller.py exp-042 --method soft_labels
+    python scripts/model_distiller.py exp-042 --target-latency 5
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import math
+import sys
+from datetime import datetime, timezone
+from pathlib import Path
+
+import yaml
+
+from scripts.turing_io import load_config, load_experiments
+
+DEFAULT_LOG_PATH = "experiments/log.jsonl"
+DEFAULT_COMPRESSION = 4  # 4x compression
+DISTILLATION_METHODS = ["soft_labels", "feature_matching", "dataset_distillation"]
+
+
+# --- Student Architecture Selection ---
+
+
+def select_student_architecture(
+    teacher_config: dict,
+    compression: float,
+) -> dict:
+    """Auto-select student architecture based on teacher and compression target.
+
+    Args:
+        teacher_config: Teacher model config.
+        compression: Compression ratio (e.g., 4 = 4x smaller).
+
+    Returns:
+        Student config dict with model type and hyperparameters.
+    """
+    model_type = teacher_config.get("model_type", "").lower()
+    hyperparams = teacher_config.get("hyperparams", {})
+
+    student = {
+        "model_type": model_type,
+        "hyperparams": {},
+        "compression_strategy": "",
+    }
+
+    if _is_tree_model(model_type):
+        student.update(_select_tree_student(hyperparams, compression))
+    elif _is_neural_model(model_type):
+        student.update(_select_neural_student(hyperparams, compression))
+    elif _is_sklearn_model(model_type):
+        student.update(_select_sklearn_student(model_type, hyperparams, compression))
+    else:
+        # Generic: reduce all numeric hyperparams by compression ratio
+        student["compression_strategy"] = "generic_reduction"
+        student["hyperparams"] = {
+            k: max(1, int(v / compression)) if isinstance(v, int) else v
+            for k, v in hyperparams.items()
+        }
+
+    return student
+
+
+def _is_tree_model(model_type: str) -> bool:
+    return any(t in model_type for t in ("xgboost", "lightgbm", "catboost", "gbm", "gradient_boosting"))
+
+
+def _is_neural_model(model_type: str) -> bool:
+    return any(t in model_type for t in ("mlp", "neural", "nn", "pytorch", "tensorflow", "transformer"))
+
+
+def _is_sklearn_model(model_type: str) -> bool:
+    return any(t in model_type for t in ("random_forest", "svm", "knn", "logistic", "ridge"))
+
+
+def _select_tree_student(hyperparams: dict, compression: float) -> dict:
+    """Select student for tree-based models: fewer estimators, shallower."""
+    n_estimators = hyperparams.get("n_estimators", 100)
+    max_depth = hyperparams.get("max_depth", 6)
+
+    return {
+        "compression_strategy": "reduce_trees",
+        "hyperparams": {
+            "n_estimators": max(1, int(n_estimators / compression)),
+            "max_depth": max(1, int(max_depth / math.sqrt(compression))),
+            "learning_rate": hyperparams.get("learning_rate", 0.1),
+        },
+    }
+
+
+def _select_neural_student(hyperparams: dict, compression: float) -> dict:
+    """Select student for neural models: fewer layers, narrower."""
+    hidden_size = hyperparams.get("hidden_size", 256)
+    n_layers = hyperparams.get("n_layers", hyperparams.get("layers", 4))
+
+    return {
+        "compression_strategy": "reduce_architecture",
+        "hyperparams": {
+            "hidden_size": max(8, int(hidden_size / math.sqrt(compression))),
+            "n_layers": max(1, int(n_layers / math.sqrt(compression))),
+            "learning_rate": hyperparams.get("learning_rate", 0.001),
+        },
+    }
+
+
+def _select_sklearn_student(model_type: str, hyperparams: dict, compression: float) -> dict:
+    """Select student for sklearn models: simpler model family."""
+    # Map complex models to simpler alternatives
+    student_map = {
+        "random_forest": "decision_tree",
+        "svm": "logistic_regression",
+        "knn": "logistic_regression",
+    }
+
+    student_type = student_map.get(model_type, model_type)
+
+    student_params = {}
+    if student_type == "decision_tree":
+        student_params["max_depth"] = max(1, int(hyperparams.get("max_depth", 10) / compression))
+    elif student_type == "logistic_regression":
+        student_params["C"] = hyperparams.get("C", 1.0)
+
+    return {
+        "model_type": student_type,
+        "compression_strategy": "simpler_family",
+        "hyperparams": student_params,
+    }
```
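For a concrete sense of the selection heuristics, here is a minimal sketch of what the selector returns for a hypothetical XGBoost teacher; the import path assumes the template is installed as `scripts/model_distiller.py`, and the config values are invented for illustration:

```python
# Minimal sketch, assuming this template is importable as scripts.model_distiller
# and using an invented XGBoost teacher config.
from scripts.model_distiller import select_student_architecture

teacher_config = {
    "model_type": "xgboost",
    "hyperparams": {"n_estimators": 400, "max_depth": 8, "learning_rate": 0.1},
}

student = select_student_architecture(teacher_config, compression=4)
# _is_tree_model matches "xgboost", so _select_tree_student applies:
#   n_estimators: 400 / 4 = 100
#   max_depth:    int(8 / sqrt(4)) = 4
# student == {
#     "model_type": "xgboost",
#     "hyperparams": {"n_estimators": 100, "max_depth": 4, "learning_rate": 0.1},
#     "compression_strategy": "reduce_trees",
# }
```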
```diff
+# --- Distillation Configuration ---
+
+
+def plan_distillation(
+    teacher_exp: dict,
+    compression: float = DEFAULT_COMPRESSION,
+    method: str = "soft_labels",
+    target_latency: float | None = None,
+) -> dict:
+    """Plan a distillation run.
+
+    Args:
+        teacher_exp: Teacher experiment dict.
+        compression: Compression ratio.
+        method: Distillation method.
+        target_latency: Optional target latency in ms.
+
+    Returns:
+        Distillation plan dict.
+    """
+    teacher_id = teacher_exp.get("experiment_id", "unknown")
+    teacher_config = teacher_exp.get("config", {})
+    teacher_metrics = teacher_exp.get("metrics", {})
+
+    teacher_size = teacher_metrics.get("model_size_bytes", teacher_metrics.get("n_params", 0))
+    teacher_latency = teacher_metrics.get("latency_ms", teacher_metrics.get("inference_ms", 0))
+
+    # If a target latency is specified, adjust compression first so the
+    # student selection and estimates below reflect the final ratio
+    if target_latency and teacher_latency and teacher_latency > 0:
+        needed_speedup = teacher_latency / target_latency
+        adjusted_compression = needed_speedup ** 2  # Latency scales with sqrt(compression)
+        if adjusted_compression > compression:
+            compression = adjusted_compression
+
+    # Select student architecture
+    student = select_student_architecture(teacher_config, compression)
+
+    # Estimate student size and latency
+    estimated_student_size = teacher_size / compression if teacher_size else None
+    estimated_student_latency = teacher_latency / math.sqrt(compression) if teacher_latency else None
+
+    plan = {
+        "teacher_id": teacher_id,
+        "teacher_metrics": teacher_metrics,
+        "teacher_config": teacher_config,
+        "compression": round(compression, 2),
+        "method": method,
+        "student": student,
+        "estimates": {
+            "student_size_bytes": int(estimated_student_size) if estimated_student_size else None,
+            "student_latency_ms": round(estimated_student_latency, 2) if estimated_student_latency else None,
+            "size_reduction": f"{(1 - 1/compression) * 100:.0f}%" if compression > 0 else "N/A",
+        },
+        "distillation_config": _build_distillation_config(method),
+    }
+
+    return plan
+
+
+def _build_distillation_config(method: str) -> dict:
+    """Build distillation-specific configuration."""
+    if method == "soft_labels":
+        return {
+            "temperature": 3.0,
+            "alpha": 0.7,  # Weight of soft labels vs hard labels
+            "description": "Train student on teacher's probability outputs with temperature scaling",
+        }
+    elif method == "feature_matching":
+        return {
+            "match_layers": "last_hidden",
+            "loss": "mse",
+            "description": "Align student's intermediate representations with teacher's",
+        }
+    elif method == "dataset_distillation":
+        return {
+            "synthetic_samples": 1000,
+            "description": "Train student on teacher-labeled synthetic data",
+        }
+    return {"description": "Unknown method"}
```
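The planner's latency model assumes inference time scales with the square root of the compression ratio, so a latency target translates into a compression floor of (teacher latency / target latency) squared. Worked by hand with hypothetical numbers:

```python
import math

# Hypothetical numbers: a 20 ms teacher run with --target-latency 5
teacher_latency_ms = 20.0
target_latency_ms = 5.0

needed_speedup = teacher_latency_ms / target_latency_ms       # 4.0
adjusted_compression = needed_speedup ** 2                    # 16.0

# Estimated student latency at that ratio, matching the formula above:
estimated_ms = teacher_latency_ms / math.sqrt(adjusted_compression)  # 5.0 ms
print(adjusted_compression, estimated_ms)
```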
```diff
+# --- Verdict ---
+
+
+def compute_distillation_verdict(
+    teacher_metrics: dict,
+    student_metrics: dict,
+    primary_metric: str,
+    compression: float,
+) -> dict:
+    """Compute verdict on distillation quality.
+
+    Args:
+        teacher_metrics: Teacher model metrics.
+        student_metrics: Student model metrics.
+        primary_metric: Name of primary metric.
+        compression: Achieved compression ratio.
+
+    Returns:
+        Verdict dict.
+    """
+    teacher_val = teacher_metrics.get(primary_metric, 0)
+    student_val = student_metrics.get(primary_metric, 0)
+
+    if teacher_val == 0:
+        return {"verdict": "no_baseline", "reason": "Teacher has no metric to compare against"}
+
+    delta = student_val - teacher_val
+    relative_loss = abs(delta) / abs(teacher_val)
+
+    if relative_loss < 0.01:  # < 1% accuracy loss
+        verdict = "excellent"
+        reason = f"{relative_loss:.1%} accuracy loss for {compression:.0f}x compression. Excellent tradeoff."
+    elif relative_loss < 0.03:  # < 3% loss
+        verdict = "acceptable"
+        reason = f"{relative_loss:.1%} accuracy loss for {compression:.0f}x compression. Acceptable for production."
+    elif relative_loss < 0.05:  # < 5% loss
+        verdict = "marginal"
+        reason = f"{relative_loss:.1%} accuracy loss for {compression:.0f}x compression. Consider lower compression."
+    else:
+        verdict = "too_much_loss"
+        reason = f"{relative_loss:.1%} accuracy loss for {compression:.0f}x compression. Try a less aggressive compression."
+
+    return {
+        "verdict": verdict,
+        "delta": round(delta, 6),
+        "relative_loss": round(relative_loss, 6),
+        "compression": compression,
+        "reason": reason,
+    }
```
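The verdict thresholds sit at 1%, 3%, and 5% relative loss on the primary metric; note that the `abs(delta)` means a student that outperforms the teacher also counts toward "loss". A minimal sketch with hypothetical metrics, again assuming the `scripts.model_distiller` import path:

```python
# Minimal sketch with hypothetical metrics.
from scripts.model_distiller import compute_distillation_verdict

verdict = compute_distillation_verdict(
    teacher_metrics={"accuracy": 0.92},
    student_metrics={"accuracy": 0.90},
    primary_metric="accuracy",
    compression=4,
)
# relative_loss = |0.90 - 0.92| / 0.92 ≈ 0.0217, between 1% and 3%,
# so verdict["verdict"] == "acceptable"
```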
```diff
+# --- Full Pipeline ---
+
+
+def distill_model(
+    teacher_exp_id: str,
+    compression: float = DEFAULT_COMPRESSION,
+    method: str = "soft_labels",
+    target_latency: float | None = None,
+    config_path: str = "config.yaml",
+    log_path: str = DEFAULT_LOG_PATH,
+) -> dict:
+    """Plan and report a model distillation.
+
+    Args:
+        teacher_exp_id: Teacher experiment ID.
+        compression: Compression ratio.
+        method: Distillation method.
+        target_latency: Target inference latency in ms.
+        config_path: Path to config.yaml.
+        log_path: Path to experiment log.
+
+    Returns:
+        Complete distillation report.
+    """
+    config = load_config(config_path)
+    eval_cfg = config.get("evaluation", {})
+    primary_metric = eval_cfg.get("primary_metric", "accuracy")
+
+    experiments = load_experiments(log_path)
+
+    teacher = None
+    for exp in experiments:
+        if exp.get("experiment_id") == teacher_exp_id:
+            teacher = exp
+            break
+
+    if not teacher:
+        return {"error": f"Teacher experiment {teacher_exp_id} not found in {log_path}"}
+
+    plan = plan_distillation(teacher, compression, method, target_latency)
+
+    report = {
+        "generated_at": datetime.now(timezone.utc).isoformat(),
+        "primary_metric": primary_metric,
+        "plan": plan,
+    }
+
+    return report
```
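`distill_model` looks the teacher up by `experiment_id` in the JSONL log. Judging from the `.get()` calls above, a consumable log record needs at least the following shape; the field names are taken from the code, the values are hypothetical:

```python
import json

# Hypothetical record; field names mirror the lookups in plan_distillation
# and distill_model (experiment_id, config.model_type, config.hyperparams,
# metrics.model_size_bytes / n_params, metrics.latency_ms / inference_ms).
record = {
    "experiment_id": "exp-042",
    "config": {
        "model_type": "xgboost",
        "hyperparams": {"n_estimators": 400, "max_depth": 8},
    },
    "metrics": {"accuracy": 0.92, "model_size_bytes": 8_000_000, "latency_ms": 20.0},
}

# experiments/log.jsonl holds one such JSON object per line:
print(json.dumps(record))
```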
```diff
+# --- Report Formatting ---
+
+
+def save_distillation_report(report: dict, output_dir: str = "experiments/distillations") -> Path:
+    """Save distillation report to YAML."""
+    out_path = Path(output_dir)
+    out_path.mkdir(parents=True, exist_ok=True)
+
+    teacher_id = report.get("plan", {}).get("teacher_id", "unknown")
+    filepath = out_path / f"distill-{teacher_id}.yaml"
+
+    with open(filepath, "w") as f:
+        yaml.dump(report, f, default_flow_style=False, sort_keys=False)
+
+    return filepath
+
+
+def format_distillation_report(report: dict) -> str:
+    """Format distillation report as markdown."""
+    if "error" in report:
+        return f"ERROR: {report['error']}"
+
+    plan = report.get("plan", {})
+    teacher_id = plan.get("teacher_id", "?")
+    compression = plan.get("compression", 0)
+    method = plan.get("method", "?")
+    student = plan.get("student", {})
+    estimates = plan.get("estimates", {})
+    dist_cfg = plan.get("distillation_config", {})
+
+    lines = [
+        f"# Distillation Plan: {teacher_id}",
+        "",
+        f"*Generated {report.get('generated_at', 'N/A')[:19]}*",
+        "",
+        f"**Compression:** {compression:.0f}x",
+        f"**Method:** {method}",
+        f"**Strategy:** {student.get('compression_strategy', '?')}",
+        "",
+    ]
+
+    # Teacher info
+    teacher_metrics = plan.get("teacher_metrics", {})
+    if teacher_metrics:
+        lines.extend(["## Teacher Model", ""])
+        for k, v in teacher_metrics.items():
+            v_str = f"{v:.4f}" if isinstance(v, float) else str(v)
+            lines.append(f"- **{k}:** {v_str}")
+        lines.append("")
+
+    # Student architecture
+    lines.extend(["## Student Architecture", ""])
+    lines.append(f"- **Model type:** {student.get('model_type', '?')}")
+    for k, v in student.get("hyperparams", {}).items():
+        lines.append(f"- **{k}:** {v}")
+    lines.append("")
+
+    # Estimates
+    if any(v is not None for v in estimates.values()):
+        lines.extend(["## Estimates", ""])
+        if estimates.get("size_reduction"):
+            lines.append(f"- **Size reduction:** {estimates['size_reduction']}")
+        if estimates.get("student_latency_ms"):
+            lines.append(f"- **Estimated latency:** {estimates['student_latency_ms']:.1f} ms")
+        lines.append("")
+
+    # Distillation config
+    lines.extend([
+        "## Distillation Config",
+        "",
+        f"*{dist_cfg.get('description', method)}*",
+        "",
+    ])
+    for k, v in dist_cfg.items():
+        if k != "description":
+            lines.append(f"- **{k}:** {v}")
+
+    # Verdict (if student metrics available)
+    verdict = report.get("verdict")
+    if verdict:
+        labels = {
+            "excellent": "EXCELLENT",
+            "acceptable": "ACCEPTABLE",
+            "marginal": "MARGINAL",
+            "too_much_loss": "TOO MUCH LOSS",
+        }
+        lines.extend([
+            "",
+            "## Verdict",
+            "",
+            f"**{labels.get(verdict.get('verdict', ''), verdict.get('verdict', '?'))}**",
+            "",
+            verdict.get("reason", ""),
+        ])
+
+    return "\n".join(lines)
+
+
+def main() -> None:
+    """CLI entry point."""
+    parser = argparse.ArgumentParser(
+        description="Model compression via distillation",
+    )
+    parser.add_argument(
+        "teacher_exp_id",
+        help="Teacher experiment ID (e.g., exp-042)",
+    )
+    parser.add_argument(
+        "--compression", type=float, default=DEFAULT_COMPRESSION,
+        help=f"Compression ratio (default: {DEFAULT_COMPRESSION}x)",
+    )
+    parser.add_argument(
+        "--method", choices=DISTILLATION_METHODS, default="soft_labels",
+        help="Distillation method (default: soft_labels)",
+    )
+    parser.add_argument(
+        "--target-latency", type=float,
+        help="Target inference latency in ms (auto-adjusts compression)",
+    )
+    parser.add_argument(
+        "--config", default="config.yaml",
+        help="Path to config.yaml",
+    )
+    parser.add_argument(
+        "--log", default=DEFAULT_LOG_PATH,
+        help="Path to experiment log",
+    )
+    parser.add_argument(
+        "--json", action="store_true",
+        help="Output raw JSON instead of formatted report",
+    )
+    args = parser.parse_args()
+
+    report = distill_model(
+        teacher_exp_id=args.teacher_exp_id,
+        compression=args.compression,
+        method=args.method,
+        target_latency=args.target_latency,
+        config_path=args.config,
+        log_path=args.log,
+    )
+
+    if "error" not in report:
+        filepath = save_distillation_report(report)
+        print(f"Saved to {filepath}", file=sys.stderr)
+
+    if args.json:
+        print(json.dumps(report, indent=2, default=str))
+    else:
+        print(format_distillation_report(report))
+
+
+if __name__ == "__main__":
+    main()
```
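Note that `distill_model` returns a plan only; `format_distillation_report` renders its Verdict section only when the caller attaches one after the planned student has actually been trained and measured. A sketch of that wiring, with a hypothetical student measurement:

```python
# Sketch, assuming scripts.model_distiller is importable and the planned
# student has since been trained and measured (student accuracy is invented).
from scripts.model_distiller import (
    compute_distillation_verdict,
    distill_model,
    format_distillation_report,
)

report = distill_model("exp-042")
report["verdict"] = compute_distillation_verdict(
    teacher_metrics=report["plan"]["teacher_metrics"],
    student_metrics={"accuracy": 0.90},
    primary_metric=report["primary_metric"],
    compression=report["plan"]["compression"],
)
print(format_distillation_report(report))  # now includes the Verdict section
```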
```diff
--- a/package/templates/scripts/scaffold.py
+++ b/package/templates/scripts/scaffold.py
@@ -113,6 +113,11 @@ TEMPLATE_DIRS = {
         "build_ensemble.py",
         "pipeline_manager.py",
         "warm_start.py",
+        "scaling_estimator.py",
+        "budget_manager.py",
+        "model_distiller.py",
+        "knowledge_transfer.py",
+        "methodology_audit.py",
     ],
     "tests": ["__init__.py", "conftest.py"],
 }
```
```diff
@@ -139,6 +144,10 @@ DIRECTORIES_TO_CREATE = [
     "experiments/ensembles",
     "experiments/cache",
     "experiments/warm_starts",
+    "experiments/scaling",
+    "experiments/distillations",
+    "experiments/transfers",
+    "experiments/audits",
     "experiments/logs",
     "models/best",
     "models/archive",
```