haoline 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haoline/.streamlit/config.toml +10 -0
- haoline/__init__.py +248 -0
- haoline/analyzer.py +935 -0
- haoline/cli.py +2712 -0
- haoline/compare.py +811 -0
- haoline/compare_visualizations.py +1564 -0
- haoline/edge_analysis.py +525 -0
- haoline/eval/__init__.py +131 -0
- haoline/eval/adapters.py +844 -0
- haoline/eval/cli.py +390 -0
- haoline/eval/comparison.py +542 -0
- haoline/eval/deployment.py +633 -0
- haoline/eval/schemas.py +833 -0
- haoline/examples/__init__.py +15 -0
- haoline/examples/basic_inspection.py +74 -0
- haoline/examples/compare_models.py +117 -0
- haoline/examples/hardware_estimation.py +78 -0
- haoline/format_adapters.py +1001 -0
- haoline/formats/__init__.py +123 -0
- haoline/formats/coreml.py +250 -0
- haoline/formats/gguf.py +483 -0
- haoline/formats/openvino.py +255 -0
- haoline/formats/safetensors.py +273 -0
- haoline/formats/tflite.py +369 -0
- haoline/hardware.py +2307 -0
- haoline/hierarchical_graph.py +462 -0
- haoline/html_export.py +1573 -0
- haoline/layer_summary.py +769 -0
- haoline/llm_summarizer.py +465 -0
- haoline/op_icons.py +618 -0
- haoline/operational_profiling.py +1492 -0
- haoline/patterns.py +1116 -0
- haoline/pdf_generator.py +265 -0
- haoline/privacy.py +250 -0
- haoline/pydantic_models.py +241 -0
- haoline/report.py +1923 -0
- haoline/report_sections.py +539 -0
- haoline/risks.py +521 -0
- haoline/schema.py +523 -0
- haoline/streamlit_app.py +2024 -0
- haoline/tests/__init__.py +4 -0
- haoline/tests/conftest.py +123 -0
- haoline/tests/test_analyzer.py +868 -0
- haoline/tests/test_compare_visualizations.py +293 -0
- haoline/tests/test_edge_analysis.py +243 -0
- haoline/tests/test_eval.py +604 -0
- haoline/tests/test_format_adapters.py +460 -0
- haoline/tests/test_hardware.py +237 -0
- haoline/tests/test_hardware_recommender.py +90 -0
- haoline/tests/test_hierarchical_graph.py +326 -0
- haoline/tests/test_html_export.py +180 -0
- haoline/tests/test_layer_summary.py +428 -0
- haoline/tests/test_llm_patterns.py +540 -0
- haoline/tests/test_llm_summarizer.py +339 -0
- haoline/tests/test_patterns.py +774 -0
- haoline/tests/test_pytorch.py +327 -0
- haoline/tests/test_report.py +383 -0
- haoline/tests/test_risks.py +398 -0
- haoline/tests/test_schema.py +417 -0
- haoline/tests/test_tensorflow.py +380 -0
- haoline/tests/test_visualizations.py +316 -0
- haoline/universal_ir.py +856 -0
- haoline/visualizations.py +1086 -0
- haoline/visualize_yolo.py +44 -0
- haoline/web.py +110 -0
- haoline-0.3.0.dist-info/METADATA +471 -0
- haoline-0.3.0.dist-info/RECORD +70 -0
- haoline-0.3.0.dist-info/WHEEL +4 -0
- haoline-0.3.0.dist-info/entry_points.txt +5 -0
- haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
haoline/cli.py
ADDED
|
@@ -0,0 +1,2712 @@
|
|
|
1
|
+
# Copyright (c) 2025 HaoLine Contributors
|
|
2
|
+
# SPDX-License-Identifier: MIT
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
HaoLine CLI - Universal Model Inspector.
|
|
6
|
+
|
|
7
|
+
Entry point for the `haoline` command-line tool.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import argparse
|
|
13
|
+
import logging
|
|
14
|
+
import os
|
|
15
|
+
import pathlib
|
|
16
|
+
import sys
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from . import ModelInspector
|
|
20
|
+
from .edge_analysis import EdgeAnalyzer
|
|
21
|
+
from .hardware import (
|
|
22
|
+
CLOUD_INSTANCES,
|
|
23
|
+
HardwareEstimator,
|
|
24
|
+
create_multi_gpu_profile,
|
|
25
|
+
detect_local_hardware,
|
|
26
|
+
get_cloud_instance,
|
|
27
|
+
get_profile,
|
|
28
|
+
)
|
|
29
|
+
from .hierarchical_graph import HierarchicalGraphBuilder
|
|
30
|
+
from .html_export import HTMLExporter
|
|
31
|
+
from .html_export import generate_html as generate_graph_html
|
|
32
|
+
from .layer_summary import LayerSummaryBuilder, generate_html_table
|
|
33
|
+
from .llm_summarizer import (
|
|
34
|
+
LLMSummarizer,
|
|
35
|
+
)
|
|
36
|
+
from .llm_summarizer import (
|
|
37
|
+
has_api_key as has_llm_api_key,
|
|
38
|
+
)
|
|
39
|
+
from .llm_summarizer import (
|
|
40
|
+
is_available as is_llm_available,
|
|
41
|
+
)
|
|
42
|
+
from .operational_profiling import OperationalProfiler
|
|
43
|
+
from .patterns import PatternAnalyzer
|
|
44
|
+
from .pdf_generator import (
|
|
45
|
+
PDFGenerator,
|
|
46
|
+
)
|
|
47
|
+
from .pdf_generator import (
|
|
48
|
+
is_available as is_pdf_available,
|
|
49
|
+
)
|
|
50
|
+
from .visualizations import (
|
|
51
|
+
VisualizationGenerator,
|
|
52
|
+
)
|
|
53
|
+
from .visualizations import (
|
|
54
|
+
is_available as is_viz_available,
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class ProgressIndicator:
    """Minimal step counter that prints CLI progress lines to stdout.

    Output is produced only when the indicator is enabled and not quiet;
    step counting happens regardless so the counters stay consistent.
    """

    def __init__(self, enabled: bool = True, quiet: bool = False):
        # Equivalent to ``enabled and not quiet``: a falsy ``enabled`` is kept as-is.
        self.enabled = (not quiet) if enabled else enabled
        self._current_step = 0
        self._total_steps = 0

    def start(self, total_steps: int, description: str = "Processing"):
        """Reset the counters and announce a new run of ``total_steps`` steps."""
        self._total_steps, self._current_step = total_steps, 0
        if not self.enabled:
            return
        print(f"\n{description}...")

    def step(self, message: str):
        """Advance one step and, if enabled, print ``[done/total] message (pct%)``."""
        self._current_step += 1
        if not self.enabled:
            return
        done, total = self._current_step, self._total_steps
        # Guard against division by zero when start() was never called.
        pct = done / total * 100 if total else 0
        print(f" [{done}/{total}] {message} ({pct:.0f}%)")

    def finish(self, message: str = "Done"):
        """Print a closing line (followed by a blank line) if enabled."""
        if not self.enabled:
            return
        print(f" {message}\n")
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
_EPILOG = """
Examples:
# Basic inspection with console output (auto-detects local hardware)
python -m haoline model.onnx

# Use specific NVIDIA GPU profile for estimates
python -m haoline model.onnx --hardware a100

# List available hardware profiles
python -m haoline --list-hardware

# Generate JSON report with hardware estimates
python -m haoline model.onnx --hardware rtx4090 --out-json report.json

# Specify precision and batch size for hardware estimates
python -m haoline model.onnx --hardware t4 --precision fp16 --batch-size 8

# Convert PyTorch model to ONNX and analyze
python -m haoline --from-pytorch model.pt --input-shape 1,3,224,224

# Convert TensorFlow SavedModel to ONNX and analyze
python -m haoline --from-tensorflow ./saved_model_dir --out-html report.html

# Convert Keras .h5 model to ONNX and analyze
python -m haoline --from-keras model.h5 --keep-onnx converted.onnx

# Convert TensorFlow frozen graph to ONNX (requires input/output names)
python -m haoline --from-frozen-graph model.pb --tf-inputs input:0 --tf-outputs output:0

# Convert JAX model to ONNX (requires apply function and input shape)
python -m haoline --from-jax params.pkl --jax-apply-fn my_model:apply --input-shape 1,3,224,224

# Generate Steam-style system requirements
python -m haoline model.onnx --system-requirements

# Run batch size sweep
python -m haoline model.onnx --hardware a100 --sweep-batch-sizes

# Run resolution sweep for vision models
python -m haoline model.onnx --hardware rtx4090 --sweep-resolutions auto

# Custom resolutions for object detection
python -m haoline yolo.onnx --hardware rtx4090 --sweep-resolutions "320x320,640x640,1280x1280"

# List available format conversions
python -m haoline --list-conversions

# Convert PyTorch to ONNX and save
python -m haoline --from-pytorch model.pt --input-shape 1,3,224,224 --convert-to onnx --convert-output model.onnx

# Export model as Universal IR (JSON)
python -m haoline model.onnx --export-ir model_ir.json

# Export graph visualization (DOT or PNG)
python -m haoline model.onnx --export-graph graph.dot
python -m haoline model.onnx --export-graph graph.png --graph-max-nodes 200
"""


def _positive_int(value: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not an integer or is < 1,
            so argparse reports a clean usage error instead of a later crash.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value!r}")
    return number


def _add_output_args(parser: argparse.ArgumentParser) -> None:
    """Add the positional model path and the report/output options."""
    parser.add_argument(
        "model_path",
        type=pathlib.Path,
        nargs="?",  # Optional since --list-hardware and friends don't need it
        help="Path to the ONNX model file to analyze.",
    )
    parser.add_argument(
        "--schema",
        action="store_true",
        help="Output the JSON Schema for InspectionReport and exit. "
        "Useful for integrating HaoLine output with other tools.",
    )
    parser.add_argument(
        "--out-json",
        type=pathlib.Path,
        default=None,
        help="Output path for JSON report. If not specified, no JSON is written.",
    )
    parser.add_argument(
        "--out-md",
        type=pathlib.Path,
        default=None,
        help="Output path for Markdown model card. If not specified, no Markdown is written.",
    )
    parser.add_argument(
        "--out-html",
        type=pathlib.Path,
        default=None,
        help="Output path for HTML report with embedded images. Single shareable file.",
    )
    parser.add_argument(
        "--out-pdf",
        type=pathlib.Path,
        default=None,
        help="Output path for PDF report. Requires playwright: pip install playwright && playwright install chromium",
    )
    parser.add_argument(
        "--html-graph",
        type=pathlib.Path,
        default=None,
        help="Output path for interactive graph visualization (standalone HTML with D3.js).",
    )
    parser.add_argument(
        "--layer-csv",
        type=pathlib.Path,
        default=None,
        help="Output path for per-layer metrics CSV (params, FLOPs, memory per layer).",
    )
    parser.add_argument(
        "--include-graph",
        action="store_true",
        help="Include interactive graph in --out-html report (makes HTML larger but more informative).",
    )
    parser.add_argument(
        "--include-layer-table",
        action="store_true",
        help="Include per-layer summary table in --out-html report.",
    )


def _add_source_conversion_args(parser: argparse.ArgumentParser) -> None:
    """Add the PyTorch, TensorFlow/Keras and JAX conversion option groups."""
    pytorch_group = parser.add_argument_group("PyTorch Conversion Options")
    pytorch_group.add_argument(
        "--from-pytorch",
        type=pathlib.Path,
        default=None,
        metavar="MODEL_PATH",
        help="Convert a PyTorch model (.pth, .pt) to ONNX before analysis. Requires torch.",
    )
    pytorch_group.add_argument(
        "--input-shape",
        type=str,
        default=None,
        metavar="SHAPE",
        help="Input shape for PyTorch conversion, e.g., '1,3,224,224'. Required with --from-pytorch.",
    )
    pytorch_group.add_argument(
        "--keep-onnx",
        type=pathlib.Path,
        default=None,
        metavar="PATH",
        help="Save the converted ONNX model to this path (otherwise uses temp file).",
    )
    pytorch_group.add_argument(
        "--opset-version",
        type=int,
        default=17,
        help="ONNX opset version for PyTorch export (default: 17).",
    )
    pytorch_group.add_argument(
        "--pytorch-weights",
        type=pathlib.Path,
        default=None,
        metavar="PATH",
        help="Path to original PyTorch weights (.pt) to extract class names/metadata. "
        "Useful when analyzing a pre-converted ONNX file.",
    )

    tf_group = parser.add_argument_group("TensorFlow/Keras Conversion Options")
    tf_group.add_argument(
        "--from-tensorflow",
        type=pathlib.Path,
        default=None,
        metavar="MODEL_PATH",
        help="Convert a TensorFlow SavedModel directory to ONNX before analysis. Requires tf2onnx.",
    )
    tf_group.add_argument(
        "--from-keras",
        type=pathlib.Path,
        default=None,
        metavar="MODEL_PATH",
        help="Convert a Keras model (.h5, .keras) to ONNX before analysis. Requires tf2onnx.",
    )
    tf_group.add_argument(
        "--from-frozen-graph",
        type=pathlib.Path,
        default=None,
        metavar="MODEL_PATH",
        help="Convert a TensorFlow frozen graph (.pb) to ONNX. Requires --tf-inputs and --tf-outputs.",
    )
    tf_group.add_argument(
        "--tf-inputs",
        type=str,
        default=None,
        metavar="NAMES",
        help="Comma-separated input tensor names for frozen graph conversion, e.g., 'input:0'.",
    )
    tf_group.add_argument(
        "--tf-outputs",
        type=str,
        default=None,
        metavar="NAMES",
        help="Comma-separated output tensor names for frozen graph conversion, e.g., 'output:0'.",
    )

    jax_group = parser.add_argument_group("JAX/Flax Conversion Options")
    jax_group.add_argument(
        "--from-jax",
        type=pathlib.Path,
        default=None,
        metavar="MODEL_PATH",
        help="Convert a JAX/Flax model to ONNX before analysis. Requires jax and tf2onnx. "
        "Expects a saved params file (.msgpack, .pkl) with --jax-apply-fn.",
    )
    jax_group.add_argument(
        "--jax-apply-fn",
        type=str,
        default=None,
        metavar="MODULE:FUNCTION",
        help="JAX apply function path, e.g., 'my_model:apply'. Required with --from-jax.",
    )


def _add_hardware_args(parser: argparse.ArgumentParser) -> None:
    """Add the hardware / deployment / sweep option group."""
    hardware_group = parser.add_argument_group("Hardware Options")
    hardware_group.add_argument(
        "--hardware",
        type=str,
        default=None,
        metavar="PROFILE",
        help="Hardware profile for performance estimates. Use 'auto' to detect local hardware, "
        "or specify a profile name (e.g., 'a100', 'rtx4090', 't4'). Use --list-hardware to see all options.",
    )
    hardware_group.add_argument(
        "--list-hardware",
        action="store_true",
        help="List all available hardware profiles and exit.",
    )
    hardware_group.add_argument(
        "--precision",
        choices=["fp32", "fp16", "bf16", "int8"],
        default="fp32",
        help="Precision for hardware estimates (default: fp32).",
    )
    hardware_group.add_argument(
        "--batch-size",
        type=_positive_int,
        default=1,
        help="Batch size for hardware estimates (default: 1).",
    )
    hardware_group.add_argument(
        "--gpu-count",
        type=_positive_int,
        default=1,
        metavar="N",
        help="Number of GPUs for multi-GPU estimates (default: 1). "
        "Scales compute and memory with efficiency factors for tensor parallelism.",
    )
    hardware_group.add_argument(
        "--cloud",
        type=str,
        default=None,
        metavar="INSTANCE",
        help="Cloud instance type for cost/performance estimates. "
        "E.g., 'aws-p4d-24xlarge', 'azure-nd-h100-v5'. Use --list-cloud to see options.",
    )
    hardware_group.add_argument(
        "--list-cloud",
        action="store_true",
        help="List all available cloud instance profiles and exit.",
    )
    hardware_group.add_argument(
        "--deployment-target",
        type=str,
        choices=["edge", "local", "cloud"],
        default=None,
        help="High-level deployment target to guide system requirement recommendations "
        "(edge device, local server, or cloud server).",
    )
    hardware_group.add_argument(
        "--target-latency-ms",
        type=float,
        default=None,
        help="Optional latency target (ms) for system requirements. "
        "If set, this is converted into a throughput target for recommendations.",
    )
    hardware_group.add_argument(
        "--target-throughput-fps",
        type=float,
        default=None,
        help="Optional throughput target (frames/requests per second) for system requirements.",
    )
    hardware_group.add_argument(
        "--deployment-fps",
        type=float,
        default=None,
        metavar="FPS",
        help="Target inference rate for deployment cost calculation (e.g., 3 for 3 fps continuous). "
        "Combined with --deployment-hours to estimate $/day and $/month.",
    )
    hardware_group.add_argument(
        "--deployment-hours",
        type=float,
        default=24.0,
        metavar="HOURS",
        help="Hours per day the model runs for deployment cost calculation (default: 24). "
        "E.g., 8 for business hours only.",
    )

    # Epic 6C features
    hardware_group.add_argument(
        "--system-requirements",
        action="store_true",
        help="Generate Steam-style system requirements (Minimum, Recommended, Optimal).",
    )
    hardware_group.add_argument(
        "--sweep-batch-sizes",
        action="store_true",
        help="Run analysis across multiple batch sizes (1, 2, 4, ..., 128) to find optimal throughput.",
    )
    hardware_group.add_argument(
        "--no-benchmark",
        action="store_true",
        help="Use theoretical estimates instead of actual inference for batch sweeps (faster but less accurate).",
    )
    hardware_group.add_argument(
        "--input-resolution",
        type=str,
        default=None,
        help=(
            "Override input resolution for analysis. Format: HxW (e.g., 640x640). "
            "For vision models, affects FLOPs and memory estimates."
        ),
    )
    hardware_group.add_argument(
        "--sweep-resolutions",
        type=str,
        default=None,
        help=(
            "Run resolution sweep analysis. Provide comma-separated resolutions "
            "(e.g., '224x224,384x384,512x512,640x640') or 'auto' for common resolutions."
        ),
    )


def _add_profiling_args(parser: argparse.ArgumentParser) -> None:
    """Add the runtime-profiling option group (profiling is on unless disabled)."""
    profiling_group = parser.add_argument_group("Runtime Profiling Options")
    profiling_group.add_argument(
        "--no-profile",
        action="store_true",
        help="Disable per-layer ONNX Runtime profiling (faster but less detailed).",
    )
    profiling_group.add_argument(
        "--profile-runs",
        type=_positive_int,
        default=10,
        metavar="N",
        help="Number of inference runs for profiling (default: 10).",
    )
    profiling_group.add_argument(
        "--no-gpu-metrics",
        action="store_true",
        help="Disable GPU metrics capture (VRAM, utilization, temperature).",
    )
    profiling_group.add_argument(
        "--no-bottleneck-analysis",
        action="store_true",
        help="Disable compute vs memory bottleneck analysis.",
    )
    profiling_group.add_argument(
        "--no-benchmark-resolutions",
        action="store_true",
        help="Disable resolution benchmarking (use theoretical estimates instead).",
    )


def _add_viz_and_llm_args(parser: argparse.ArgumentParser) -> None:
    """Add the visualization and LLM-summarization option groups."""
    viz_group = parser.add_argument_group("Visualization Options")
    viz_group.add_argument(
        "--with-plots",
        action="store_true",
        help="Generate visualization assets (requires matplotlib).",
    )
    viz_group.add_argument(
        "--assets-dir",
        type=pathlib.Path,
        default=None,
        metavar="PATH",
        help="Directory for plot PNG files (default: same directory as output files, or 'assets/').",
    )

    llm_group = parser.add_argument_group("LLM Summarization Options")
    llm_group.add_argument(
        "--llm-summary",
        action="store_true",
        help="Generate LLM-powered summaries (requires openai package and OPENAI_API_KEY env var).",
    )
    llm_group.add_argument(
        "--llm-model",
        type=str,
        default="gpt-4o-mini",
        metavar="MODEL",
        help="OpenAI model to use for summaries (default: gpt-4o-mini).",
    )


def _add_runtime_args(parser: argparse.ArgumentParser) -> None:
    """Add general runtime options (logging, quiet, progress, offline)."""
    parser.add_argument(
        "--log-level",
        choices=["debug", "info", "warning", "error"],
        default="info",
        help="Logging verbosity level (default: info).",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Suppress console output. Only write to files if --out-json or --out-md specified.",
    )
    parser.add_argument(
        "--progress",
        action="store_true",
        help="Show progress indicators during analysis (useful for large models).",
    )
    parser.add_argument(
        "--offline",
        action="store_true",
        help="Run in offline mode. Fails if any network access is attempted. "
        "Disables LLM summaries and other features requiring internet.",
    )


def _add_privacy_and_export_args(parser: argparse.ArgumentParser) -> None:
    """Add the privacy, format-conversion and Universal IR export option groups."""
    privacy_group = parser.add_argument_group("Privacy Options")
    privacy_group.add_argument(
        "--redact-names",
        action="store_true",
        help="Anonymize layer and tensor names in output (e.g., layer_001, tensor_042). "
        "Useful for sharing reports without revealing proprietary architecture details.",
    )
    privacy_group.add_argument(
        "--summary-only",
        action="store_true",
        help="Output only aggregate statistics (params, FLOPs, memory). "
        "Omit per-layer details and graph structure for maximum privacy.",
    )

    convert_group = parser.add_argument_group("Format Conversion Options")
    convert_group.add_argument(
        "--convert-to",
        type=str,
        choices=["onnx"],  # Only ONNX writing is currently supported
        default=None,
        metavar="FORMAT",
        help="Convert the model to the specified format (currently only 'onnx' supported). "
        "Use with --from-pytorch or other input formats.",
    )
    convert_group.add_argument(
        "--convert-output",
        type=pathlib.Path,
        default=None,
        metavar="PATH",
        help="Output path for format conversion. Required when using --convert-to.",
    )
    convert_group.add_argument(
        "--list-conversions",
        action="store_true",
        help="List available format conversion paths and exit.",
    )

    ir_group = parser.add_argument_group("Universal IR Export Options")
    ir_group.add_argument(
        "--export-ir",
        type=pathlib.Path,
        default=None,
        metavar="PATH",
        help="Export model as Universal IR to JSON file. "
        "Provides format-agnostic representation of the model graph.",
    )
    ir_group.add_argument(
        "--export-graph",
        type=pathlib.Path,
        default=None,
        metavar="PATH",
        help="Export model graph visualization. Supports .dot (Graphviz) and .png formats. "
        "Requires graphviz package for PNG: pip install graphviz",
    )
    ir_group.add_argument(
        "--graph-max-nodes",
        type=_positive_int,
        default=500,
        metavar="N",
        help="Maximum nodes to include in graph visualization (default: 500). "
        "Prevents huge graphs from crashing.",
    )


def _build_arg_parser() -> argparse.ArgumentParser:
    """Construct the full ``haoline`` argument parser.

    Argument groups are added in a fixed order because argparse lists
    them in ``--help`` in creation order.
    """
    parser = argparse.ArgumentParser(
        # The installed command is ``haoline``; using __file__ would print "cli.py".
        prog="haoline",
        description="Analyze an ONNX model and generate architecture documentation.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=_EPILOG,
    )
    _add_output_args(parser)
    _add_source_conversion_args(parser)
    _add_hardware_args(parser)
    _add_profiling_args(parser)
    _add_viz_and_llm_args(parser)
    _add_runtime_args(parser)
    _add_privacy_and_export_args(parser)
    return parser


def parse_args(argv: list[str] | None = None):
    """Parse command line arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing a list allows
            programmatic use and testing.

    Returns:
        argparse.Namespace with one attribute per CLI option.

    Raises:
        SystemExit: on invalid arguments or ``--help`` (argparse behaviour).
    """
    return _build_arg_parser().parse_args(argv)
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
def setup_logging(log_level: str) -> logging.Logger:
    """Configure logging for the CLI and return the ``haoline`` logger.

    Args:
        log_level: One of "debug", "info", "warning", "error"
            (case-insensitive). Unknown values fall back to INFO.

    Returns:
        The ``haoline`` logger, with its level set to the requested value.
    """
    level_map = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
        "error": logging.ERROR,
    }
    level = level_map.get(str(log_level).lower(), logging.INFO)

    logging.basicConfig(
        level=level,
        format="%(levelname)s - %(message)s",
    )

    logger = logging.getLogger("haoline")
    # basicConfig() does nothing if the root logger already has handlers
    # (e.g. installed by an imported library or a test harness), so set the
    # level on the package logger too; otherwise --log-level could be ignored.
    logger.setLevel(level)
    return logger
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
# Plot keys understood by the Markdown report, paired with the heading (and
# image alt text) used for each; tuple order is the order of appearance.
_MARKDOWN_VIZ_SECTIONS: tuple[tuple[str, str], ...] = (
    ("complexity_summary", "Complexity Overview"),
    ("op_histogram", "Operator Distribution"),
    ("param_distribution", "Parameter Distribution"),
    ("flops_distribution", "FLOPs Distribution"),
)


