claude-turing 1.4.0 → 2.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/plugin.json +2 -2
- package/README.md +5 -2
- package/commands/checkpoint.md +47 -0
- package/commands/export.md +48 -0
- package/commands/profile.md +43 -0
- package/commands/turing.md +6 -0
- package/package.json +1 -1
- package/src/install.js +1 -1
- package/src/verify.js +3 -0
- package/templates/scripts/__pycache__/checkpoint_manager.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/equivalence_checker.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/export_card.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/export_formats.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/generate_brief.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/latency_benchmark.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/profile_training.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/scaffold.cpython-314.pyc +0 -0
- package/templates/scripts/checkpoint_manager.py +449 -0
- package/templates/scripts/equivalence_checker.py +158 -0
- package/templates/scripts/export_card.py +183 -0
- package/templates/scripts/export_formats.py +385 -0
- package/templates/scripts/export_model.py +324 -0
- package/templates/scripts/generate_brief.py +38 -1
- package/templates/scripts/latency_benchmark.py +167 -0
- package/templates/scripts/profile_training.py +533 -0
- package/templates/scripts/scaffold.py +10 -0
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Deployment model card generation for exported models.
|
|
3
|
+
|
|
4
|
+
Produces a structured model card with metrics, seed study results,
|
|
5
|
+
export format, equivalence check, latency benchmarks, and dependencies.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import yaml
|
|
11
|
+
from datetime import datetime, timezone
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
from scripts.turing_io import load_config, load_seed_study
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def generate_export_card(
|
|
18
|
+
experiment: dict,
|
|
19
|
+
export_result: dict,
|
|
20
|
+
equivalence: dict | None = None,
|
|
21
|
+
latency: dict | None = None,
|
|
22
|
+
config: dict | None = None,
|
|
23
|
+
) -> dict:
|
|
24
|
+
"""Generate a deployment model card for an exported model.
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
experiment: Original experiment dict from log.jsonl.
|
|
28
|
+
export_result: Result from export_formats.export_model().
|
|
29
|
+
equivalence: Result from equivalence_checker.compare_outputs().
|
|
30
|
+
latency: Latency comparison from latency_benchmark.compare_latency().
|
|
31
|
+
config: Project config dict.
|
|
32
|
+
|
|
33
|
+
Returns:
|
|
34
|
+
Model card dict.
|
|
35
|
+
"""
|
|
36
|
+
exp_id = experiment.get("experiment_id", "unknown")
|
|
37
|
+
metrics = experiment.get("metrics", {})
|
|
38
|
+
exp_config = experiment.get("config", {})
|
|
39
|
+
model_type = exp_config.get("model_type", config.get("model", {}).get("type", "unknown") if config else "unknown")
|
|
40
|
+
|
|
41
|
+
eval_cfg = config.get("evaluation", {}) if config else {}
|
|
42
|
+
primary_metric = eval_cfg.get("primary_metric", "accuracy")
|
|
43
|
+
task_desc = config.get("task_description", eval_cfg.get("primary_metric", "N/A")) if config else "N/A"
|
|
44
|
+
|
|
45
|
+
card = {
|
|
46
|
+
"name": f"{exp_id}-{model_type}",
|
|
47
|
+
"experiment_id": exp_id,
|
|
48
|
+
"task": task_desc,
|
|
49
|
+
"model_type": model_type,
|
|
50
|
+
"primary_metric": primary_metric,
|
|
51
|
+
"metrics": {k: round(v, 4) if isinstance(v, float) else v for k, v in metrics.items()},
|
|
52
|
+
"export_format": export_result.get("format", "unknown"),
|
|
53
|
+
"export_path": export_result.get("path"),
|
|
54
|
+
"size_mb": export_result.get("size_mb", 0),
|
|
55
|
+
"dependencies": export_result.get("dependencies", []),
|
|
56
|
+
"training_date": experiment.get("timestamp", "unknown"),
|
|
57
|
+
"export_date": datetime.now(timezone.utc).isoformat(),
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
# Seed study (if available)
|
|
61
|
+
seed_study = load_seed_study(exp_id)
|
|
62
|
+
if seed_study and "mean" in seed_study:
|
|
63
|
+
card["seed_study"] = {
|
|
64
|
+
"mean": seed_study["mean"],
|
|
65
|
+
"std": seed_study.get("std", 0),
|
|
66
|
+
"cv_percent": seed_study.get("cv_percent", 0),
|
|
67
|
+
"seed_sensitive": seed_study.get("seed_sensitive", False),
|
|
68
|
+
"seeds_tested": len(seed_study.get("seeds_run", [])),
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
# Equivalence check
|
|
72
|
+
if equivalence:
|
|
73
|
+
card["equivalence"] = {
|
|
74
|
+
"verdict": equivalence.get("verdict", "unknown"),
|
|
75
|
+
"max_delta": equivalence.get("max_delta", 0),
|
|
76
|
+
"n_samples_tested": equivalence.get("n_samples", 0),
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
# Latency benchmark
|
|
80
|
+
if latency and latency.get("verdict") != "error":
|
|
81
|
+
card["inference_latency"] = {
|
|
82
|
+
"exported_p50_ms": latency.get("exported_p50_ms"),
|
|
83
|
+
"exported_p95_ms": latency.get("exported_p95_ms"),
|
|
84
|
+
"original_p50_ms": latency.get("original_p50_ms"),
|
|
85
|
+
"speedup": latency.get("speedup_ratio"),
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
# Environment
|
|
89
|
+
env = experiment.get("environment")
|
|
90
|
+
if env:
|
|
91
|
+
card["training_environment"] = {
|
|
92
|
+
"python_version": env.get("python_version"),
|
|
93
|
+
"gpu": env.get("gpu_name") or env.get("gpu"),
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
return card
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def save_export_card(card: dict, output_dir: str) -> Path:
|
|
100
|
+
"""Save export model card to YAML file."""
|
|
101
|
+
out_path = Path(output_dir)
|
|
102
|
+
out_path.mkdir(parents=True, exist_ok=True)
|
|
103
|
+
filepath = out_path / "model_card.yaml"
|
|
104
|
+
with open(filepath, "w") as f:
|
|
105
|
+
yaml.dump(card, f, default_flow_style=False, sort_keys=False)
|
|
106
|
+
return filepath
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def format_export_card(card: dict) -> str:
|
|
110
|
+
"""Format export model card as readable markdown."""
|
|
111
|
+
lines = [
|
|
112
|
+
f"# Export Model Card: {card.get('name', 'unknown')}",
|
|
113
|
+
"",
|
|
114
|
+
f"- **Experiment:** {card.get('experiment_id', '?')}",
|
|
115
|
+
f"- **Task:** {card.get('task', 'N/A')}",
|
|
116
|
+
f"- **Model type:** {card.get('model_type', '?')}",
|
|
117
|
+
f"- **Export format:** {card.get('export_format', '?')}",
|
|
118
|
+
f"- **Size:** {card.get('size_mb', 0):.2f} MB",
|
|
119
|
+
f"- **Dependencies:** {', '.join(card.get('dependencies', []))}",
|
|
120
|
+
"",
|
|
121
|
+
"## Metrics",
|
|
122
|
+
"",
|
|
123
|
+
]
|
|
124
|
+
|
|
125
|
+
for metric, value in card.get("metrics", {}).items():
|
|
126
|
+
if isinstance(value, float):
|
|
127
|
+
lines.append(f"- **{metric}:** {value:.4f}")
|
|
128
|
+
else:
|
|
129
|
+
lines.append(f"- **{metric}:** {value}")
|
|
130
|
+
|
|
131
|
+
# Seed study
|
|
132
|
+
seed = card.get("seed_study")
|
|
133
|
+
if seed:
|
|
134
|
+
status = "SEED-SENSITIVE" if seed.get("seed_sensitive") else "STABLE"
|
|
135
|
+
lines.extend([
|
|
136
|
+
"",
|
|
137
|
+
"## Seed Study",
|
|
138
|
+
"",
|
|
139
|
+
f"- **Status:** {status}",
|
|
140
|
+
f"- **Mean ± Std:** {seed['mean']:.4f} ± {seed.get('std', 0):.4f}",
|
|
141
|
+
f"- **CV:** {seed.get('cv_percent', 0):.2f}%",
|
|
142
|
+
f"- **Seeds tested:** {seed.get('seeds_tested', 0)}",
|
|
143
|
+
])
|
|
144
|
+
|
|
145
|
+
# Equivalence
|
|
146
|
+
eq = card.get("equivalence")
|
|
147
|
+
if eq:
|
|
148
|
+
verdict_markers = {
|
|
149
|
+
"equivalent": "PASS (exact)",
|
|
150
|
+
"approximately_equivalent": "PASS (approx)",
|
|
151
|
+
"divergent": "FAIL",
|
|
152
|
+
}
|
|
153
|
+
marker = verdict_markers.get(eq["verdict"], eq["verdict"])
|
|
154
|
+
lines.extend([
|
|
155
|
+
"",
|
|
156
|
+
"## Equivalence",
|
|
157
|
+
"",
|
|
158
|
+
f"- **Verdict:** {marker}",
|
|
159
|
+
f"- **Max delta:** {eq.get('max_delta', 0):.2e}",
|
|
160
|
+
f"- **Samples tested:** {eq.get('n_samples_tested', 0)}",
|
|
161
|
+
])
|
|
162
|
+
|
|
163
|
+
# Latency
|
|
164
|
+
lat = card.get("inference_latency")
|
|
165
|
+
if lat:
|
|
166
|
+
lines.extend([
|
|
167
|
+
"",
|
|
168
|
+
"## Inference Latency",
|
|
169
|
+
"",
|
|
170
|
+
f"- **Exported p50:** {lat.get('exported_p50_ms', 0):.2f} ms",
|
|
171
|
+
f"- **Exported p95:** {lat.get('exported_p95_ms', 0):.2f} ms",
|
|
172
|
+
])
|
|
173
|
+
if lat.get("original_p50_ms"):
|
|
174
|
+
lines.append(f"- **Original p50:** {lat['original_p50_ms']:.2f} ms")
|
|
175
|
+
if lat.get("speedup"):
|
|
176
|
+
lines.append(f"- **Speedup:** {lat['speedup']:.1f}x")
|
|
177
|
+
|
|
178
|
+
lines.extend([
|
|
179
|
+
"",
|
|
180
|
+
f"*Exported: {card.get('export_date', 'unknown')}*",
|
|
181
|
+
])
|
|
182
|
+
|
|
183
|
+
return "\n".join(lines)
|
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Format-specific model export handlers.
|
|
3
|
+
|
|
4
|
+
Each handler knows how to export a specific model type to a production
|
|
5
|
+
format. Returns the export path and metadata.
|
|
6
|
+
|
|
7
|
+
Supported formats:
|
|
8
|
+
- joblib: scikit-learn, XGBoost, LightGBM (default)
|
|
9
|
+
- xgboost_json: XGBoost native JSON format
|
|
10
|
+
- lightgbm_text: LightGBM native text format
|
|
11
|
+
- onnx: ONNX via framework-specific converters
|
|
12
|
+
- torchscript: PyTorch JIT trace
|
|
13
|
+
- tflite: TensorFlow Lite
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import shutil
|
|
19
|
+
import sys
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Registry of model types -> supported export formats
|
|
24
|
+
FORMAT_REGISTRY = {
|
|
25
|
+
"xgboost": ["joblib", "xgboost_json", "onnx"],
|
|
26
|
+
"lightgbm": ["joblib", "lightgbm_text", "onnx"],
|
|
27
|
+
"random_forest": ["joblib", "onnx"],
|
|
28
|
+
"gradient_boosting": ["joblib", "onnx"],
|
|
29
|
+
"logistic_regression": ["joblib", "onnx"],
|
|
30
|
+
"svm": ["joblib", "onnx"],
|
|
31
|
+
"mlp": ["joblib", "onnx"],
|
|
32
|
+
"pytorch": ["torchscript", "onnx"],
|
|
33
|
+
"tensorflow": ["tflite", "onnx"],
|
|
34
|
+
"keras": ["tflite", "onnx"],
|
|
35
|
+
"catboost": ["joblib", "onnx"],
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
# Default format for each model type
|
|
39
|
+
DEFAULT_FORMAT = {
|
|
40
|
+
"xgboost": "joblib",
|
|
41
|
+
"lightgbm": "joblib",
|
|
42
|
+
"random_forest": "joblib",
|
|
43
|
+
"gradient_boosting": "joblib",
|
|
44
|
+
"logistic_regression": "joblib",
|
|
45
|
+
"svm": "joblib",
|
|
46
|
+
"mlp": "joblib",
|
|
47
|
+
"pytorch": "torchscript",
|
|
48
|
+
"tensorflow": "tflite",
|
|
49
|
+
"keras": "tflite",
|
|
50
|
+
"catboost": "joblib",
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
# File extensions for each format
|
|
54
|
+
FORMAT_EXTENSIONS = {
|
|
55
|
+
"joblib": ".joblib",
|
|
56
|
+
"xgboost_json": ".json",
|
|
57
|
+
"lightgbm_text": ".txt",
|
|
58
|
+
"onnx": ".onnx",
|
|
59
|
+
"torchscript": ".pt",
|
|
60
|
+
"tflite": ".tflite",
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
# Dependencies required for each format
|
|
64
|
+
FORMAT_DEPENDENCIES = {
|
|
65
|
+
"joblib": ["joblib"],
|
|
66
|
+
"xgboost_json": ["xgboost>=1.7"],
|
|
67
|
+
"lightgbm_text": ["lightgbm>=3.0"],
|
|
68
|
+
"onnx": ["onnx", "onnxruntime"],
|
|
69
|
+
"torchscript": ["torch>=1.9"],
|
|
70
|
+
"tflite": ["tensorflow>=2.0"],
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def get_supported_formats(model_type: str) -> list[str]:
|
|
75
|
+
"""Get supported export formats for a model type."""
|
|
76
|
+
return FORMAT_REGISTRY.get(model_type, ["joblib"])
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def get_default_format(model_type: str) -> str:
|
|
80
|
+
"""Get the default export format for a model type."""
|
|
81
|
+
return DEFAULT_FORMAT.get(model_type, "joblib")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def detect_model_type(config: dict) -> str:
|
|
85
|
+
"""Detect model type from experiment config."""
|
|
86
|
+
model_type = config.get("model", {}).get("type", "")
|
|
87
|
+
if not model_type:
|
|
88
|
+
model_type = config.get("model_type", "unknown")
|
|
89
|
+
return model_type.lower().replace("-", "_").replace(" ", "_")
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def export_joblib(
|
|
93
|
+
model_path: str,
|
|
94
|
+
output_dir: str,
|
|
95
|
+
model_name: str,
|
|
96
|
+
) -> dict:
|
|
97
|
+
"""Export model as joblib bundle (copy if already joblib, else convert).
|
|
98
|
+
|
|
99
|
+
Returns dict with path, size_bytes, format, and dependencies.
|
|
100
|
+
"""
|
|
101
|
+
src = Path(model_path)
|
|
102
|
+
if not src.exists():
|
|
103
|
+
return {"error": f"Model file not found: {model_path}"}
|
|
104
|
+
|
|
105
|
+
out_path = Path(output_dir)
|
|
106
|
+
out_path.mkdir(parents=True, exist_ok=True)
|
|
107
|
+
|
|
108
|
+
dst = out_path / f"{model_name}.joblib"
|
|
109
|
+
shutil.copy2(str(src), str(dst))
|
|
110
|
+
|
|
111
|
+
return {
|
|
112
|
+
"path": str(dst),
|
|
113
|
+
"format": "joblib",
|
|
114
|
+
"size_bytes": dst.stat().st_size,
|
|
115
|
+
"size_mb": round(dst.stat().st_size / 1024**2, 2),
|
|
116
|
+
"dependencies": FORMAT_DEPENDENCIES["joblib"],
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def export_xgboost_json(
|
|
121
|
+
model_path: str,
|
|
122
|
+
output_dir: str,
|
|
123
|
+
model_name: str,
|
|
124
|
+
) -> dict:
|
|
125
|
+
"""Export XGBoost model to native JSON format."""
|
|
126
|
+
try:
|
|
127
|
+
import joblib
|
|
128
|
+
import xgboost as xgb
|
|
129
|
+
except ImportError as e:
|
|
130
|
+
return {"error": f"Missing dependency: {e}"}
|
|
131
|
+
|
|
132
|
+
src = Path(model_path)
|
|
133
|
+
if not src.exists():
|
|
134
|
+
return {"error": f"Model file not found: {model_path}"}
|
|
135
|
+
|
|
136
|
+
out_path = Path(output_dir)
|
|
137
|
+
out_path.mkdir(parents=True, exist_ok=True)
|
|
138
|
+
|
|
139
|
+
try:
|
|
140
|
+
model = joblib.load(str(src))
|
|
141
|
+
# Handle wrapped models (e.g., in a pipeline)
|
|
142
|
+
if hasattr(model, "get_booster"):
|
|
143
|
+
booster = model.get_booster()
|
|
144
|
+
elif isinstance(model, xgb.Booster):
|
|
145
|
+
booster = model
|
|
146
|
+
else:
|
|
147
|
+
return {"error": "Model is not an XGBoost model or doesn't have get_booster()"}
|
|
148
|
+
|
|
149
|
+
dst = out_path / f"{model_name}.json"
|
|
150
|
+
booster.save_model(str(dst))
|
|
151
|
+
|
|
152
|
+
return {
|
|
153
|
+
"path": str(dst),
|
|
154
|
+
"format": "xgboost_json",
|
|
155
|
+
"size_bytes": dst.stat().st_size,
|
|
156
|
+
"size_mb": round(dst.stat().st_size / 1024**2, 2),
|
|
157
|
+
"dependencies": FORMAT_DEPENDENCIES["xgboost_json"],
|
|
158
|
+
}
|
|
159
|
+
except Exception as e:
|
|
160
|
+
return {"error": f"XGBoost JSON export failed: {e}"}
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def export_lightgbm_text(
|
|
164
|
+
model_path: str,
|
|
165
|
+
output_dir: str,
|
|
166
|
+
model_name: str,
|
|
167
|
+
) -> dict:
|
|
168
|
+
"""Export LightGBM model to native text format."""
|
|
169
|
+
try:
|
|
170
|
+
import joblib
|
|
171
|
+
import lightgbm as lgb
|
|
172
|
+
except ImportError as e:
|
|
173
|
+
return {"error": f"Missing dependency: {e}"}
|
|
174
|
+
|
|
175
|
+
src = Path(model_path)
|
|
176
|
+
if not src.exists():
|
|
177
|
+
return {"error": f"Model file not found: {model_path}"}
|
|
178
|
+
|
|
179
|
+
out_path = Path(output_dir)
|
|
180
|
+
out_path.mkdir(parents=True, exist_ok=True)
|
|
181
|
+
|
|
182
|
+
try:
|
|
183
|
+
model = joblib.load(str(src))
|
|
184
|
+
if hasattr(model, "booster_"):
|
|
185
|
+
booster = model.booster_
|
|
186
|
+
elif isinstance(model, lgb.Booster):
|
|
187
|
+
booster = model
|
|
188
|
+
else:
|
|
189
|
+
return {"error": "Model is not a LightGBM model"}
|
|
190
|
+
|
|
191
|
+
dst = out_path / f"{model_name}.txt"
|
|
192
|
+
booster.save_model(str(dst))
|
|
193
|
+
|
|
194
|
+
return {
|
|
195
|
+
"path": str(dst),
|
|
196
|
+
"format": "lightgbm_text",
|
|
197
|
+
"size_bytes": dst.stat().st_size,
|
|
198
|
+
"size_mb": round(dst.stat().st_size / 1024**2, 2),
|
|
199
|
+
"dependencies": FORMAT_DEPENDENCIES["lightgbm_text"],
|
|
200
|
+
}
|
|
201
|
+
except Exception as e:
|
|
202
|
+
return {"error": f"LightGBM text export failed: {e}"}
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def export_onnx(
|
|
206
|
+
model_path: str,
|
|
207
|
+
output_dir: str,
|
|
208
|
+
model_name: str,
|
|
209
|
+
model_type: str,
|
|
210
|
+
) -> dict:
|
|
211
|
+
"""Export model to ONNX format."""
|
|
212
|
+
try:
|
|
213
|
+
import joblib
|
|
214
|
+
except ImportError as e:
|
|
215
|
+
return {"error": f"Missing dependency: {e}"}
|
|
216
|
+
|
|
217
|
+
src = Path(model_path)
|
|
218
|
+
if not src.exists():
|
|
219
|
+
return {"error": f"Model file not found: {model_path}"}
|
|
220
|
+
|
|
221
|
+
out_path = Path(output_dir)
|
|
222
|
+
out_path.mkdir(parents=True, exist_ok=True)
|
|
223
|
+
dst = out_path / f"{model_name}.onnx"
|
|
224
|
+
|
|
225
|
+
try:
|
|
226
|
+
model = joblib.load(str(src))
|
|
227
|
+
|
|
228
|
+
# Try sklearn-onnx for scikit-learn compatible models
|
|
229
|
+
try:
|
|
230
|
+
from skl2onnx import convert_sklearn
|
|
231
|
+
from skl2onnx.common.data_types import FloatTensorType
|
|
232
|
+
import numpy as np
|
|
233
|
+
|
|
234
|
+
# Infer input shape from model if possible
|
|
235
|
+
n_features = getattr(model, "n_features_in_", 10)
|
|
236
|
+
initial_type = [("float_input", FloatTensorType([None, n_features]))]
|
|
237
|
+
onx = convert_sklearn(model, initial_types=initial_type)
|
|
238
|
+
|
|
239
|
+
with open(dst, "wb") as f:
|
|
240
|
+
f.write(onx.SerializeToString())
|
|
241
|
+
|
|
242
|
+
return {
|
|
243
|
+
"path": str(dst),
|
|
244
|
+
"format": "onnx",
|
|
245
|
+
"size_bytes": dst.stat().st_size,
|
|
246
|
+
"size_mb": round(dst.stat().st_size / 1024**2, 2),
|
|
247
|
+
"dependencies": FORMAT_DEPENDENCIES["onnx"] + ["skl2onnx"],
|
|
248
|
+
}
|
|
249
|
+
except ImportError:
|
|
250
|
+
return {"error": "ONNX export requires skl2onnx: pip install skl2onnx"}
|
|
251
|
+
except Exception as e:
|
|
252
|
+
return {"error": f"ONNX conversion failed: {e}"}
|
|
253
|
+
|
|
254
|
+
except Exception as e:
|
|
255
|
+
return {"error": f"ONNX export failed: {e}"}
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def export_torchscript(
|
|
259
|
+
model_path: str,
|
|
260
|
+
output_dir: str,
|
|
261
|
+
model_name: str,
|
|
262
|
+
) -> dict:
|
|
263
|
+
"""Export PyTorch model to TorchScript."""
|
|
264
|
+
try:
|
|
265
|
+
import torch
|
|
266
|
+
except ImportError:
|
|
267
|
+
return {"error": "TorchScript export requires PyTorch: pip install torch"}
|
|
268
|
+
|
|
269
|
+
src = Path(model_path)
|
|
270
|
+
if not src.exists():
|
|
271
|
+
return {"error": f"Model file not found: {model_path}"}
|
|
272
|
+
|
|
273
|
+
out_path = Path(output_dir)
|
|
274
|
+
out_path.mkdir(parents=True, exist_ok=True)
|
|
275
|
+
dst = out_path / f"{model_name}.pt"
|
|
276
|
+
|
|
277
|
+
try:
|
|
278
|
+
model = torch.load(str(src), map_location="cpu")
|
|
279
|
+
if hasattr(model, "eval"):
|
|
280
|
+
model.eval()
|
|
281
|
+
|
|
282
|
+
# Try tracing with dummy input
|
|
283
|
+
if hasattr(model, "example_input"):
|
|
284
|
+
dummy = model.example_input
|
|
285
|
+
else:
|
|
286
|
+
# Create a dummy input — user may need to customize
|
|
287
|
+
dummy = torch.randn(1, 10)
|
|
288
|
+
|
|
289
|
+
scripted = torch.jit.trace(model, dummy)
|
|
290
|
+
scripted.save(str(dst))
|
|
291
|
+
|
|
292
|
+
return {
|
|
293
|
+
"path": str(dst),
|
|
294
|
+
"format": "torchscript",
|
|
295
|
+
"size_bytes": dst.stat().st_size,
|
|
296
|
+
"size_mb": round(dst.stat().st_size / 1024**2, 2),
|
|
297
|
+
"dependencies": FORMAT_DEPENDENCIES["torchscript"],
|
|
298
|
+
}
|
|
299
|
+
except Exception as e:
|
|
300
|
+
return {"error": f"TorchScript export failed: {e}"}
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def export_tflite(
|
|
304
|
+
model_path: str,
|
|
305
|
+
output_dir: str,
|
|
306
|
+
model_name: str,
|
|
307
|
+
) -> dict:
|
|
308
|
+
"""Export TensorFlow/Keras model to TFLite."""
|
|
309
|
+
try:
|
|
310
|
+
import tensorflow as tf
|
|
311
|
+
except ImportError:
|
|
312
|
+
return {"error": "TFLite export requires TensorFlow: pip install tensorflow"}
|
|
313
|
+
|
|
314
|
+
src = Path(model_path)
|
|
315
|
+
if not src.exists():
|
|
316
|
+
return {"error": f"Model file not found: {model_path}"}
|
|
317
|
+
|
|
318
|
+
out_path = Path(output_dir)
|
|
319
|
+
out_path.mkdir(parents=True, exist_ok=True)
|
|
320
|
+
dst = out_path / f"{model_name}.tflite"
|
|
321
|
+
|
|
322
|
+
try:
|
|
323
|
+
model = tf.keras.models.load_model(str(src))
|
|
324
|
+
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
|
325
|
+
tflite_model = converter.convert()
|
|
326
|
+
|
|
327
|
+
with open(dst, "wb") as f:
|
|
328
|
+
f.write(tflite_model)
|
|
329
|
+
|
|
330
|
+
return {
|
|
331
|
+
"path": str(dst),
|
|
332
|
+
"format": "tflite",
|
|
333
|
+
"size_bytes": dst.stat().st_size,
|
|
334
|
+
"size_mb": round(dst.stat().st_size / 1024**2, 2),
|
|
335
|
+
"dependencies": FORMAT_DEPENDENCIES["tflite"],
|
|
336
|
+
}
|
|
337
|
+
except Exception as e:
|
|
338
|
+
return {"error": f"TFLite export failed: {e}"}
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def export_model(
|
|
342
|
+
model_path: str,
|
|
343
|
+
output_dir: str,
|
|
344
|
+
model_name: str,
|
|
345
|
+
model_type: str,
|
|
346
|
+
export_format: str | None = None,
|
|
347
|
+
) -> dict:
|
|
348
|
+
"""Export a model to the specified format.
|
|
349
|
+
|
|
350
|
+
Auto-selects format if not specified. Dispatches to format-specific handler.
|
|
351
|
+
|
|
352
|
+
Args:
|
|
353
|
+
model_path: Path to the original model file.
|
|
354
|
+
output_dir: Directory to write exported model.
|
|
355
|
+
model_name: Base name for the exported file.
|
|
356
|
+
model_type: Model type (e.g., "xgboost", "pytorch").
|
|
357
|
+
export_format: Target format (e.g., "joblib", "onnx"). Auto-detected if None.
|
|
358
|
+
|
|
359
|
+
Returns:
|
|
360
|
+
Export result dict with path, format, size, dependencies.
|
|
361
|
+
"""
|
|
362
|
+
if not export_format:
|
|
363
|
+
export_format = get_default_format(model_type)
|
|
364
|
+
|
|
365
|
+
supported = get_supported_formats(model_type)
|
|
366
|
+
if export_format not in supported:
|
|
367
|
+
return {
|
|
368
|
+
"error": f"Format '{export_format}' not supported for model type '{model_type}'. "
|
|
369
|
+
f"Supported: {supported}",
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
handlers = {
|
|
373
|
+
"joblib": lambda: export_joblib(model_path, output_dir, model_name),
|
|
374
|
+
"xgboost_json": lambda: export_xgboost_json(model_path, output_dir, model_name),
|
|
375
|
+
"lightgbm_text": lambda: export_lightgbm_text(model_path, output_dir, model_name),
|
|
376
|
+
"onnx": lambda: export_onnx(model_path, output_dir, model_name, model_type),
|
|
377
|
+
"torchscript": lambda: export_torchscript(model_path, output_dir, model_name),
|
|
378
|
+
"tflite": lambda: export_tflite(model_path, output_dir, model_name),
|
|
379
|
+
}
|
|
380
|
+
|
|
381
|
+
handler = handlers.get(export_format)
|
|
382
|
+
if not handler:
|
|
383
|
+
return {"error": f"No handler for format '{export_format}'"}
|
|
384
|
+
|
|
385
|
+
return handler()
|