haoline-0.3.0-py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (70)
  1. haoline/.streamlit/config.toml +10 -0
  2. haoline/__init__.py +248 -0
  3. haoline/analyzer.py +935 -0
  4. haoline/cli.py +2712 -0
  5. haoline/compare.py +811 -0
  6. haoline/compare_visualizations.py +1564 -0
  7. haoline/edge_analysis.py +525 -0
  8. haoline/eval/__init__.py +131 -0
  9. haoline/eval/adapters.py +844 -0
  10. haoline/eval/cli.py +390 -0
  11. haoline/eval/comparison.py +542 -0
  12. haoline/eval/deployment.py +633 -0
  13. haoline/eval/schemas.py +833 -0
  14. haoline/examples/__init__.py +15 -0
  15. haoline/examples/basic_inspection.py +74 -0
  16. haoline/examples/compare_models.py +117 -0
  17. haoline/examples/hardware_estimation.py +78 -0
  18. haoline/format_adapters.py +1001 -0
  19. haoline/formats/__init__.py +123 -0
  20. haoline/formats/coreml.py +250 -0
  21. haoline/formats/gguf.py +483 -0
  22. haoline/formats/openvino.py +255 -0
  23. haoline/formats/safetensors.py +273 -0
  24. haoline/formats/tflite.py +369 -0
  25. haoline/hardware.py +2307 -0
  26. haoline/hierarchical_graph.py +462 -0
  27. haoline/html_export.py +1573 -0
  28. haoline/layer_summary.py +769 -0
  29. haoline/llm_summarizer.py +465 -0
  30. haoline/op_icons.py +618 -0
  31. haoline/operational_profiling.py +1492 -0
  32. haoline/patterns.py +1116 -0
  33. haoline/pdf_generator.py +265 -0
  34. haoline/privacy.py +250 -0
  35. haoline/pydantic_models.py +241 -0
  36. haoline/report.py +1923 -0
  37. haoline/report_sections.py +539 -0
  38. haoline/risks.py +521 -0
  39. haoline/schema.py +523 -0
  40. haoline/streamlit_app.py +2024 -0
  41. haoline/tests/__init__.py +4 -0
  42. haoline/tests/conftest.py +123 -0
  43. haoline/tests/test_analyzer.py +868 -0
  44. haoline/tests/test_compare_visualizations.py +293 -0
  45. haoline/tests/test_edge_analysis.py +243 -0
  46. haoline/tests/test_eval.py +604 -0
  47. haoline/tests/test_format_adapters.py +460 -0
  48. haoline/tests/test_hardware.py +237 -0
  49. haoline/tests/test_hardware_recommender.py +90 -0
  50. haoline/tests/test_hierarchical_graph.py +326 -0
  51. haoline/tests/test_html_export.py +180 -0
  52. haoline/tests/test_layer_summary.py +428 -0
  53. haoline/tests/test_llm_patterns.py +540 -0
  54. haoline/tests/test_llm_summarizer.py +339 -0
  55. haoline/tests/test_patterns.py +774 -0
  56. haoline/tests/test_pytorch.py +327 -0
  57. haoline/tests/test_report.py +383 -0
  58. haoline/tests/test_risks.py +398 -0
  59. haoline/tests/test_schema.py +417 -0
  60. haoline/tests/test_tensorflow.py +380 -0
  61. haoline/tests/test_visualizations.py +316 -0
  62. haoline/universal_ir.py +856 -0
  63. haoline/visualizations.py +1086 -0
  64. haoline/visualize_yolo.py +44 -0
  65. haoline/web.py +110 -0
  66. haoline-0.3.0.dist-info/METADATA +471 -0
  67. haoline-0.3.0.dist-info/RECORD +70 -0
  68. haoline-0.3.0.dist-info/WHEEL +4 -0
  69. haoline-0.3.0.dist-info/entry_points.txt +5 -0
  70. haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
haoline/cli.py ADDED
@@ -0,0 +1,2712 @@
1
+ # Copyright (c) 2025 HaoLine Contributors
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """
5
+ HaoLine CLI - Universal Model Inspector.
6
+
7
+ Entry point for the `haoline` command-line tool.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import logging
14
+ import os
15
+ import pathlib
16
+ import sys
17
+ from typing import Any
18
+
19
+ from . import ModelInspector
20
+ from .edge_analysis import EdgeAnalyzer
21
+ from .hardware import (
22
+ CLOUD_INSTANCES,
23
+ HardwareEstimator,
24
+ create_multi_gpu_profile,
25
+ detect_local_hardware,
26
+ get_cloud_instance,
27
+ get_profile,
28
+ )
29
+ from .hierarchical_graph import HierarchicalGraphBuilder
30
+ from .html_export import HTMLExporter
31
+ from .html_export import generate_html as generate_graph_html
32
+ from .layer_summary import LayerSummaryBuilder, generate_html_table
33
+ from .llm_summarizer import (
34
+ LLMSummarizer,
35
+ )
36
+ from .llm_summarizer import (
37
+ has_api_key as has_llm_api_key,
38
+ )
39
+ from .llm_summarizer import (
40
+ is_available as is_llm_available,
41
+ )
42
+ from .operational_profiling import OperationalProfiler
43
+ from .patterns import PatternAnalyzer
44
+ from .pdf_generator import (
45
+ PDFGenerator,
46
+ )
47
+ from .pdf_generator import (
48
+ is_available as is_pdf_available,
49
+ )
50
+ from .visualizations import (
51
+ VisualizationGenerator,
52
+ )
53
+ from .visualizations import (
54
+ is_available as is_viz_available,
55
+ )
56
+
57
+
58
+ class ProgressIndicator:
59
+ """Simple progress indicator for CLI operations."""
60
+
61
+ def __init__(self, enabled: bool = True, quiet: bool = False):
62
+ self.enabled = enabled and not quiet
63
+ self._current_step = 0
64
+ self._total_steps = 0
65
+
66
+ def start(self, total_steps: int, description: str = "Processing"):
67
+ """Start progress tracking."""
68
+ self._total_steps = total_steps
69
+ self._current_step = 0
70
+ if self.enabled:
71
+ print(f"\n{description}...")
72
+
73
+ def step(self, message: str):
74
+ """Mark completion of a step."""
75
+ self._current_step += 1
76
+ if self.enabled:
77
+ pct = (self._current_step / self._total_steps * 100) if self._total_steps else 0
78
+ print(f" [{self._current_step}/{self._total_steps}] {message} ({pct:.0f}%)")
79
+
80
+ def finish(self, message: str = "Done"):
81
+ """Mark completion of all steps."""
82
+ if self.enabled:
83
+ print(f" {message}\n")
84
+
85
+
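# Editor's sketch (not part of the package): the typical driving pattern for the
# ProgressIndicator above during a multi-step analysis.
#
#   progress = ProgressIndicator(enabled=True, quiet=False)
#   progress.start(total_steps=3, description="Inspecting model")
#   progress.step("Loaded ONNX graph")      # -> [1/3] Loaded ONNX graph (33%)
#   progress.step("Computed complexity")    # -> [2/3] Computed complexity (67%)
#   progress.step("Rendered report")        # -> [3/3] Rendered report (100%)
#   progress.finish("Analysis complete")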
86
+ def parse_args():
87
+ """Parse command line arguments."""
88
+ parser = argparse.ArgumentParser(
89
+ os.path.basename(__file__),
90
+ description="Analyze an ONNX model and generate architecture documentation.",
91
+ formatter_class=argparse.RawDescriptionHelpFormatter,
92
+ epilog="""
93
+ Examples:
94
+ # Basic inspection with console output (auto-detects local hardware)
95
+ python -m haoline model.onnx
96
+
97
+ # Use specific NVIDIA GPU profile for estimates
98
+ python -m haoline model.onnx --hardware a100
99
+
100
+ # List available hardware profiles
101
+ python -m haoline --list-hardware
102
+
103
+ # Generate JSON report with hardware estimates
104
+ python -m haoline model.onnx --hardware rtx4090 --out-json report.json
105
+
106
+ # Specify precision and batch size for hardware estimates
107
+ python -m haoline model.onnx --hardware t4 --precision fp16 --batch-size 8
108
+
109
+ # Convert PyTorch model to ONNX and analyze
110
+ python -m haoline --from-pytorch model.pt --input-shape 1,3,224,224
111
+
112
+ # Convert TensorFlow SavedModel to ONNX and analyze
113
+ python -m haoline --from-tensorflow ./saved_model_dir --out-html report.html
114
+
115
+ # Convert Keras .h5 model to ONNX and analyze
116
+ python -m haoline --from-keras model.h5 --keep-onnx converted.onnx
117
+
118
+ # Convert TensorFlow frozen graph to ONNX (requires input/output names)
119
+ python -m haoline --from-frozen-graph model.pb --tf-inputs input:0 --tf-outputs output:0
120
+
121
+ # Convert JAX model to ONNX (requires apply function and input shape)
122
+ python -m haoline --from-jax params.pkl --jax-apply-fn my_model:apply --input-shape 1,3,224,224
123
+
124
+ # Generate Steam-style system requirements
125
+ python -m haoline model.onnx --system-requirements
126
+
127
+ # Run batch size sweep
128
+ python -m haoline model.onnx --hardware a100 --sweep-batch-sizes
129
+
130
+ # Run resolution sweep for vision models
131
+ python -m haoline model.onnx --hardware rtx4090 --sweep-resolutions auto
132
+
133
+ # Custom resolutions for object detection
134
+ python -m haoline yolo.onnx --hardware rtx4090 --sweep-resolutions "320x320,640x640,1280x1280"
135
+
136
+ # List available format conversions
137
+ python -m haoline --list-conversions
138
+
139
+ # Convert PyTorch to ONNX and save
140
+ python -m haoline --from-pytorch model.pt --input-shape 1,3,224,224 --convert-to onnx --convert-output model.onnx
141
+
142
+ # Export model as Universal IR (JSON)
143
+ python -m haoline model.onnx --export-ir model_ir.json
144
+
145
+ # Export graph visualization (DOT or PNG)
146
+ python -m haoline model.onnx --export-graph graph.dot
147
+ python -m haoline model.onnx --export-graph graph.png --graph-max-nodes 200
148
+ """,
149
+ )
150
+
151
+ parser.add_argument(
152
+ "model_path",
153
+ type=pathlib.Path,
154
+ nargs="?", # Optional now since --list-hardware doesn't need it
155
+ help="Path to the ONNX model file to analyze.",
156
+ )
157
+
158
+ parser.add_argument(
159
+ "--schema",
160
+ action="store_true",
161
+ help="Output the JSON Schema for InspectionReport and exit. "
162
+ "Useful for integrating HaoLine output with other tools.",
163
+ )
164
+
165
+ parser.add_argument(
166
+ "--out-json",
167
+ type=pathlib.Path,
168
+ default=None,
169
+ help="Output path for JSON report. If not specified, no JSON is written.",
170
+ )
171
+
172
+ parser.add_argument(
173
+ "--out-md",
174
+ type=pathlib.Path,
175
+ default=None,
176
+ help="Output path for Markdown model card. If not specified, no Markdown is written.",
177
+ )
178
+
179
+ parser.add_argument(
180
+ "--out-html",
181
+ type=pathlib.Path,
182
+ default=None,
183
+ help="Output path for HTML report with embedded images. Single shareable file.",
184
+ )
185
+
186
+ parser.add_argument(
187
+ "--out-pdf",
188
+ type=pathlib.Path,
189
+ default=None,
190
+ help="Output path for PDF report. Requires playwright: pip install playwright && playwright install chromium",
191
+ )
192
+
193
+ parser.add_argument(
194
+ "--html-graph",
195
+ type=pathlib.Path,
196
+ default=None,
197
+ help="Output path for interactive graph visualization (standalone HTML with D3.js).",
198
+ )
199
+
200
+ parser.add_argument(
201
+ "--layer-csv",
202
+ type=pathlib.Path,
203
+ default=None,
204
+ help="Output path for per-layer metrics CSV (params, FLOPs, memory per layer).",
205
+ )
206
+
207
+ parser.add_argument(
208
+ "--include-graph",
209
+ action="store_true",
210
+ help="Include interactive graph in --out-html report (makes HTML larger but more informative).",
211
+ )
212
+
213
+ parser.add_argument(
214
+ "--include-layer-table",
215
+ action="store_true",
216
+ help="Include per-layer summary table in --out-html report.",
217
+ )
218
+
219
+ # PyTorch conversion options
220
+ pytorch_group = parser.add_argument_group("PyTorch Conversion Options")
221
+ pytorch_group.add_argument(
222
+ "--from-pytorch",
223
+ type=pathlib.Path,
224
+ default=None,
225
+ metavar="MODEL_PATH",
226
+ help="Convert a PyTorch model (.pth, .pt) to ONNX before analysis. Requires torch.",
227
+ )
228
+ pytorch_group.add_argument(
229
+ "--input-shape",
230
+ type=str,
231
+ default=None,
232
+ metavar="SHAPE",
233
+ help="Input shape for PyTorch conversion, e.g., '1,3,224,224'. Required with --from-pytorch.",
234
+ )
235
+ pytorch_group.add_argument(
236
+ "--keep-onnx",
237
+ type=pathlib.Path,
238
+ default=None,
239
+ metavar="PATH",
240
+ help="Save the converted ONNX model to this path (otherwise uses temp file).",
241
+ )
242
+ pytorch_group.add_argument(
243
+ "--opset-version",
244
+ type=int,
245
+ default=17,
246
+ help="ONNX opset version for PyTorch export (default: 17).",
247
+ )
248
+ pytorch_group.add_argument(
249
+ "--pytorch-weights",
250
+ type=pathlib.Path,
251
+ default=None,
252
+ metavar="PATH",
253
+ help="Path to original PyTorch weights (.pt) to extract class names/metadata. "
254
+ "Useful when analyzing a pre-converted ONNX file.",
255
+ )
256
+
257
+ # TensorFlow/Keras conversion options
258
+ tf_group = parser.add_argument_group("TensorFlow/Keras Conversion Options")
259
+ tf_group.add_argument(
260
+ "--from-tensorflow",
261
+ type=pathlib.Path,
262
+ default=None,
263
+ metavar="MODEL_PATH",
264
+ help="Convert a TensorFlow SavedModel directory to ONNX before analysis. Requires tf2onnx.",
265
+ )
266
+ tf_group.add_argument(
267
+ "--from-keras",
268
+ type=pathlib.Path,
269
+ default=None,
270
+ metavar="MODEL_PATH",
271
+ help="Convert a Keras model (.h5, .keras) to ONNX before analysis. Requires tf2onnx.",
272
+ )
273
+ tf_group.add_argument(
274
+ "--from-frozen-graph",
275
+ type=pathlib.Path,
276
+ default=None,
277
+ metavar="MODEL_PATH",
278
+ help="Convert a TensorFlow frozen graph (.pb) to ONNX. Requires --tf-inputs and --tf-outputs.",
279
+ )
280
+ tf_group.add_argument(
281
+ "--tf-inputs",
282
+ type=str,
283
+ default=None,
284
+ metavar="NAMES",
285
+ help="Comma-separated input tensor names for frozen graph conversion, e.g., 'input:0'.",
286
+ )
287
+ tf_group.add_argument(
288
+ "--tf-outputs",
289
+ type=str,
290
+ default=None,
291
+ metavar="NAMES",
292
+ help="Comma-separated output tensor names for frozen graph conversion, e.g., 'output:0'.",
293
+ )
294
+
295
+ # JAX conversion options
296
+ jax_group = parser.add_argument_group("JAX/Flax Conversion Options")
297
+ jax_group.add_argument(
298
+ "--from-jax",
299
+ type=pathlib.Path,
300
+ default=None,
301
+ metavar="MODEL_PATH",
302
+ help="Convert a JAX/Flax model to ONNX before analysis. Requires jax and tf2onnx. "
303
+ "Expects a saved params file (.msgpack, .pkl) with --jax-apply-fn.",
304
+ )
305
+ jax_group.add_argument(
306
+ "--jax-apply-fn",
307
+ type=str,
308
+ default=None,
309
+ metavar="MODULE:FUNCTION",
310
+ help="JAX apply function path, e.g., 'my_model:apply'. Required with --from-jax.",
311
+ )
312
+
313
+ # Hardware options
314
+ hardware_group = parser.add_argument_group("Hardware Options")
315
+ hardware_group.add_argument(
316
+ "--hardware",
317
+ type=str,
318
+ default=None,
319
+ metavar="PROFILE",
320
+ help="Hardware profile for performance estimates. Use 'auto' to detect local hardware, "
321
+ "or specify a profile name (e.g., 'a100', 'rtx4090', 't4'). Use --list-hardware to see all options.",
322
+ )
323
+
324
+ hardware_group.add_argument(
325
+ "--list-hardware",
326
+ action="store_true",
327
+ help="List all available hardware profiles and exit.",
328
+ )
329
+
330
+ hardware_group.add_argument(
331
+ "--precision",
332
+ choices=["fp32", "fp16", "bf16", "int8"],
333
+ default="fp32",
334
+ help="Precision for hardware estimates (default: fp32).",
335
+ )
336
+
337
+ hardware_group.add_argument(
338
+ "--batch-size",
339
+ type=int,
340
+ default=1,
341
+ help="Batch size for hardware estimates (default: 1).",
342
+ )
343
+
344
+ hardware_group.add_argument(
345
+ "--gpu-count",
346
+ type=int,
347
+ default=1,
348
+ metavar="N",
349
+ help="Number of GPUs for multi-GPU estimates (default: 1). "
350
+ "Scales compute and memory with efficiency factors for tensor parallelism.",
351
+ )
352
+
353
+ hardware_group.add_argument(
354
+ "--cloud",
355
+ type=str,
356
+ default=None,
357
+ metavar="INSTANCE",
358
+ help="Cloud instance type for cost/performance estimates. "
359
+ "E.g., 'aws-p4d-24xlarge', 'azure-nd-h100-v5'. Use --list-cloud to see options.",
360
+ )
361
+
362
+ hardware_group.add_argument(
363
+ "--list-cloud",
364
+ action="store_true",
365
+ help="List all available cloud instance profiles and exit.",
366
+ )
367
+
368
+ hardware_group.add_argument(
369
+ "--deployment-target",
370
+ type=str,
371
+ choices=["edge", "local", "cloud"],
372
+ default=None,
373
+ help="High-level deployment target to guide system requirement recommendations "
374
+ "(edge device, local server, or cloud server).",
375
+ )
376
+
377
+ hardware_group.add_argument(
378
+ "--target-latency-ms",
379
+ type=float,
380
+ default=None,
381
+ help="Optional latency target (ms) for system requirements. "
382
+ "If set, this is converted into a throughput target for recommendations.",
383
+ )
384
+
385
+ hardware_group.add_argument(
386
+ "--target-throughput-fps",
387
+ type=float,
388
+ default=None,
389
+ help="Optional throughput target (frames/requests per second) for system requirements.",
390
+ )
391
+
392
+ hardware_group.add_argument(
393
+ "--deployment-fps",
394
+ type=float,
395
+ default=None,
396
+ metavar="FPS",
397
+ help="Target inference rate for deployment cost calculation (e.g., 3 for 3 fps continuous). "
398
+ "Combined with --deployment-hours to estimate $/day and $/month.",
399
+ )
400
+
401
+ hardware_group.add_argument(
402
+ "--deployment-hours",
403
+ type=float,
404
+ default=24.0,
405
+ metavar="HOURS",
406
+ help="Hours per day the model runs for deployment cost calculation (default: 24). "
407
+ "E.g., 8 for business hours only.",
408
+ )
409
+
410
+ # Epic 6C features
411
+ hardware_group.add_argument(
412
+ "--system-requirements",
413
+ action="store_true",
414
+ help="Generate Steam-style system requirements (Minimum, Recommended, Optimal).",
415
+ )
416
+
417
+ hardware_group.add_argument(
418
+ "--sweep-batch-sizes",
419
+ action="store_true",
420
+ help="Run analysis across multiple batch sizes (1, 2, 4, ..., 128) to find optimal throughput.",
421
+ )
422
+
423
+ hardware_group.add_argument(
424
+ "--no-benchmark",
425
+ action="store_true",
426
+ help="Use theoretical estimates instead of actual inference for batch sweeps (faster but less accurate).",
427
+ )
428
+
429
+ hardware_group.add_argument(
430
+ "--input-resolution",
431
+ type=str,
432
+ default=None,
433
+ help=(
434
+ "Override input resolution for analysis. Format: HxW (e.g., 640x640). "
435
+ "For vision models, affects FLOPs and memory estimates."
436
+ ),
437
+ )
438
+
439
+ hardware_group.add_argument(
440
+ "--sweep-resolutions",
441
+ type=str,
442
+ default=None,
443
+ help=(
444
+ "Run resolution sweep analysis. Provide comma-separated resolutions "
445
+ "(e.g., '224x224,384x384,512x512,640x640') or 'auto' for common resolutions."
446
+ ),
447
+ )
448
+
449
+ # Epic 9: Runtime Profiling Options (defaults to ON for real measurements)
450
+ profiling_group = parser.add_argument_group("Runtime Profiling Options")
451
+ profiling_group.add_argument(
452
+ "--no-profile",
453
+ action="store_true",
454
+ help="Disable per-layer ONNX Runtime profiling (faster but less detailed).",
455
+ )
456
+
457
+ profiling_group.add_argument(
458
+ "--profile-runs",
459
+ type=int,
460
+ default=10,
461
+ metavar="N",
462
+ help="Number of inference runs for profiling (default: 10).",
463
+ )
464
+
465
+ profiling_group.add_argument(
466
+ "--no-gpu-metrics",
467
+ action="store_true",
468
+ help="Disable GPU metrics capture (VRAM, utilization, temperature).",
469
+ )
470
+
471
+ profiling_group.add_argument(
472
+ "--no-bottleneck-analysis",
473
+ action="store_true",
474
+ help="Disable compute vs memory bottleneck analysis.",
475
+ )
476
+
477
+ profiling_group.add_argument(
478
+ "--no-benchmark-resolutions",
479
+ action="store_true",
480
+ help="Disable resolution benchmarking (use theoretical estimates instead).",
481
+ )
482
+
483
+ # Visualization options
484
+ viz_group = parser.add_argument_group("Visualization Options")
485
+ viz_group.add_argument(
486
+ "--with-plots",
487
+ action="store_true",
488
+ help="Generate visualization assets (requires matplotlib).",
489
+ )
490
+
491
+ viz_group.add_argument(
492
+ "--assets-dir",
493
+ type=pathlib.Path,
494
+ default=None,
495
+ metavar="PATH",
496
+ help="Directory for plot PNG files (default: same directory as output files, or 'assets/').",
497
+ )
498
+
499
+ # LLM options
500
+ llm_group = parser.add_argument_group("LLM Summarization Options")
501
+ llm_group.add_argument(
502
+ "--llm-summary",
503
+ action="store_true",
504
+ help="Generate LLM-powered summaries (requires openai package and OPENAI_API_KEY env var).",
505
+ )
506
+
507
+ llm_group.add_argument(
508
+ "--llm-model",
509
+ type=str,
510
+ default="gpt-4o-mini",
511
+ metavar="MODEL",
512
+ help="OpenAI model to use for summaries (default: gpt-4o-mini).",
513
+ )
514
+
515
+ parser.add_argument(
516
+ "--log-level",
517
+ choices=["debug", "info", "warning", "error"],
518
+ default="info",
519
+ help="Logging verbosity level (default: info).",
520
+ )
521
+
522
+ parser.add_argument(
523
+ "--quiet",
524
+ action="store_true",
525
+ help="Suppress console output. Only write to files if --out-json or --out-md specified.",
526
+ )
527
+
528
+ parser.add_argument(
529
+ "--progress",
530
+ action="store_true",
531
+ help="Show progress indicators during analysis (useful for large models).",
532
+ )
533
+
534
+ parser.add_argument(
535
+ "--offline",
536
+ action="store_true",
537
+ help="Run in offline mode. Fails if any network access is attempted. "
538
+ "Disables LLM summaries and other features requiring internet.",
539
+ )
540
+
541
+ # Privacy controls
542
+ privacy_group = parser.add_argument_group("Privacy Options")
543
+ privacy_group.add_argument(
544
+ "--redact-names",
545
+ action="store_true",
546
+ help="Anonymize layer and tensor names in output (e.g., layer_001, tensor_042). "
547
+ "Useful for sharing reports without revealing proprietary architecture details.",
548
+ )
549
+ privacy_group.add_argument(
550
+ "--summary-only",
551
+ action="store_true",
552
+ help="Output only aggregate statistics (params, FLOPs, memory). "
553
+ "Omit per-layer details and graph structure for maximum privacy.",
554
+ )
555
+
556
+ # Format conversion
557
+ convert_group = parser.add_argument_group("Format Conversion Options")
558
+ convert_group.add_argument(
559
+ "--convert-to",
560
+ type=str,
561
+ choices=["onnx"], # Only ONNX writing is currently supported
562
+ default=None,
563
+ metavar="FORMAT",
564
+ help="Convert the model to the specified format (currently only 'onnx' supported). "
565
+ "Use with --from-pytorch or other input formats.",
566
+ )
567
+ convert_group.add_argument(
568
+ "--convert-output",
569
+ type=pathlib.Path,
570
+ default=None,
571
+ metavar="PATH",
572
+ help="Output path for format conversion. Required when using --convert-to.",
573
+ )
574
+ convert_group.add_argument(
575
+ "--list-conversions",
576
+ action="store_true",
577
+ help="List available format conversion paths and exit.",
578
+ )
579
+
580
+ # Universal IR export
581
+ ir_group = parser.add_argument_group("Universal IR Export Options")
582
+ ir_group.add_argument(
583
+ "--export-ir",
584
+ type=pathlib.Path,
585
+ default=None,
586
+ metavar="PATH",
587
+ help="Export model as Universal IR to JSON file. "
588
+ "Provides format-agnostic representation of the model graph.",
589
+ )
590
+ ir_group.add_argument(
591
+ "--export-graph",
592
+ type=pathlib.Path,
593
+ default=None,
594
+ metavar="PATH",
595
+ help="Export model graph visualization. Supports .dot (Graphviz) and .png formats. "
596
+ "Requires graphviz package for PNG: pip install graphviz",
597
+ )
598
+ ir_group.add_argument(
599
+ "--graph-max-nodes",
600
+ type=int,
601
+ default=500,
602
+ metavar="N",
603
+ help="Maximum nodes to include in graph visualization (default: 500). "
604
+ "Prevents huge graphs from crashing.",
605
+ )
606
+
607
+ return parser.parse_args()
608
+
609
+
610
+ def setup_logging(log_level: str) -> logging.Logger:
611
+ """Configure logging for the CLI."""
612
+ level_map = {
613
+ "debug": logging.DEBUG,
614
+ "info": logging.INFO,
615
+ "warning": logging.WARNING,
616
+ "error": logging.ERROR,
617
+ }
618
+
619
+ logging.basicConfig(
620
+ level=level_map.get(log_level, logging.INFO),
621
+ format="%(levelname)s - %(message)s",
622
+ )
623
+
624
+ return logging.getLogger("haoline")
625
+
626
+
627
+ def _generate_markdown_with_extras(
628
+ report, viz_paths: dict, report_dir: pathlib.Path, llm_summary=None
629
+ ) -> str:
630
+ """Generate markdown with embedded visualizations and LLM summaries."""
631
+ lines = []
632
+ base_md = report.to_markdown()
633
+
634
+ # If we have an LLM summary, insert it after the header
635
+ if llm_summary and llm_summary.success:
636
+ # Insert executive summary after the metadata section
637
+ header_end = base_md.find("## Graph Summary")
638
+ if header_end != -1:
639
+ lines.append(base_md[:header_end])
640
+ lines.append("## Executive Summary\n")
641
+ if llm_summary.short_summary:
642
+ lines.append(f"{llm_summary.short_summary}\n")
643
+ if llm_summary.detailed_summary:
644
+ lines.append(f"\n{llm_summary.detailed_summary}\n")
645
+ lines.append(f"\n*Generated by {llm_summary.model_used}*\n\n")
646
+ base_md = base_md[header_end:]
647
+
648
+ # Split the markdown at the Complexity Metrics section to insert plots
649
+ sections = base_md.split("## Complexity Metrics")
650
+
651
+ if len(sections) < 2:
652
+ # No complexity section found, just append plots at end
653
+ lines.append(base_md)
654
+ else:
655
+ lines.append(sections[0])
656
+
657
+ # Insert visualizations section before Complexity Metrics
658
+ if viz_paths:
659
+ lines.append("## Visualizations\n")
660
+
661
+ if "complexity_summary" in viz_paths:
662
+ rel_path = (
663
+ viz_paths["complexity_summary"].relative_to(report_dir)
664
+ if viz_paths["complexity_summary"].is_relative_to(report_dir)
665
+ else viz_paths["complexity_summary"]
666
+ )
667
+ lines.append("### Complexity Overview\n")
668
+ lines.append(f"![Complexity Summary]({rel_path})\n")
669
+
670
+ if "op_histogram" in viz_paths:
671
+ rel_path = (
672
+ viz_paths["op_histogram"].relative_to(report_dir)
673
+ if viz_paths["op_histogram"].is_relative_to(report_dir)
674
+ else viz_paths["op_histogram"]
675
+ )
676
+ lines.append("### Operator Distribution\n")
677
+ lines.append(f"![Operator Histogram]({rel_path})\n")
678
+
679
+ if "param_distribution" in viz_paths:
680
+ rel_path = (
681
+ viz_paths["param_distribution"].relative_to(report_dir)
682
+ if viz_paths["param_distribution"].is_relative_to(report_dir)
683
+ else viz_paths["param_distribution"]
684
+ )
685
+ lines.append("### Parameter Distribution\n")
686
+ lines.append(f"![Parameter Distribution]({rel_path})\n")
687
+
688
+ if "flops_distribution" in viz_paths:
689
+ rel_path = (
690
+ viz_paths["flops_distribution"].relative_to(report_dir)
691
+ if viz_paths["flops_distribution"].is_relative_to(report_dir)
692
+ else viz_paths["flops_distribution"]
693
+ )
694
+ lines.append("### FLOPs Distribution\n")
695
+ lines.append(f"![FLOPs Distribution]({rel_path})\n")
696
+
697
+ lines.append("")
698
+
699
+ lines.append("## Complexity Metrics" + sections[1])
700
+
701
+ return "\n".join(lines)
702
+
703
+
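# Editor's sketch: the shape of the viz_paths mapping consumed above. The keys are the
# ones this function looks up; the file names themselves are hypothetical.
#
#   viz_paths = {
#       "complexity_summary": pathlib.Path("assets/complexity_summary.png"),
#       "op_histogram": pathlib.Path("assets/op_histogram.png"),
#       "param_distribution": pathlib.Path("assets/param_distribution.png"),
#       "flops_distribution": pathlib.Path("assets/flops_distribution.png"),
#   }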
704
+ def _extract_ultralytics_metadata(
705
+ weights_path: pathlib.Path,
706
+ logger: logging.Logger,
707
+ ) -> dict[str, Any] | None:
708
+ """
709
+ Extract metadata from an Ultralytics model (.pt file).
710
+
711
+ Returns dict with task, num_classes, class_names or None if not Ultralytics.
712
+ """
713
+ try:
714
+ from ultralytics import YOLO
715
+
716
+ model = YOLO(str(weights_path))
717
+
718
+ return {
719
+ "task": model.task,
720
+ "num_classes": len(model.names),
721
+ "class_names": list(model.names.values()),
722
+ "source": "ultralytics",
723
+ }
724
+ except ImportError:
725
+ logger.debug("ultralytics not installed, skipping metadata extraction")
726
+ return None
727
+ except Exception as e:
728
+ logger.debug(f"Could not extract Ultralytics metadata: {e}")
729
+ return None
730
+
731
+
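# Editor's sketch (hypothetical values): what a successful call returns for an
# Ultralytics detection checkpoint such as yolov8n.pt, assuming `ultralytics` is installed.
#
#   meta = _extract_ultralytics_metadata(pathlib.Path("yolov8n.pt"), logger)
#   # {"task": "detect", "num_classes": 80,
#   #  "class_names": ["person", "bicycle", ..., "toothbrush"], "source": "ultralytics"}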
732
+ def _convert_pytorch_to_onnx(
733
+ pytorch_path: pathlib.Path,
734
+ input_shape_str: str | None,
735
+ output_path: pathlib.Path | None,
736
+ opset_version: int,
737
+ logger: logging.Logger,
738
+ ) -> tuple[pathlib.Path | None, Any]:
739
+ """
740
+ Convert a PyTorch model to ONNX format.
741
+
742
+ Args:
743
+ pytorch_path: Path to PyTorch model (.pth, .pt)
744
+ input_shape_str: Input shape as comma-separated string, e.g., "1,3,224,224"
745
+ output_path: Where to save ONNX file (None = temp file)
746
+ opset_version: ONNX opset version
747
+ logger: Logger instance
748
+
749
+ Returns:
750
+ Tuple of (onnx_path, temp_file_handle_or_None)
751
+ """
752
+ # Check if torch is available
753
+ try:
754
+ import torch
755
+ except ImportError:
756
+ logger.error("PyTorch not installed. Install with: pip install torch")
757
+ return None, None
758
+
759
+ pytorch_path = pytorch_path.resolve()
760
+ if not pytorch_path.exists():
761
+ logger.error(f"PyTorch model not found: {pytorch_path}")
762
+ return None, None
763
+
764
+ # Parse input shape
765
+ if not input_shape_str:
766
+ logger.error(
767
+ "--input-shape is required for PyTorch conversion. Example: --input-shape 1,3,224,224"
768
+ )
769
+ return None, None
770
+
771
+ try:
772
+ input_shape = tuple(int(x.strip()) for x in input_shape_str.split(","))
773
+ logger.info(f"Input shape: {input_shape}")
774
+ except ValueError:
775
+ logger.error(
776
+ f"Invalid --input-shape format: '{input_shape_str}'. "
777
+ "Use comma-separated integers, e.g., '1,3,224,224'"
778
+ )
779
+ return None, None
780
+
781
+ # Load PyTorch model
782
+ logger.info(f"Loading PyTorch model from: {pytorch_path}")
783
+ model = None
784
+
785
+ # Try 1: TorchScript model (.pt files from torch.jit.save)
786
+ try:
787
+ model = torch.jit.load(str(pytorch_path), map_location="cpu")
788
+ logger.info(f"Loaded TorchScript model: {type(model).__name__}")
789
+ except Exception:
790
+ pass
791
+
792
+ # Try 2: Check for Ultralytics YOLO format first
793
+ if model is None:
794
+ try:
795
+ loaded = torch.load(pytorch_path, map_location="cpu", weights_only=False)
796
+
797
+ if isinstance(loaded, dict):
798
+ # Check if it's an Ultralytics model (has 'model' key with the actual model)
799
+ if "model" in loaded and hasattr(loaded.get("model"), "forward"):
800
+ logger.info("Detected Ultralytics YOLO format, using native export...")
801
+ try:
802
+ from ultralytics import YOLO
803
+
804
+ yolo_model = YOLO(str(pytorch_path))
805
+
806
+ # Determine output path for Ultralytics export
807
+ if output_path:
808
+ onnx_out = output_path.resolve()
809
+ else:
810
+ import tempfile as tf
811
+
812
+ temp = tf.NamedTemporaryFile(suffix=".onnx", delete=False)
813
+ onnx_out = pathlib.Path(temp.name)
814
+ temp.close()
815
+
816
+ # Export using Ultralytics (handles all the complexity)
817
+ yolo_model.export(
818
+ format="onnx",
819
+ imgsz=input_shape[2] if len(input_shape) >= 3 else 640,
820
+ simplify=True,
821
+ opset=opset_version,
822
+ )
823
+
824
+ # Ultralytics saves next to the .pt file, move if needed
825
+ default_onnx = pytorch_path.with_suffix(".onnx")
826
+ if default_onnx.exists() and default_onnx != onnx_out:
827
+ import shutil
828
+
829
+ shutil.move(str(default_onnx), str(onnx_out))
830
+
831
+ logger.info(f"ONNX model saved to: {onnx_out}")
832
+ return onnx_out, None if output_path else onnx_out
833
+
834
+ except ImportError:
835
+ logger.error(
836
+ "Ultralytics YOLO model detected but 'ultralytics' package not installed.\n"
837
+ "Install with: pip install ultralytics\n"
838
+ "Then re-run this command."
839
+ )
840
+ return None, None
841
+
842
+ # It's a regular state_dict - we can't use it directly
843
+ logger.error(
844
+ "Model file appears to be a state_dict (weights only). "
845
+ "To convert, you need either:\n"
846
+ " 1. A TorchScript model: torch.jit.save(torch.jit.script(model), 'model.pt')\n"
847
+ " 2. A full model: torch.save(model, 'model.pth') # run from same codebase\n"
848
+ " 3. Export to ONNX directly in your training code using torch.onnx.export()"
849
+ )
850
+ return None, None
851
+
852
+ model = loaded
853
+ logger.info(f"Loaded PyTorch model: {type(model).__name__}")
854
+
855
+ except Exception as e:
856
+ error_msg = str(e)
857
+ if "Can't get attribute" in error_msg:
858
+ logger.error(
859
+ "Failed to load model - class definition not found.\n"
860
+ "The model was saved with torch.save(model, ...) which requires "
861
+ "the original class to be importable.\n\n"
862
+ "Solutions:\n"
863
+ " 1. Save as TorchScript: torch.jit.save(torch.jit.script(model), 'model.pt')\n"
864
+ " 2. Export to ONNX in your code: torch.onnx.export(model, dummy_input, 'model.onnx')\n"
865
+ " 3. Run this tool from the directory containing your model definition"
866
+ )
867
+ else:
868
+ logger.error(f"Failed to load PyTorch model: {e}")
869
+ return None, None
870
+
871
+ if model is None:
872
+ logger.error("Could not load the PyTorch model.")
873
+ return None, None
874
+
875
+ model.eval()
876
+
877
+ # Create dummy input
878
+ try:
879
+ dummy_input = torch.randn(*input_shape)
880
+ logger.info(f"Created dummy input with shape: {dummy_input.shape}")
881
+ except Exception as e:
882
+ logger.error(f"Failed to create input tensor: {e}")
883
+ return None, None
884
+
885
+ # Determine output path
886
+ temp_file = None
887
+ if output_path:
888
+ onnx_path = output_path.resolve()
889
+ onnx_path.parent.mkdir(parents=True, exist_ok=True)
890
+ else:
891
+ import tempfile
892
+
893
+ temp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
894
+ onnx_path = pathlib.Path(temp_file.name)
895
+ temp_file.close()
896
+
897
+ # Export to ONNX
898
+ logger.info(f"Exporting to ONNX (opset {opset_version})...")
899
+ try:
900
+ torch.onnx.export(
901
+ model,
902
+ dummy_input,
903
+ str(onnx_path),
904
+ input_names=["input"],
905
+ output_names=["output"],
906
+ dynamic_axes={
907
+ "input": {0: "batch_size"},
908
+ "output": {0: "batch_size"},
909
+ },
910
+ opset_version=opset_version,
911
+ do_constant_folding=True,
912
+ )
913
+ logger.info(f"ONNX model saved to: {onnx_path}")
914
+
915
+ # Verify the export
916
+ import onnx
917
+
918
+ onnx_model = onnx.load(str(onnx_path))
919
+ onnx.checker.check_model(onnx_model)
920
+ logger.info("ONNX model validated successfully")
921
+
922
+ except Exception as e:
923
+ logger.error(f"ONNX export failed: {e}")
924
+ if temp_file:
925
+ try:
926
+ onnx_path.unlink()
927
+ except Exception:
928
+ pass
929
+ return None, None
930
+
931
+ return onnx_path, temp_file
932
+
933
+
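# Editor's sketch (not part of the package): the two artifact types the loader above
# accepts, shown from the model author's side. Assumes a torch.nn.Module named `model`.
#
#   import torch
#   torch.jit.save(torch.jit.script(model), "model.pt")   # loadable via torch.jit.load
#   # ...or skip this converter entirely and export ONNX yourself:
#   torch.onnx.export(model, torch.randn(1, 3, 224, 224), "model.onnx", opset_version=17)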
934
+ def _convert_tensorflow_to_onnx(
935
+ tf_path: pathlib.Path,
936
+ output_path: pathlib.Path | None,
937
+ opset_version: int,
938
+ logger: logging.Logger,
939
+ ) -> tuple[pathlib.Path | None, Any]:
940
+ """
941
+ Convert a TensorFlow SavedModel to ONNX format.
942
+
943
+ Args:
944
+ tf_path: Path to TensorFlow SavedModel directory
945
+ output_path: Where to save ONNX file (None = temp file)
946
+ opset_version: ONNX opset version
947
+ logger: Logger instance
948
+
949
+ Returns:
950
+ Tuple of (onnx_path, temp_file_handle_or_None)
951
+ """
952
+ # Check if tf2onnx is available
953
+ try:
954
+ import tf2onnx
955
+ from tf2onnx import tf_loader  # noqa: F401 - unused below; conversion shells out to the tf2onnx CLI
956
+ except ImportError:
957
+ logger.error("tf2onnx not installed. Install with: pip install tf2onnx tensorflow")
958
+ return None, None
959
+
960
+ tf_path = tf_path.resolve()
961
+ if not tf_path.exists():
962
+ logger.error(f"TensorFlow model not found: {tf_path}")
963
+ return None, None
964
+
965
+ # Determine output path
966
+ temp_file = None
967
+ if output_path:
968
+ onnx_path = output_path.resolve()
969
+ onnx_path.parent.mkdir(parents=True, exist_ok=True)
970
+ else:
971
+ import tempfile
972
+
973
+ temp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
974
+ onnx_path = pathlib.Path(temp_file.name)
975
+ temp_file.close()
976
+
977
+ logger.info("Converting TensorFlow SavedModel to ONNX...")
978
+ logger.info(f" Source: {tf_path}")
979
+ logger.info(f" Target: {onnx_path}")
980
+
981
+ try:
982
+ import subprocess
983
+ import sys
984
+
985
+ # Use tf2onnx CLI for robustness (handles TF version quirks)
986
+ cmd = [
987
+ sys.executable,
988
+ "-m",
989
+ "tf2onnx.convert",
990
+ "--saved-model",
991
+ str(tf_path),
992
+ "--output",
993
+ str(onnx_path),
994
+ "--opset",
995
+ str(opset_version),
996
+ ]
997
+
998
+ result = subprocess.run(
999
+ cmd,
1000
+ check=False,
1001
+ capture_output=True,
1002
+ text=True,
1003
+ timeout=600, # 10 minute timeout for large models
1004
+ )
1005
+
1006
+ if result.returncode != 0:
1007
+ logger.error(f"tf2onnx conversion failed:\n{result.stderr}")
1008
+ if temp_file:
1009
+ try:
1010
+ onnx_path.unlink()
1011
+ except Exception:
1012
+ pass
1013
+ return None, None
1014
+
1015
+ logger.info("TensorFlow model converted successfully")
1016
+
1017
+ # Verify the export
1018
+ import onnx
1019
+
1020
+ onnx_model = onnx.load(str(onnx_path))
1021
+ onnx.checker.check_model(onnx_model)
1022
+ logger.info("ONNX model validated successfully")
1023
+
1024
+ except subprocess.TimeoutExpired:
1025
+ logger.error("tf2onnx conversion timed out after 10 minutes")
1026
+ if temp_file:
1027
+ try:
1028
+ onnx_path.unlink()
1029
+ except Exception:
1030
+ pass
1031
+ return None, None
1032
+ except Exception as e:
1033
+ logger.error(f"TensorFlow conversion failed: {e}")
1034
+ if temp_file:
1035
+ try:
1036
+ onnx_path.unlink()
1037
+ except Exception:
1038
+ pass
1039
+ return None, None
1040
+
1041
+ return onnx_path, temp_file
1042
+
1043
+
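# Editor's note: the subprocess calls in this converter and in the Keras/frozen-graph
# variants below mirror tf2onnx's own CLI, e.g.:
#   python -m tf2onnx.convert --saved-model ./saved_model_dir --output model.onnx --opset 17
#   python -m tf2onnx.convert --keras model.h5 --output model.onnx --opset 17
#   python -m tf2onnx.convert --graphdef model.pb --inputs input:0 --outputs output:0 \
#       --output model.onnx --opset 17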
1044
+ def _convert_keras_to_onnx(
1045
+ keras_path: pathlib.Path,
1046
+ output_path: pathlib.Path | None,
1047
+ opset_version: int,
1048
+ logger: logging.Logger,
1049
+ ) -> tuple[pathlib.Path | None, Any]:
1050
+ """
1051
+ Convert a Keras model (.h5, .keras) to ONNX format.
1052
+
1053
+ Args:
1054
+ keras_path: Path to Keras model file (.h5 or .keras)
1055
+ output_path: Where to save ONNX file (None = temp file)
1056
+ opset_version: ONNX opset version
1057
+ logger: Logger instance
1058
+
1059
+ Returns:
1060
+ Tuple of (onnx_path, temp_file_handle_or_None)
1061
+ """
1062
+ # Check if tf2onnx is available
1063
+ try:
1064
+ import tf2onnx
1065
+ except ImportError:
1066
+ logger.error("tf2onnx not installed. Install with: pip install tf2onnx tensorflow")
1067
+ return None, None
1068
+
1069
+ keras_path = keras_path.resolve()
1070
+ if not keras_path.exists():
1071
+ logger.error(f"Keras model not found: {keras_path}")
1072
+ return None, None
1073
+
1074
+ suffix = keras_path.suffix.lower()
1075
+ if suffix not in (".h5", ".keras", ".hdf5"):
1076
+ logger.warning(f"Unexpected Keras file extension: {suffix}. Proceeding anyway.")
1077
+
1078
+ # Determine output path
1079
+ temp_file = None
1080
+ if output_path:
1081
+ onnx_path = output_path.resolve()
1082
+ onnx_path.parent.mkdir(parents=True, exist_ok=True)
1083
+ else:
1084
+ import tempfile
1085
+
1086
+ temp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
1087
+ onnx_path = pathlib.Path(temp_file.name)
1088
+ temp_file.close()
1089
+
1090
+ logger.info("Converting Keras model to ONNX...")
1091
+ logger.info(f" Source: {keras_path}")
1092
+ logger.info(f" Target: {onnx_path}")
1093
+
1094
+ try:
1095
+ import subprocess
1096
+ import sys
1097
+
1098
+ # Use tf2onnx CLI with --keras flag
1099
+ cmd = [
1100
+ sys.executable,
1101
+ "-m",
1102
+ "tf2onnx.convert",
1103
+ "--keras",
1104
+ str(keras_path),
1105
+ "--output",
1106
+ str(onnx_path),
1107
+ "--opset",
1108
+ str(opset_version),
1109
+ ]
1110
+
1111
+ result = subprocess.run(
1112
+ cmd,
1113
+ check=False,
1114
+ capture_output=True,
1115
+ text=True,
1116
+ timeout=600, # 10 minute timeout
1117
+ )
1118
+
1119
+ if result.returncode != 0:
1120
+ logger.error(f"tf2onnx conversion failed:\n{result.stderr}")
1121
+ if temp_file:
1122
+ try:
1123
+ onnx_path.unlink()
1124
+ except Exception:
1125
+ pass
1126
+ return None, None
1127
+
1128
+ logger.info("Keras model converted successfully")
1129
+
1130
+ # Verify the export
1131
+ import onnx
1132
+
1133
+ onnx_model = onnx.load(str(onnx_path))
1134
+ onnx.checker.check_model(onnx_model)
1135
+ logger.info("ONNX model validated successfully")
1136
+
1137
+ except subprocess.TimeoutExpired:
1138
+ logger.error("tf2onnx conversion timed out after 10 minutes")
1139
+ if temp_file:
1140
+ try:
1141
+ onnx_path.unlink()
1142
+ except Exception:
1143
+ pass
1144
+ return None, None
1145
+ except Exception as e:
1146
+ logger.error(f"Keras conversion failed: {e}")
1147
+ if temp_file:
1148
+ try:
1149
+ onnx_path.unlink()
1150
+ except Exception:
1151
+ pass
1152
+ return None, None
1153
+
1154
+ return onnx_path, temp_file
1155
+
1156
+
1157
+ def _convert_frozen_graph_to_onnx(
1158
+ pb_path: pathlib.Path,
1159
+ inputs: str,
1160
+ outputs: str,
1161
+ output_path: pathlib.Path | None,
1162
+ opset_version: int,
1163
+ logger: logging.Logger,
1164
+ ) -> tuple[pathlib.Path | None, Any]:
1165
+ """
1166
+ Convert a TensorFlow frozen graph (.pb) to ONNX format.
1167
+
1168
+ Args:
1169
+ pb_path: Path to frozen graph .pb file
1170
+ inputs: Comma-separated input tensor names (e.g., "input:0")
1171
+ outputs: Comma-separated output tensor names (e.g., "output:0")
1172
+ output_path: Where to save ONNX file (None = temp file)
1173
+ opset_version: ONNX opset version
1174
+ logger: Logger instance
1175
+
1176
+ Returns:
1177
+ Tuple of (onnx_path, temp_file_handle_or_None)
1178
+ """
1179
+ # Check if tf2onnx is available
1180
+ try:
1181
+ import tf2onnx
1182
+ except ImportError:
1183
+ logger.error("tf2onnx not installed. Install with: pip install tf2onnx tensorflow")
1184
+ return None, None
1185
+
1186
+ pb_path = pb_path.resolve()
1187
+ if not pb_path.exists():
1188
+ logger.error(f"Frozen graph not found: {pb_path}")
1189
+ return None, None
1190
+
1191
+ if not inputs or not outputs:
1192
+ logger.error(
1193
+ "--tf-inputs and --tf-outputs are required for frozen graph conversion.\n"
1194
+ "Example: --from-frozen-graph model.pb --tf-inputs input:0 --tf-outputs output:0"
1195
+ )
1196
+ return None, None
1197
+
1198
+ # Determine output path
1199
+ temp_file = None
1200
+ if output_path:
1201
+ onnx_path = output_path.resolve()
1202
+ onnx_path.parent.mkdir(parents=True, exist_ok=True)
1203
+ else:
1204
+ import tempfile
1205
+
1206
+ temp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
1207
+ onnx_path = pathlib.Path(temp_file.name)
1208
+ temp_file.close()
1209
+
1210
+ logger.info("Converting TensorFlow frozen graph to ONNX...")
1211
+ logger.info(f" Source: {pb_path}")
1212
+ logger.info(f" Inputs: {inputs}")
1213
+ logger.info(f" Outputs: {outputs}")
1214
+ logger.info(f" Target: {onnx_path}")
1215
+
1216
+ try:
1217
+ import subprocess
1218
+ import sys
1219
+
1220
+ # Use tf2onnx CLI with --graphdef flag
1221
+ cmd = [
1222
+ sys.executable,
1223
+ "-m",
1224
+ "tf2onnx.convert",
1225
+ "--graphdef",
1226
+ str(pb_path),
1227
+ "--inputs",
1228
+ inputs,
1229
+ "--outputs",
1230
+ outputs,
1231
+ "--output",
1232
+ str(onnx_path),
1233
+ "--opset",
1234
+ str(opset_version),
1235
+ ]
1236
+
1237
+ result = subprocess.run(
1238
+ cmd,
1239
+ check=False,
1240
+ capture_output=True,
1241
+ text=True,
1242
+ timeout=600, # 10 minute timeout
1243
+ )
1244
+
1245
+ if result.returncode != 0:
1246
+ logger.error(f"tf2onnx conversion failed:\n{result.stderr}")
1247
+ if temp_file:
1248
+ try:
1249
+ onnx_path.unlink()
1250
+ except Exception:
1251
+ pass
1252
+ return None, None
1253
+
1254
+ logger.info("Frozen graph converted successfully")
1255
+
1256
+ # Verify the export
1257
+ import onnx
1258
+
1259
+ onnx_model = onnx.load(str(onnx_path))
1260
+ onnx.checker.check_model(onnx_model)
1261
+ logger.info("ONNX model validated successfully")
1262
+
1263
+ except subprocess.TimeoutExpired:
1264
+ logger.error("tf2onnx conversion timed out after 10 minutes")
1265
+ if temp_file:
1266
+ try:
1267
+ onnx_path.unlink()
1268
+ except Exception:
1269
+ pass
1270
+ return None, None
1271
+ except Exception as e:
1272
+ logger.error(f"Frozen graph conversion failed: {e}")
1273
+ if temp_file:
1274
+ try:
1275
+ onnx_path.unlink()
1276
+ except Exception:
1277
+ pass
1278
+ return None, None
1279
+
1280
+ return onnx_path, temp_file
1281
+
1282
+
1283
+ def _convert_jax_to_onnx(
1284
+ jax_path: pathlib.Path,
1285
+ apply_fn_path: str | None,
1286
+ input_shape_str: str | None,
1287
+ output_path: pathlib.Path | None,
1288
+ opset_version: int,
1289
+ logger: logging.Logger,
1290
+ ) -> tuple[pathlib.Path | None, Any]:
1291
+ """
1292
+ Convert a JAX/Flax model to ONNX format.
1293
+
1294
+ This is more complex than other conversions because JAX doesn't have a standard
1295
+ serialization format. The typical flow is:
1296
+ 1. Load model params from file (.msgpack, .pkl, etc.)
1297
+ 2. Import the apply function from user's code
1298
+ 3. Convert JAX -> TF SavedModel -> ONNX
1299
+
1300
+ Args:
1301
+ jax_path: Path to JAX params file (.msgpack, .pkl, .npy)
1302
+ apply_fn_path: Module:function path to the apply function
1303
+ input_shape_str: Input shape for tracing
1304
+ output_path: Where to save ONNX file (None = temp file)
1305
+ opset_version: ONNX opset version
1306
+ logger: Logger instance
1307
+
1308
+ Returns:
1309
+ Tuple of (onnx_path, temp_file_handle_or_None)
1310
+ """
1311
+ # Check dependencies
1312
+ try:
1313
+ import jax
1314
+ import jax.numpy as jnp
1315
+ except ImportError:
1316
+ logger.error("JAX not installed. Install with: pip install jax jaxlib")
1317
+ return None, None
1318
+
1319
+ try:
1320
+ import tf2onnx
1321
+ except ImportError:
1322
+ logger.error("tf2onnx not installed. Install with: pip install tf2onnx tensorflow")
1323
+ return None, None
1324
+
1325
+ jax_path = jax_path.resolve()
1326
+ if not jax_path.exists():
1327
+ logger.error(f"JAX params file not found: {jax_path}")
1328
+ return None, None
1329
+
1330
+ if not apply_fn_path:
1331
+ logger.error(
1332
+ "--jax-apply-fn is required for JAX conversion.\n"
1333
+ "Example: --from-jax params.pkl --jax-apply-fn my_model:apply --input-shape 1,3,224,224"
1334
+ )
1335
+ return None, None
1336
+
1337
+ if not input_shape_str:
1338
+ logger.error(
1339
+ "--input-shape is required for JAX conversion.\n"
1340
+ "Example: --from-jax params.pkl --jax-apply-fn my_model:apply --input-shape 1,3,224,224"
1341
+ )
1342
+ return None, None
1343
+
1344
+ # Parse input shape
1345
+ try:
1346
+ input_shape = tuple(int(x.strip()) for x in input_shape_str.split(","))
1347
+ except ValueError:
1348
+ logger.error(
1349
+ f"Invalid --input-shape format: '{input_shape_str}'. "
1350
+ "Use comma-separated integers, e.g., '1,3,224,224'"
1351
+ )
1352
+ return None, None
1353
+
1354
+ # Parse apply function path (module:function)
1355
+ if ":" not in apply_fn_path:
1356
+ logger.error(
1357
+ f"Invalid --jax-apply-fn format: '{apply_fn_path}'. "
1358
+ "Use module:function format, e.g., 'my_model:apply'"
1359
+ )
1360
+ return None, None
1361
+
1362
+ module_path, fn_name = apply_fn_path.rsplit(":", 1)
1363
+
1364
+ # Determine output path
1365
+ temp_file = None
1366
+ if output_path:
1367
+ onnx_path = output_path.resolve()
1368
+ onnx_path.parent.mkdir(parents=True, exist_ok=True)
1369
+ else:
1370
+ import tempfile
1371
+
1372
+ temp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
1373
+ onnx_path = pathlib.Path(temp_file.name)
1374
+ temp_file.close()
1375
+
1376
+ logger.info("Converting JAX model to ONNX...")
1377
+ logger.info(f" Params: {jax_path}")
1378
+ logger.info(f" Apply fn: {apply_fn_path}")
1379
+ logger.info(f" Input shape: {input_shape}")
1380
+ logger.info(f" Target: {onnx_path}")
1381
+
1382
+ try:
1383
+ import importlib.util
1384
+ import sys
1385
+
1386
+ # Load params
1387
+ suffix = jax_path.suffix.lower()
1388
+ if suffix == ".msgpack":
1389
+ try:
1390
+ from flax.serialization import msgpack_restore
1391
+
1392
+ with open(jax_path, "rb") as f:
1393
+ params = msgpack_restore(f.read())
1394
+ logger.info("Loaded Flax msgpack params")
1395
+ except ImportError:
1396
+ logger.error("Flax not installed. Install with: pip install flax")
1397
+ return None, None
1398
+ elif suffix in (".pkl", ".pickle"):
1399
+ import pickle
1400
+
1401
+ with open(jax_path, "rb") as f:
1402
+ params = pickle.load(f)
1403
+ logger.info("Loaded pickle params")
1404
+ elif suffix == ".npy":
1405
+ import numpy as np
1406
+
1407
+ params = np.load(jax_path, allow_pickle=True).item()
1408
+ logger.info("Loaded numpy params")
1409
+ else:
1410
+ logger.error(f"Unsupported params format: {suffix}. Use .msgpack, .pkl, or .npy")
1411
+ return None, None
1412
+
1413
+ # Import apply function
1414
+ # Add current directory to path for local imports
1415
+ sys.path.insert(0, str(pathlib.Path.cwd()))
1416
+
1417
+ try:
1418
+ module = importlib.import_module(module_path)
1419
+ apply_fn = getattr(module, fn_name)
1420
+ logger.info(f"Loaded apply function: {apply_fn_path}")
1421
+ except (ImportError, AttributeError) as e:
1422
+ logger.error(f"Could not import {apply_fn_path}: {e}")
1423
+ return None, None
1424
+
1425
+ # Convert via jax2tf (JAX -> TF -> ONNX)
1426
+ try:
1427
+ import tensorflow as tf
1428
+ from jax.experimental import jax2tf
1429
+ except ImportError:
1430
+ logger.error("jax2tf or TensorFlow not available. Install with: pip install tensorflow")
1431
+ return None, None
1432
+
1433
+ # NOTE: this dummy array is currently unused; the tf.TensorSpec in TFModule below defines the traced input signature.
1434
+ jnp.zeros(input_shape, dtype=jnp.float32)
1435
+
1436
+ # Convert JAX function to TF
1437
+ tf_fn = jax2tf.convert(
1438
+ lambda x: apply_fn(params, x),
1439
+ enable_xla=False,
1440
+ )
1441
+
1442
+ # Create TF SavedModel
1443
+ import tempfile as tf_tempfile
1444
+
1445
+ with tf_tempfile.TemporaryDirectory() as tf_model_dir:
1446
+ # Wrap in tf.Module for SavedModel export
1447
+ class TFModule(tf.Module):
1448
+ @tf.function(input_signature=[tf.TensorSpec(input_shape, tf.float32)])
1449
+ def __call__(self, x):
1450
+ return tf_fn(x)
1451
+
1452
+ tf_module = TFModule()
1453
+ tf.saved_model.save(tf_module, tf_model_dir)
1454
+ logger.info("Created temporary TF SavedModel")
1455
+
1456
+ # Convert TF SavedModel to ONNX
1457
+ import subprocess
1458
+
1459
+ cmd = [
1460
+ sys.executable,
1461
+ "-m",
1462
+ "tf2onnx.convert",
1463
+ "--saved-model",
1464
+ tf_model_dir,
1465
+ "--output",
1466
+ str(onnx_path),
1467
+ "--opset",
1468
+ str(opset_version),
1469
+ ]
1470
+
1471
+ result = subprocess.run(
1472
+ cmd,
1473
+ check=False,
1474
+ capture_output=True,
1475
+ text=True,
1476
+ timeout=600,
1477
+ )
1478
+
1479
+ if result.returncode != 0:
1480
+ logger.error(f"tf2onnx conversion failed:\n{result.stderr}")
1481
+ if temp_file:
1482
+ try:
1483
+ onnx_path.unlink()
1484
+ except Exception:
1485
+ pass
1486
+ return None, None
1487
+
1488
+ logger.info("JAX model converted successfully")
1489
+
1490
+ # Verify the export
1491
+ import onnx
1492
+
1493
+ onnx_model = onnx.load(str(onnx_path))
1494
+ onnx.checker.check_model(onnx_model)
1495
+ logger.info("ONNX model validated successfully")
1496
+
1497
+ except Exception as e:
1498
+ logger.error(f"JAX conversion failed: {e}")
1499
+ import traceback
1500
+
1501
+ logger.debug(traceback.format_exc())
1502
+ if temp_file:
1503
+ try:
1504
+ onnx_path.unlink()
1505
+ except Exception:
1506
+ pass
1507
+ return None, None
1508
+
1509
+ return onnx_path, temp_file
1510
+
1511
+
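# Editor's sketch (hypothetical module my_model.py): the kind of apply function that
# --jax-apply-fn my_model:apply expects, given params saved alongside it. Assumes the
# saved file holds the variables dict in the form the module's apply expects.
#
#   from flax import linen as nn
#
#   class MLP(nn.Module):
#       @nn.compact
#       def __call__(self, x):
#           return nn.Dense(10)(nn.relu(nn.Dense(128)(x)))
#
#   _model = MLP()
#
#   def apply(params, x):          # called as apply_fn(params, x) by the converter above
#       return _model.apply(params, x)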
1512
+ def run_inspect():
1513
+ """Main entry point for the model_inspect CLI."""
1514
+ # Load environment variables from .env file if present
1515
+ try:
1516
+ from dotenv import load_dotenv
1517
+
1518
+ load_dotenv()
1519
+ except ImportError:
1520
+ pass # python-dotenv not installed, use environment variables directly
1521
+
1522
+ args = parse_args()
1523
+ logger = setup_logging(args.log_level)
1524
+
1525
+ # Handle --schema
1526
+ if args.schema:
1527
+ import json
1528
+
1529
+ from .schema import get_schema
1530
+
1531
+ schema = get_schema()
1532
+ print(json.dumps(schema, indent=2))
1533
+ return 0
1534
+
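# Editor's note: e.g. `python -m haoline --schema > inspection_report.schema.json`
# dumps the report schema for downstream tooling and exits.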
1535
+ # Handle --list-hardware
1536
+ if args.list_hardware:
1537
+ print("\n" + "=" * 70)
1538
+ print("Available Hardware Profiles")
1539
+ print("=" * 70)
1540
+
1541
+ print("\nData Center GPUs - H100 Series:")
1542
+ for name in ["h100-sxm", "h100-pcie", "h100-nvl"]:
1543
+ profile = get_profile(name)
1544
+ if profile:
1545
+ print(
1546
+ f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
1547
+ )
1548
+
1549
+ print("\nData Center GPUs - A100 Series:")
1550
+ for name in [
1551
+ "a100-80gb-sxm",
1552
+ "a100-80gb-pcie",
1553
+ "a100-40gb-sxm",
1554
+ "a100-40gb-pcie",
1555
+ ]:
1556
+ profile = get_profile(name)
1557
+ if profile:
1558
+ print(
1559
+ f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
1560
+ )
1561
+
1562
+ print("\nData Center GPUs - Other:")
1563
+ for name in ["a10", "l4", "l40", "l40s", "t4"]:
1564
+ profile = get_profile(name)
1565
+ if profile:
1566
+ print(
1567
+ f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
1568
+ )
1569
+
1570
+ print("\nData Center GPUs - V100 Series:")
1571
+ for name in [
1572
+ "v100-32gb-sxm",
1573
+ "v100-32gb-pcie",
1574
+ "v100-16gb-sxm",
1575
+ "v100-16gb-pcie",
1576
+ ]:
1577
+ profile = get_profile(name)
1578
+ if profile:
1579
+ print(
1580
+ f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
1581
+ )
1582
+
1583
+ print("\nDGX Systems (Multi-GPU):")
1584
+ for name in ["dgx-h100", "dgx-a100-640gb", "dgx-a100-320gb"]:
1585
+ profile = get_profile(name)
1586
+ if profile:
1587
+ print(
1588
+ f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
1589
+ )
1590
+
1591
+ print("\nJetson Edge/Embedded (Orin Series - Recommended for new projects):")
1592
+ for name in [
1593
+ "jetson-agx-orin-64gb",
1594
+ "jetson-agx-orin-32gb",
1595
+ "jetson-orin-nx-16gb",
1596
+ "jetson-orin-nx-8gb",
1597
+ "jetson-orin-nano-8gb",
1598
+ "jetson-orin-nano-4gb",
1599
+ ]:
1600
+ profile = get_profile(name)
1601
+ if profile:
1602
+ print(
1603
+ f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
1604
+ )
1605
+
1606
+ print("\nJetson Edge/Embedded (Xavier Series):")
1607
+ for name in ["jetson-agx-xavier", "jetson-xavier-nx-8gb"]:
1608
+ profile = get_profile(name)
1609
+ if profile:
1610
+ print(
1611
+ f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
1612
+ )
1613
+
1614
+ print("\nJetson Edge/Embedded (Legacy - Very Constrained!):")
1615
+ for name in ["jetson-tx2", "jetson-nano", "jetson-nano-2gb"]:
1616
+ profile = get_profile(name)
1617
+ if profile:
1618
+ vram_gb = profile.vram_bytes / (1024**3)
1619
+ print(
1620
+ f" {name:20} {profile.name:30} {vram_gb:3.0f} GB {profile.peak_fp16_tflops:6.3f} TF16"
1621
+ )
1622
+
1623
+ print("\nConsumer GPUs - RTX 40 Series:")
1624
+ for name in [
1625
+ "rtx4090",
1626
+ "4080-super",
1627
+ "rtx4080",
1628
+ "4070-ti-super",
1629
+ "4070-ti",
1630
+ "4070-super",
1631
+ "rtx4070",
1632
+ "4060-ti-16gb",
1633
+ "rtx4060",
1634
+ ]:
1635
+ profile = get_profile(name)
1636
+ if profile:
1637
+ print(
1638
+ f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
1639
+ )
1640
+
1641
+ print("\nConsumer GPUs - RTX 30 Series:")
1642
+ for name in [
1643
+ "3090-ti",
1644
+ "rtx3090",
1645
+ "3080-ti",
1646
+ "3080-12gb",
1647
+ "rtx3080",
1648
+ "3070-ti",
1649
+ "rtx3070",
1650
+ "3060-ti",
1651
+ "rtx3060",
1652
+ "rtx3050",
1653
+ ]:
1654
+ profile = get_profile(name)
1655
+ if profile:
1656
+ print(
1657
+ f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
1658
+ )
1659
+
1660
+ print("\nLaptop/Mobile GPUs:")
1661
+ for name in [
1662
+ "4090-mobile",
1663
+ "4080-mobile",
1664
+ "4070-mobile",
1665
+ "3080-mobile",
1666
+ "3070-mobile",
1667
+ ]:
1668
+ profile = get_profile(name)
1669
+ if profile:
1670
+ print(
1671
+ f" {name:20} {profile.name:30} {profile.vram_bytes // (1024**3):3} GB {profile.peak_fp16_tflops:6.1f} TF16"
1672
+ )
1673
+
1674
+ print("\nOther:")
1675
+ print(" auto Auto-detect local GPU/CPU")
1676
+ print(" cpu Generic CPU profile")
1677
+
1678
+ print("\n" + "-" * 70)
1679
+ print("TF16 = Peak FP16 TFLOPS (higher = faster)")
1680
+ print("Use --gpu-count N for multi-GPU estimates")
1681
+ print("Use --list-cloud for cloud instance options")
1682
+ print("=" * 70 + "\n")
1683
+ sys.exit(0)
1684
+
1685
+ # Handle --list-cloud
1686
+ if args.list_cloud:
1687
+ print("\n" + "=" * 70)
1688
+ print("Available Cloud Instance Profiles")
1689
+ print("=" * 70)
1690
+
1691
+ print("\nAWS GPU Instances:")
1692
+ for name, instance in CLOUD_INSTANCES.items():
1693
+ if instance.provider == "aws":
1694
+ vram_gb = instance.hardware.vram_bytes * instance.gpu_count // (1024**3)
1695
+ print(
1696
+ f" {name:25} {instance.gpu_count}x GPU {vram_gb:4} GB ${instance.hourly_cost_usd:6.2f}/hr"
1697
+ )
1698
+
1699
+ print("\nAzure GPU Instances:")
1700
+ for name, instance in CLOUD_INSTANCES.items():
1701
+ if instance.provider == "azure":
1702
+ vram_gb = instance.hardware.vram_bytes * instance.gpu_count // (1024**3)
1703
+ print(
1704
+ f" {name:25} {instance.gpu_count}x GPU {vram_gb:4} GB ${instance.hourly_cost_usd:6.2f}/hr"
1705
+ )
1706
+
1707
+ print("\nGCP GPU Instances:")
1708
+ for name, instance in CLOUD_INSTANCES.items():
1709
+ if instance.provider == "gcp":
1710
+ vram_gb = instance.hardware.vram_bytes * instance.gpu_count // (1024**3)
1711
+ print(
1712
+ f" {name:25} {instance.gpu_count}x GPU {vram_gb:4} GB ${instance.hourly_cost_usd:6.2f}/hr"
1713
+ )
1714
+
1715
+ print("\n" + "-" * 70)
1716
+ print("Prices are approximate on-demand rates (us-east-1 or equivalent)")
1717
+ print("Use --cloud <instance> to get cost estimates for your model")
1718
+ print("=" * 70 + "\n")
1719
+ sys.exit(0)
1720
+
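The aggregate-VRAM figure printed for each cloud instance above multiplies per-GPU VRAM by the instance's GPU count before converting to whole GB. A small worked example with illustrative numbers (not taken from the shipped CLOUD_INSTANCES table):

    per_gpu_vram_bytes = 40 * 1024**3   # a hypothetical 40 GB card
    gpu_count = 8
    total_vram_gb = per_gpu_vram_bytes * gpu_count // (1024**3)
    print(total_vram_gb)                # 320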
1721
+ # Handle --list-conversions
1722
+ if args.list_conversions:
1723
+ from .format_adapters import list_conversion_paths
1724
+
1725
+ print("\n" + "=" * 70)
1726
+ print("Available Format Conversion Paths")
1727
+ print("=" * 70)
1728
+
1729
+ paths = list_conversion_paths()
1730
+ current_source = None
1731
+
1732
+ for path in paths:
1733
+ if path["source"] != current_source:
1734
+ current_source = path["source"]
1735
+ print(f"\nFrom {current_source.upper()}:")
1736
+
1737
+ level_icons = {
1738
+ "full": "[FULL] ",
1739
+ "partial": "[PARTIAL] ",
1740
+ "lossy": "[LOSSY] ",
1741
+ "none": "[NONE] ",
1742
+ }
1743
+ icon = level_icons.get(path["level"], " ")
1744
+ print(f" {icon} -> {path['target']}")
1745
+
1746
+ print("\n" + "-" * 70)
1747
+ print("Conversion Levels:")
1748
+ print(" FULL = Lossless, complete conversion")
1749
+ print(" PARTIAL = Some limitations or multi-step required")
1750
+ print(" LOSSY = Information may be lost")
1751
+ print(" NONE = No conversion path available")
1752
+ print("\nUse --convert-to FORMAT to convert a model")
1753
+ print("=" * 70 + "\n")
1754
+ sys.exit(0)
1755
+
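The loop above relies only on the "source", "target", and "level" keys of each path record. A minimal sketch of records in that shape (the real list_conversion_paths() output may carry additional fields):

    # Hypothetical records; actual entries may include more fields.
    paths = [
        {"source": "onnx", "target": "tflite", "level": "partial"},
        {"source": "onnx", "target": "coreml", "level": "lossy"},
    ]
    level_icons = {"full": "[FULL]", "partial": "[PARTIAL]", "lossy": "[LOSSY]", "none": "[NONE]"}
    for path in paths:
        print(f"  {level_icons.get(path['level'], '')} -> {path['target']}")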
1756
+ # Handle model conversion if requested
1757
+ temp_onnx_file = None
1758
+ conversion_sources = [
1759
+ ("--from-pytorch", args.from_pytorch),
1760
+ ("--from-tensorflow", args.from_tensorflow),
1761
+ ("--from-keras", args.from_keras),
1762
+ ("--from-frozen-graph", args.from_frozen_graph),
1763
+ ("--from-jax", args.from_jax),
1764
+ ]
1765
+ active_conversions = [(name, path) for name, path in conversion_sources if path]
1766
+
1767
+ if len(active_conversions) > 1:
1768
+ names = [name for name, _ in active_conversions]
1769
+ logger.error(f"Cannot use multiple conversion flags together: {', '.join(names)}")
1770
+ sys.exit(1)
1771
+
1772
+ if args.from_pytorch:
1773
+ model_path, temp_onnx_file = _convert_pytorch_to_onnx(
1774
+ args.from_pytorch,
1775
+ args.input_shape,
1776
+ args.keep_onnx,
1777
+ args.opset_version,
1778
+ logger,
1779
+ )
1780
+ if model_path is None:
1781
+ sys.exit(1)
1782
+ elif args.from_tensorflow:
1783
+ model_path, temp_onnx_file = _convert_tensorflow_to_onnx(
1784
+ args.from_tensorflow,
1785
+ args.keep_onnx,
1786
+ args.opset_version,
1787
+ logger,
1788
+ )
1789
+ if model_path is None:
1790
+ sys.exit(1)
1791
+ elif args.from_keras:
1792
+ model_path, temp_onnx_file = _convert_keras_to_onnx(
1793
+ args.from_keras,
1794
+ args.keep_onnx,
1795
+ args.opset_version,
1796
+ logger,
1797
+ )
1798
+ if model_path is None:
1799
+ sys.exit(1)
1800
+ elif args.from_frozen_graph:
1801
+ model_path, temp_onnx_file = _convert_frozen_graph_to_onnx(
1802
+ args.from_frozen_graph,
1803
+ args.tf_inputs,
1804
+ args.tf_outputs,
1805
+ args.keep_onnx,
1806
+ args.opset_version,
1807
+ logger,
1808
+ )
1809
+ if model_path is None:
1810
+ sys.exit(1)
1811
+ elif args.from_jax:
1812
+ model_path, temp_onnx_file = _convert_jax_to_onnx(
1813
+ args.from_jax,
1814
+ args.jax_apply_fn,
1815
+ args.input_shape,
1816
+ args.keep_onnx,
1817
+ args.opset_version,
1818
+ logger,
1819
+ )
1820
+ if model_path is None:
1821
+ sys.exit(1)
1822
+ else:
1823
+ # Validate model path (no conversion requested)
1824
+ if args.model_path is None:
1825
+ logger.error(
1826
+ "Model path is required. Use --list-hardware to see available profiles, "
1827
+ "or use a conversion flag (--from-pytorch, --from-tensorflow, --from-keras, "
1828
+ "--from-frozen-graph, --from-jax)."
1829
+ )
1830
+ sys.exit(1)
1831
+
1832
+ model_path = args.model_path.resolve()
1833
+ if not model_path.exists():
1834
+ logger.error(f"Model file not found: {model_path}")
1835
+ sys.exit(1)
1836
+
1837
+ if model_path.suffix.lower() not in (".onnx", ".pb", ".ort"):
1838
+ logger.warning(f"Unexpected file extension: {model_path.suffix}. Proceeding anyway.")
1839
+
1840
+ # Handle --convert-to for format conversion
1841
+ if args.convert_to:
1842
+ from .format_adapters import (
1843
+ ConversionLevel,
1844
+ OnnxAdapter,
1845
+ get_conversion_level,
1846
+ load_model,
1847
+ )
1848
+ from .universal_ir import SourceFormat
1849
+
1850
+ if not args.convert_output:
1851
+ logger.error("--convert-output is required when using --convert-to")
1852
+ sys.exit(1)
1853
+
1854
+ target_format = SourceFormat(args.convert_to.lower())
1855
+ output_path = args.convert_output.resolve()
1856
+
1857
+ # Determine source format from file extension
1858
+ source_ext = model_path.suffix.lower()
1859
+ source_format_map = {
1860
+ ".onnx": SourceFormat.ONNX,
1861
+ ".pt": SourceFormat.PYTORCH,
1862
+ ".pth": SourceFormat.PYTORCH,
1863
+ }
1864
+ source_format = source_format_map.get(source_ext, SourceFormat.UNKNOWN)
1865
+
1866
+ # Check conversion level
1867
+ level = get_conversion_level(source_format, target_format)
1868
+ if level == ConversionLevel.NONE:
1869
+ logger.error(
1870
+ f"Cannot convert from {source_format.value} to {target_format.value}. "
1871
+ f"Use --list-conversions to see available paths."
1872
+ )
1873
+ sys.exit(1)
1874
+ elif level == ConversionLevel.LOSSY:
1875
+ logger.warning(
1876
+ f"Converting {source_format.value} to {target_format.value} may lose information"
1877
+ )
1878
+
1879
+ # Load to Universal IR and convert
1880
+ try:
1881
+ logger.info(f"Loading {model_path} as Universal IR...")
1882
+ graph = load_model(model_path)
1883
+
1884
+ if target_format == SourceFormat.ONNX:
1885
+ logger.info(f"Exporting to ONNX: {output_path}")
1886
+ OnnxAdapter().write(graph, output_path)
1887
+ logger.info(f"Successfully converted to {output_path}")
1888
+ print(f"\nConverted: {model_path} -> {output_path}")
1889
+ print(f" Nodes: {graph.num_nodes}")
1890
+ print(f" Parameters: {graph.total_parameters:,}")
1891
+ print(f" Source: {graph.metadata.source_format.value}")
1892
+ else:
1893
+ logger.error(f"Conversion to {target_format.value} not yet implemented")
1894
+ sys.exit(1)
1895
+
1896
+ # If no analysis requested, exit after conversion
1897
+ if not (args.out_json or args.out_md or args.out_html or args.out_pdf):
1898
+ sys.exit(0)
1899
+
1900
+ # Use the converted file for analysis
1901
+ model_path = output_path
1902
+
1903
+ except Exception as e:
1904
+ logger.error(f"Conversion failed: {e}")
1905
+ sys.exit(1)
1906
+
1907
+ # Handle --export-ir and --export-graph (Universal IR exports)
1908
+ if args.export_ir or args.export_graph:
1909
+ from .format_adapters import load_model
1910
+
1911
+ try:
1912
+ logger.info(f"Loading {model_path} as Universal IR...")
1913
+ ir_graph = load_model(model_path)
1914
+
1915
+ if args.export_ir:
1916
+ logger.info(f"Exporting Universal IR to {args.export_ir}")
1917
+ ir_graph.to_json(args.export_ir)
1918
+ print(f"Exported IR: {args.export_ir}")
1919
+ print(f" Nodes: {ir_graph.num_nodes}")
1920
+ print(f" Parameters: {ir_graph.total_parameters:,}")
1921
+
1922
+ if args.export_graph:
1923
+ export_path = args.export_graph
1924
+ suffix = export_path.suffix.lower()
1925
+
1926
+ if suffix == ".dot":
1927
+ logger.info(f"Exporting graph to DOT: {export_path}")
1928
+ ir_graph.save_dot(export_path)
1929
+ print(f"Exported graph: {export_path}")
1930
+ elif suffix == ".png":
1931
+ logger.info(f"Rendering graph to PNG: {export_path}")
1932
+ ir_graph.save_png(export_path, max_nodes=args.graph_max_nodes)
1933
+ print(f"Rendered graph: {export_path}")
1934
+ else:
1935
+ logger.error(f"Unsupported graph format: {suffix}. Use .dot or .png")
1936
+ sys.exit(1)
1937
+
1938
+ # If no other outputs requested, exit after IR export
1939
+ if not (args.out_json or args.out_md or args.out_html or args.out_pdf):
1940
+ sys.exit(0)
1941
+
1942
+ except ImportError as e:
1943
+ logger.error(f"Missing dependency: {e}")
1944
+ sys.exit(1)
1945
+ except Exception as e:
1946
+ logger.error(f"IR export failed: {e}")
1947
+ sys.exit(1)
1948
+
1949
+ # Determine hardware profile
1950
+ hardware_profile = None
1951
+ cloud_instance = None
1952
+
1953
+ if args.cloud:
1954
+ # Cloud instance takes precedence
1955
+ cloud_instance = get_cloud_instance(args.cloud)
1956
+ if cloud_instance is None:
1957
+ logger.error(f"Unknown cloud instance: {args.cloud}")
1958
+ logger.error("Use --list-cloud to see available instances.")
1959
+ sys.exit(1)
1960
+ hardware_profile = cloud_instance.hardware
1961
+ # Override gpu_count from cloud instance
1962
+ args.gpu_count = cloud_instance.gpu_count
1963
+ logger.info(
1964
+ f"Using cloud instance: {cloud_instance.name} "
1965
+ f"({cloud_instance.gpu_count}x GPU, ${cloud_instance.hourly_cost_usd:.2f}/hr)"
1966
+ )
1967
+ elif args.hardware:
1968
+ if args.hardware.lower() == "auto":
1969
+ logger.info("Auto-detecting local hardware...")
1970
+ hardware_profile = detect_local_hardware()
1971
+ logger.info(f"Detected: {hardware_profile.name}")
1972
+ else:
1973
+ hardware_profile = get_profile(args.hardware)
1974
+ if hardware_profile is None:
1975
+ logger.error(f"Unknown hardware profile: {args.hardware}")
1976
+ logger.error("Use --list-hardware to see available profiles.")
1977
+ sys.exit(1)
1978
+ logger.info(f"Using hardware profile: {hardware_profile.name}")
1979
+
1980
+ # Apply multi-GPU scaling if requested
1981
+ if hardware_profile and args.gpu_count > 1 and not args.cloud:
1982
+ multi_gpu = create_multi_gpu_profile(args.hardware or "auto", args.gpu_count)
1983
+ if multi_gpu:
1984
+ hardware_profile = multi_gpu.get_effective_profile()
1985
+ logger.info(
1986
+ f"Multi-GPU: {args.gpu_count}x scaling with "
1987
+ f"{multi_gpu.compute_efficiency:.0%} efficiency"
1988
+ )
1989
+
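A hedged sketch of the compute side of multi-GPU scaling: peak throughput modeled as GPU count times a per-setup efficiency factor (the percentage logged above). create_multi_gpu_profile() may apply a different model, including separate VRAM handling; the numbers below are illustrative only.

    def effective_peak_tflops(single_gpu_tflops: float, gpu_count: int, efficiency: float = 0.9) -> float:
        # Sketch only; not necessarily the actual create_multi_gpu_profile() logic.
        return single_gpu_tflops * gpu_count * efficiency

    print(f"{effective_peak_tflops(82.6, 2):.1f}")  # ~148.7 TFLOPS for two hypothetical 82.6-TFLOPS cards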
1990
+ # Setup progress indicator
1991
+ progress = ProgressIndicator(enabled=args.progress, quiet=args.quiet)
1992
+
1993
+ # Calculate total steps based on what will be done
1994
+ total_steps = 2 # Load + Analyze always
1995
+ if hardware_profile:
1996
+ total_steps += 1
1997
+ if args.system_requirements:
1998
+ total_steps += 1
1999
+ if args.sweep_batch_sizes and hardware_profile:
2000
+ total_steps += 1
2001
+ if args.sweep_resolutions and hardware_profile:
2002
+ total_steps += 1
2003
+ if not args.no_gpu_metrics:
2004
+ total_steps += 1
2005
+ if not args.no_profile:
2006
+ total_steps += 1
2007
+ if not args.no_profile and not args.no_bottleneck_analysis and hardware_profile:
2008
+ total_steps += 1
2009
+ if not args.no_benchmark_resolutions:
2010
+ total_steps += 1
2011
+ if args.with_plots and is_viz_available():
2012
+ total_steps += 1
2013
+     if args.llm_summary and not args.offline and is_llm_available() and has_llm_api_key():
2014
+ total_steps += 1
2015
+     if args.out_json or args.out_md or args.out_html or args.out_pdf or args.html_graph or args.layer_csv:
2016
+ total_steps += 1
2017
+
2018
+ progress.start(total_steps, f"Analyzing {model_path.name}")
2019
+
2020
+ # Run inspection
2021
+ try:
2022
+ progress.step("Loading model and extracting graph structure")
2023
+ inspector = ModelInspector(logger=logger)
2024
+ report = inspector.inspect(model_path)
2025
+ progress.step("Computing metrics (params, FLOPs, memory)")
2026
+
2027
+ # Add hardware estimates if profile specified
2028
+ if (
2029
+ hardware_profile
2030
+ and report.param_counts
2031
+ and report.flop_counts
2032
+ and report.memory_estimates
2033
+ ):
2034
+ progress.step(f"Estimating performance on {hardware_profile.name}")
2035
+ estimator = HardwareEstimator(logger=logger)
2036
+ hw_estimates = estimator.estimate(
2037
+ model_params=report.param_counts.total,
2038
+ model_flops=report.flop_counts.total,
2039
+ peak_activation_bytes=report.memory_estimates.peak_activation_bytes,
2040
+ hardware=hardware_profile,
2041
+ batch_size=args.batch_size,
2042
+ precision=args.precision,
2043
+ )
2044
+ report.hardware_estimates = hw_estimates
2045
+ report.hardware_profile = hardware_profile
2046
+
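A rough sketch of the kind of VRAM requirement reported here: weights stored at the chosen precision plus peak activations scaled by batch size. The real HardwareEstimator may include further terms (workspace, caches), so treat this only as an approximation under those assumptions.

    # Approximation only; the real estimator may add workspace/cache terms.
    BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}

    def rough_vram_bytes(params: int, peak_activation_bytes: int, batch_size: int, precision: str) -> int:
        return params * BYTES_PER_PARAM[precision] + peak_activation_bytes * batch_size

    est = rough_vram_bytes(25_000_000, 300 * 1024**2, 8, "fp16")
    print(f"{est / 1024**3:.2f} GiB")  # ~2.39 GiB for this hypothetical model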
2047
+ # Batch Size Sweep (Story 6C.1)
2048
+ if args.sweep_batch_sizes:
2049
+ profiler = OperationalProfiler(logger=logger)
2050
+
2051
+ if args.no_benchmark:
2052
+ # Use theoretical estimates (faster but less accurate)
2053
+ progress.step("Running batch size sweep (theoretical)")
2054
+ sweep_result = profiler.run_batch_sweep(
2055
+ model_params=report.param_counts.total,
2056
+ model_flops=report.flop_counts.total,
2057
+ peak_activation_bytes=report.memory_estimates.peak_activation_bytes,
2058
+ hardware=hardware_profile,
2059
+ precision=args.precision,
2060
+ )
2061
+ else:
2062
+ # Default: use actual inference benchmarking
2063
+ progress.step("Benchmarking batch sizes (actual inference)")
2064
+ sweep_result = profiler.run_batch_sweep_benchmark(
2065
+ model_path=str(model_path),
2066
+ batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128],
2067
+ )
2068
+ if sweep_result is None:
2069
+ # Fall back to theoretical if benchmark fails
2070
+ logger.warning("Benchmark failed, using theoretical estimates")
2071
+ sweep_result = profiler.run_batch_sweep(
2072
+ model_params=report.param_counts.total,
2073
+ model_flops=report.flop_counts.total,
2074
+ peak_activation_bytes=report.memory_estimates.peak_activation_bytes,
2075
+ hardware=hardware_profile,
2076
+ precision=args.precision,
2077
+ )
2078
+
2079
+ report.batch_size_sweep = sweep_result
2080
+ logger.info(
2081
+ f"Batch sweep complete. Optimal batch size: {sweep_result.optimal_batch_size}"
2082
+ )
2083
+
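One common way an "optimal" batch size is read off a sweep is the highest-throughput point that still fits in memory; run_batch_sweep() and run_batch_sweep_benchmark() may apply different criteria, so the selection rule below is only a sketch.

    # (batch_size, throughput in inferences/s, fits_in_vram) -- illustrative numbers.
    # The selection rule is an assumption, not necessarily what the profiler does.
    results = [(1, 210.0, True), (8, 900.0, True), (32, 1500.0, True), (128, 1800.0, False)]
    optimal_batch = max((r for r in results if r[2]), key=lambda r: r[1])[0]
    print(optimal_batch)  # 32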
2084
+ # Resolution Sweep (Story 6.8)
2085
+ if args.sweep_resolutions:
2086
+ progress.step("Running resolution sweep")
2087
+ profiler = OperationalProfiler(logger=logger)
2088
+
2089
+ # Determine base/training resolution from model input shape
2090
+ base_resolution = (224, 224) # Default
2091
+ if report.graph_summary and report.graph_summary.input_shapes:
2092
+ for shape in report.graph_summary.input_shapes.values():
2093
+ if len(shape) >= 3:
2094
+ # Assume NCHW or NHWC format
2095
+ h, w = shape[-2], shape[-1]
2096
+ if isinstance(h, int) and isinstance(w, int) and h > 1 and w > 1:
2097
+ base_resolution = (h, w)
2098
+ break
2099
+
2100
+ # Parse resolutions from CLI argument
2101
+ # Note: Only resolutions UP TO training resolution are reliable
2102
+ resolutions: list[tuple[int, int]] | None = None
2103
+ if args.sweep_resolutions != "auto":
2104
+ resolutions = []
2105
+ for res_part in args.sweep_resolutions.split(","):
2106
+ res_str = res_part.strip()
2107
+ if "x" in res_str:
2108
+ h, w = res_str.split("x")
2109
+ res_h, res_w = int(h), int(w)
2110
+ # Warn if resolution exceeds training resolution
2111
+ if res_h > base_resolution[0] or res_w > base_resolution[1]:
2112
+ logger.warning(
2113
+ f"Resolution {res_str} exceeds training resolution "
2114
+ f"{base_resolution[0]}x{base_resolution[1]}. "
2115
+ "Results may be unreliable."
2116
+ )
2117
+ resolutions.append((res_h, res_w))
2118
+
2119
+ res_sweep_result = profiler.run_resolution_sweep(
2120
+ base_flops=report.flop_counts.total,
2121
+ base_activation_bytes=report.memory_estimates.peak_activation_bytes,
2122
+ base_resolution=base_resolution,
2123
+ model_params=report.param_counts.total,
2124
+ hardware=hardware_profile,
2125
+ resolutions=resolutions,
2126
+ batch_size=args.batch_size,
2127
+ precision=args.precision,
2128
+ )
2129
+ report.resolution_sweep = res_sweep_result
2130
+ logger.info(
2131
+ f"Resolution sweep complete. Max resolution: {res_sweep_result.max_resolution}, "
2132
+ f"Optimal: {res_sweep_result.optimal_resolution}"
2133
+ )
2134
+
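As parsed above, the resolution sweep argument (args.sweep_resolutions) accepts either "auto" or a comma-separated list of HxW pairs, and values at or below the training resolution keep the estimates meaningful. For example, with hypothetical values:

    arg = "640x640,512x512,320x320"   # example argument string
    resolutions = [tuple(int(v) for v in part.strip().split("x")) for part in arg.split(",")]
    print(resolutions)  # [(640, 640), (512, 512), (320, 320)]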
2135
+ # System Requirements (Story 6C.2)
2136
+ if (
2137
+ args.system_requirements
2138
+ and report.param_counts
2139
+ and report.flop_counts
2140
+ and report.memory_estimates
2141
+ ):
2142
+ progress.step("Generating system requirements")
2143
+ profiler = OperationalProfiler(logger=logger)
2144
+
2145
+ # Derive a target FPS based on deployment target / explicit knobs
2146
+ target_fps: float | None = None
2147
+ if getattr(args, "target_throughput_fps", None):
2148
+ target_fps = float(args.target_throughput_fps)
2149
+ elif getattr(args, "target_latency_ms", None):
2150
+ # latency (ms) -> fps
2151
+ if args.target_latency_ms > 0:
2152
+ target_fps = 1000.0 / float(args.target_latency_ms)
2153
+
2154
+ # Fallback targets based on deployment target category
2155
+ if target_fps is None:
2156
+ if args.deployment_target == "edge":
2157
+ target_fps = 30.0
2158
+ elif args.deployment_target == "local":
2159
+ target_fps = 60.0
2160
+ elif args.deployment_target == "cloud":
2161
+ target_fps = 120.0
2162
+ else:
2163
+ target_fps = 30.0
2164
+
2165
+ sys_reqs = profiler.determine_system_requirements(
2166
+ model_params=report.param_counts.total,
2167
+ model_flops=report.flop_counts.total,
2168
+ peak_activation_bytes=report.memory_estimates.peak_activation_bytes,
2169
+ precision=args.precision,
2170
+ target_fps=target_fps,
2171
+ )
2172
+ report.system_requirements = sys_reqs
2173
+ if sys_reqs.recommended_gpu:
2174
+ logger.info(f"Recommended device: {sys_reqs.recommended_gpu.device}")
2175
+ elif sys_reqs.minimum_gpu:
2176
+ logger.info(f"Minimum device: {sys_reqs.minimum_gpu.device}")
2177
+
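The latency-to-throughput conversion used above, worked through for a 33 ms target; the 30/60/120 FPS fallbacks apply only when neither an explicit target throughput nor a target latency is supplied.

    target_latency_ms = 33.0            # example target
    target_fps = 1000.0 / target_latency_ms
    print(f"{target_fps:.1f}")          # 30.3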
2178
+ # Epic 9: Runtime Profiling (defaults ON for real measurements)
2179
+ profiler = OperationalProfiler(logger=logger)
2180
+ profiling_result = None # Store for use in NN graph visualization
2181
+
2182
+ # GPU Metrics (Story 9.2) - default ON
2183
+ if not args.no_gpu_metrics:
2184
+ progress.step("Capturing GPU metrics")
2185
+ gpu_metrics = profiler.get_gpu_metrics()
2186
+ if gpu_metrics:
2187
+ logger.info(
2188
+ f"GPU: {gpu_metrics.vram_used_bytes / (1024**3):.2f} GB VRAM used, "
2189
+ f"{gpu_metrics.gpu_utilization_percent:.0f}% utilization, "
2190
+ f"{gpu_metrics.temperature_c}C"
2191
+ )
2192
+ # Store in report (add to JSON output)
2193
+ report.extra_data = report.extra_data or {}
2194
+ report.extra_data["gpu_metrics"] = gpu_metrics.to_dict()
2195
+ else:
2196
+ logger.debug("GPU metrics unavailable (pynvml not installed)")
2197
+
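For reference, GPU metrics of this kind are commonly read through pynvml along the following lines; the actual OperationalProfiler.get_gpu_metrics() implementation may differ.

    # Illustration of typical pynvml usage; get_gpu_metrics() may query these differently.
    import pynvml

    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
    util = pynvml.nvmlDeviceGetUtilizationRates(handle)
    temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
    print(f"{mem.used / (1024**3):.2f} GB used, {util.gpu}% util, {temp} C")
    pynvml.nvmlShutdown()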
2198
+ # Per-Layer Profiling (Story 9.3) - default ON
2199
+ if not args.no_profile:
2200
+ progress.step("Running ONNX Runtime profiler")
2201
+ profiling_result = profiler.profile_model(
2202
+ model_path=str(model_path),
2203
+ batch_size=args.batch_size,
2204
+ num_runs=args.profile_runs,
2205
+ )
2206
+ if profiling_result:
2207
+ logger.info(
2208
+ f"Profiling complete: {profiling_result.total_time_ms:.2f}ms "
2209
+ f"({len(profiling_result.layer_profiles)} layers)"
2210
+ )
2211
+ # Show slowest layers
2212
+ slowest = profiling_result.get_slowest_layers(5)
2213
+ if slowest:
2214
+ logger.info("Top 5 slowest layers:")
2215
+ for lp in slowest:
2216
+ logger.info(f" {lp.name}: {lp.duration_ms:.3f}ms ({lp.op_type})")
2217
+
2218
+ # Store in report
2219
+ report.extra_data = report.extra_data or {}
2220
+ report.extra_data["profiling"] = profiling_result.to_dict()
2221
+
2222
+ # Bottleneck Analysis (Story 9.4) - default ON
2223
+ if not args.no_bottleneck_analysis and hardware_profile:
2224
+ progress.step("Analyzing bottlenecks")
2225
+ bottleneck = profiler.analyze_bottleneck(
2226
+ model_flops=report.flop_counts.total,
2227
+ profiling_result=profiling_result,
2228
+ hardware=hardware_profile,
2229
+ precision=args.precision,
2230
+ )
2231
+ logger.info(
2232
+ f"Bottleneck: {bottleneck.bottleneck_type} "
2233
+ f"(compute: {bottleneck.compute_ratio:.0%}, "
2234
+ f"memory: {bottleneck.memory_ratio:.0%})"
2235
+ )
2236
+ logger.info(f"Efficiency: {bottleneck.efficiency_percent:.1f}%")
2237
+ for rec in bottleneck.recommendations[:3]:
2238
+ logger.info(f" - {rec}")
2239
+ report.extra_data["bottleneck_analysis"] = bottleneck.to_dict()
2240
+ else:
2241
+                 logger.debug("Skipping bottleneck analysis (disabled or no hardware profile specified)")
2242
+
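A hedged sketch of the compute- vs memory-bound classification implied by the ratios above: compare the time the arithmetic would take at peak FLOPS with the time the data movement would take at peak bandwidth. analyze_bottleneck() may use a different formulation, for example weighting measured layer times.

    def classify(flops: float, bytes_moved: float, peak_flops: float, peak_bandwidth: float) -> str:
        # Sketch only; not necessarily the profiler's exact rule.
        compute_time = flops / peak_flops
        memory_time = bytes_moved / peak_bandwidth
        return "compute-bound" if compute_time >= memory_time else "memory-bound"

    # 4 GFLOPs and 250 MB moved on a hypothetical 82.6 TFLOPS / 1 TB/s device
    print(classify(4e9, 250e6, 82.6e12, 1.0e12))  # memory-bound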
2243
+ # Resolution Benchmarking (Story 9.5) - default ON
2244
+ if not args.no_benchmark_resolutions:
2245
+ progress.step("Benchmarking resolutions (actual inference)")
2246
+ res_benchmark = profiler.benchmark_resolutions(
2247
+ model_path=str(model_path),
2248
+ batch_size=args.batch_size,
2249
+ )
2250
+ if res_benchmark:
2251
+ logger.info(
2252
+ f"Resolution benchmark complete. Optimal: {res_benchmark.optimal_resolution}"
2253
+ )
2254
+ report.extra_data = report.extra_data or {}
2255
+ report.extra_data["resolution_benchmark"] = res_benchmark.to_dict()
2256
+ else:
2257
+ logger.debug("Resolution benchmark unavailable")
2258
+
2259
+ except Exception as e:
2260
+ logger.error(f"Failed to inspect model: {e}")
2261
+ if args.log_level == "debug":
2262
+ import traceback
2263
+
2264
+ traceback.print_exc()
2265
+ sys.exit(1)
2266
+
2267
+ # Extract dataset metadata if PyTorch weights provided
2268
+ if args.pytorch_weights or args.from_pytorch:
2269
+ weights_path = args.pytorch_weights or args.from_pytorch
2270
+ if weights_path.exists():
2271
+ logger.info(f"Extracting metadata from: {weights_path}")
2272
+ metadata = _extract_ultralytics_metadata(weights_path, logger)
2273
+ if metadata:
2274
+ from .report import DatasetInfo
2275
+
2276
+ report.dataset_info = DatasetInfo(
2277
+ task=metadata.get("task"),
2278
+ num_classes=metadata.get("num_classes"),
2279
+ class_names=metadata.get("class_names", []),
2280
+ source=metadata.get("source"),
2281
+ )
2282
+ logger.info(
2283
+ f"Extracted {report.dataset_info.num_classes} class(es): "
2284
+ f"{', '.join(report.dataset_info.class_names[:5])}"
2285
+ f"{'...' if len(report.dataset_info.class_names) > 5 else ''}"
2286
+ )
2287
+
2288
+ # Generate LLM summaries if requested
2289
+ llm_summary = None
2290
+ if args.llm_summary:
2291
+ if args.offline:
2292
+ print("\n[OFFLINE MODE] Skipping LLM summary (requires network access)\n")
2293
+ elif not is_llm_available():
2294
+ print("\n" + "=" * 60)
2295
+ print("LLM PACKAGE NOT INSTALLED")
2296
+ print("=" * 60)
2297
+ print("To enable AI-powered summaries, install the LLM extras:\n")
2298
+ print(" pip install haoline[llm]")
2299
+ print("\nThen set your API key and try again.")
2300
+ print("=" * 60 + "\n")
2301
+ elif not has_llm_api_key():
2302
+ print("\n" + "=" * 60)
2303
+ print("API KEY REQUIRED FOR LLM SUMMARIES")
2304
+ print("=" * 60)
2305
+ print("Set one of the following environment variables:\n")
2306
+ print(" PowerShell: $env:OPENAI_API_KEY = 'sk-...'")
2307
+ print(" Bash/Zsh: export OPENAI_API_KEY='sk-...'")
2308
+ print("\nGet your API key at: https://platform.openai.com/api-keys")
2309
+ print("=" * 60 + "\n")
2310
+ else:
2311
+ try:
2312
+ progress.step(f"Generating LLM summary with {args.llm_model}")
2313
+ logger.info(f"Generating LLM summaries with {args.llm_model}...")
2314
+ summarizer = LLMSummarizer(model=args.llm_model, logger=logger)
2315
+ llm_summary = summarizer.summarize(report)
2316
+ if llm_summary.success:
2317
+ logger.info(f"LLM summaries generated ({llm_summary.tokens_used} tokens used)")
2318
+ else:
2319
+ logger.warning(f"LLM summarization failed: {llm_summary.error_message}")
2320
+ except Exception as e:
2321
+ logger.warning(f"Failed to generate LLM summaries: {e}")
2322
+
2323
+ # Store LLM summary in report for output
2324
+ if llm_summary and llm_summary.success:
2325
+ # Add to report dict for JSON output
2326
+ report._llm_summary = llm_summary # type: ignore
2327
+
2328
+ # Apply privacy transformations if requested
2329
+ report_dict = report.to_dict()
2330
+ if args.summary_only:
2331
+ from .privacy import create_summary_only_dict
2332
+
2333
+ logger.info("Applying summary-only mode (omitting per-layer details)")
2334
+ report_dict = create_summary_only_dict(report_dict)
2335
+ elif args.redact_names:
2336
+ from .privacy import collect_names_from_dict, create_name_mapping, redact_dict
2337
+
2338
+ logger.info("Applying name redaction")
2339
+ names = collect_names_from_dict(report_dict)
2340
+ mapping = create_name_mapping(names)
2341
+ report_dict = redact_dict(report_dict, mapping)
2342
+ logger.debug(f"Redacted {len(mapping)} names")
2343
+
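A minimal sketch of the redaction idea: map sensitive layer names to opaque tokens and rewrite them wherever they occur in the report dictionary. The real helpers in haoline.privacy may behave differently (for example, hashing names or preserving op types).

    # Hypothetical names and mapping; not the actual privacy helpers.
    names = ["backbone.conv1", "head.cls_logits"]
    mapping = {name: f"layer_{i}" for i, name in enumerate(names)}

    def redact(value):
        if isinstance(value, dict):
            return {redact(k): redact(v) for k, v in value.items()}
        if isinstance(value, list):
            return [redact(v) for v in value]
        return mapping.get(value, value)

    print(redact({"slowest_layers": ["backbone.conv1", "head.cls_logits"]}))
    # {'slowest_layers': ['layer_0', 'layer_1']}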
2344
+ # Output results
2345
+ has_output = (
2346
+ args.out_json
2347
+ or args.out_md
2348
+ or args.out_html
2349
+ or args.out_pdf
2350
+ or args.html_graph
2351
+ or args.layer_csv
2352
+ )
2353
+ if has_output:
2354
+ progress.step("Writing output files")
2355
+
2356
+ if args.out_json:
2357
+ try:
2358
+ import json
2359
+
2360
+ args.out_json.parent.mkdir(parents=True, exist_ok=True)
2361
+ args.out_json.write_text(json.dumps(report_dict, indent=2), encoding="utf-8")
2362
+ logger.info(f"JSON report written to: {args.out_json}")
2363
+ except Exception as e:
2364
+ logger.error(f"Failed to write JSON report: {e}")
2365
+ sys.exit(1)
2366
+
2367
+ # Generate visualizations if requested
2368
+ viz_paths = {}
2369
+ if args.with_plots:
2370
+ if not is_viz_available():
2371
+ logger.warning(
2372
+ "matplotlib not installed. Skipping visualizations. Install with: pip install matplotlib"
2373
+ )
2374
+ else:
2375
+ progress.step("Generating visualizations")
2376
+ # Determine assets directory
2377
+ if args.assets_dir:
2378
+ assets_dir = args.assets_dir
2379
+ elif args.out_html:
2380
+                 # HTML embeds images inline, but we still write the asset files to disk
2381
+ assets_dir = args.out_html.parent / "assets"
2382
+ elif args.out_md:
2383
+ assets_dir = args.out_md.parent / "assets"
2384
+ elif args.out_json:
2385
+ assets_dir = args.out_json.parent / "assets"
2386
+ else:
2387
+ assets_dir = pathlib.Path("assets")
2388
+
2389
+ try:
2390
+ viz_gen = VisualizationGenerator(logger=logger)
2391
+ viz_paths = viz_gen.generate_all(report, assets_dir)
2392
+ logger.info(f"Generated {len(viz_paths)} visualization assets in {assets_dir}")
2393
+ except Exception as e:
2394
+ logger.warning(f"Failed to generate some visualizations: {e}")
2395
+ if args.log_level == "debug":
2396
+ import traceback
2397
+
2398
+ traceback.print_exc()
2399
+
2400
+ if args.out_md:
2401
+ try:
2402
+ args.out_md.parent.mkdir(parents=True, exist_ok=True)
2403
+ # Generate markdown with visualizations and/or LLM summaries
2404
+ if viz_paths or llm_summary:
2405
+ md_content = _generate_markdown_with_extras(
2406
+ report, viz_paths, args.out_md.parent, llm_summary
2407
+ )
2408
+ else:
2409
+ md_content = report.to_markdown()
2410
+ args.out_md.write_text(md_content, encoding="utf-8")
2411
+ logger.info(f"Markdown model card written to: {args.out_md}")
2412
+ except Exception as e:
2413
+ logger.error(f"Failed to write Markdown report: {e}")
2414
+ sys.exit(1)
2415
+
2416
+ if args.out_html or args.out_pdf:
2417
+ # Add LLM summary to report if available
2418
+ if llm_summary and llm_summary.success:
2419
+ report.llm_summary = {
2420
+ "success": True,
2421
+ "short_summary": llm_summary.short_summary,
2422
+ "detailed_summary": llm_summary.detailed_summary,
2423
+ "model": args.llm_model,
2424
+ }
2425
+ # Generate layer table HTML if requested
2426
+ layer_table_html = None
2427
+ if args.include_layer_table or args.layer_csv:
2428
+ try:
2429
+ # Re-load graph info if needed
2430
+ if hasattr(inspector, "_graph_info") and inspector._graph_info:
2431
+ graph_info = inspector._graph_info
2432
+ else:
2433
+ from .analyzer import ONNXGraphLoader
2434
+
2435
+ loader = ONNXGraphLoader(logger=logger)
2436
+ _, graph_info = loader.load(model_path)
2437
+
2438
+ layer_builder = LayerSummaryBuilder(logger=logger)
2439
+ layer_summary = layer_builder.build(
2440
+ graph_info,
2441
+ report.param_counts,
2442
+ report.flop_counts,
2443
+ report.memory_estimates,
2444
+ )
2445
+
2446
+ if args.include_layer_table:
2447
+ layer_table_html = generate_html_table(layer_summary)
2448
+ logger.debug("Generated layer summary table for HTML report")
2449
+
2450
+ if args.layer_csv:
2451
+ args.layer_csv.parent.mkdir(parents=True, exist_ok=True)
2452
+ layer_summary.save_csv(args.layer_csv)
2453
+ logger.info(f"Layer summary CSV written to: {args.layer_csv}")
2454
+
2455
+ except Exception as e:
2456
+ logger.warning(f"Could not generate layer summary: {e}")
2457
+
2458
+ # Generate embedded graph HTML if requested
2459
+ graph_html = None
2460
+ if args.include_graph:
2461
+ try:
2462
+ # Re-load graph info if needed
2463
+ if hasattr(inspector, "_graph_info") and inspector._graph_info:
2464
+ graph_info = inspector._graph_info
2465
+ else:
2466
+ from .analyzer import ONNXGraphLoader
2467
+
2468
+ loader = ONNXGraphLoader(logger=logger)
2469
+ _, graph_info = loader.load(model_path)
2470
+
2471
+ pattern_analyzer = PatternAnalyzer(logger=logger)
2472
+ blocks = pattern_analyzer.group_into_blocks(graph_info)
2473
+
2474
+ edge_analyzer = EdgeAnalyzer(logger=logger)
2475
+ edge_result = edge_analyzer.analyze(graph_info)
2476
+
2477
+ builder = HierarchicalGraphBuilder(logger=logger)
2478
+ hier_graph = builder.build(graph_info, blocks, model_path.stem)
2479
+
2480
+ # Generate graph HTML (just the interactive part, not full document)
2481
+ # For embedding, we'll use an iframe approach
2482
+
2483
+ # Extract layer timing from profiling results if available
2484
+ layer_timing = None
2485
+ if (
2486
+ report.extra_data
2487
+ and "profiling" in report.extra_data
2488
+ and "slowest_layers" in report.extra_data["profiling"]
2489
+ ):
2490
+ layer_timing = {}
2491
+ for layer in report.extra_data["profiling"]["slowest_layers"]:
2492
+ layer_timing[layer["name"]] = layer["duration_ms"]
2493
+
2494
+ full_graph_html = generate_graph_html(
2495
+ hier_graph, edge_result, model_path.stem, layer_timing=layer_timing
2496
+ )
2497
+ # Wrap in iframe data URI for embedding
2498
+ import base64
2499
+
2500
+ graph_data = base64.b64encode(full_graph_html.encode()).decode()
2501
+ graph_html = f'<iframe src="data:text/html;base64,{graph_data}" style="width:100%;height:100%;border:none;"></iframe>'
2502
+ logger.debug("Generated interactive graph for HTML report")
2503
+
2504
+ except Exception as e:
2505
+ logger.warning(f"Could not generate embedded graph: {e}")
2506
+
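The embedding technique used above, in isolation: the standalone graph page is base64-encoded into a data: URI so the interactive view can be inlined in a single self-contained HTML report.

    import base64

    inner_html = "<html><body><p>graph goes here</p></body></html>"
    encoded = base64.b64encode(inner_html.encode()).decode()
    iframe = f'<iframe src="data:text/html;base64,{encoded}" style="width:100%;height:100%;border:none;"></iframe>'
    print(iframe[:60] + "...")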
2507
+ # Generate HTML with embedded images, graph, and layer table
2508
+ html_content = report.to_html(
2509
+ image_paths=viz_paths,
2510
+ graph_html=graph_html,
2511
+ layer_table_html=layer_table_html,
2512
+ )
2513
+
2514
+ if args.out_html:
2515
+ try:
2516
+ args.out_html.parent.mkdir(parents=True, exist_ok=True)
2517
+ args.out_html.write_text(html_content, encoding="utf-8")
2518
+ logger.info(f"HTML report written to: {args.out_html}")
2519
+ except Exception as e:
2520
+ logger.error(f"Failed to write HTML report: {e}")
2521
+ sys.exit(1)
2522
+
2523
+ if args.out_pdf:
2524
+ if not is_pdf_available():
2525
+ logger.error(
2526
+ "Playwright not installed. Install with: pip install playwright && playwright install chromium"
2527
+ )
2528
+ sys.exit(1)
2529
+ try:
2530
+ args.out_pdf.parent.mkdir(parents=True, exist_ok=True)
2531
+ pdf_gen = PDFGenerator(logger=logger)
2532
+ success = pdf_gen.generate_from_html(html_content, args.out_pdf)
2533
+ if success:
2534
+ logger.info(f"PDF report written to: {args.out_pdf}")
2535
+ else:
2536
+ logger.error("PDF generation failed")
2537
+ sys.exit(1)
2538
+ except Exception as e:
2539
+ logger.error(f"Failed to write PDF report: {e}")
2540
+ sys.exit(1)
2541
+
2542
+ # Interactive graph visualization
2543
+ if args.html_graph:
2544
+ try:
2545
+ args.html_graph.parent.mkdir(parents=True, exist_ok=True)
2546
+
2547
+ # Use graph_info from the inspector if available
2548
+ if hasattr(inspector, "_graph_info") and inspector._graph_info:
2549
+ graph_info = inspector._graph_info
2550
+ else:
2551
+ # Re-load the model to get graph_info
2552
+ from .analyzer import ONNXGraphLoader
2553
+
2554
+ loader = ONNXGraphLoader(logger=logger)
2555
+ _, graph_info = loader.load(model_path)
2556
+
2557
+ # Detect patterns
2558
+ pattern_analyzer = PatternAnalyzer(logger=logger)
2559
+ blocks = pattern_analyzer.group_into_blocks(graph_info)
2560
+
2561
+ # Analyze edges
2562
+ edge_analyzer = EdgeAnalyzer(logger=logger)
2563
+ edge_result = edge_analyzer.analyze(graph_info)
2564
+
2565
+ # Build hierarchy
2566
+ builder = HierarchicalGraphBuilder(logger=logger)
2567
+ hier_graph = builder.build(graph_info, blocks, model_path.stem)
2568
+
2569
+ # Export HTML with model size and layer timing
2570
+ model_size = model_path.stat().st_size if model_path.exists() else None
2571
+
2572
+ # Extract layer timing from profiling results if available
2573
+ layer_timing = None
2574
+ if (
2575
+ report.extra_data
2576
+ and "profiling" in report.extra_data
2577
+ and "slowest_layers" in report.extra_data["profiling"]
2578
+ ):
2579
+ # Build timing dict from profiling results
2580
+ layer_timing = {}
2581
+ for layer in report.extra_data["profiling"]["slowest_layers"]:
2582
+ layer_timing[layer["name"]] = layer["duration_ms"]
2583
+
2584
+ exporter = HTMLExporter(logger=logger)
2585
+ exporter.export(
2586
+ hier_graph,
2587
+ edge_result,
2588
+ args.html_graph,
2589
+ model_path.stem,
2590
+ model_size_bytes=model_size,
2591
+ layer_timing=layer_timing,
2592
+ )
2593
+
2594
+ logger.info(f"Interactive graph visualization written to: {args.html_graph}")
2595
+ except Exception as e:
2596
+ logger.error(f"Failed to generate graph visualization: {e}")
2597
+ if not args.quiet:
2598
+ import traceback
2599
+
2600
+ traceback.print_exc()
2601
+ sys.exit(1)
2602
+
2603
+ # Console output
2604
+ if (
2605
+ not args.quiet
2606
+ and not args.out_json
2607
+ and not args.out_md
2608
+ and not args.out_html
2609
+ and not args.out_pdf
2610
+ and not args.html_graph
2611
+ ):
2612
+ # No output files specified - print summary to console
2613
+ print("\n" + "=" * 60)
2614
+ print(f"Model: {model_path.name}")
2615
+ print("=" * 60)
2616
+
2617
+ if report.graph_summary:
2618
+ print(f"\nNodes: {report.graph_summary.num_nodes}")
2619
+ print(f"Inputs: {report.graph_summary.num_inputs}")
2620
+ print(f"Outputs: {report.graph_summary.num_outputs}")
2621
+ print(f"Initializers: {report.graph_summary.num_initializers}")
2622
+
2623
+ if report.param_counts:
2624
+ print(f"\nParameters: {report._format_number(report.param_counts.total)}")
2625
+
2626
+ if report.flop_counts:
2627
+ print(f"FLOPs: {report._format_number(report.flop_counts.total)}")
2628
+
2629
+ if report.memory_estimates:
2630
+ print(f"Model Size: {report._format_bytes(report.memory_estimates.model_size_bytes)}")
2631
+
2632
+ print(f"\nArchitecture: {report.architecture_type}")
2633
+ print(f"Detected Blocks: {len(report.detected_blocks)}")
2634
+
2635
+ # Hardware estimates
2636
+ if hasattr(report, "hardware_estimates") and report.hardware_estimates:
2637
+ hw = report.hardware_estimates
2638
+ print(f"\n--- Hardware Estimates ({hw.device}) ---")
2639
+ print(f"Precision: {hw.precision}, Batch Size: {hw.batch_size}")
2640
+ print(f"VRAM Required: {report._format_bytes(hw.vram_required_bytes)}")
2641
+ print(f"Fits in VRAM: {'Yes' if hw.fits_in_vram else 'NO'}")
2642
+ if hw.fits_in_vram:
2643
+ print(f"Theoretical Latency: {hw.theoretical_latency_ms:.2f} ms")
2644
+ print(f"Bottleneck: {hw.bottleneck}")
2645
+
2646
+ # System Requirements (Console)
2647
+ if hasattr(report, "system_requirements") and report.system_requirements:
2648
+ reqs = report.system_requirements
2649
+ print("\n--- System Requirements ---")
2650
+ print(f"Minimum: {reqs.minimum_gpu.name} ({reqs.minimum_vram_gb} GB VRAM)")
2651
+ print(f"Recommended: {reqs.recommended_gpu.name} ({reqs.recommended_vram_gb} GB VRAM)")
2652
+ print(f"Optimal: {reqs.optimal_gpu.name}")
2653
+
2654
+ # Batch Scaling (Console)
2655
+ if hasattr(report, "batch_size_sweep") and report.batch_size_sweep:
2656
+ sweep = report.batch_size_sweep
2657
+ print("\n--- Batch Size Scaling ---")
2658
+ print(f"Optimal Batch Size: {sweep.optimal_batch_size}")
2659
+ print(f"Max Throughput: {max(sweep.throughputs):.1f} inf/s")
2660
+
2661
+ if report.risk_signals:
2662
+ print(f"\nRisk Signals: {len(report.risk_signals)}")
2663
+ for risk in report.risk_signals:
2664
+ severity_icon = {
2665
+ "info": "[INFO]",
2666
+ "warning": "[WARN]",
2667
+ "high": "[HIGH]",
2668
+ }
2669
+ print(f" {severity_icon.get(risk.severity, '')} {risk.id}")
2670
+
2671
+ # LLM Summary
2672
+ if llm_summary and llm_summary.success:
2673
+ print(f"\n--- LLM Summary ({llm_summary.model_used}) ---")
2674
+ if llm_summary.short_summary:
2675
+ print(f"{llm_summary.short_summary}")
2676
+
2677
+ print("\n" + "=" * 60)
2678
+ print("Use --out-json or --out-md for detailed reports.")
2679
+ if not args.hardware:
2680
+ print("Use --hardware auto or --hardware <profile> for hardware estimates.")
2681
+ print("=" * 60 + "\n")
2682
+
2683
+ elif not args.quiet:
2684
+ # Files written - just confirm
2685
+ print(f"\nInspection complete for: {model_path.name}")
2686
+ if args.out_json:
2687
+ print(f" JSON report: {args.out_json}")
2688
+ if args.out_md:
2689
+ print(f" Markdown card: {args.out_md}")
2690
+ if args.out_html:
2691
+ print(f" HTML report: {args.out_html}")
2692
+ if args.out_pdf:
2693
+ print(f" PDF report: {args.out_pdf}")
2694
+ if args.html_graph:
2695
+ print(f" Graph visualization: {args.html_graph}")
2696
+ if args.layer_csv:
2697
+ print(f" Layer CSV: {args.layer_csv}")
2698
+
2699
+ # Finish progress indicator
2700
+ progress.finish("Analysis complete!")
2701
+
2702
+ # Cleanup temp ONNX file if we created one
2703
+ if temp_onnx_file is not None:
2704
+ try:
2705
+ pathlib.Path(temp_onnx_file.name).unlink()
2706
+ logger.debug(f"Cleaned up temp ONNX file: {temp_onnx_file.name}")
2707
+ except Exception:
2708
+ pass
2709
+
2710
+
2711
+ if __name__ == "__main__":
2712
+ run_inspect()