def _markdown_visualization_lines(viz_paths: dict, report_dir: pathlib.Path) -> list[str]:
    """Build the "## Visualizations" Markdown lines for the known plot keys.

    Returns an empty list when ``viz_paths`` is empty. Keys not listed in
    ``_MARKDOWN_VIZ_SECTIONS`` are ignored.
    """
    if not viz_paths:
        return []
    out = ["## Visualizations\n"]
    for key, title in _MARKDOWN_VIZ_SECTIONS:
        if key not in viz_paths:
            continue
        image_path = pathlib.Path(viz_paths[key])
        # Prefer a path relative to the report so the Markdown stays portable
        # when the report directory is moved or shared.
        if image_path.is_relative_to(report_dir):
            image_path = image_path.relative_to(report_dir)
        out.append(f"### {title}\n")
        # as_posix(): Markdown image links need forward slashes on every OS.
        out.append(f"![{title}]({image_path.as_posix()})\n")
    out.append("")
    return out


def _generate_markdown_with_extras(
    report, viz_paths: dict, report_dir: pathlib.Path, llm_summary=None
) -> str:
    """Generate markdown with embedded visualizations and LLM summaries.

    Args:
        report: Object exposing ``to_markdown() -> str``.
        viz_paths: Mapping of plot key (e.g. "op_histogram") to image path.
        report_dir: Directory the Markdown file is written to; image links
            are made relative to it when possible.
        llm_summary: Optional object with ``success``, ``short_summary``,
            ``detailed_summary`` and ``model_used`` attributes. When it
            succeeded, an "Executive Summary" section is inserted before
            "## Graph Summary" (if that heading exists).

    Returns:
        The combined Markdown text. The visualization section is placed just
        before "## Complexity Metrics", or appended at the end if that
        heading is absent.
    """
    lines: list[str] = []
    base_md = report.to_markdown()

    # If we have an LLM summary, insert it after the header/metadata section.
    if llm_summary and llm_summary.success:
        header_end = base_md.find("## Graph Summary")
        if header_end != -1:
            lines.append(base_md[:header_end])
            lines.append("## Executive Summary\n")
            if llm_summary.short_summary:
                lines.append(f"{llm_summary.short_summary}\n")
            if llm_summary.detailed_summary:
                lines.append(f"\n{llm_summary.detailed_summary}\n")
            lines.append(f"\n*Generated by {llm_summary.model_used}*\n\n")
            base_md = base_md[header_end:]

    viz_lines = _markdown_visualization_lines(viz_paths, report_dir)

    # partition() splits only at the first occurrence, so nothing after a
    # repeated heading is lost (split() + sections[1] used to drop it).
    before, marker, after = base_md.partition("## Complexity Metrics")
    if marker:
        lines.append(before)
        lines.extend(viz_lines)
        lines.append(marker + after)
    else:
        # No complexity section found: append plots at the end.
        lines.append(base_md)
        lines.extend(viz_lines)

    return "\n".join(lines)
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
def _extract_ultralytics_metadata(
    weights_path: pathlib.Path,
    logger: logging.Logger,
) -> dict[str, Any] | None:
    """
    Read task and class-label metadata from an Ultralytics ``.pt`` checkpoint.

    Args:
        weights_path: Path to the checkpoint handed to ``ultralytics.YOLO``.
        logger: Logger used for debug messages on failure.

    Returns:
        A dict with keys ``task``, ``num_classes``, ``class_names`` and
        ``source`` (always "ultralytics"), or ``None`` when ultralytics is
        not importable or loading/inspecting the checkpoint fails for any
        reason (best-effort: failures are only logged at debug level).
    """
    # NOTE(review): YOLO() presumably deserializes the checkpoint (pickle-based
    # for .pt files) -- only pass weights from a trusted source.
    try:
        from ultralytics import YOLO

        detector = YOLO(str(weights_path))
        metadata = {
            "task": detector.task,
            "num_classes": len(detector.names),
            "class_names": list(detector.names.values()),
            "source": "ultralytics",
        }
    except ImportError:
        logger.debug("ultralytics not installed, skipping metadata extraction")
        return None
    except Exception as err:
        logger.debug(f"Could not extract Ultralytics metadata: {err}")
        return None
    return metadata
|
|
730
|
+
|
|
731
|
+
|
|
732
|
+
def _convert_pytorch_to_onnx(
|
|
733
|
+
pytorch_path: pathlib.Path,
|
|
734
|
+
input_shape_str: str | None,
|
|
735
|
+
output_path: pathlib.Path | None,
|
|
736
|
+
opset_version: int,
|
|
737
|
+
logger: logging.Logger,
|
|
738
|
+
) -> tuple[pathlib.Path | None, Any]:
|
|
739
|
+
"""
|
|
740
|
+
Convert a PyTorch model to ONNX format.
|
|
741
|
+
|
|
742
|
+
Args:
|
|
743
|
+
pytorch_path: Path to PyTorch model (.pth, .pt)
|
|
744
|
+
input_shape_str: Input shape as comma-separated string, e.g., "1,3,224,224"
|
|
745
|
+
output_path: Where to save ONNX file (None = temp file)
|
|
746
|
+
opset_version: ONNX opset version
|
|
747
|
+
logger: Logger instance
|
|
748
|
+
|
|
749
|
+
Returns:
|
|
750
|
+
Tuple of (onnx_path, temp_file_handle_or_None)
|
|
751
|
+
"""
|
|
752
|
+
# Check if torch is available
|
|
753
|
+
try:
|
|
754
|
+
import torch
|
|
755
|
+
except ImportError:
|
|
756
|
+
logger.error("PyTorch not installed. Install with: pip install torch")
|
|
757
|
+
return None, None
|
|
758
|
+
|
|
759
|
+
pytorch_path = pytorch_path.resolve()
|
|
760
|
+
if not pytorch_path.exists():
|
|
761
|
+
logger.error(f"PyTorch model not found: {pytorch_path}")
|
|
762
|
+
return None, None
|
|
763
|
+
|
|
764
|
+
# Parse input shape
|
|
765
|
+
if not input_shape_str:
|
|
766
|
+
logger.error(
|
|
767
|
+
"--input-shape is required for PyTorch conversion. Example: --input-shape 1,3,224,224"
|
|
768
|
+
)
|
|
769
|
+
return None, None
|
|
770
|
+
|
|
771
|
+
try:
|
|
772
|
+
input_shape = tuple(int(x.strip()) for x in input_shape_str.split(","))
|
|
773
|
+
logger.info(f"Input shape: {input_shape}")
|
|
774
|
+
except ValueError:
|
|
775
|
+
logger.error(
|
|
776
|
+
f"Invalid --input-shape format: '{input_shape_str}'. "
|
|
777
|
+
"Use comma-separated integers, e.g., '1,3,224,224'"
|
|
778
|
+
)
|
|
779
|
+
return None, None
|
|
780
|
+
|
|
781
|
+
# Load PyTorch model
|
|
782
|
+
logger.info(f"Loading PyTorch model from: {pytorch_path}")
|
|
783
|
+
model = None
|
|
784
|
+
|
|
785
|
+
# Try 1: TorchScript model (.pt files from torch.jit.save)
|
|
786
|
+
try:
|
|
787
|
+
model = torch.jit.load(str(pytorch_path), map_location="cpu")
|
|
788
|
+
logger.info(f"Loaded TorchScript model: {type(model).__name__}")
|
|
789
|
+
except Exception:
|
|
790
|
+
pass
|
|
791
|
+
|
|
792
|
+
# Try 2: Check for Ultralytics YOLO format first
|
|
793
|
+
if model is None:
|
|
794
|
+
try:
|
|
795
|
+
loaded = torch.load(pytorch_path, map_location="cpu", weights_only=False)
|
|
796
|
+
|
|
797
|
+
if isinstance(loaded, dict):
|
|
798
|
+
# Check if it's an Ultralytics model (has 'model' key with the actual model)
|
|
799
|
+
if "model" in loaded and hasattr(loaded.get("model"), "forward"):
|
|
800
|
+
logger.info("Detected Ultralytics YOLO format, using native export...")
|
|
801
|
+
try:
|
|
802
|
+
from ultralytics import YOLO
|
|
803
|
+
|
|
804
|
+
yolo_model = YOLO(str(pytorch_path))
|
|
805
|
+
|
|
806
|
+
# Determine output path for Ultralytics export
|
|
807
|
+
if output_path:
|
|
808
|
+
onnx_out = output_path.resolve()
|
|
809
|
+
else:
|
|
810
|
+
import tempfile as tf
|
|
811
|
+
|
|
812
|
+
temp = tf.NamedTemporaryFile(suffix=".onnx", delete=False)
|
|
813
|
+
onnx_out = pathlib.Path(temp.name)
|
|
814
|
+
temp.close()
|
|
815
|
+
|
|
816
|
+
# Export using Ultralytics (handles all the complexity)
|
|
817
|
+
yolo_model.export(
|
|
818
|
+
format="onnx",
|
|
819
|
+
imgsz=input_shape[2] if len(input_shape) >= 3 else 640,
|
|
820
|
+
simplify=True,
|
|
821
|
+
opset=opset_version,
|
|
822
|
+
)
|
|
823
|
+
|
|
824
|
+
# Ultralytics saves next to the .pt file, move if needed
|
|
825
|
+
default_onnx = pytorch_path.with_suffix(".onnx")
|
|
826
|
+
if default_onnx.exists() and default_onnx != onnx_out:
|
|
827
|
+
import shutil
|
|
828
|
+
|
|
829
|
+
shutil.move(str(default_onnx), str(onnx_out))
|
|
830
|
+
|
|
831
|
+
logger.info(f"ONNX model saved to: {onnx_out}")
|
|
832
|
+
return onnx_out, None if output_path else onnx_out
|
|
833
|
+
|
|
834
|
+
except ImportError:
|
|
835
|
+
logger.error(
|
|
836
|
+
"Ultralytics YOLO model detected but 'ultralytics' package not installed.\n"
|
|
837
|
+
"Install with: pip install ultralytics\n"
|
|
838
|
+
"Then re-run this command."
|
|
839
|
+
)
|
|
840
|
+
return None, None
|
|
841
|
+
|
|
842
|
+
# It's a regular state_dict - we can't use it directly
|
|
843
|
+
logger.error(
|
|
844
|
+
"Model file appears to be a state_dict (weights only). "
|
|
845
|
+
"To convert, you need either:\n"
|
|
846
|
+
" 1. A TorchScript model: torch.jit.save(torch.jit.script(model), 'model.pt')\n"
|
|
847
|
+
" 2. A full model: torch.save(model, 'model.pth') # run from same codebase\n"
|
|
848
|
+
" 3. Export to ONNX directly in your training code using torch.onnx.export()"
|
|
849
|
+
)
|
|
850
|
+
return None, None
|
|
851
|
+
|
|
852
|
+
model = loaded
|
|
853
|
+
logger.info(f"Loaded PyTorch model: {type(model).__name__}")
|
|
854
|
+
|
|
855
|
+
except Exception as e:
|
|
856
|
+
error_msg = str(e)
|
|
857
|
+
if "Can't get attribute" in error_msg:
|
|
858
|
+
logger.error(
|
|
859
|
+
"Failed to load model - class definition not found.\n"
|
|
860
|
+
"The model was saved with torch.save(model, ...) which requires "
|
|
861
|
+
"the original class to be importable.\n\n"
|
|
862
|
+
"Solutions:\n"
|
|
863
|
+
" 1. Save as TorchScript: torch.jit.save(torch.jit.script(model), 'model.pt')\n"
|
|
864
|
+
" 2. Export to ONNX in your code: torch.onnx.export(model, dummy_input, 'model.onnx')\n"
|
|
865
|
+
" 3. Run this tool from the directory containing your model definition"
|
|
866
|
+
)
|
|
867
|
+
else:
|
|
868
|
+
logger.error(f"Failed to load PyTorch model: {e}")
|
|
869
|
+
return None, None
|
|
870
|
+
|
|
871
|
+
if model is None:
|
|
872
|
+
logger.error("Could not load the PyTorch model.")
|
|
873
|
+
return None, None
|
|
874
|
+
|
|
875
|
+
model.eval()
|
|
876
|
+
|
|
877
|
+
# Create dummy input
|
|
878
|
+
try:
|
|
879
|
+
dummy_input = torch.randn(*input_shape)
|
|
880
|
+
logger.info(f"Created dummy input with shape: {dummy_input.shape}")
|
|
881
|
+
except Exception as e:
|
|
882
|
+
logger.error(f"Failed to create input tensor: {e}")
|
|
883
|
+
return None, None
|
|
884
|
+
|
|
885
|
+
# Determine output path
|
|
886
|
+
temp_file = None
|
|
887
|
+
if output_path:
|
|
888
|
+
onnx_path = output_path.resolve()
|
|
889
|
+
onnx_path.parent.mkdir(parents=True, exist_ok=True)
|
|
890
|
+
else:
|
|
891
|
+
import tempfile
|
|
892
|
+
|
|
893
|
+
temp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
|
894
|
+
onnx_path = pathlib.Path(temp_file.name)
|
|
895
|
+
temp_file.close()
|
|
896
|
+
|
|
897
|
+
# Export to ONNX
|
|
898
|
+
logger.info(f"Exporting to ONNX (opset {opset_version})...")
|
|
899
|
+
try:
|
|
900
|
+
torch.onnx.export(
|
|
901
|
+
model,
|
|
902
|
+
dummy_input,
|
|
903
|
+
str(onnx_path),
|
|
904
|
+
input_names=["input"],
|
|
905
|
+
output_names=["output"],
|
|
906
|
+
dynamic_axes={
|
|
907
|
+
"input": {0: "batch_size"},
|
|
908
|
+
"output": {0: "batch_size"},
|
|
909
|
+
},
|
|
910
|
+
opset_version=opset_version,
|
|
911
|
+
do_constant_folding=True,
|
|
912
|
+
)
|
|
913
|
+
logger.info(f"ONNX model saved to: {onnx_path}")
|
|
914
|
+
|
|
915
|
+
# Verify the export
|
|
916
|
+
import onnx
|
|
917
|
+
|
|
918
|
+
onnx_model = onnx.load(str(onnx_path))
|
|
919
|
+
onnx.checker.check_model(onnx_model)
|
|
920
|
+
logger.info("ONNX model validated successfully")
|
|
921
|
+
|
|
922
|
+
except Exception as e:
|
|
923
|
+
logger.error(f"ONNX export failed: {e}")
|
|
924
|
+
if temp_file:
|
|
925
|
+
try:
|
|
926
|
+
onnx_path.unlink()
|
|
927
|
+
except Exception:
|
|
928
|
+
pass
|
|
929
|
+
return None, None
|
|
930
|
+
|
|
931
|
+
return onnx_path, temp_file
|
|
932
|
+
|
|
933
|
+
|
|
934
|
+
def _convert_tensorflow_to_onnx(
|
|
935
|
+
tf_path: pathlib.Path,
|
|
936
|
+
output_path: pathlib.Path | None,
|
|
937
|
+
opset_version: int,
|
|
938
|
+
logger: logging.Logger,
|
|
939
|
+
) -> tuple[pathlib.Path | None, Any]:
|
|
940
|
+
"""
|
|
941
|
+
Convert a TensorFlow SavedModel to ONNX format.
|
|
942
|
+
|
|
943
|
+
Args:
|
|
944
|
+
tf_path: Path to TensorFlow SavedModel directory
|
|
945
|
+
output_path: Where to save ONNX file (None = temp file)
|
|
946
|
+
opset_version: ONNX opset version
|
|
947
|
+
logger: Logger instance
|
|
948
|
+
|
|
949
|
+
Returns:
|
|
950
|
+
Tuple of (onnx_path, temp_file_handle_or_None)
|
|
951
|
+
"""
|
|
952
|
+
# Check if tf2onnx is available
|
|
953
|
+
try:
|
|
954
|
+
import tf2onnx
|
|
955
|
+
from tf2onnx import tf_loader
|
|
956
|
+
except ImportError:
|
|
957
|
+
logger.error("tf2onnx not installed. Install with: pip install tf2onnx tensorflow")
|
|
958
|
+
return None, None
|
|
959
|
+
|
|
960
|
+
tf_path = tf_path.resolve()
|
|
961
|
+
if not tf_path.exists():
|
|
962
|
+
logger.error(f"TensorFlow model not found: {tf_path}")
|
|
963
|
+
return None, None
|
|
964
|
+
|
|
965
|
+
# Determine output path
|
|
966
|
+
temp_file = None
|
|
967
|
+
if output_path:
|
|
968
|
+
onnx_path = output_path.resolve()
|
|
969
|
+
onnx_path.parent.mkdir(parents=True, exist_ok=True)
|
|
970
|
+
else:
|
|
971
|
+
import tempfile
|
|
972
|
+
|
|
973
|
+
temp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
|
974
|
+
onnx_path = pathlib.Path(temp_file.name)
|
|
975
|
+
temp_file.close()
|
|
976
|
+
|
|
977
|
+
logger.info("Converting TensorFlow SavedModel to ONNX...")
|
|
978
|
+
logger.info(f" Source: {tf_path}")
|
|
979
|
+
logger.info(f" Target: {onnx_path}")
|
|
980
|
+
|
|
981
|
+
try:
|
|
982
|
+
import subprocess
|
|
983
|
+
import sys
|
|
984
|
+
|
|
985
|
+
# Use tf2onnx CLI for robustness (handles TF version quirks)
|
|
986
|
+
cmd = [
|
|
987
|
+
sys.executable,
|
|
988
|
+
"-m",
|
|
989
|
+
"tf2onnx.convert",
|
|
990
|
+
"--saved-model",
|
|
991
|
+
str(tf_path),
|
|
992
|
+
"--output",
|
|
993
|
+
str(onnx_path),
|
|
994
|
+
"--opset",
|
|
995
|
+
str(opset_version),
|
|
996
|
+
]
|
|
997
|
+
|
|
998
|
+
result = subprocess.run(
|
|
999
|
+
cmd,
|
|
1000
|
+
check=False,
|
|
1001
|
+
capture_output=True,
|
|
1002
|
+
text=True,
|
|
1003
|
+
timeout=600, # 10 minute timeout for large models
|
|
1004
|
+
)
|
|
1005
|
+
|
|
1006
|
+
if result.returncode != 0:
|
|
1007
|
+
logger.error(f"tf2onnx conversion failed:\n{result.stderr}")
|
|
1008
|
+
if temp_file:
|
|
1009
|
+
try:
|
|
1010
|
+
onnx_path.unlink()
|
|
1011
|
+
except Exception:
|
|
1012
|
+
pass
|
|
1013
|
+
return None, None
|
|
1014
|
+
|
|
1015
|
+
logger.info("TensorFlow model converted successfully")
|
|
1016
|
+
|
|
1017
|
+
# Verify the export
|
|
1018
|
+
import onnx
|
|
1019
|
+
|
|
1020
|
+
onnx_model = onnx.load(str(onnx_path))
|
|
1021
|
+
onnx.checker.check_model(onnx_model)
|
|
1022
|
+
logger.info("ONNX model validated successfully")
|
|
1023
|
+
|
|
1024
|
+
except subprocess.TimeoutExpired:
|
|
1025
|
+
logger.error("tf2onnx conversion timed out after 10 minutes")
|
|
1026
|
+
if temp_file:
|
|
1027
|
+
try:
|
|
1028
|
+
onnx_path.unlink()
|
|
1029
|
+
except Exception:
|
|
1030
|
+
pass
|
|
1031
|
+
return None, None
|
|
1032
|
+
except Exception as e:
|
|
1033
|
+
logger.error(f"TensorFlow conversion failed: {e}")
|
|
1034
|
+
if temp_file:
|
|
1035
|
+
try:
|
|
1036
|
+
onnx_path.unlink()
|
|
1037
|
+
except Exception:
|
|
1038
|
+
pass
|
|
1039
|
+
return None, None
|
|
1040
|
+
|
|
1041
|
+
return onnx_path, temp_file
|
|
1042
|
+
|
|
1043
|
+
|
|
1044
|
+
def _convert_keras_to_onnx(
|
|
1045
|
+
keras_path: pathlib.Path,
|
|
1046
|
+
output_path: pathlib.Path | None,
|
|
1047
|
+
opset_version: int,
|
|
1048
|
+
logger: logging.Logger,
|
|
1049
|
+
) -> tuple[pathlib.Path | None, Any]:
|
|
1050
|
+
"""
|
|
1051
|
+
Convert a Keras model (.h5, .keras) to ONNX format.
|
|
1052
|
+
|
|
1053
|
+
Args:
|
|
1054
|
+
keras_path: Path to Keras model file (.h5 or .keras)
|
|
1055
|
+
output_path: Where to save ONNX file (None = temp file)
|
|
1056
|
+
opset_version: ONNX opset version
|
|
1057
|
+
logger: Logger instance
|
|
1058
|
+
|
|
1059
|
+
Returns:
|
|
1060
|
+
Tuple of (onnx_path, temp_file_handle_or_None)
|
|
1061
|
+
"""
|
|
1062
|
+
# Check if tf2onnx is available
|
|
1063
|
+
try:
|
|
1064
|
+
import tf2onnx
|
|
1065
|
+
except ImportError:
|
|
1066
|
+
logger.error("tf2onnx not installed. Install with: pip install tf2onnx tensorflow")
|
|
1067
|
+
return None, None
|
|
1068
|
+
|
|
1069
|
+
keras_path = keras_path.resolve()
|
|
1070
|
+
if not keras_path.exists():
|
|
1071
|
+
logger.error(f"Keras model not found: {keras_path}")
|
|
1072
|
+
return None, None
|
|
1073
|
+
|
|
1074
|
+
suffix = keras_path.suffix.lower()
|
|
1075
|
+
if suffix not in (".h5", ".keras", ".hdf5"):
|
|
1076
|
+
logger.warning(f"Unexpected Keras file extension: {suffix}. Proceeding anyway.")
|
|
1077
|
+
|
|
1078
|
+
# Determine output path
|
|
1079
|
+
temp_file = None
|
|
1080
|
+
if output_path:
|
|
1081
|
+
onnx_path = output_path.resolve()
|
|
1082
|
+
onnx_path.parent.mkdir(parents=True, exist_ok=True)
|
|
1083
|
+
else:
|
|
1084
|
+
import tempfile
|
|
1085
|
+
|
|
1086
|
+
temp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
|
1087
|
+
onnx_path = pathlib.Path(temp_file.name)
|
|
1088
|
+
temp_file.close()
|
|
1089
|
+
|
|
1090
|
+
logger.info("Converting Keras model to ONNX...")
|
|
1091
|
+
logger.info(f" Source: {keras_path}")
|
|
1092
|
+
logger.info(f" Target: {onnx_path}")
|
|
1093
|
+
|
|
1094
|
+
try:
|
|
1095
|
+
import subprocess
|
|
1096
|
+
import sys
|
|
1097
|
+
|
|
1098
|
+
# Use tf2onnx CLI with --keras flag
|
|
1099
|
+
cmd = [
|
|
1100
|
+
sys.executable,
|
|
1101
|
+
"-m",
|
|
1102
|
+
"tf2onnx.convert",
|
|
1103
|
+
"--keras",
|
|
1104
|
+
str(keras_path),
|
|
1105
|
+
"--output",
|
|
1106
|
+
str(onnx_path),
|
|
1107
|
+
"--opset",
|
|
1108
|
+
str(opset_version),
|
|
1109
|
+
]
|
|
1110
|
+
|
|
1111
|
+
result = subprocess.run(
|
|
1112
|
+
cmd,
|
|
1113
|
+
check=False,
|
|
1114
|
+
capture_output=True,
|
|
1115
|
+
text=True,
|
|
1116
|
+
timeout=600, # 10 minute timeout
|
|
1117
|
+
)
|
|
1118
|
+
|
|
1119
|
+
if result.returncode != 0:
|
|
1120
|
+
logger.error(f"tf2onnx conversion failed:\n{result.stderr}")
|
|
1121
|
+
if temp_file:
|
|
1122
|
+
try:
|
|
1123
|
+
onnx_path.unlink()
|
|
1124
|
+
except Exception:
|
|
1125
|
+
pass
|
|
1126
|
+
return None, None
|
|
1127
|
+
|
|
1128
|
+
logger.info("Keras model converted successfully")
|
|
1129
|
+
|
|
1130
|
+
# Verify the export
|
|
1131
|
+
import onnx
|
|
1132
|
+
|
|
1133
|
+
onnx_model = onnx.load(str(onnx_path))
|
|
1134
|
+
onnx.checker.check_model(onnx_model)
|
|
1135
|
+
logger.info("ONNX model validated successfully")
|
|
1136
|
+
|
|
1137
|
+
except subprocess.TimeoutExpired:
|
|
1138
|
+
logger.error("tf2onnx conversion timed out after 10 minutes")
|
|
1139
|
+
if temp_file:
|
|
1140
|
+
try:
|
|
1141
|
+
onnx_path.unlink()
|
|
1142
|
+
except Exception:
|
|
1143
|
+
pass
|
|
1144
|
+
return None, None
|
|
1145
|
+
except Exception as e:
|
|
1146
|
+
logger.error(f"Keras conversion failed: {e}")
|
|
1147
|
+
if temp_file:
|
|
1148
|
+
try:
|
|
1149
|
+
onnx_path.unlink()
|
|
1150
|
+
except Exception:
|
|
1151
|
+
pass
|
|
1152
|
+
return None, None
|
|
1153
|
+
|
|
1154
|
+
return onnx_path, temp_file
|
|
1155
|
+
|
|
1156
|
+
|
|
1157
|
+
def _convert_frozen_graph_to_onnx(
|
|
1158
|
+
pb_path: pathlib.Path,
|
|
1159
|
+
inputs: str,
|
|
1160
|
+
outputs: str,
|
|
1161
|
+
output_path: pathlib.Path | None,
|
|
1162
|
+
opset_version: int,
|
|
1163
|
+
logger: logging.Logger,
|
|
1164
|
+
) -> tuple[pathlib.Path | None, Any]:
|
|
1165
|
+
"""
|
|
1166
|
+
Convert a TensorFlow frozen graph (.pb) to ONNX format.
|
|
1167
|
+
|
|
1168
|
+
Args:
|
|
1169
|
+
pb_path: Path to frozen graph .pb file
|
|
1170
|
+
inputs: Comma-separated input tensor names (e.g., "input:0")
|
|
1171
|
+
outputs: Comma-separated output tensor names (e.g., "output:0")
|
|
1172
|
+
output_path: Where to save ONNX file (None = temp file)
|
|
1173
|
+
opset_version: ONNX opset version
|
|
1174
|
+
logger: Logger instance
|
|
1175
|
+
|
|
1176
|
+
Returns:
|
|
1177
|
+
Tuple of (onnx_path, temp_file_handle_or_None)
|
|
1178
|
+
"""
|
|
1179
|
+
# Check if tf2onnx is available
|
|
1180
|
+
try:
|
|
1181
|
+
import tf2onnx
|
|
1182
|
+
except ImportError:
|
|
1183
|
+
logger.error("tf2onnx not installed. Install with: pip install tf2onnx tensorflow")
|
|
1184
|
+
return None, None
|
|
1185
|
+
|
|
1186
|
+
pb_path = pb_path.resolve()
|
|
1187
|
+
if not pb_path.exists():
|
|
1188
|
+
logger.error(f"Frozen graph not found: {pb_path}")
|
|
1189
|
+
return None, None
|
|
1190
|
+
|
|
1191
|
+
if not inputs or not outputs:
|
|
1192
|
+
logger.error(
|
|
1193
|
+
"--tf-inputs and --tf-outputs are required for frozen graph conversion.\n"
|
|
1194
|
+
"Example: --from-frozen-graph model.pb --tf-inputs input:0 --tf-outputs output:0"
|
|
1195
|
+
)
|
|
1196
|
+
return None, None
|
|
1197
|
+
|
|
1198
|
+
# Determine output path
|
|
1199
|
+
temp_file = None
|
|
1200
|
+
if output_path:
|
|
1201
|
+
onnx_path = output_path.resolve()
|
|
1202
|
+
onnx_path.parent.mkdir(parents=True, exist_ok=True)
|
|
1203
|
+
else:
|
|
1204
|
+
import tempfile
|
|
1205
|
+
|
|
1206
|
+
temp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
|
1207
|
+
onnx_path = pathlib.Path(temp_file.name)
|
|
1208
|
+
temp_file.close()
|
|
1209
|
+
|
|
1210
|
+
logger.info("Converting TensorFlow frozen graph to ONNX...")
|
|
1211
|
+
logger.info(f" Source: {pb_path}")
|
|
1212
|
+
logger.info(f" Inputs: {inputs}")
|
|
1213
|
+
logger.info(f" Outputs: {outputs}")
|
|
1214
|
+
logger.info(f" Target: {onnx_path}")
|
|
1215
|
+
|
|
1216
|
+
try:
|
|
1217
|
+
import subprocess
|
|
1218
|
+
import sys
|
|
1219
|
+
|
|
1220
|
+
# Use tf2onnx CLI with --graphdef flag
|
|
1221
|
+
cmd = [
|
|
1222
|
+
sys.executable,
|
|
1223
|
+
"-m",
|
|
1224
|
+
"tf2onnx.convert",
|
|
1225
|
+
"--graphdef",
|
|
1226
|
+
str(pb_path),
|
|
1227
|
+
"--inputs",
|
|
1228
|
+
inputs,
|
|
1229
|
+
"--outputs",
|
|
1230
|
+
outputs,
|
|
1231
|
+
"--output",
|
|
1232
|
+
str(onnx_path),
|
|
1233
|
+
"--opset",
|
|
1234
|
+
str(opset_version),
|
|
1235
|
+
]
|
|
1236
|
+
|
|
1237
|
+
result = subprocess.run(
|
|
1238
|
+
cmd,
|
|
1239
|
+
check=False,
|
|
1240
|
+
capture_output=True,
|
|
1241
|
+
text=True,
|
|
1242
|
+
timeout=600, # 10 minute timeout
|
|
1243
|
+
)
|
|
1244
|
+
|
|
1245
|
+
if result.returncode != 0:
|
|
1246
|
+
logger.error(f"tf2onnx conversion failed:\n{result.stderr}")
|
|
1247
|
+
if temp_file:
|
|
1248
|
+
try:
|
|
1249
|
+
onnx_path.unlink()
|
|
1250
|
+
except Exception:
|
|
1251
|
+
pass
|
|
1252
|
+
return None, None
|
|
1253
|
+
|
|
1254
|
+
logger.info("Frozen graph converted successfully")
|
|
1255
|
+
|
|
1256
|
+
# Verify the export
|
|
1257
|
+
import onnx
|
|
1258
|
+
|
|
1259
|
+
onnx_model = onnx.load(str(onnx_path))
|
|
1260
|
+
onnx.checker.check_model(onnx_model)
|
|
1261
|
+
logger.info("ONNX model validated successfully")
|
|
1262
|
+
|
|
1263
|
+
except subprocess.TimeoutExpired:
|
|
1264
|
+
logger.error("tf2onnx conversion timed out after 10 minutes")
|
|
1265
|
+
if temp_file:
|
|
1266
|
+
try:
|
|
1267
|
+
onnx_path.unlink()
|
|
1268
|
+
except Exception:
|
|
1269
|
+
pass
|
|
1270
|
+
return None, None
|
|
1271
|
+
except Exception as e:
|
|
1272
|
+
logger.error(f"Frozen graph conversion failed: {e}")
|
|
1273
|
+
if temp_file:
|
|
1274
|
+
try:
|
|
1275
|
+
onnx_path.unlink()
|
|
1276
|
+
except Exception:
|
|
1277
|
+
pass
|
|
1278
|
+
return None, None
|
|
1279
|
+
|
|
1280
|
+
return onnx_path, temp_file
|
|
1281
|
+
|
|
1282
|
+
|
|
1283
|
+
def _convert_jax_to_onnx(
|
|
1284
|
+
jax_path: pathlib.Path,
|
|
1285
|
+
apply_fn_path: str | None,
|
|
1286
|
+
input_shape_str: str | None,
|
|
1287
|
+
output_path: pathlib.Path | None,
|
|
1288
|
+
opset_version: int,
|
|
1289
|
+
logger: logging.Logger,
|
|
1290
|
+
) -> tuple[pathlib.Path | None, Any]:
|
|
1291
|
+
"""
|
|
1292
|
+
Convert a JAX/Flax model to ONNX format.
|
|
1293
|
+
|
|
1294
|
+
This is more complex than other conversions because JAX doesn't have a standard
|
|
1295
|
+
serialization format. The typical flow is:
|
|
1296
|
+
1. Load model params from file (.msgpack, .pkl, etc.)
|
|
1297
|
+
2. Import the apply function from user's code
|
|
1298
|
+
3. Convert JAX -> TF SavedModel -> ONNX
|
|
1299
|
+
|
|
1300
|
+
Args:
|
|
1301
|
+
jax_path: Path to JAX params file (.msgpack, .pkl, .npy)
|
|
1302
|
+
apply_fn_path: Module:function path to the apply function
|
|
1303
|
+
input_shape_str: Input shape for tracing
|
|
1304
|
+
output_path: Where to save ONNX file (None = temp file)
|
|
1305
|
+
opset_version: ONNX opset version
|
|
1306
|
+
logger: Logger instance
|
|
1307
|
+
|
|
1308
|
+
Returns:
|
|
1309
|
+
Tuple of (onnx_path, temp_file_handle_or_None)
|
|
1310
|
+
"""
|
|
1311
|
+
# Check dependencies
|
|
1312
|
+
try:
|
|
1313
|
+
import jax
|
|
1314
|
+
import jax.numpy as jnp
|
|
1315
|
+
except ImportError:
|
|
1316
|
+
logger.error("JAX not installed. Install with: pip install jax jaxlib")
|
|
1317
|
+
return None, None
|
|
1318
|
+
|
|
1319
|
+
try:
|
|
1320
|
+
import tf2onnx
|
|
1321
|
+
except ImportError:
|
|
1322
|
+
logger.error("tf2onnx not installed. Install with: pip install tf2onnx tensorflow")
|
|
1323
|
+
return None, None
|
|
1324
|
+
|
|
1325
|
+
jax_path = jax_path.resolve()
|
|
1326
|
+
if not jax_path.exists():
|
|
1327
|
+
logger.error(f"JAX params file not found: {jax_path}")
|
|
1328
|
+
return None, None
|
|
1329
|
+
|
|
1330
|
+
if not apply_fn_path:
|
|
1331
|
+
logger.error(
|
|
1332
|
+
"--jax-apply-fn is required for JAX conversion.\n"
|
|
1333
|
+
"Example: --from-jax params.pkl --jax-apply-fn my_model:apply --input-shape 1,3,224,224"
|
|
1334
|
+
)
|
|
1335
|
+
return None, None
|
|
1336
|
+
|
|
1337
|
+
if not input_shape_str:
|
|
1338
|
+
logger.error(
|
|
1339
|
+
"--input-shape is required for JAX conversion.\n"
|
|
1340
|
+
"Example: --from-jax params.pkl --jax-apply-fn my_model:apply --input-shape 1,3,224,224"
|
|
1341
|
+
)
|
|
1342
|
+
return None, None
|
|
1343
|
+
|
|
1344
|
+
# Parse input shape
|
|
1345
|
+
try:
|
|
1346
|
+
input_shape = tuple(int(x.strip()) for x in input_shape_str.split(","))
|
|
1347
|
+
except ValueError:
|
|
1348
|
+
logger.error(
|
|
1349
|
+
f"Invalid --input-shape format: '{input_shape_str}'. "
|
|
1350
|
+
"Use comma-separated integers, e.g., '1,3,224,224'"
|
|
1351
|
+
)
|
|
1352
|
+
return None, None
|
|
1353
|
+
|
|
1354
|
+
# Parse apply function path (module:function)
|
|
1355
|
+
if ":" not in apply_fn_path:
|
|
1356
|
+
logger.error(
|
|
1357
|
+
f"Invalid --jax-apply-fn format: '{apply_fn_path}'. "
|
|
1358
|
+
"Use module:function format, e.g., 'my_model:apply'"
|
|
1359
|
+
)
|
|
1360
|
+
return None, None
|
|
1361
|
+
|
|
1362
|
+
module_path, fn_name = apply_fn_path.rsplit(":", 1)
|
|
1363
|
+
|
|
1364
|
+
# Determine output path
|
|
1365
|
+
temp_file = None
|
|
1366
|
+
if output_path:
|
|
1367
|
+
onnx_path = output_path.resolve()
|
|
1368
|
+
onnx_path.parent.mkdir(parents=True, exist_ok=True)
|
|
1369
|
+
else:
|
|
1370
|
+
import tempfile
|
|
1371
|
+
|
|
1372
|
+
temp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
|
1373
|
+
onnx_path = pathlib.Path(temp_file.name)
|
|
1374
|
+
temp_file.close()
|
|
1375
|
+
|
|
1376
|
+
logger.info("Converting JAX model to ONNX...")
|
|
1377
|
+
logger.info(f" Params: {jax_path}")
|
|
1378
|
+
logger.info(f" Apply fn: {apply_fn_path}")
|
|
1379
|
+
logger.info(f" Input shape: {input_shape}")
|
|
1380
|
+
logger.info(f" Target: {onnx_path}")
|
|
1381
|
+
|
|
1382
|
+
try:
|
|
1383
|
+
import importlib.util
|
|
1384
|
+
import sys
|
|
1385
|
+
|
|
1386
|
+
# Load params
|
|
1387
|
+
suffix = jax_path.suffix.lower()
|
|
1388
|
+
if suffix == ".msgpack":
|
|
1389
|
+
try:
|
|
1390
|
+
from flax.serialization import msgpack_restore
|
|
1391
|
+
|
|
1392
|
+
with open(jax_path, "rb") as f:
|
|
1393
|
+
params = msgpack_restore(f.read())
|
|
1394
|
+
logger.info("Loaded Flax msgpack params")
|
|
1395
|
+
except ImportError:
|
|
1396
|
+
logger.error("Flax not installed. Install with: pip install flax")
|
|
1397
|
+
return None, None
|
|
1398
|
+
elif suffix in (".pkl", ".pickle"):
|
|
1399
|
+
import pickle
|
|
1400
|
+
|
|
1401
|
+
with open(jax_path, "rb") as f:
|
|
1402
|
+
params = pickle.load(f)
|
|
1403
|
+
logger.info("Loaded pickle params")
|
|
1404
|
+
elif suffix == ".npy":
|
|
1405
|
+
import numpy as np
|
|
1406
|
+
|
|
1407
|
+
params = np.load(jax_path, allow_pickle=True).item()
|
|
1408
|
+
logger.info("Loaded numpy params")
|
|
1409
|
+
else:
|
|
1410
|
+
logger.error(f"Unsupported params format: {suffix}. Use .msgpack, .pkl, or .npy")
|
|
1411
|
+
return None, None
|
|
1412
|
+
|
|
1413
|
+
# Import apply function
|
|
1414
|
+
# Add current directory to path for local imports
|
|
1415
|
+
sys.path.insert(0, str(pathlib.Path.cwd()))
|
|
1416
|
+
|
|
1417
|
+
try:
|
|
1418
|
+
module = importlib.import_module(module_path)
|
|
1419
|
+
apply_fn = getattr(module, fn_name)
|
|
1420
|
+
logger.info(f"Loaded apply function: {apply_fn_path}")
|
|
1421
|
+
except (ImportError, AttributeError) as e:
|
|
1422
|
+
logger.error(f"Could not import {apply_fn_path}: {e}")
|
|
1423
|
+
return None, None
|
|
1424
|
+
|
|
1425
|
+
# Convert via jax2tf (JAX -> TF -> ONNX)
|
|
1426
|
+
try:
|
|
1427
|
+
import tensorflow as tf
|
|
1428
|
+
from jax.experimental import jax2tf
|
|
1429
|
+
except ImportError:
|
|
1430
|
+
logger.error("jax2tf or TensorFlow not available. Install with: pip install tensorflow")
|
|
1431
|
+
return None, None
|
|
1432
|
+
|
|
1433
|
+
# Create a concrete function
|
|
1434
|
+
jnp.zeros(input_shape, dtype=jnp.float32)
|
|
1435
|
+
|
|
1436
|
+
# Convert JAX function to TF
|
|
1437
|
+
tf_fn = jax2tf.convert(
|
|
1438
|
+
lambda x: apply_fn(params, x),
|
|
1439
|
+
enable_xla=False,
|
|
1440
|
+
)
|
|
1441
|
+
|
|
1442
|
+
# Create TF SavedModel
|
|
1443
|
+
import tempfile as tf_tempfile
|
|
1444
|
+
|
|
1445
|
+
with tf_tempfile.TemporaryDirectory() as tf_model_dir:
|
|
1446
|
+
# Wrap in tf.Module for SavedModel export
|
|
1447
|
+
class TFModule(tf.Module):
|
|
1448
|
+
@tf.function(input_signature=[tf.TensorSpec(input_shape, tf.float32)])
|
|
1449
|
+
def __call__(self, x):
|
|
1450
|
+
return tf_fn(x)
|
|
1451
|
+
|
|
1452
|
+
tf_module = TFModule()
|
|
1453
|
+
tf.saved_model.save(tf_module, tf_model_dir)
|
|
1454
|
+
logger.info("Created temporary TF SavedModel")
|
|
1455
|
+
|
|
1456
|
+
# Convert TF SavedModel to ONNX
|
|
1457
|
+
import subprocess
|
|
1458
|
+
|
|
1459
|
+
cmd = [
|
|
1460
|
+
sys.executable,
|
|
1461
|
+
"-m",
|
|
1462
|
+
"tf2onnx.convert",
|
|
1463
|
+
"--saved-model",
|
|
1464
|
+
tf_model_dir,
|
|
1465
|
+
"--output",
|
|
1466
|
+
str(onnx_path),
|
|
1467
|
+
"--opset",
|
|
1468
|
+
str(opset_version),
|
|
1469
|
+
]
|
|
1470
|
+
|
|
1471
|
+
result = subprocess.run(
|
|
1472
|
+
cmd,
|
|
1473
|
+
check=False,
|
|
1474
|
+
capture_output=True,
|
|
1475
|
+
text=True,
|
|
1476
|
+
timeout=600,
|
|
1477
|
+
)
|
|
1478
|
+
|
|
1479
|
+
if result.returncode != 0:
|
|
1480
|
+
logger.error(f"tf2onnx conversion failed:\n{result.stderr}")
|
|
1481
|
+
if temp_file:
|
|
1482
|
+
try:
|
|
1483
|
+
onnx_path.unlink()
|
|
1484
|
+
except Exception:
|
|
1485
|
+
pass
|
|
1486
|
+
return None, None
|
|
1487
|
+
|
|
1488
|
+
logger.info("JAX model converted successfully")
|
|
1489
|
+
|
|
1490
|
+
# Verify the export
|
|
1491
|
+
import onnx
|
|
1492
|
+
|
|
1493
|
+
onnx_model = onnx.load(str(onnx_path))
|
|
1494
|
+
onnx.checker.check_model(onnx_model)
|
|
1495
|
+
logger.info("ONNX model validated successfully")
|
|
1496
|
+
|
|
1497
|
+
except Exception as e:
|
|
1498
|
+
logger.error(f"JAX conversion failed: {e}")
|
|
1499
|
+
import traceback
|
|
1500
|
+
|
|
1501
|
+
logger.debug(traceback.format_exc())
|
|
1502
|
+
if temp_file:
|
|
1503
|
+
try:
|
|
1504
|
+
onnx_path.unlink()
|
|
1505
|
+
except Exception:
|
|
1506
|
+
pass
|
|
1507
|
+
return None, None
|
|
1508
|
+
|
|
1509
|
+
return onnx_path, temp_file
|
|
1510
|
+
|
|
1511
|
+
|
|
1512
|
+
def run_inspect():
|
|
1513
|
+
"""Main entry point for the model_inspect CLI."""
|
|
1514
|
+
# Load environment variables from .env file if present
|
|
1515
|
+
try:
|
|
1516
|
+
from dotenv import load_dotenv
|
|
1517
|
+
|
|
1518
|
+
load_dotenv()
|
|
1519
|
+
except ImportError:
|
|
1520
|
+
pass # python-dotenv not installed, use environment variables directly
|
|
1521
|
+
|
|
1522
|
+
args = parse_args()
|
|
1523
|
+
logger = setup_logging(args.log_level)
|
|
1524
|
+
|
|
1525
|
+
# Handle --schema
|
|
1526
|
+
if args.schema:
|
|
1527
|
+
import json
|
|
1528
|
+
|
|
1529
|
+
from .schema import get_schema
|
|
1530
|
+
|
|
1531
|
+
schema = get_schema()
|
|
1532
|
+
print(json.dumps(schema, indent=2))
|
|
1533
|
+
return 0
|
|
1534
|
+
|
|
1535
|
+
# Handle --list-hardware
|
|
1536
|
+
if args.list_hardware:
|
|
1537
|
+
print("\n" + "=" * 70)
|
|
1538
|
+
print("Available Hardware Profiles")
|
|
1539
|
+
print("=" * 70)
|
|
1540
|
+
|
|
1541
|
+
print("\nData Center GPUs - H100 Series:")
|
|
1542
|
+
for name in ["h100-sxm", "h100-pcie", "h100-nvl"]:
|
|
1543
|
+
profile = get_profile(name)
|
|
1544
|
+
if profile:
|
|
1545
|
+
print(
|
|
1546
|
+
f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
|
|
1547
|
+
)
|
|
1548
|
+
|
|
1549
|
+
print("\nData Center GPUs - A100 Series:")
|
|
1550
|
+
for name in [
|
|
1551
|
+
"a100-80gb-sxm",
|
|
1552
|
+
"a100-80gb-pcie",
|
|
1553
|
+
"a100-40gb-sxm",
|
|
1554
|
+
"a100-40gb-pcie",
|
|
1555
|
+
]:
|
|
1556
|
+
profile = get_profile(name)
|
|
1557
|
+
if profile:
|
|
1558
|
+
print(
|
|
1559
|
+
f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
|
|
1560
|
+
)
|
|
1561
|
+
|
|
1562
|
+
print("\nData Center GPUs - Other:")
|
|
1563
|
+
for name in ["a10", "l4", "l40", "l40s", "t4"]:
|
|
1564
|
+
profile = get_profile(name)
|
|
1565
|
+
if profile:
|
|
1566
|
+
print(
|
|
1567
|
+
f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
|
|
1568
|
+
)
|
|
1569
|
+
|
|
1570
|
+
print("\nData Center GPUs - V100 Series:")
|
|
1571
|
+
for name in [
|
|
1572
|
+
"v100-32gb-sxm",
|
|
1573
|
+
"v100-32gb-pcie",
|
|
1574
|
+
"v100-16gb-sxm",
|
|
1575
|
+
"v100-16gb-pcie",
|
|
1576
|
+
]:
|
|
1577
|
+
profile = get_profile(name)
|
|
1578
|
+
if profile:
|
|
1579
|
+
print(
|
|
1580
|
+
f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
|
|
1581
|
+
)
|
|
1582
|
+
|
|
1583
|
+
print("\nDGX Systems (Multi-GPU):")
|
|
1584
|
+
for name in ["dgx-h100", "dgx-a100-640gb", "dgx-a100-320gb"]:
|
|
1585
|
+
profile = get_profile(name)
|
|
1586
|
+
if profile:
|
|
1587
|
+
print(
|
|
1588
|
+
f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
|
|
1589
|
+
)
|
|
1590
|
+
|
|
1591
|
+
print("\nJetson Edge/Embedded (Orin Series - Recommended for new projects):")
|
|
1592
|
+
for name in [
|
|
1593
|
+
"jetson-agx-orin-64gb",
|
|
1594
|
+
"jetson-agx-orin-32gb",
|
|
1595
|
+
"jetson-orin-nx-16gb",
|
|
1596
|
+
"jetson-orin-nx-8gb",
|
|
1597
|
+
"jetson-orin-nano-8gb",
|
|
1598
|
+
"jetson-orin-nano-4gb",
|
|
1599
|
+
]:
|
|
1600
|
+
profile = get_profile(name)
|
|
1601
|
+
if profile:
|
|
1602
|
+
print(
|
|
1603
|
+
f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
|
|
1604
|
+
)
|
|
1605
|
+
|
|
1606
|
+
print("\nJetson Edge/Embedded (Xavier Series):")
|
|
1607
|
+
for name in ["jetson-agx-xavier", "jetson-xavier-nx-8gb"]:
|
|
1608
|
+
profile = get_profile(name)
|
|
1609
|
+
if profile:
|
|
1610
|
+
print(
|
|
1611
|
+
f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
|
|
1612
|
+
)
|
|
1613
|
+
|
|
1614
|
+
print("\nJetson Edge/Embedded (Legacy - Very Constrained!):")
|
|
1615
|
+
for name in ["jetson-tx2", "jetson-nano", "jetson-nano-2gb"]:
|
|
1616
|
+
profile = get_profile(name)
|
|
1617
|
+
if profile:
|
|
1618
|
+
vram_gb = profile.vram_bytes / (1024**3)
|
|
1619
|
+
print(
|
|
1620
|
+
f" {name:20} {profile.name:30} {vram_gb:3.0f} GB {profile.peak_fp16_tflops:6.3f} TF16"
|
|
1621
|
+
)
|
|
1622
|
+
|
|
1623
|
+
print("\nConsumer GPUs - RTX 40 Series:")
|
|
1624
|
+
for name in [
|
|
1625
|
+
"rtx4090",
|
|
1626
|
+
"4080-super",
|
|
1627
|
+
"rtx4080",
|
|
1628
|
+
"4070-ti-super",
|
|
1629
|
+
"4070-ti",
|
|
1630
|
+
"4070-super",
|
|
1631
|
+
"rtx4070",
|
|
1632
|
+
"4060-ti-16gb",
|
|
1633
|
+
"rtx4060",
|
|
1634
|
+
]:
|
|
1635
|
+
profile = get_profile(name)
|
|
1636
|
+
if profile:
|
|
1637
|
+
print(
|
|
1638
|
+
f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
|
|
1639
|
+
)
|
|
1640
|
+
|
|
1641
|
+
print("\nConsumer GPUs - RTX 30 Series:")
|
|
1642
|
+
for name in [
|
|
1643
|
+
"3090-ti",
|
|
1644
|
+
"rtx3090",
|
|
1645
|
+
"3080-ti",
|
|
1646
|
+
"3080-12gb",
|
|
1647
|
+
"rtx3080",
|
|
1648
|
+
"3070-ti",
|
|
1649
|
+
"rtx3070",
|
|
1650
|
+
"3060-ti",
|
|
1651
|
+
"rtx3060",
|
|
1652
|
+
"rtx3050",
|
|
1653
|
+
]:
|
|
1654
|
+
profile = get_profile(name)
|
|
1655
|
+
if profile:
|
|
1656
|
+
print(
|
|
1657
|
+
f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
|
|
1658
|
+
)
|
|
1659
|
+
|
|
1660
|
+
print("\nLaptop/Mobile GPUs:")
|
|
1661
|
+
for name in [
|
|
1662
|
+
"4090-mobile",
|
|
1663
|
+
"4080-mobile",
|
|
1664
|
+
"4070-mobile",
|
|
1665
|
+
"3080-mobile",
|
|
1666
|
+
"3070-mobile",
|
|
1667
|
+
]:
|
|
1668
|
+
profile = get_profile(name)
|
|
1669
|
+
if profile:
|
|
1670
|
+
print(
|
|
1671
|
+
f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
|
|
1672
|
+
)
|
|
1673
|
+
|
|
1674
|
+
print("\nOther:")
|
|
1675
|
+
print(" auto Auto-detect local GPU/CPU")
|
|
1676
|
+
print(" cpu Generic CPU profile")
|
|
1677
|
+
|
|
1678
|
+
print("\n" + "-" * 70)
|
|
1679
|
+
print("TF16 = Peak FP16 TFLOPS (higher = faster)")
|
|
1680
|
+
print("Use --gpu-count N for multi-GPU estimates")
|
|
1681
|
+
print("Use --list-cloud for cloud instance options")
|
|
1682
|
+
print("=" * 70 + "\n")
|
|
1683
|
+
sys.exit(0)
|
|
1684
|
+
|
|
1685
|
+
# Handle --list-cloud
|
|
1686
|
+
if args.list_cloud:
|
|
1687
|
+
print("\n" + "=" * 70)
|
|
1688
|
+
print("Available Cloud Instance Profiles")
|
|
1689
|
+
print("=" * 70)
|
|
1690
|
+
|
|
1691
|
+
print("\nAWS GPU Instances:")
|
|
1692
|
+
for name, instance in CLOUD_INSTANCES.items():
|
|
1693
|
+
if instance.provider == "aws":
|
|
1694
|
+
vram_gb = instance.hardware.vram_bytes * instance.gpu_count // (1024**3)
|
|
1695
|
+
print(
|
|
1696
|
+
f" {name:25} {instance.gpu_count}x GPU {vram_gb:4} GB ${instance.hourly_cost_usd:6.2f}/hr"
|
|
1697
|
+
)
|
|
1698
|
+
|
|
1699
|
+
print("\nAzure GPU Instances:")
|
|
1700
|
+
for name, instance in CLOUD_INSTANCES.items():
|
|
1701
|
+
if instance.provider == "azure":
|
|
1702
|
+
vram_gb = instance.hardware.vram_bytes * instance.gpu_count // (1024**3)
|
|
1703
|
+
print(
|
|
1704
|
+
f" {name:25} {instance.gpu_count}x GPU {vram_gb:4} GB ${instance.hourly_cost_usd:6.2f}/hr"
|
|
1705
|
+
)
|
|
1706
|
+
|
|
1707
|
+
print("\nGCP GPU Instances:")
|
|
1708
|
+
for name, instance in CLOUD_INSTANCES.items():
|
|
1709
|
+
if instance.provider == "gcp":
|
|
1710
|
+
vram_gb = instance.hardware.vram_bytes * instance.gpu_count // (1024**3)
|
|
1711
|
+
print(
|
|
1712
|
+
f" {name:25} {instance.gpu_count}x GPU {vram_gb:4} GB ${instance.hourly_cost_usd:6.2f}/hr"
|
|
1713
|
+
)
|
|
1714
|
+
|
|
1715
|
+
print("\n" + "-" * 70)
|
|
1716
|
+
print("Prices are approximate on-demand rates (us-east-1 or equivalent)")
|
|
1717
|
+
print("Use --cloud <instance> to get cost estimates for your model")
|
|
1718
|
+
print("=" * 70 + "\n")
|
|
1719
|
+
sys.exit(0)
|
|
1720
|
+
|
|
1721
|
+
# Handle --list-conversions
|
|
1722
|
+
if args.list_conversions:
|
|
1723
|
+
from .format_adapters import list_conversion_paths
|
|
1724
|
+
|
|
1725
|
+
print("\n" + "=" * 70)
|
|
1726
|
+
print("Available Format Conversion Paths")
|
|
1727
|
+
print("=" * 70)
|
|
1728
|
+
|
|
1729
|
+
paths = list_conversion_paths()
|
|
1730
|
+
current_source = None
|
|
1731
|
+
|
|
1732
|
+
for path in paths:
|
|
1733
|
+
if path["source"] != current_source:
|
|
1734
|
+
current_source = path["source"]
|
|
1735
|
+
print(f"\nFrom {current_source.upper()}:")
|
|
1736
|
+
|
|
1737
|
+
level_icons = {
|
|
1738
|
+
"full": "[FULL] ",
|
|
1739
|
+
"partial": "[PARTIAL] ",
|
|
1740
|
+
"lossy": "[LOSSY] ",
|
|
1741
|
+
"none": "[NONE] ",
|
|
1742
|
+
}
|
|
1743
|
+
icon = level_icons.get(path["level"], " ")
|
|
1744
|
+
print(f" {icon} -> {path['target']}")
|
|
1745
|
+
|
|
1746
|
+
print("\n" + "-" * 70)
|
|
1747
|
+
print("Conversion Levels:")
|
|
1748
|
+
print(" FULL = Lossless, complete conversion")
|
|
1749
|
+
print(" PARTIAL = Some limitations or multi-step required")
|
|
1750
|
+
print(" LOSSY = Information may be lost")
|
|
1751
|
+
print(" NONE = No conversion path available")
|
|
1752
|
+
print("\nUse --convert-to FORMAT to convert a model")
|
|
1753
|
+
print("=" * 70 + "\n")
|
|
1754
|
+
sys.exit(0)
|
|
1755
|
+
|
|
1756
|
+
# Handle model conversion if requested
|
|
1757
|
+
temp_onnx_file = None
|
|
1758
|
+
conversion_sources = [
|
|
1759
|
+
("--from-pytorch", args.from_pytorch),
|
|
1760
|
+
("--from-tensorflow", args.from_tensorflow),
|
|
1761
|
+
("--from-keras", args.from_keras),
|
|
1762
|
+
("--from-frozen-graph", args.from_frozen_graph),
|
|
1763
|
+
("--from-jax", args.from_jax),
|
|
1764
|
+
]
|
|
1765
|
+
active_conversions = [(name, path) for name, path in conversion_sources if path]
|
|
1766
|
+
|
|
1767
|
+
if len(active_conversions) > 1:
|
|
1768
|
+
names = [name for name, _ in active_conversions]
|
|
1769
|
+
logger.error(f"Cannot use multiple conversion flags together: {', '.join(names)}")
|
|
1770
|
+
sys.exit(1)
|
|
1771
|
+
|
|
1772
|
+
if args.from_pytorch:
|
|
1773
|
+
model_path, temp_onnx_file = _convert_pytorch_to_onnx(
|
|
1774
|
+
args.from_pytorch,
|
|
1775
|
+
args.input_shape,
|
|
1776
|
+
args.keep_onnx,
|
|
1777
|
+
args.opset_version,
|
|
1778
|
+
logger,
|
|
1779
|
+
)
|
|
1780
|
+
if model_path is None:
|
|
1781
|
+
sys.exit(1)
|
|
1782
|
+
elif args.from_tensorflow:
|
|
1783
|
+
model_path, temp_onnx_file = _convert_tensorflow_to_onnx(
|
|
1784
|
+
args.from_tensorflow,
|
|
1785
|
+
args.keep_onnx,
|
|
1786
|
+
args.opset_version,
|
|
1787
|
+
logger,
|
|
1788
|
+
)
|
|
1789
|
+
if model_path is None:
|
|
1790
|
+
sys.exit(1)
|
|
1791
|
+
elif args.from_keras:
|
|
1792
|
+
model_path, temp_onnx_file = _convert_keras_to_onnx(
|
|
1793
|
+
args.from_keras,
|
|
1794
|
+
args.keep_onnx,
|
|
1795
|
+
args.opset_version,
|
|
1796
|
+
logger,
|
|
1797
|
+
)
|
|
1798
|
+
if model_path is None:
|
|
1799
|
+
sys.exit(1)
|
|
1800
|
+
elif args.from_frozen_graph:
|
|
1801
|
+
model_path, temp_onnx_file = _convert_frozen_graph_to_onnx(
|
|
1802
|
+
args.from_frozen_graph,
|
|
1803
|
+
args.tf_inputs,
|
|
1804
|
+
args.tf_outputs,
|
|
1805
|
+
args.keep_onnx,
|
|
1806
|
+
args.opset_version,
|
|
1807
|
+
logger,
|
|
1808
|
+
)
|
|
1809
|
+
if model_path is None:
|
|
1810
|
+
sys.exit(1)
|
|
1811
|
+
elif args.from_jax:
|
|
1812
|
+
model_path, temp_onnx_file = _convert_jax_to_onnx(
|
|
1813
|
+
args.from_jax,
|
|
1814
|
+
args.jax_apply_fn,
|
|
1815
|
+
args.input_shape,
|
|
1816
|
+
args.keep_onnx,
|
|
1817
|
+
args.opset_version,
|
|
1818
|
+
logger,
|
|
1819
|
+
)
|
|
1820
|
+
if model_path is None:
|
|
1821
|
+
sys.exit(1)
|
|
1822
|
+
else:
|
|
1823
|
+
# Validate model path (no conversion requested)
|
|
1824
|
+
if args.model_path is None:
|
|
1825
|
+
logger.error(
|
|
1826
|
+
"Model path is required. Use --list-hardware to see available profiles, "
|
|
1827
|
+
"or use a conversion flag (--from-pytorch, --from-tensorflow, --from-keras, "
|
|
1828
|
+
"--from-frozen-graph, --from-jax)."
|
|
1829
|
+
)
|
|
1830
|
+
sys.exit(1)
|
|
1831
|
+
|
|
1832
|
+
model_path = args.model_path.resolve()
|
|
1833
|
+
if not model_path.exists():
|
|
1834
|
+
logger.error(f"Model file not found: {model_path}")
|
|
1835
|
+
sys.exit(1)
|
|
1836
|
+
|
|
1837
|
+
if model_path.suffix.lower() not in (".onnx", ".pb", ".ort"):
|
|
1838
|
+
logger.warning(f"Unexpected file extension: {model_path.suffix}. Proceeding anyway.")
|
|
1839
|
+
|
|
1840
|
+
# Handle --convert-to for format conversion
|
|
1841
|
+
if args.convert_to:
|
|
1842
|
+
from .format_adapters import (
|
|
1843
|
+
ConversionLevel,
|
|
1844
|
+
OnnxAdapter,
|
|
1845
|
+
get_conversion_level,
|
|
1846
|
+
load_model,
|
|
1847
|
+
)
|
|
1848
|
+
from .universal_ir import SourceFormat
|
|
1849
|
+
|
|
1850
|
+
if not args.convert_output:
|
|
1851
|
+
logger.error("--convert-output is required when using --convert-to")
|
|
1852
|
+
sys.exit(1)
|
|
1853
|
+
|
|
1854
|
+
target_format = SourceFormat(args.convert_to.lower())
|
|
1855
|
+
output_path = args.convert_output.resolve()
|
|
1856
|
+
|
|
1857
|
+
# Determine source format from file extension
|
|
1858
|
+
source_ext = model_path.suffix.lower()
|
|
1859
|
+
source_format_map = {
|
|
1860
|
+
".onnx": SourceFormat.ONNX,
|
|
1861
|
+
".pt": SourceFormat.PYTORCH,
|
|
1862
|
+
".pth": SourceFormat.PYTORCH,
|
|
1863
|
+
}
|
|
1864
|
+
source_format = source_format_map.get(source_ext, SourceFormat.UNKNOWN)
|
|
1865
|
+
|
|
1866
|
+
# Check conversion level
|
|
1867
|
+
level = get_conversion_level(source_format, target_format)
|
|
1868
|
+
if level == ConversionLevel.NONE:
|
|
1869
|
+
logger.error(
|
|
1870
|
+
f"Cannot convert from {source_format.value} to {target_format.value}. "
|
|
1871
|
+
f"Use --list-conversions to see available paths."
|
|
1872
|
+
)
|
|
1873
|
+
sys.exit(1)
|
|
1874
|
+
elif level == ConversionLevel.LOSSY:
|
|
1875
|
+
logger.warning(
|
|
1876
|
+
f"Converting {source_format.value} to {target_format.value} may lose information"
|
|
1877
|
+
)
|
|
1878
|
+
|
|
1879
|
+
# Load to Universal IR and convert
|
|
1880
|
+
try:
|
|
1881
|
+
logger.info(f"Loading {model_path} as Universal IR...")
|
|
1882
|
+
graph = load_model(model_path)
|
|
1883
|
+
|
|
1884
|
+
if target_format == SourceFormat.ONNX:
|
|
1885
|
+
logger.info(f"Exporting to ONNX: {output_path}")
|
|
1886
|
+
OnnxAdapter().write(graph, output_path)
|
|
1887
|
+
logger.info(f"Successfully converted to {output_path}")
|
|
1888
|
+
print(f"\nConverted: {model_path} -> {output_path}")
|
|
1889
|
+
print(f" Nodes: {graph.num_nodes}")
|
|
1890
|
+
print(f" Parameters: {graph.total_parameters:,}")
|
|
1891
|
+
print(f" Source: {graph.metadata.source_format.value}")
|
|
1892
|
+
else:
|
|
1893
|
+
logger.error(f"Conversion to {target_format.value} not yet implemented")
|
|
1894
|
+
sys.exit(1)
|
|
1895
|
+
|
|
1896
|
+
# If no analysis requested, exit after conversion
|
|
1897
|
+
if not (args.out_json or args.out_md or args.out_html or args.out_pdf):
|
|
1898
|
+
sys.exit(0)
|
|
1899
|
+
|
|
1900
|
+
# Use the converted file for analysis
|
|
1901
|
+
model_path = output_path
|
|
1902
|
+
|
|
1903
|
+
except Exception as e:
|
|
1904
|
+
logger.error(f"Conversion failed: {e}")
|
|
1905
|
+
sys.exit(1)
|
|
1906
|
+
|
|
1907
|
+
# Handle --export-ir and --export-graph (Universal IR exports)
|
|
1908
|
+
if args.export_ir or args.export_graph:
|
|
1909
|
+
from .format_adapters import load_model
|
|
1910
|
+
|
|
1911
|
+
try:
|
|
1912
|
+
logger.info(f"Loading {model_path} as Universal IR...")
|
|
1913
|
+
ir_graph = load_model(model_path)
|
|
1914
|
+
|
|
1915
|
+
if args.export_ir:
|
|
1916
|
+
logger.info(f"Exporting Universal IR to {args.export_ir}")
|
|
1917
|
+
ir_graph.to_json(args.export_ir)
|
|
1918
|
+
print(f"Exported IR: {args.export_ir}")
|
|
1919
|
+
print(f" Nodes: {ir_graph.num_nodes}")
|
|
1920
|
+
print(f" Parameters: {ir_graph.total_parameters:,}")
|
|
1921
|
+
|
|
1922
|
+
if args.export_graph:
|
|
1923
|
+
export_path = args.export_graph
|
|
1924
|
+
suffix = export_path.suffix.lower()
|
|
1925
|
+
|
|
1926
|
+
if suffix == ".dot":
|
|
1927
|
+
logger.info(f"Exporting graph to DOT: {export_path}")
|
|
1928
|
+
ir_graph.save_dot(export_path)
|
|
1929
|
+
print(f"Exported graph: {export_path}")
|
|
1930
|
+
elif suffix == ".png":
|
|
1931
|
+
logger.info(f"Rendering graph to PNG: {export_path}")
|
|
1932
|
+
ir_graph.save_png(export_path, max_nodes=args.graph_max_nodes)
|
|
1933
|
+
print(f"Rendered graph: {export_path}")
|
|
1934
|
+
else:
|
|
1935
|
+
logger.error(f"Unsupported graph format: {suffix}. Use .dot or .png")
|
|
1936
|
+
sys.exit(1)
|
|
1937
|
+
|
|
1938
|
+
# If no other outputs requested, exit after IR export
|
|
1939
|
+
if not (args.out_json or args.out_md or args.out_html or args.out_pdf):
|
|
1940
|
+
sys.exit(0)
|
|
1941
|
+
|
|
1942
|
+
except ImportError as e:
|
|
1943
|
+
logger.error(f"Missing dependency: {e}")
|
|
1944
|
+
sys.exit(1)
|
|
1945
|
+
except Exception as e:
|
|
1946
|
+
logger.error(f"IR export failed: {e}")
|
|
1947
|
+
sys.exit(1)
|
|
1948
|
+
|
|
1949
|
+
# Determine hardware profile
|
|
1950
|
+
hardware_profile = None
|
|
1951
|
+
cloud_instance = None
|
|
1952
|
+
|
|
1953
|
+
if args.cloud:
|
|
1954
|
+
# Cloud instance takes precedence
|
|
1955
|
+
cloud_instance = get_cloud_instance(args.cloud)
|
|
1956
|
+
if cloud_instance is None:
|
|
1957
|
+
logger.error(f"Unknown cloud instance: {args.cloud}")
|
|
1958
|
+
logger.error("Use --list-cloud to see available instances.")
|
|
1959
|
+
sys.exit(1)
|
|
1960
|
+
hardware_profile = cloud_instance.hardware
|
|
1961
|
+
# Override gpu_count from cloud instance
|
|
1962
|
+
args.gpu_count = cloud_instance.gpu_count
|
|
1963
|
+
logger.info(
|
|
1964
|
+
f"Using cloud instance: {cloud_instance.name} "
|
|
1965
|
+
f"({cloud_instance.gpu_count}x GPU, ${cloud_instance.hourly_cost_usd:.2f}/hr)"
|
|
1966
|
+
)
|
|
1967
|
+
elif args.hardware:
|
|
1968
|
+
if args.hardware.lower() == "auto":
|
|
1969
|
+
logger.info("Auto-detecting local hardware...")
|
|
1970
|
+
hardware_profile = detect_local_hardware()
|
|
1971
|
+
logger.info(f"Detected: {hardware_profile.name}")
|
|
1972
|
+
else:
|
|
1973
|
+
hardware_profile = get_profile(args.hardware)
|
|
1974
|
+
if hardware_profile is None:
|
|
1975
|
+
logger.error(f"Unknown hardware profile: {args.hardware}")
|
|
1976
|
+
logger.error("Use --list-hardware to see available profiles.")
|
|
1977
|
+
sys.exit(1)
|
|
1978
|
+
logger.info(f"Using hardware profile: {hardware_profile.name}")
|
|
1979
|
+
|
|
1980
|
+
# Apply multi-GPU scaling if requested
|
|
1981
|
+
if hardware_profile and args.gpu_count > 1 and not args.cloud:
|
|
1982
|
+
multi_gpu = create_multi_gpu_profile(args.hardware or "auto", args.gpu_count)
|
|
1983
|
+
if multi_gpu:
|
|
1984
|
+
hardware_profile = multi_gpu.get_effective_profile()
|
|
1985
|
+
logger.info(
|
|
1986
|
+
f"Multi-GPU: {args.gpu_count}x scaling with "
|
|
1987
|
+
f"{multi_gpu.compute_efficiency:.0%} efficiency"
|
|
1988
|
+
)
|
|
1989
|
+
|
|
1990
|
+
# Setup progress indicator
|
|
1991
|
+
progress = ProgressIndicator(enabled=args.progress, quiet=args.quiet)
|
|
1992
|
+
|
|
1993
|
+
# Calculate total steps based on what will be done
|
|
1994
|
+
total_steps = 2 # Load + Analyze always
|
|
1995
|
+
if hardware_profile:
|
|
1996
|
+
total_steps += 1
|
|
1997
|
+
if args.system_requirements:
|
|
1998
|
+
total_steps += 1
|
|
1999
|
+
if args.sweep_batch_sizes and hardware_profile:
|
|
2000
|
+
total_steps += 1
|
|
2001
|
+
if args.sweep_resolutions and hardware_profile:
|
|
2002
|
+
total_steps += 1
|
|
2003
|
+
if not args.no_gpu_metrics:
|
|
2004
|
+
total_steps += 1
|
|
2005
|
+
if not args.no_profile:
|
|
2006
|
+
total_steps += 1
|
|
2007
|
+
if not args.no_profile and not args.no_bottleneck_analysis and hardware_profile:
|
|
2008
|
+
total_steps += 1
|
|
2009
|
+
if not args.no_benchmark_resolutions:
|
|
2010
|
+
total_steps += 1
|
|
2011
|
+
if args.with_plots and is_viz_available():
|
|
2012
|
+
total_steps += 1
|
|
2013
|
+
if args.llm_summary and is_llm_available() and has_llm_api_key():
|
|
2014
|
+
total_steps += 1
|
|
2015
|
+
if args.out_json or args.out_md or args.out_html:
|
|
2016
|
+
total_steps += 1
|
|
2017
|
+
|
|
2018
|
+
progress.start(total_steps, f"Analyzing {model_path.name}")
|
|
2019
|
+
|
|
2020
|
+
# Run inspection
|
|
2021
|
+
try:
|
|
2022
|
+
progress.step("Loading model and extracting graph structure")
|
|
2023
|
+
inspector = ModelInspector(logger=logger)
|
|
2024
|
+
report = inspector.inspect(model_path)
|
|
2025
|
+
progress.step("Computing metrics (params, FLOPs, memory)")
|
|
2026
|
+
|
|
2027
|
+
# Add hardware estimates if profile specified
|
|
2028
|
+
if (
|
|
2029
|
+
hardware_profile
|
|
2030
|
+
and report.param_counts
|
|
2031
|
+
and report.flop_counts
|
|
2032
|
+
and report.memory_estimates
|
|
2033
|
+
):
|
|
2034
|
+
progress.step(f"Estimating performance on {hardware_profile.name}")
|
|
2035
|
+
estimator = HardwareEstimator(logger=logger)
|
|
2036
|
+
hw_estimates = estimator.estimate(
|
|
2037
|
+
model_params=report.param_counts.total,
|
|
2038
|
+
model_flops=report.flop_counts.total,
|
|
2039
|
+
peak_activation_bytes=report.memory_estimates.peak_activation_bytes,
|
|
2040
|
+
hardware=hardware_profile,
|
|
2041
|
+
batch_size=args.batch_size,
|
|
2042
|
+
precision=args.precision,
|
|
2043
|
+
)
|
|
2044
|
+
report.hardware_estimates = hw_estimates
|
|
2045
|
+
report.hardware_profile = hardware_profile
|
|
2046
|
+
|
|
2047
|
+
# Batch Size Sweep (Story 6C.1)
|
|
2048
|
+
if args.sweep_batch_sizes:
|
|
2049
|
+
profiler = OperationalProfiler(logger=logger)
|
|
2050
|
+
|
|
2051
|
+
if args.no_benchmark:
|
|
2052
|
+
# Use theoretical estimates (faster but less accurate)
|
|
2053
|
+
progress.step("Running batch size sweep (theoretical)")
|
|
2054
|
+
sweep_result = profiler.run_batch_sweep(
|
|
2055
|
+
model_params=report.param_counts.total,
|
|
2056
|
+
model_flops=report.flop_counts.total,
|
|
2057
|
+
peak_activation_bytes=report.memory_estimates.peak_activation_bytes,
|
|
2058
|
+
hardware=hardware_profile,
|
|
2059
|
+
precision=args.precision,
|
|
2060
|
+
)
|
|
2061
|
+
else:
|
|
2062
|
+
# Default: use actual inference benchmarking
|
|
2063
|
+
progress.step("Benchmarking batch sizes (actual inference)")
|
|
2064
|
+
sweep_result = profiler.run_batch_sweep_benchmark(
|
|
2065
|
+
model_path=str(model_path),
|
|
2066
|
+
batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128],
|
|
2067
|
+
)
|
|
2068
|
+
if sweep_result is None:
|
|
2069
|
+
# Fall back to theoretical if benchmark fails
|
|
2070
|
+
logger.warning("Benchmark failed, using theoretical estimates")
|
|
2071
|
+
sweep_result = profiler.run_batch_sweep(
|
|
2072
|
+
model_params=report.param_counts.total,
|
|
2073
|
+
model_flops=report.flop_counts.total,
|
|
2074
|
+
peak_activation_bytes=report.memory_estimates.peak_activation_bytes,
|
|
2075
|
+
hardware=hardware_profile,
|
|
2076
|
+
precision=args.precision,
|
|
2077
|
+
)
|
|
2078
|
+
|
|
2079
|
+
report.batch_size_sweep = sweep_result
|
|
2080
|
+
logger.info(
|
|
2081
|
+
f"Batch sweep complete. Optimal batch size: {sweep_result.optimal_batch_size}"
|
|
2082
|
+
)
|
|
2083
|
+
|
|
2084
|
+
# Resolution Sweep (Story 6.8)
|
|
2085
|
+
if args.sweep_resolutions:
|
|
2086
|
+
progress.step("Running resolution sweep")
|
|
2087
|
+
profiler = OperationalProfiler(logger=logger)
|
|
2088
|
+
|
|
2089
|
+
# Determine base/training resolution from model input shape
|
|
2090
|
+
base_resolution = (224, 224) # Default
|
|
2091
|
+
if report.graph_summary and report.graph_summary.input_shapes:
|
|
2092
|
+
for shape in report.graph_summary.input_shapes.values():
|
|
2093
|
+
if len(shape) >= 3:
|
|
2094
|
+
# Assume NCHW or NHWC format
|
|
2095
|
+
h, w = shape[-2], shape[-1]
|
|
2096
|
+
if isinstance(h, int) and isinstance(w, int) and h > 1 and w > 1:
|
|
2097
|
+
base_resolution = (h, w)
|
|
2098
|
+
break
|
|
2099
|
+
|
|
2100
|
+
# Parse resolutions from CLI argument
|
|
2101
|
+
# Note: Only resolutions UP TO training resolution are reliable
|
|
2102
|
+
resolutions: list[tuple[int, int]] | None = None
|
|
2103
|
+
if args.sweep_resolutions != "auto":
|
|
2104
|
+
resolutions = []
|
|
2105
|
+
for res_part in args.sweep_resolutions.split(","):
|
|
2106
|
+
res_str = res_part.strip()
|
|
2107
|
+
if "x" in res_str:
|
|
2108
|
+
h, w = res_str.split("x")
|
|
2109
|
+
res_h, res_w = int(h), int(w)
|
|
2110
|
+
# Warn if resolution exceeds training resolution
|
|
2111
|
+
if res_h > base_resolution[0] or res_w > base_resolution[1]:
|
|
2112
|
+
logger.warning(
|
|
2113
|
+
f"Resolution {res_str} exceeds training resolution "
|
|
2114
|
+
f"{base_resolution[0]}x{base_resolution[1]}. "
|
|
2115
|
+
"Results may be unreliable."
|
|
2116
|
+
)
|
|
2117
|
+
resolutions.append((res_h, res_w))
|
|
2118
|
+
|
|
2119
|
+
res_sweep_result = profiler.run_resolution_sweep(
|
|
2120
|
+
base_flops=report.flop_counts.total,
|
|
2121
|
+
base_activation_bytes=report.memory_estimates.peak_activation_bytes,
|
|
2122
|
+
base_resolution=base_resolution,
|
|
2123
|
+
model_params=report.param_counts.total,
|
|
2124
|
+
hardware=hardware_profile,
|
|
2125
|
+
resolutions=resolutions,
|
|
2126
|
+
batch_size=args.batch_size,
|
|
2127
|
+
precision=args.precision,
|
|
2128
|
+
)
|
|
2129
|
+
report.resolution_sweep = res_sweep_result
|
|
2130
|
+
logger.info(
|
|
2131
|
+
f"Resolution sweep complete. Max resolution: {res_sweep_result.max_resolution}, "
|
|
2132
|
+
f"Optimal: {res_sweep_result.optimal_resolution}"
|
|
2133
|
+
)
|
|
2134
|
+
|
|
2135
|
+
# System Requirements (Story 6C.2)
|
|
2136
|
+
if (
|
|
2137
|
+
args.system_requirements
|
|
2138
|
+
and report.param_counts
|
|
2139
|
+
and report.flop_counts
|
|
2140
|
+
and report.memory_estimates
|
|
2141
|
+
):
|
|
2142
|
+
progress.step("Generating system requirements")
|
|
2143
|
+
profiler = OperationalProfiler(logger=logger)
|
|
2144
|
+
|
|
2145
|
+
# Derive a target FPS based on deployment target / explicit knobs
|
|
2146
|
+
target_fps: float | None = None
|
|
2147
|
+
if getattr(args, "target_throughput_fps", None):
|
|
2148
|
+
target_fps = float(args.target_throughput_fps)
|
|
2149
|
+
elif getattr(args, "target_latency_ms", None):
|
|
2150
|
+
# latency (ms) -> fps
|
|
2151
|
+
if args.target_latency_ms > 0:
|
|
2152
|
+
target_fps = 1000.0 / float(args.target_latency_ms)
|
|
2153
|
+
|
|
2154
|
+
# Fallback targets based on deployment target category
|
|
2155
|
+
if target_fps is None:
|
|
2156
|
+
if args.deployment_target == "edge":
|
|
2157
|
+
target_fps = 30.0
|
|
2158
|
+
elif args.deployment_target == "local":
|
|
2159
|
+
target_fps = 60.0
|
|
2160
|
+
elif args.deployment_target == "cloud":
|
|
2161
|
+
target_fps = 120.0
|
|
2162
|
+
else:
|
|
2163
|
+
target_fps = 30.0
|
|
2164
|
+
|
|
2165
|
+
sys_reqs = profiler.determine_system_requirements(
|
|
2166
|
+
model_params=report.param_counts.total,
|
|
2167
|
+
model_flops=report.flop_counts.total,
|
|
2168
|
+
peak_activation_bytes=report.memory_estimates.peak_activation_bytes,
|
|
2169
|
+
precision=args.precision,
|
|
2170
|
+
target_fps=target_fps,
|
|
2171
|
+
)
|
|
2172
|
+
report.system_requirements = sys_reqs
|
|
2173
|
+
if sys_reqs.recommended_gpu:
|
|
2174
|
+
logger.info(f"Recommended device: {sys_reqs.recommended_gpu.device}")
|
|
2175
|
+
elif sys_reqs.minimum_gpu:
|
|
2176
|
+
logger.info(f"Minimum device: {sys_reqs.minimum_gpu.device}")
|
|
2177
|
+
|
|
2178
|
+
# Epic 9: Runtime Profiling (defaults ON for real measurements)
|
|
2179
|
+
profiler = OperationalProfiler(logger=logger)
|
|
2180
|
+
profiling_result = None # Store for use in NN graph visualization
|
|
2181
|
+
|
|
2182
|
+
# GPU Metrics (Story 9.2) - default ON
|
|
2183
|
+
if not args.no_gpu_metrics:
|
|
2184
|
+
progress.step("Capturing GPU metrics")
|
|
2185
|
+
gpu_metrics = profiler.get_gpu_metrics()
|
|
2186
|
+
if gpu_metrics:
|
|
2187
|
+
logger.info(
|
|
2188
|
+
f"GPU: {gpu_metrics.vram_used_bytes / (1024**3):.2f} GB VRAM used, "
|
|
2189
|
+
f"{gpu_metrics.gpu_utilization_percent:.0f}% utilization, "
|
|
2190
|
+
f"{gpu_metrics.temperature_c}C"
|
|
2191
|
+
)
|
|
2192
|
+
# Store in report (add to JSON output)
|
|
2193
|
+
report.extra_data = report.extra_data or {}
|
|
2194
|
+
report.extra_data["gpu_metrics"] = gpu_metrics.to_dict()
|
|
2195
|
+
else:
|
|
2196
|
+
logger.debug("GPU metrics unavailable (pynvml not installed)")
|
|
2197
|
+
|
|
2198
|
+
# Per-Layer Profiling (Story 9.3) - default ON
|
|
2199
|
+
if not args.no_profile:
|
|
2200
|
+
progress.step("Running ONNX Runtime profiler")
|
|
2201
|
+
profiling_result = profiler.profile_model(
|
|
2202
|
+
model_path=str(model_path),
|
|
2203
|
+
batch_size=args.batch_size,
|
|
2204
|
+
num_runs=args.profile_runs,
|
|
2205
|
+
)
|
|
2206
|
+
if profiling_result:
|
|
2207
|
+
logger.info(
|
|
2208
|
+
f"Profiling complete: {profiling_result.total_time_ms:.2f}ms "
|
|
2209
|
+
f"({len(profiling_result.layer_profiles)} layers)"
|
|
2210
|
+
)
|
|
2211
|
+
# Show slowest layers
|
|
2212
|
+
slowest = profiling_result.get_slowest_layers(5)
|
|
2213
|
+
if slowest:
|
|
2214
|
+
logger.info("Top 5 slowest layers:")
|
|
2215
|
+
for lp in slowest:
|
|
2216
|
+
logger.info(f" {lp.name}: {lp.duration_ms:.3f}ms ({lp.op_type})")
|
|
2217
|
+
|
|
2218
|
+
# Store in report
|
|
2219
|
+
report.extra_data = report.extra_data or {}
|
|
2220
|
+
report.extra_data["profiling"] = profiling_result.to_dict()
|
|
2221
|
+
|
|
2222
|
+
# Bottleneck Analysis (Story 9.4) - default ON
|
|
2223
|
+
if not args.no_bottleneck_analysis and hardware_profile:
|
|
2224
|
+
progress.step("Analyzing bottlenecks")
|
|
2225
|
+
bottleneck = profiler.analyze_bottleneck(
|
|
2226
|
+
model_flops=report.flop_counts.total,
|
|
2227
|
+
profiling_result=profiling_result,
|
|
2228
|
+
hardware=hardware_profile,
|
|
2229
|
+
precision=args.precision,
|
|
2230
|
+
)
|
|
2231
|
+
logger.info(
|
|
2232
|
+
f"Bottleneck: {bottleneck.bottleneck_type} "
|
|
2233
|
+
f"(compute: {bottleneck.compute_ratio:.0%}, "
|
|
2234
|
+
f"memory: {bottleneck.memory_ratio:.0%})"
|
|
2235
|
+
)
|
|
2236
|
+
logger.info(f"Efficiency: {bottleneck.efficiency_percent:.1f}%")
|
|
2237
|
+
for rec in bottleneck.recommendations[:3]:
|
|
2238
|
+
logger.info(f" - {rec}")
|
|
2239
|
+
report.extra_data["bottleneck_analysis"] = bottleneck.to_dict()
|
|
2240
|
+
else:
|
|
2241
|
+
logger.debug("Profiling unavailable (onnxruntime not installed)")
|
|
2242
|
+
|
|
2243
|
+
# Resolution Benchmarking (Story 9.5) - default ON
|
|
2244
|
+
if not args.no_benchmark_resolutions:
|
|
2245
|
+
progress.step("Benchmarking resolutions (actual inference)")
|
|
2246
|
+
res_benchmark = profiler.benchmark_resolutions(
|
|
2247
|
+
model_path=str(model_path),
|
|
2248
|
+
batch_size=args.batch_size,
|
|
2249
|
+
)
|
|
2250
|
+
if res_benchmark:
|
|
2251
|
+
logger.info(
|
|
2252
|
+
f"Resolution benchmark complete. Optimal: {res_benchmark.optimal_resolution}"
|
|
2253
|
+
)
|
|
2254
|
+
report.extra_data = report.extra_data or {}
|
|
2255
|
+
report.extra_data["resolution_benchmark"] = res_benchmark.to_dict()
|
|
2256
|
+
else:
|
|
2257
|
+
logger.debug("Resolution benchmark unavailable")
|
|
2258
|
+
|
|
2259
|
+
except Exception as e:
|
|
2260
|
+
logger.error(f"Failed to inspect model: {e}")
|
|
2261
|
+
if args.log_level == "debug":
|
|
2262
|
+
import traceback
|
|
2263
|
+
|
|
2264
|
+
traceback.print_exc()
|
|
2265
|
+
sys.exit(1)
|
|
2266
|
+
|
|
2267
|
+
# Extract dataset metadata if PyTorch weights provided
|
|
2268
|
+
if args.pytorch_weights or args.from_pytorch:
|
|
2269
|
+
weights_path = args.pytorch_weights or args.from_pytorch
|
|
2270
|
+
if weights_path.exists():
|
|
2271
|
+
logger.info(f"Extracting metadata from: {weights_path}")
|
|
2272
|
+
metadata = _extract_ultralytics_metadata(weights_path, logger)
|
|
2273
|
+
if metadata:
|
|
2274
|
+
from .report import DatasetInfo
|
|
2275
|
+
|
|
2276
|
+
report.dataset_info = DatasetInfo(
|
|
2277
|
+
task=metadata.get("task"),
|
|
2278
|
+
num_classes=metadata.get("num_classes"),
|
|
2279
|
+
class_names=metadata.get("class_names", []),
|
|
2280
|
+
source=metadata.get("source"),
|
|
2281
|
+
)
|
|
2282
|
+
logger.info(
|
|
2283
|
+
f"Extracted {report.dataset_info.num_classes} class(es): "
|
|
2284
|
+
f"{', '.join(report.dataset_info.class_names[:5])}"
|
|
2285
|
+
f"{'...' if len(report.dataset_info.class_names) > 5 else ''}"
|
|
2286
|
+
)
|
|
2287
|
+
|
|
2288
|
+
# Generate LLM summaries if requested
|
|
2289
|
+
llm_summary = None
|
|
2290
|
+
if args.llm_summary:
|
|
2291
|
+
if args.offline:
|
|
2292
|
+
print("\n[OFFLINE MODE] Skipping LLM summary (requires network access)\n")
|
|
2293
|
+
elif not is_llm_available():
|
|
2294
|
+
print("\n" + "=" * 60)
|
|
2295
|
+
print("LLM PACKAGE NOT INSTALLED")
|
|
2296
|
+
print("=" * 60)
|
|
2297
|
+
print("To enable AI-powered summaries, install the LLM extras:\n")
|
|
2298
|
+
print(" pip install haoline[llm]")
|
|
2299
|
+
print("\nThen set your API key and try again.")
|
|
2300
|
+
print("=" * 60 + "\n")
|
|
2301
|
+
elif not has_llm_api_key():
|
|
2302
|
+
print("\n" + "=" * 60)
|
|
2303
|
+
print("API KEY REQUIRED FOR LLM SUMMARIES")
|
|
2304
|
+
print("=" * 60)
|
|
2305
|
+
print("Set one of the following environment variables:\n")
|
|
2306
|
+
print(" PowerShell: $env:OPENAI_API_KEY = 'sk-...'")
|
|
2307
|
+
print(" Bash/Zsh: export OPENAI_API_KEY='sk-...'")
|
|
2308
|
+
print("\nGet your API key at: https://platform.openai.com/api-keys")
|
|
2309
|
+
print("=" * 60 + "\n")
|
|
2310
|
+
else:
|
|
2311
|
+
try:
|
|
2312
|
+
progress.step(f"Generating LLM summary with {args.llm_model}")
|
|
2313
|
+
logger.info(f"Generating LLM summaries with {args.llm_model}...")
|
|
2314
|
+
summarizer = LLMSummarizer(model=args.llm_model, logger=logger)
|
|
2315
|
+
llm_summary = summarizer.summarize(report)
|
|
2316
|
+
if llm_summary.success:
|
|
2317
|
+
logger.info(f"LLM summaries generated ({llm_summary.tokens_used} tokens used)")
|
|
2318
|
+
else:
|
|
2319
|
+
logger.warning(f"LLM summarization failed: {llm_summary.error_message}")
|
|
2320
|
+
except Exception as e:
|
|
2321
|
+
logger.warning(f"Failed to generate LLM summaries: {e}")
|
|
2322
|
+
|
|
2323
|
+
# Store LLM summary in report for output
|
|
2324
|
+
if llm_summary and llm_summary.success:
|
|
2325
|
+
# Add to report dict for JSON output
|
|
2326
|
+
report._llm_summary = llm_summary # type: ignore
|
|
2327
|
+
|
|
2328
|
+
# Apply privacy transformations if requested
|
|
2329
|
+
report_dict = report.to_dict()
|
|
2330
|
+
if args.summary_only:
|
|
2331
|
+
from .privacy import create_summary_only_dict
|
|
2332
|
+
|
|
2333
|
+
logger.info("Applying summary-only mode (omitting per-layer details)")
|
|
2334
|
+
report_dict = create_summary_only_dict(report_dict)
|
|
2335
|
+
elif args.redact_names:
|
|
2336
|
+
from .privacy import collect_names_from_dict, create_name_mapping, redact_dict
|
|
2337
|
+
|
|
2338
|
+
logger.info("Applying name redaction")
|
|
2339
|
+
names = collect_names_from_dict(report_dict)
|
|
2340
|
+
mapping = create_name_mapping(names)
|
|
2341
|
+
report_dict = redact_dict(report_dict, mapping)
|
|
2342
|
+
logger.debug(f"Redacted {len(mapping)} names")
|
|
2343
|
+
|
|
2344
|
+
# Output results
|
|
2345
|
+
has_output = (
|
|
2346
|
+
args.out_json
|
|
2347
|
+
or args.out_md
|
|
2348
|
+
or args.out_html
|
|
2349
|
+
or args.out_pdf
|
|
2350
|
+
or args.html_graph
|
|
2351
|
+
or args.layer_csv
|
|
2352
|
+
)
|
|
2353
|
+
if has_output:
|
|
2354
|
+
progress.step("Writing output files")
|
|
2355
|
+
|
|
2356
|
+
if args.out_json:
|
|
2357
|
+
try:
|
|
2358
|
+
import json
|
|
2359
|
+
|
|
2360
|
+
args.out_json.parent.mkdir(parents=True, exist_ok=True)
|
|
2361
|
+
args.out_json.write_text(json.dumps(report_dict, indent=2), encoding="utf-8")
|
|
2362
|
+
logger.info(f"JSON report written to: {args.out_json}")
|
|
2363
|
+
except Exception as e:
|
|
2364
|
+
logger.error(f"Failed to write JSON report: {e}")
|
|
2365
|
+
sys.exit(1)
|
|
2366
|
+
|
|
2367
|
+
# Generate visualizations if requested
|
|
2368
|
+
viz_paths = {}
|
|
2369
|
+
if args.with_plots:
|
|
2370
|
+
if not is_viz_available():
|
|
2371
|
+
logger.warning(
|
|
2372
|
+
"matplotlib not installed. Skipping visualizations. Install with: pip install matplotlib"
|
|
2373
|
+
)
|
|
2374
|
+
else:
|
|
2375
|
+
progress.step("Generating visualizations")
|
|
2376
|
+
# Determine assets directory
|
|
2377
|
+
if args.assets_dir:
|
|
2378
|
+
assets_dir = args.assets_dir
|
|
2379
|
+
elif args.out_html:
|
|
2380
|
+
# HTML embeds images, but we still generate them for the file
|
|
2381
|
+
assets_dir = args.out_html.parent / "assets"
|
|
2382
|
+
elif args.out_md:
|
|
2383
|
+
assets_dir = args.out_md.parent / "assets"
|
|
2384
|
+
elif args.out_json:
|
|
2385
|
+
assets_dir = args.out_json.parent / "assets"
|
|
2386
|
+
else:
|
|
2387
|
+
assets_dir = pathlib.Path("assets")
|
|
2388
|
+
|
|
2389
|
+
try:
|
|
2390
|
+
viz_gen = VisualizationGenerator(logger=logger)
|
|
2391
|
+
viz_paths = viz_gen.generate_all(report, assets_dir)
|
|
2392
|
+
logger.info(f"Generated {len(viz_paths)} visualization assets in {assets_dir}")
|
|
2393
|
+
except Exception as e:
|
|
2394
|
+
logger.warning(f"Failed to generate some visualizations: {e}")
|
|
2395
|
+
if args.log_level == "debug":
|
|
2396
|
+
import traceback
|
|
2397
|
+
|
|
2398
|
+
traceback.print_exc()
|
|
2399
|
+
|
|
2400
|
+
if args.out_md:
|
|
2401
|
+
try:
|
|
2402
|
+
args.out_md.parent.mkdir(parents=True, exist_ok=True)
|
|
2403
|
+
# Generate markdown with visualizations and/or LLM summaries
|
|
2404
|
+
if viz_paths or llm_summary:
|
|
2405
|
+
md_content = _generate_markdown_with_extras(
|
|
2406
|
+
report, viz_paths, args.out_md.parent, llm_summary
|
|
2407
|
+
)
|
|
2408
|
+
else:
|
|
2409
|
+
md_content = report.to_markdown()
|
|
2410
|
+
args.out_md.write_text(md_content, encoding="utf-8")
|
|
2411
|
+
logger.info(f"Markdown model card written to: {args.out_md}")
|
|
2412
|
+
except Exception as e:
|
|
2413
|
+
logger.error(f"Failed to write Markdown report: {e}")
|
|
2414
|
+
sys.exit(1)
|
|
2415
|
+
|
|
2416
|
+
if args.out_html or args.out_pdf:
|
|
2417
|
+
# Add LLM summary to report if available
|
|
2418
|
+
if llm_summary and llm_summary.success:
|
|
2419
|
+
report.llm_summary = {
|
|
2420
|
+
"success": True,
|
|
2421
|
+
"short_summary": llm_summary.short_summary,
|
|
2422
|
+
"detailed_summary": llm_summary.detailed_summary,
|
|
2423
|
+
"model": args.llm_model,
|
|
2424
|
+
}
|
|
2425
|
+
# Generate layer table HTML if requested
|
|
2426
|
+
layer_table_html = None
|
|
2427
|
+
if args.include_layer_table or args.layer_csv:
|
|
2428
|
+
try:
|
|
2429
|
+
# Re-load graph info if needed
|
|
2430
|
+
if hasattr(inspector, "_graph_info") and inspector._graph_info:
|
|
2431
|
+
graph_info = inspector._graph_info
|
|
2432
|
+
else:
|
|
2433
|
+
from .analyzer import ONNXGraphLoader
|
|
2434
|
+
|
|
2435
|
+
loader = ONNXGraphLoader(logger=logger)
|
|
2436
|
+
_, graph_info = loader.load(model_path)
|
|
2437
|
+
|
|
2438
|
+
layer_builder = LayerSummaryBuilder(logger=logger)
|
|
2439
|
+
layer_summary = layer_builder.build(
|
|
2440
|
+
graph_info,
|
|
2441
|
+
report.param_counts,
|
|
2442
|
+
report.flop_counts,
|
|
2443
|
+
report.memory_estimates,
|
|
2444
|
+
)
|
|
2445
|
+
|
|
2446
|
+
if args.include_layer_table:
|
|
2447
|
+
layer_table_html = generate_html_table(layer_summary)
|
|
2448
|
+
logger.debug("Generated layer summary table for HTML report")
|
|
2449
|
+
|
|
2450
|
+
if args.layer_csv:
|
|
2451
|
+
args.layer_csv.parent.mkdir(parents=True, exist_ok=True)
|
|
2452
|
+
layer_summary.save_csv(args.layer_csv)
|
|
2453
|
+
logger.info(f"Layer summary CSV written to: {args.layer_csv}")
|
|
2454
|
+
|
|
2455
|
+
except Exception as e:
|
|
2456
|
+
logger.warning(f"Could not generate layer summary: {e}")
|
|
2457
|
+
|
|
2458
|
+
# Generate embedded graph HTML if requested
|
|
2459
|
+
graph_html = None
|
|
2460
|
+
if args.include_graph:
|
|
2461
|
+
try:
|
|
2462
|
+
# Re-load graph info if needed
|
|
2463
|
+
if hasattr(inspector, "_graph_info") and inspector._graph_info:
|
|
2464
|
+
graph_info = inspector._graph_info
|
|
2465
|
+
else:
|
|
2466
|
+
from .analyzer import ONNXGraphLoader
|
|
2467
|
+
|
|
2468
|
+
loader = ONNXGraphLoader(logger=logger)
|
|
2469
|
+
_, graph_info = loader.load(model_path)
|
|
2470
|
+
|
|
2471
|
+
pattern_analyzer = PatternAnalyzer(logger=logger)
|
|
2472
|
+
blocks = pattern_analyzer.group_into_blocks(graph_info)
|
|
2473
|
+
|
|
2474
|
+
edge_analyzer = EdgeAnalyzer(logger=logger)
|
|
2475
|
+
edge_result = edge_analyzer.analyze(graph_info)
|
|
2476
|
+
|
|
2477
|
+
builder = HierarchicalGraphBuilder(logger=logger)
|
|
2478
|
+
hier_graph = builder.build(graph_info, blocks, model_path.stem)
|
|
2479
|
+
|
|
2480
|
+
# Generate graph HTML (just the interactive part, not full document)
|
|
2481
|
+
# For embedding, we'll use an iframe approach
|
|
2482
|
+
|
|
2483
|
+
# Extract layer timing from profiling results if available
|
|
2484
|
+
layer_timing = None
|
|
2485
|
+
if (
|
|
2486
|
+
report.extra_data
|
|
2487
|
+
and "profiling" in report.extra_data
|
|
2488
|
+
and "slowest_layers" in report.extra_data["profiling"]
|
|
2489
|
+
):
|
|
2490
|
+
layer_timing = {}
|
|
2491
|
+
for layer in report.extra_data["profiling"]["slowest_layers"]:
|
|
2492
|
+
layer_timing[layer["name"]] = layer["duration_ms"]
|
|
2493
|
+
|
|
2494
|
+
full_graph_html = generate_graph_html(
|
|
2495
|
+
hier_graph, edge_result, model_path.stem, layer_timing=layer_timing
|
|
2496
|
+
)
|
|
2497
|
+
# Wrap in iframe data URI for embedding
|
|
2498
|
+
import base64
|
|
2499
|
+
|
|
2500
|
+
graph_data = base64.b64encode(full_graph_html.encode()).decode()
|
|
2501
|
+
graph_html = f'<iframe src="data:text/html;base64,{graph_data}" style="width:100%;height:100%;border:none;"></iframe>'
|
|
2502
|
+
logger.debug("Generated interactive graph for HTML report")
|
|
2503
|
+
|
|
2504
|
+
except Exception as e:
|
|
2505
|
+
logger.warning(f"Could not generate embedded graph: {e}")
|
|
2506
|
+
|
|
2507
|
+
# Generate HTML with embedded images, graph, and layer table
|
|
2508
|
+
html_content = report.to_html(
|
|
2509
|
+
image_paths=viz_paths,
|
|
2510
|
+
graph_html=graph_html,
|
|
2511
|
+
layer_table_html=layer_table_html,
|
|
2512
|
+
)
|
|
2513
|
+
|
|
2514
|
+
if args.out_html:
|
|
2515
|
+
try:
|
|
2516
|
+
args.out_html.parent.mkdir(parents=True, exist_ok=True)
|
|
2517
|
+
args.out_html.write_text(html_content, encoding="utf-8")
|
|
2518
|
+
logger.info(f"HTML report written to: {args.out_html}")
|
|
2519
|
+
except Exception as e:
|
|
2520
|
+
logger.error(f"Failed to write HTML report: {e}")
|
|
2521
|
+
sys.exit(1)
|
|
2522
|
+
|
|
2523
|
+
if args.out_pdf:
|
|
2524
|
+
if not is_pdf_available():
|
|
2525
|
+
logger.error(
|
|
2526
|
+
"Playwright not installed. Install with: pip install playwright && playwright install chromium"
|
|
2527
|
+
)
|
|
2528
|
+
sys.exit(1)
|
|
2529
|
+
try:
|
|
2530
|
+
args.out_pdf.parent.mkdir(parents=True, exist_ok=True)
|
|
2531
|
+
pdf_gen = PDFGenerator(logger=logger)
|
|
2532
|
+
success = pdf_gen.generate_from_html(html_content, args.out_pdf)
|
|
2533
|
+
if success:
|
|
2534
|
+
logger.info(f"PDF report written to: {args.out_pdf}")
|
|
2535
|
+
else:
|
|
2536
|
+
logger.error("PDF generation failed")
|
|
2537
|
+
sys.exit(1)
|
|
2538
|
+
except Exception as e:
|
|
2539
|
+
logger.error(f"Failed to write PDF report: {e}")
|
|
2540
|
+
sys.exit(1)
|
|
2541
|
+
|
|
2542
|
+
# Interactive graph visualization
|
|
2543
|
+
if args.html_graph:
|
|
2544
|
+
try:
|
|
2545
|
+
args.html_graph.parent.mkdir(parents=True, exist_ok=True)
|
|
2546
|
+
|
|
2547
|
+
# Use graph_info from the inspector if available
|
|
2548
|
+
if hasattr(inspector, "_graph_info") and inspector._graph_info:
|
|
2549
|
+
graph_info = inspector._graph_info
|
|
2550
|
+
else:
|
|
2551
|
+
# Re-load the model to get graph_info
|
|
2552
|
+
from .analyzer import ONNXGraphLoader
|
|
2553
|
+
|
|
2554
|
+
loader = ONNXGraphLoader(logger=logger)
|
|
2555
|
+
_, graph_info = loader.load(model_path)
|
|
2556
|
+
|
|
2557
|
+
# Detect patterns
|
|
2558
|
+
pattern_analyzer = PatternAnalyzer(logger=logger)
|
|
2559
|
+
blocks = pattern_analyzer.group_into_blocks(graph_info)
|
|
2560
|
+
|
|
2561
|
+
# Analyze edges
|
|
2562
|
+
edge_analyzer = EdgeAnalyzer(logger=logger)
|
|
2563
|
+
edge_result = edge_analyzer.analyze(graph_info)
|
|
2564
|
+
|
|
2565
|
+
# Build hierarchy
|
|
2566
|
+
builder = HierarchicalGraphBuilder(logger=logger)
|
|
2567
|
+
hier_graph = builder.build(graph_info, blocks, model_path.stem)
|
|
2568
|
+
|
|
2569
|
+
# Export HTML with model size and layer timing
|
|
2570
|
+
model_size = model_path.stat().st_size if model_path.exists() else None
|
|
2571
|
+
|
|
2572
|
+
# Extract layer timing from profiling results if available
|
|
2573
|
+
layer_timing = None
|
|
2574
|
+
if (
|
|
2575
|
+
report.extra_data
|
|
2576
|
+
and "profiling" in report.extra_data
|
|
2577
|
+
and "slowest_layers" in report.extra_data["profiling"]
|
|
2578
|
+
):
|
|
2579
|
+
# Build timing dict from profiling results
|
|
2580
|
+
layer_timing = {}
|
|
2581
|
+
for layer in report.extra_data["profiling"]["slowest_layers"]:
|
|
2582
|
+
layer_timing[layer["name"]] = layer["duration_ms"]
|
|
2583
|
+
|
|
2584
|
+
exporter = HTMLExporter(logger=logger)
|
|
2585
|
+
exporter.export(
|
|
2586
|
+
hier_graph,
|
|
2587
|
+
edge_result,
|
|
2588
|
+
args.html_graph,
|
|
2589
|
+
model_path.stem,
|
|
2590
|
+
model_size_bytes=model_size,
|
|
2591
|
+
layer_timing=layer_timing,
|
|
2592
|
+
)
|
|
2593
|
+
|
|
2594
|
+
logger.info(f"Interactive graph visualization written to: {args.html_graph}")
|
|
2595
|
+
except Exception as e:
|
|
2596
|
+
logger.error(f"Failed to generate graph visualization: {e}")
|
|
2597
|
+
if not args.quiet:
|
|
2598
|
+
import traceback
|
|
2599
|
+
|
|
2600
|
+
traceback.print_exc()
|
|
2601
|
+
sys.exit(1)
|
|
2602
|
+
|
|
2603
|
+
# Console output
|
|
2604
|
+
if (
|
|
2605
|
+
not args.quiet
|
|
2606
|
+
and not args.out_json
|
|
2607
|
+
and not args.out_md
|
|
2608
|
+
and not args.out_html
|
|
2609
|
+
and not args.out_pdf
|
|
2610
|
+
and not args.html_graph
|
|
2611
|
+
):
|
|
2612
|
+
# No output files specified - print summary to console
|
|
2613
|
+
print("\n" + "=" * 60)
|
|
2614
|
+
print(f"Model: {model_path.name}")
|
|
2615
|
+
print("=" * 60)
|
|
2616
|
+
|
|
2617
|
+
if report.graph_summary:
|
|
2618
|
+
print(f"\nNodes: {report.graph_summary.num_nodes}")
|
|
2619
|
+
print(f"Inputs: {report.graph_summary.num_inputs}")
|
|
2620
|
+
print(f"Outputs: {report.graph_summary.num_outputs}")
|
|
2621
|
+
print(f"Initializers: {report.graph_summary.num_initializers}")
|
|
2622
|
+
|
|
2623
|
+
if report.param_counts:
|
|
2624
|
+
print(f"\nParameters: {report._format_number(report.param_counts.total)}")
|
|
2625
|
+
|
|
2626
|
+
if report.flop_counts:
|
|
2627
|
+
print(f"FLOPs: {report._format_number(report.flop_counts.total)}")
|
|
2628
|
+
|
|
2629
|
+
if report.memory_estimates:
|
|
2630
|
+
print(f"Model Size: {report._format_bytes(report.memory_estimates.model_size_bytes)}")
|
|
2631
|
+
|
|
2632
|
+
print(f"\nArchitecture: {report.architecture_type}")
|
|
2633
|
+
print(f"Detected Blocks: {len(report.detected_blocks)}")
|
|
2634
|
+
|
|
2635
|
+
# Hardware estimates
|
|
2636
|
+
if hasattr(report, "hardware_estimates") and report.hardware_estimates:
|
|
2637
|
+
hw = report.hardware_estimates
|
|
2638
|
+
print(f"\n--- Hardware Estimates ({hw.device}) ---")
|
|
2639
|
+
print(f"Precision: {hw.precision}, Batch Size: {hw.batch_size}")
|
|
2640
|
+
print(f"VRAM Required: {report._format_bytes(hw.vram_required_bytes)}")
|
|
2641
|
+
print(f"Fits in VRAM: {'Yes' if hw.fits_in_vram else 'NO'}")
|
|
2642
|
+
if hw.fits_in_vram:
|
|
2643
|
+
print(f"Theoretical Latency: {hw.theoretical_latency_ms:.2f} ms")
|
|
2644
|
+
print(f"Bottleneck: {hw.bottleneck}")
|
|
2645
|
+
|
|
2646
|
+
# System Requirements (Console)
|
|
2647
|
+
if hasattr(report, "system_requirements") and report.system_requirements:
|
|
2648
|
+
reqs = report.system_requirements
|
|
2649
|
+
print("\n--- System Requirements ---")
|
|
2650
|
+
print(f"Minimum: {reqs.minimum_gpu.name} ({reqs.minimum_vram_gb} GB VRAM)")
|
|
2651
|
+
print(f"Recommended: {reqs.recommended_gpu.name} ({reqs.recommended_vram_gb} GB VRAM)")
|
|
2652
|
+
print(f"Optimal: {reqs.optimal_gpu.name}")
|
|
2653
|
+
|
|
2654
|
+
# Batch Scaling (Console)
|
|
2655
|
+
if hasattr(report, "batch_size_sweep") and report.batch_size_sweep:
|
|
2656
|
+
sweep = report.batch_size_sweep
|
|
2657
|
+
print("\n--- Batch Size Scaling ---")
|
|
2658
|
+
print(f"Optimal Batch Size: {sweep.optimal_batch_size}")
|
|
2659
|
+
print(f"Max Throughput: {max(sweep.throughputs):.1f} inf/s")
|
|
2660
|
+
|
|
2661
|
+
if report.risk_signals:
|
|
2662
|
+
print(f"\nRisk Signals: {len(report.risk_signals)}")
|
|
2663
|
+
for risk in report.risk_signals:
|
|
2664
|
+
severity_icon = {
|
|
2665
|
+
"info": "[INFO]",
|
|
2666
|
+
"warning": "[WARN]",
|
|
2667
|
+
"high": "[HIGH]",
|
|
2668
|
+
}
|
|
2669
|
+
print(f" {severity_icon.get(risk.severity, '')} {risk.id}")
|
|
2670
|
+
|
|
2671
|
+
# LLM Summary
|
|
2672
|
+
if llm_summary and llm_summary.success:
|
|
2673
|
+
print(f"\n--- LLM Summary ({llm_summary.model_used}) ---")
|
|
2674
|
+
if llm_summary.short_summary:
|
|
2675
|
+
print(f"{llm_summary.short_summary}")
|
|
2676
|
+
|
|
2677
|
+
print("\n" + "=" * 60)
|
|
2678
|
+
print("Use --out-json or --out-md for detailed reports.")
|
|
2679
|
+
if not args.hardware:
|
|
2680
|
+
print("Use --hardware auto or --hardware <profile> for hardware estimates.")
|
|
2681
|
+
print("=" * 60 + "\n")
|
|
2682
|
+
|
|
2683
|
+
elif not args.quiet:
|
|
2684
|
+
# Files written - just confirm
|
|
2685
|
+
print(f"\nInspection complete for: {model_path.name}")
|
|
2686
|
+
if args.out_json:
|
|
2687
|
+
print(f" JSON report: {args.out_json}")
|
|
2688
|
+
if args.out_md:
|
|
2689
|
+
print(f" Markdown card: {args.out_md}")
|
|
2690
|
+
if args.out_html:
|
|
2691
|
+
print(f" HTML report: {args.out_html}")
|
|
2692
|
+
if args.out_pdf:
|
|
2693
|
+
print(f" PDF report: {args.out_pdf}")
|
|
2694
|
+
if args.html_graph:
|
|
2695
|
+
print(f" Graph visualization: {args.html_graph}")
|
|
2696
|
+
if args.layer_csv:
|
|
2697
|
+
print(f" Layer CSV: {args.layer_csv}")
|
|
2698
|
+
|
|
2699
|
+
# Finish progress indicator
|
|
2700
|
+
progress.finish("Analysis complete!")
|
|
2701
|
+
|
|
2702
|
+
# Cleanup temp ONNX file if we created one
|
|
2703
|
+
if temp_onnx_file is not None:
|
|
2704
|
+
try:
|
|
2705
|
+
pathlib.Path(temp_onnx_file.name).unlink()
|
|
2706
|
+
logger.debug(f"Cleaned up temp ONNX file: {temp_onnx_file.name}")
|
|
2707
|
+
except Exception:
|
|
2708
|
+
pass
|
|
2709
|
+
|
|
2710
|
+
|
|
2711
|
+
# Script entry point: executing this module directly dispatches to run_inspect().
# This module uses package-relative imports (e.g. ``from .analyzer import ...``),
# so run it as a module (``python -m <package>.<module>``) rather than by file path.
if __name__ == "__main__":
    run_inspect()
|