haoline 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haoline/.streamlit/config.toml +10 -0
- haoline/__init__.py +248 -0
- haoline/analyzer.py +935 -0
- haoline/cli.py +2712 -0
- haoline/compare.py +811 -0
- haoline/compare_visualizations.py +1564 -0
- haoline/edge_analysis.py +525 -0
- haoline/eval/__init__.py +131 -0
- haoline/eval/adapters.py +844 -0
- haoline/eval/cli.py +390 -0
- haoline/eval/comparison.py +542 -0
- haoline/eval/deployment.py +633 -0
- haoline/eval/schemas.py +833 -0
- haoline/examples/__init__.py +15 -0
- haoline/examples/basic_inspection.py +74 -0
- haoline/examples/compare_models.py +117 -0
- haoline/examples/hardware_estimation.py +78 -0
- haoline/format_adapters.py +1001 -0
- haoline/formats/__init__.py +123 -0
- haoline/formats/coreml.py +250 -0
- haoline/formats/gguf.py +483 -0
- haoline/formats/openvino.py +255 -0
- haoline/formats/safetensors.py +273 -0
- haoline/formats/tflite.py +369 -0
- haoline/hardware.py +2307 -0
- haoline/hierarchical_graph.py +462 -0
- haoline/html_export.py +1573 -0
- haoline/layer_summary.py +769 -0
- haoline/llm_summarizer.py +465 -0
- haoline/op_icons.py +618 -0
- haoline/operational_profiling.py +1492 -0
- haoline/patterns.py +1116 -0
- haoline/pdf_generator.py +265 -0
- haoline/privacy.py +250 -0
- haoline/pydantic_models.py +241 -0
- haoline/report.py +1923 -0
- haoline/report_sections.py +539 -0
- haoline/risks.py +521 -0
- haoline/schema.py +523 -0
- haoline/streamlit_app.py +2024 -0
- haoline/tests/__init__.py +4 -0
- haoline/tests/conftest.py +123 -0
- haoline/tests/test_analyzer.py +868 -0
- haoline/tests/test_compare_visualizations.py +293 -0
- haoline/tests/test_edge_analysis.py +243 -0
- haoline/tests/test_eval.py +604 -0
- haoline/tests/test_format_adapters.py +460 -0
- haoline/tests/test_hardware.py +237 -0
- haoline/tests/test_hardware_recommender.py +90 -0
- haoline/tests/test_hierarchical_graph.py +326 -0
- haoline/tests/test_html_export.py +180 -0
- haoline/tests/test_layer_summary.py +428 -0
- haoline/tests/test_llm_patterns.py +540 -0
- haoline/tests/test_llm_summarizer.py +339 -0
- haoline/tests/test_patterns.py +774 -0
- haoline/tests/test_pytorch.py +327 -0
- haoline/tests/test_report.py +383 -0
- haoline/tests/test_risks.py +398 -0
- haoline/tests/test_schema.py +417 -0
- haoline/tests/test_tensorflow.py +380 -0
- haoline/tests/test_visualizations.py +316 -0
- haoline/universal_ir.py +856 -0
- haoline/visualizations.py +1086 -0
- haoline/visualize_yolo.py +44 -0
- haoline/web.py +110 -0
- haoline-0.3.0.dist-info/METADATA +471 -0
- haoline-0.3.0.dist-info/RECORD +70 -0
- haoline-0.3.0.dist-info/WHEEL +4 -0
- haoline-0.3.0.dist-info/entry_points.txt +5 -0
- haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,1001 @@
|
|
|
1
|
+
# Copyright (c) 2025 HaoLine Contributors
|
|
2
|
+
# SPDX-License-Identifier: MIT
|
|
3
|
+
|
|
4
|
+
"""
Format Adapter system for Universal IR.

This module provides the plugin interface for model format readers/writers.
Each adapter converts format-specific models to/from UniversalGraph, and
adapters are looked up by (lowercased) file extension.

Usage:
    from haoline.format_adapters import get_adapter, list_adapters

    # Auto-detect and load
    adapter = get_adapter("model.onnx")
    graph = adapter.read("model.onnx")

    # Explicit adapter selection
    from haoline.format_adapters import OnnxAdapter
    graph = OnnxAdapter().read("model.onnx")

    # One-step helpers (format chosen by file extension)
    from haoline.format_adapters import load_model, save_model
    graph = load_model("model.onnx")
    save_model(graph, "copy.onnx")
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import logging
|
|
25
|
+
from abc import ABC, abstractmethod
|
|
26
|
+
from enum import Enum
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
from typing import TYPE_CHECKING
|
|
29
|
+
|
|
30
|
+
if TYPE_CHECKING:
|
|
31
|
+
import onnx
|
|
32
|
+
import torch
|
|
33
|
+
|
|
34
|
+
from .universal_ir import (
|
|
35
|
+
DataType,
|
|
36
|
+
GraphMetadata,
|
|
37
|
+
SourceFormat,
|
|
38
|
+
TensorOrigin,
|
|
39
|
+
UniversalGraph,
|
|
40
|
+
UniversalNode,
|
|
41
|
+
UniversalTensor,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
# Module-level logger: adapter registration messages and ONNX shape-inference
# failures are reported through it.
logger = logging.getLogger(__name__)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# =============================================================================
|
|
48
|
+
# Format Adapter Protocol
|
|
49
|
+
# =============================================================================
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class FormatAdapter(ABC):
    """Base class for plugins that translate a model format to/from UniversalGraph.

    To support a new format, subclass this, fill in the three metadata class
    attributes, implement ``can_read()`` and ``read()``, and register the
    subclass with ``register_adapter()`` (it also works as a class decorator).
    Override ``can_write()`` and ``write()`` as well if the format can be exported.

    Example:
        class MyFormatAdapter(FormatAdapter):
            name = "myformat"
            extensions = [".myf", ".myformat"]
            source_format = SourceFormat.UNKNOWN

            def can_read(self, path: Path) -> bool:
                return path.suffix.lower() in self.extensions

            def read(self, path: Path) -> UniversalGraph:
                # Parse the format-specific file, then build a UniversalGraph
                ...
    """

    # Class-level adapter metadata; concrete adapters override all three.
    name: str = "unknown"
    extensions: list[str] = []
    source_format: SourceFormat = SourceFormat.UNKNOWN

    @abstractmethod
    def can_read(self, path: Path) -> bool:
        """Report whether this adapter understands the file at ``path``.

        Args:
            path: Location of the model file.

        Returns:
            True when ``read()`` can be called on this file.
        """
        ...

    @abstractmethod
    def read(self, path: Path) -> UniversalGraph:
        """Parse the model file at ``path`` into a UniversalGraph.

        Args:
            path: Location of the model file.

        Returns:
            The model expressed as a UniversalGraph.

        Raises:
            FileNotFoundError: When no file exists at ``path``.
            ValueError: When the file contents are not valid for this format.
        """
        ...

    def can_write(self) -> bool:
        """Report whether ``write()`` is implemented for this format.

        The base implementation answers False; adapters able to export a
        UniversalGraph back to their own file format override this.

        Returns:
            True when ``write()`` is supported.
        """
        return False

    def write(self, graph: UniversalGraph, path: Path) -> None:
        """Serialize ``graph`` into this adapter's file format at ``path``.

        Args:
            graph: The graph to export.
            path: Destination file path.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError(f"{self.name} adapter does not support writing")
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# =============================================================================
|
|
131
|
+
# Adapter Registry
|
|
132
|
+
# =============================================================================
|
|
133
|
+
|
|
134
|
+
# Global registry mapping lowercased file extensions (as listed in each
# adapter's `extensions`, e.g. ".onnx") to adapter classes.
# Written by register_adapter(); read by get_adapter() and list_adapters().
_ADAPTER_REGISTRY: dict[str, type[FormatAdapter]] = {}
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def register_adapter(adapter_class: type[FormatAdapter]) -> type[FormatAdapter]:
    """Register a format adapter for each of its file extensions.

    Can be used as a decorator:

        @register_adapter
        class MyFormatAdapter(FormatAdapter):
            ...

    Extensions are stored lowercased and with a leading dot, because
    ``get_adapter()`` looks adapters up by ``Path.suffix`` (e.g. ".onnx").
    If another adapter class already owns an extension, it is replaced and a
    warning is logged. Re-registering the same class does not log a warning.

    Args:
        adapter_class: The adapter class to register.

    Returns:
        The adapter class, unchanged (for decorator use).
    """
    for ext in adapter_class.extensions:
        key = ext.lower()
        # Path.suffix always includes the dot; "onnx" would otherwise never match.
        if key and not key.startswith("."):
            key = f".{key}"

        existing = _ADAPTER_REGISTRY.get(key)
        if existing is not None and existing is not adapter_class:
            logger.warning(
                "Overwriting adapter for %s: %s -> %s",
                key,
                existing.name,
                adapter_class.name,
            )
        _ADAPTER_REGISTRY[key] = adapter_class

    logger.debug("Registered adapter: %s for %s", adapter_class.name, adapter_class.extensions)
    return adapter_class
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def get_adapter(path: str | Path) -> FormatAdapter:
    """Instantiate the adapter registered for ``path``'s file extension.

    The lookup key is the file suffix, lowercased (so "Model.ONNX" resolves
    the same as "model.onnx"). The file itself is not opened.

    Args:
        path: Model file location, as a string or Path.

    Returns:
        A fresh instance of the matching FormatAdapter subclass.

    Raises:
        ValueError: When no adapter is registered for the extension; the
            message lists every registered extension.
    """
    suffix = Path(path).suffix.lower()
    adapter_cls = _ADAPTER_REGISTRY.get(suffix)
    if adapter_cls is None:
        known = ", ".join(sorted(_ADAPTER_REGISTRY))
        raise ValueError(f"No adapter registered for extension '{suffix}'. Available: {known}")
    return adapter_cls()
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def list_adapters() -> list[dict[str, str | list[str] | bool]]:
    """List all registered adapters, one entry per adapter name.

    An adapter registered for several extensions appears only once.

    Returns:
        Info dicts sorted by adapter name, each with the keys:
        ``name``; ``extensions`` (a copy of the adapter's extension list, so
        mutating it does not affect the registered class); ``source_format``
        (the adapter's SourceFormat value); and ``can_write``.
    """
    by_name: dict[str, dict[str, str | list[str] | bool]] = {}

    for adapter_class in _ADAPTER_REGISTRY.values():
        if adapter_class.name in by_name:
            continue
        by_name[adapter_class.name] = {
            "name": adapter_class.name,
            # Copy: returning the class attribute itself would let callers mutate it.
            "extensions": list(adapter_class.extensions),
            "source_format": adapter_class.source_format.value,
            "can_write": adapter_class().can_write(),
        }

    return [by_name[adapter_name] for adapter_name in sorted(by_name)]
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# =============================================================================
|
|
216
|
+
# Op Type Mapping (ONNX -> Universal)
|
|
217
|
+
# =============================================================================
|
|
218
|
+
|
|
219
|
+
# Map ONNX op types to universal op types.
# ONNX ops that are missing from this table keep their ONNX name
# (map_onnx_op_to_universal falls back to the input op name).
# NOTE: the mapping is many-to-one (Gemm, MatMul and MatMulInteger all map to
# "MatMul"), so inverting it requires choosing one ONNX op per universal op.
# NOTE(review): ONNX Conv/ConvTranspose/MaxPool/AveragePool are N-dimensional,
# but they are always labeled with 2D universal names here.
ONNX_TO_UNIVERSAL_OP: dict[str, str] = {
    # Convolution
    "Conv": "Conv2D",
    "ConvTranspose": "ConvTranspose2D",
    # Linear/Dense
    "Gemm": "MatMul",
    "MatMul": "MatMul",
    "MatMulInteger": "MatMul",
    # Normalization
    "BatchNormalization": "BatchNorm",
    "LayerNormalization": "LayerNorm",
    "InstanceNormalization": "InstanceNorm",
    "GroupNormalization": "GroupNorm",
    # Activations
    "Relu": "Relu",
    "LeakyRelu": "LeakyRelu",
    "Sigmoid": "Sigmoid",
    "Tanh": "Tanh",
    "Softmax": "Softmax",
    "Gelu": "Gelu",
    "Silu": "Silu",
    "Mish": "Mish",
    # Pooling
    "MaxPool": "MaxPool2D",
    "AveragePool": "AvgPool2D",
    "GlobalAveragePool": "GlobalAvgPool",
    "GlobalMaxPool": "GlobalMaxPool",
    # Element-wise
    "Add": "Add",
    "Sub": "Sub",
    "Mul": "Mul",
    "Div": "Div",
    # Reshape/View
    "Reshape": "Reshape",
    "Flatten": "Flatten",
    "Squeeze": "Squeeze",
    "Unsqueeze": "Unsqueeze",
    "Transpose": "Transpose",
    # Attention (custom/subgraph)
    "Attention": "Attention",
    "MultiHeadAttention": "MultiHeadAttention",
    # Misc
    "Concat": "Concat",
    "Split": "Split",
    "Slice": "Slice",
    "Gather": "Gather",
    "Dropout": "Dropout",
    "Constant": "Constant",
    "Identity": "Identity",
    "Cast": "Cast",
    "ReduceMean": "ReduceMean",
    "ReduceSum": "ReduceSum",
    "Clip": "Clip",
    "Pad": "Pad",
    "Resize": "Resize",
    "Upsample": "Upsample",
}
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def map_onnx_op_to_universal(onnx_op: str) -> str:
    """Translate an ONNX operator name into HaoLine's universal op name.

    Operators without an entry in ONNX_TO_UNIVERSAL_OP are passed through
    unchanged, so custom or uncommon ops keep their ONNX spelling.

    Args:
        onnx_op: ONNX operator name, e.g. "Conv" or "Gemm".

    Returns:
        The universal op name, e.g. "Conv2D" or "MatMul".
    """
    try:
        return ONNX_TO_UNIVERSAL_OP[onnx_op]
    except KeyError:
        return onnx_op
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
# =============================================================================
|
|
292
|
+
# ONNX Adapter
|
|
293
|
+
# =============================================================================
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
@register_adapter
class OnnxAdapter(FormatAdapter):
    """Adapter for ONNX models (.onnx files).

    This is the primary adapter since ONNX is HaoLine's native format.
    Supports both reading (ONNX -> UniversalGraph) and writing
    (UniversalGraph -> ONNX).
    """

    name = "onnx"
    extensions = [".onnx"]
    source_format = SourceFormat.ONNX

    # Domain strings that denote the default ONNX operator set. read() tags
    # default-domain nodes as "ai.onnx"; write() stores them as "".
    _DEFAULT_DOMAINS = ("", "ai.onnx")

    def can_read(self, path: Path) -> bool:
        """Check if file is an ONNX model (by ``.onnx`` suffix, case-insensitive)."""
        return path.suffix.lower() == ".onnx"

    def read(self, path: Path) -> UniversalGraph:
        """Read an ONNX model and convert it to a UniversalGraph.

        ONNX shape inference is attempted first, for better tensor metadata.
        If it fails, a warning is logged and the model is used as loaded.

        Args:
            path: Path to the ``.onnx`` file (a ``str`` is also accepted).

        Returns:
            UniversalGraph containing weights (with data), graph inputs and
            outputs, intermediate tensors from ``value_info``, and one
            UniversalNode per ONNX node.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        import onnx
        from onnx import numpy_helper

        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"ONNX model not found: {path}")

        model = onnx.load(str(path))

        # Best effort: the graph is still usable without inferred shapes.
        try:
            model = onnx.shape_inference.infer_shapes(model)
        except Exception as e:
            logger.warning("Shape inference failed: %s", e)

        graph = model.graph

        metadata = GraphMetadata(
            name=graph.name or path.stem,
            source_format=SourceFormat.ONNX,
            source_path=str(path),
            ir_version=model.ir_version,
            producer_name=model.producer_name or None,
            producer_version=model.producer_version or None,
            # opset_import may list custom domains first; report the default-domain opset.
            opset_version=self._default_opset_version(model),
            input_names=[inp.name for inp in graph.input],
            output_names=[out.name for out in graph.output],
        )

        tensors: dict[str, UniversalTensor] = {}

        # Initializers (weights), with their data materialized as NumPy arrays.
        for init in graph.initializer:
            tensor_data = numpy_helper.to_array(init)
            tensors[init.name] = UniversalTensor(
                name=init.name,
                shape=list(tensor_data.shape),
                dtype=DataType.from_numpy_dtype(tensor_data.dtype),
                origin=TensorOrigin.WEIGHT,
                data=tensor_data,
                source_name=init.name,
            )

        # Graph inputs that are not initializers (some models list weights as inputs too).
        initializer_names = {init.name for init in graph.initializer}
        for inp in graph.input:
            if inp.name not in initializer_names:
                tensors[inp.name] = self._value_info_to_tensor(inp, TensorOrigin.INPUT)

        # Graph outputs.
        for out in graph.output:
            tensors[out.name] = self._value_info_to_tensor(out, TensorOrigin.OUTPUT)

        # value_info (intermediate tensors); never overrides entries added above.
        for vi in graph.value_info:
            if vi.name not in tensors:
                tensors[vi.name] = self._value_info_to_tensor(vi, TensorOrigin.ACTIVATION)

        nodes: list[UniversalNode] = []
        for node in graph.node:
            # Outputs without tensor info get an empty shape and UNKNOWN dtype.
            known = [tensors.get(out_name) for out_name in node.output]
            output_shapes: list[list[int]] = [t.shape if t is not None else [] for t in known]
            output_dtypes: list[DataType] = [
                t.dtype if t is not None else DataType.UNKNOWN for t in known
            ]

            nodes.append(
                UniversalNode(
                    id=node.name or f"{node.op_type}_{len(nodes)}",
                    op_type=map_onnx_op_to_universal(node.op_type),
                    inputs=list(node.input),
                    outputs=list(node.output),
                    attributes=self._extract_attributes(node),
                    output_shapes=output_shapes,
                    output_dtypes=output_dtypes,
                    source_op=node.op_type,
                    source_domain=node.domain or "ai.onnx",
                )
            )

        return UniversalGraph(
            nodes=nodes,
            tensors=tensors,
            metadata=metadata,
        )

    def can_write(self) -> bool:
        """ONNX adapter supports writing."""
        return True

    def write(self, graph: UniversalGraph, path: Path) -> None:
        """Write a UniversalGraph to an ONNX file.

        Each node keeps its original ONNX op (``source_op``) when present;
        otherwise its universal op type is mapped back to an ONNX op.
        Intermediate (ACTIVATION) tensors are not written as value_info.
        Dynamic dimensions (stored as -1) are written as unknown dims.

        Args:
            graph: The graph to export.
            path: Output ``.onnx`` file path.
        """
        import onnx
        from onnx import helper, numpy_helper

        all_tensors = list(graph.tensors.values())

        # Initializers: only weights that actually carry data.
        initializers = [
            numpy_helper.from_array(t.data, name=t.name)
            for t in all_tensors
            if t.origin == TensorOrigin.WEIGHT and t.data is not None
        ]

        # Graph inputs; an empty shape is written as "shape unknown".
        inputs = [
            helper.make_tensor_value_info(
                t.name,
                self._dtype_to_onnx(t.dtype),
                self._to_onnx_dims(t.shape) if t.shape else None,
            )
            for t in all_tensors
            if t.origin == TensorOrigin.INPUT
        ]
        # Weights are also listed as graph inputs (ONNX convention); their shape is
        # written as-is, so an empty shape means a scalar here.
        inputs.extend(
            helper.make_tensor_value_info(
                t.name, self._dtype_to_onnx(t.dtype), self._to_onnx_dims(t.shape)
            )
            for t in all_tensors
            if t.origin == TensorOrigin.WEIGHT
        )

        outputs = [
            helper.make_tensor_value_info(
                t.name,
                self._dtype_to_onnx(t.dtype),
                self._to_onnx_dims(t.shape) if t.shape else None,
            )
            for t in all_tensors
            if t.origin == TensorOrigin.OUTPUT
        ]

        onnx_nodes = []
        custom_domains: set[str] = set()
        for node in graph.nodes:
            onnx_op = node.source_op or self._universal_to_onnx_op(node.op_type)

            # Normalize the default domain to "" so it matches the "" opset import below.
            domain = node.source_domain or ""
            if domain in self._DEFAULT_DOMAINS:
                domain = ""
            else:
                custom_domains.add(domain)

            onnx_nodes.append(
                helper.make_node(
                    onnx_op,
                    inputs=node.inputs,
                    outputs=node.outputs,
                    name=node.id,
                    domain=domain,
                    **node.attributes,
                )
            )

        onnx_graph = helper.make_graph(
            onnx_nodes,
            name=graph.metadata.name or "haoline_export",
            inputs=inputs,
            outputs=outputs,
            initializer=initializers,
        )

        opset_imports = [helper.make_opsetid("", graph.metadata.opset_version or 17)]
        # Nodes in custom domains need an opset import. GraphMetadata only records
        # the default-domain opset, so version 1 is assumed for custom domains.
        opset_imports.extend(helper.make_opsetid(d, 1) for d in sorted(custom_domains))

        model = helper.make_model(
            onnx_graph,
            opset_imports=opset_imports,
            producer_name="haoline",
        )

        onnx.save(model, str(path))

    def _default_opset_version(self, model: onnx.ModelProto) -> int | None:
        """Return the opset version imported for the default ONNX domain, or None."""
        for opset in model.opset_import:
            if opset.domain in self._DEFAULT_DOMAINS:
                return opset.version
        return None

    def _value_info_to_tensor(
        self, value_info: onnx.ValueInfoProto, origin: TensorOrigin
    ) -> UniversalTensor:
        """Build a data-less UniversalTensor (shape + dtype) from a ValueInfoProto."""
        return UniversalTensor(
            name=value_info.name,
            shape=self._extract_shape(value_info),
            dtype=self._extract_dtype(value_info),
            origin=origin,
            source_name=value_info.name,
        )

    @staticmethod
    def _to_onnx_dims(shape: list[int]) -> list[int | None]:
        """Convert a UniversalTensor shape to ONNX dims.

        Negative (dynamic) dimensions become None, i.e. unknown dims.
        """
        return [int(d) if d >= 0 else None for d in shape]

    def _extract_shape(self, value_info: onnx.ValueInfoProto) -> list[int]:
        """Extract the shape from an ONNX ValueInfoProto.

        Dimensions with an explicit ``dim_value`` (including 0) are kept.
        Symbolic (``dim_param``) or unset dimensions become -1. A value with
        no shape information yields an empty list.
        """
        tensor_type = value_info.type.tensor_type
        return [
            dim.dim_value if dim.HasField("dim_value") else -1
            for dim in tensor_type.shape.dim
        ]

    def _extract_dtype(self, value_info: onnx.ValueInfoProto) -> DataType:
        """Extract the element dtype from an ONNX ValueInfoProto (UNKNOWN on failure)."""
        try:
            elem_type = value_info.type.tensor_type.elem_type
            return DataType.from_onnx_dtype(elem_type)
        except Exception:
            return DataType.UNKNOWN

    def _extract_attributes(self, node: onnx.NodeProto) -> dict[str, object]:
        """Extract scalar and list attributes from an ONNX NodeProto.

        Strings are decoded as UTF-8. TENSOR, GRAPH and other complex attribute
        types are skipped for now, so nodes that depend on them (e.g. a Constant
        holding a tensor ``value``) do not round-trip through write().
        """
        from onnx import AttributeProto

        attrs: dict[str, object] = {}
        for attr in node.attribute:
            if attr.type == AttributeProto.FLOAT:
                attrs[attr.name] = attr.f
            elif attr.type == AttributeProto.INT:
                attrs[attr.name] = attr.i
            elif attr.type == AttributeProto.STRING:
                attrs[attr.name] = attr.s.decode("utf-8") if attr.s else ""
            elif attr.type == AttributeProto.FLOATS:
                attrs[attr.name] = list(attr.floats)
            elif attr.type == AttributeProto.INTS:
                attrs[attr.name] = list(attr.ints)
            elif attr.type == AttributeProto.STRINGS:
                attrs[attr.name] = [s.decode("utf-8") for s in attr.strings]
        return attrs

    def _dtype_to_onnx(self, dtype: DataType) -> int:
        """Convert DataType to an ONNX TensorProto dtype (unmapped types -> FLOAT)."""
        from onnx import TensorProto

        mapping = {
            DataType.FLOAT32: TensorProto.FLOAT,
            DataType.FLOAT64: TensorProto.DOUBLE,
            DataType.FLOAT16: TensorProto.FLOAT16,
            DataType.BFLOAT16: TensorProto.BFLOAT16,
            DataType.INT64: TensorProto.INT64,
            DataType.INT32: TensorProto.INT32,
            DataType.INT16: TensorProto.INT16,
            DataType.INT8: TensorProto.INT8,
            DataType.UINT8: TensorProto.UINT8,
            DataType.BOOL: TensorProto.BOOL,
            DataType.STRING: TensorProto.STRING,
        }
        return mapping.get(dtype, TensorProto.FLOAT)

    def _universal_to_onnx_op(self, universal_op: str) -> str:
        """Map a universal op type back to an ONNX op type.

        ONNX_TO_UNIVERSAL_OP is many-to-one (Gemm, MatMul and MatMulInteger all
        map to "MatMul"). Resolution order:
          1. an ONNX op with the same name that maps to itself;
          2. otherwise the first table entry that maps to the universal op;
          3. otherwise the universal op name unchanged.
        A naive inverted dict would let the last entry win and turn "MatMul"
        into "MatMulInteger".
        """
        if ONNX_TO_UNIVERSAL_OP.get(universal_op) == universal_op:
            return universal_op
        for onnx_op, mapped in ONNX_TO_UNIVERSAL_OP.items():
            if mapped == universal_op:
                return onnx_op
        return universal_op
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
# =============================================================================
|
|
582
|
+
# PyTorch Adapter
|
|
583
|
+
# =============================================================================
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
@register_adapter
class PyTorchAdapter(FormatAdapter):
    """Adapter for PyTorch models (.pt, .pth files).

    Full models are converted by exporting to ONNX and reading the result with
    OnnxAdapter, which keeps the representation consistent. Supported inputs:

    - pickled ``nn.Module`` objects, exported with ``torch.onnx.export``;
    - Ultralytics YOLO checkpoints, exported with Ultralytics' own exporter;
    - state_dicts, either bare or nested under a ``"model"`` key. These yield
      weights only, with no graph structure.
    """

    name = "pytorch"
    extensions = [".pt", ".pth"]
    source_format = SourceFormat.PYTORCH

    def can_read(self, path: Path) -> bool:
        """Check if file is a PyTorch model (by ``.pt``/``.pth`` suffix)."""
        return path.suffix.lower() in [".pt", ".pth"]

    def read(self, path: Path) -> UniversalGraph:
        """Read a PyTorch file and convert it to a UniversalGraph.

        For ``nn.Module`` files, a dummy input is synthesized for tracing from
        the model's first Conv2d/Linear layer (see ``_create_dummy_input``).

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If the file holds neither a module nor a dict.
        """
        import torch

        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"PyTorch model not found: {path}")

        # SECURITY: weights_only=False unpickles arbitrary Python objects and can run
        # code embedded in the file. Only read checkpoints from trusted sources.
        # It is kept because pickled nn.Module objects cannot be loaded otherwise.
        loaded = torch.load(str(path), map_location="cpu", weights_only=False)

        if isinstance(loaded, torch.nn.Module):
            return self._convert_module(loaded, path)
        if not isinstance(loaded, dict):
            raise ValueError(f"Unknown PyTorch file format: {type(loaded)}")
        if "model" not in loaded:
            # Pure state_dict - weights only
            return self._convert_state_dict(loaded, path)
        if isinstance(loaded["model"], dict):
            # Generic training checkpoint, e.g. {"model": state_dict, "optimizer": ...}
            return self._convert_state_dict(loaded["model"], path)
        # Any other value under "model" is treated as an Ultralytics YOLO checkpoint.
        return self._convert_ultralytics(loaded, path)

    def _convert_module(self, model: torch.nn.Module, path: Path) -> UniversalGraph:
        """Convert a torch.nn.Module to UniversalGraph via a temporary ONNX export (opset 17).

        Note: calls ``model.eval()`` on the passed module.
        """
        import tempfile

        import torch

        model.eval()
        dummy_input = self._create_dummy_input(model)

        # Only reserve a unique file name here; torch.onnx.export writes it by path.
        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
            onnx_path = Path(f.name)

        try:
            torch.onnx.export(
                model,
                (dummy_input,),
                str(onnx_path),
                opset_version=17,
                do_constant_folding=True,
            )

            graph = OnnxAdapter().read(onnx_path)

            # Report the PyTorch file, not the temporary ONNX file, as the source.
            graph.metadata.source_format = SourceFormat.PYTORCH
            graph.metadata.source_path = str(path)
            return graph
        finally:
            onnx_path.unlink(missing_ok=True)

    def _convert_ultralytics(self, loaded: dict[str, object], path: Path) -> UniversalGraph:
        """Convert an Ultralytics YOLO checkpoint to UniversalGraph via its ONNX exporter.

        ``loaded`` is not used: Ultralytics reloads the checkpoint from disk itself.

        Raises:
            ImportError: If the ``ultralytics`` package is not installed.
        """
        import shutil
        import tempfile

        try:
            from ultralytics import YOLO
        except ImportError as e:
            raise ImportError(
                "Ultralytics YOLO model detected. Install ultralytics: pip install ultralytics"
            ) from e

        with tempfile.TemporaryDirectory() as tmpdir:
            # Ultralytics' exporter writes the .onnx beside the weights it was loaded
            # from. Export from a temporary copy so that nothing is written, or
            # overwritten, next to the user's original model.
            local_weights = Path(tmpdir) / path.name
            shutil.copy2(path, local_weights)

            onnx_path = Path(YOLO(str(local_weights)).export(format="onnx"))
            graph = OnnxAdapter().read(onnx_path)

        graph.metadata.source_format = SourceFormat.PYTORCH
        graph.metadata.source_path = str(path)
        graph.metadata.extra["ultralytics"] = True

        return graph

    def _convert_state_dict(self, state_dict: dict[str, object], path: Path) -> UniversalGraph:
        """Convert a state_dict to a UniversalGraph holding weights only (no nodes).

        Non-tensor entries are skipped. NumPy has no bfloat16, so bfloat16 tensors
        are stored with float32 data while their dtype is reported as BFLOAT16.
        """
        import torch

        tensors: dict[str, UniversalTensor] = {}

        for name, param in state_dict.items():
            if not isinstance(param, torch.Tensor):
                continue

            cpu_tensor = param.detach().cpu()
            if cpu_tensor.dtype == torch.bfloat16:
                # .numpy() raises for bfloat16; upcast the data, keep the true dtype.
                np_data = cpu_tensor.float().numpy()
                dtype = DataType.BFLOAT16
            else:
                np_data = cpu_tensor.numpy()
                dtype = DataType.from_numpy_dtype(np_data.dtype)

            tensors[name] = UniversalTensor(
                name=name,
                shape=list(np_data.shape),
                dtype=dtype,
                origin=TensorOrigin.WEIGHT,
                data=np_data,
                source_name=name,
            )

        return UniversalGraph(
            nodes=[],  # No graph structure for state_dict
            tensors=tensors,
            metadata=GraphMetadata(
                name=path.stem,
                source_format=SourceFormat.PYTORCH,
                source_path=str(path),
                extra={"type": "state_dict"},
            ),
        )

    def _create_dummy_input(self, model: torch.nn.Module) -> torch.Tensor:
        """Create a dummy input for ONNX export.

        Uses the first Conv2d (-> (1, C, 224, 224)) or Linear (-> (1, in_features))
        found in ``model.modules()``, and falls back to (1, 3, 224, 224).
        NOTE(review): modules() yields modules in registration order, which is not
        necessarily the order they run in forward().
        """
        import torch

        for module in model.modules():
            if isinstance(module, torch.nn.Conv2d):
                # Assume image input
                return torch.randn(1, module.in_channels, 224, 224)
            if isinstance(module, torch.nn.Linear):
                return torch.randn(1, module.in_features)

        # Default: batch of 224x224 RGB images
        return torch.randn(1, 3, 224, 224)
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
# =============================================================================
|
|
748
|
+
# Utility Functions
|
|
749
|
+
# =============================================================================
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
def load_model(path: str | Path) -> UniversalGraph:
    """Read any supported model file into a UniversalGraph in one call.

    The adapter is picked from the file extension via ``get_adapter()``.

    Args:
        path: Model file location, as a string or Path.

    Returns:
        The model as a UniversalGraph.

    Example:
        graph = load_model("model.onnx")
        graph = load_model("model.pt")
    """
    model_path = Path(path)
    return get_adapter(model_path).read(model_path)
|
|
769
|
+
|
|
770
|
+
|
|
771
|
+
def save_model(graph: UniversalGraph, path: str | Path) -> None:
    """Export ``graph`` to ``path``, choosing the format from the file extension.

    Args:
        graph: The graph to save.
        path: Destination file, as a string or Path.

    Raises:
        ValueError: When no adapter handles the extension, or when the
            matching adapter cannot write.
    """
    target = Path(path)
    adapter = get_adapter(target)
    if adapter.can_write():
        adapter.write(graph, target)
        return
    raise ValueError(f"{adapter.name} adapter does not support writing")
|
|
788
|
+
|
|
789
|
+
|
|
790
|
+
# =============================================================================
|
|
791
|
+
# Conversion Matrix (Task 18.3)
|
|
792
|
+
# =============================================================================
|
|
793
|
+
|
|
794
|
+
|
|
795
|
+
class ConversionLevel(str, Enum):
    """How faithfully a model can be converted from one format to another.

    Members, listed from most to least capable:

    - ``FULL``: lossless conversion; all information is preserved.
    - ``PARTIAL``: possible, but with limitations or via intermediate steps.
    - ``LOSSY``: possible, but some information is dropped on the way.
    - ``NONE``: there is no conversion path.

    Because the enum mixes in ``str``, each member compares equal to its
    lowercase value string (for example ``ConversionLevel.FULL == "full"``).
    """

    FULL = "full"
    PARTIAL = "partial"
    LOSSY = "lossy"
    NONE = "none"
|
|
809
|
+
|
|
810
|
+
|
|
811
|
+
# Conversion matrix: (source, target) -> ConversionLevel.
# Lookup: _CONVERSION_MATRIX[(source_format, target_format)] -> level
# Pairs absent from this table are reported as ConversionLevel.NONE by
# get_conversion_level(); identity pairs (source == target) are not listed
# because get_conversion_level() reports them as FULL before consulting it.
_CONVERSION_MATRIX: dict[tuple[SourceFormat, SourceFormat], ConversionLevel] = {
    # ONNX conversions (primary interchange format)
    (SourceFormat.ONNX, SourceFormat.TENSORRT): ConversionLevel.PARTIAL,  # TensorRT-specific ops
    (SourceFormat.ONNX, SourceFormat.TFLITE): ConversionLevel.PARTIAL,  # Some ops unsupported
    (SourceFormat.ONNX, SourceFormat.COREML): ConversionLevel.PARTIAL,  # iOS-specific limits
    (SourceFormat.ONNX, SourceFormat.OPENVINO): ConversionLevel.FULL,  # Good ONNX support
    # PyTorch conversions (via ONNX)
    (SourceFormat.PYTORCH, SourceFormat.ONNX): ConversionLevel.FULL,  # torch.onnx.export
    (SourceFormat.PYTORCH, SourceFormat.TENSORRT): ConversionLevel.PARTIAL,  # Via ONNX
    (SourceFormat.PYTORCH, SourceFormat.TFLITE): ConversionLevel.PARTIAL,  # Via ONNX
    (SourceFormat.PYTORCH, SourceFormat.COREML): ConversionLevel.PARTIAL,  # coremltools
    # TensorFlow conversions
    (SourceFormat.TENSORFLOW, SourceFormat.ONNX): ConversionLevel.PARTIAL,  # tf2onnx
    (SourceFormat.TENSORFLOW, SourceFormat.TFLITE): ConversionLevel.FULL,  # TFLite converter
    (SourceFormat.TENSORFLOW, SourceFormat.COREML): ConversionLevel.PARTIAL,  # coremltools
    # TensorRT (inference-only, limited export)
    (SourceFormat.TENSORRT, SourceFormat.ONNX): ConversionLevel.NONE,  # Cannot export
    # CoreML (Apple ecosystem)
    (SourceFormat.COREML, SourceFormat.ONNX): ConversionLevel.LOSSY,  # Some info lost
    # TFLite (mobile)
    (SourceFormat.TFLITE, SourceFormat.ONNX): ConversionLevel.PARTIAL,  # tflite2onnx
    # Weights-only formats (no graph structure)
    (SourceFormat.SAFETENSORS, SourceFormat.ONNX): ConversionLevel.NONE,  # Weights only
    (SourceFormat.GGUF, SourceFormat.ONNX): ConversionLevel.NONE,  # Weights only
}
|
|
838
|
+
|
|
839
|
+
|
|
840
|
+
def get_conversion_level(source: SourceFormat | str, target: SourceFormat | str) -> ConversionLevel:
    """Look up how well *source* can be converted into *target*.

    String arguments are lowercased and parsed as ``SourceFormat`` values;
    a string that names no known format yields ``ConversionLevel.NONE``
    instead of raising.

    Args:
        source: Format the model is currently in.
        target: Format the model should be converted to.

    Returns:
        ``FULL`` when both formats are the same, otherwise the matrix entry
        for the pair, or ``NONE`` when the pair is not in the matrix.
    """
    resolved: list[SourceFormat] = []
    # Source is resolved before target, so an invalid source short-circuits.
    for fmt in (source, target):
        if isinstance(fmt, str):
            try:
                fmt = SourceFormat(fmt.lower())
            except ValueError:
                return ConversionLevel.NONE
        resolved.append(fmt)

    src_fmt, tgt_fmt = resolved
    if src_fmt == tgt_fmt:
        # Staying in the same format is always lossless.
        return ConversionLevel.FULL
    return _CONVERSION_MATRIX.get((src_fmt, tgt_fmt), ConversionLevel.NONE)
|
|
867
|
+
|
|
868
|
+
|
|
869
|
+
def list_conversion_paths(
    source: SourceFormat | str | None = None,
    target: SourceFormat | str | None = None,
) -> list[dict[str, str]]:
    """List the known conversion paths, optionally filtered.

    Only explicit entries of the conversion matrix are listed. Identity
    conversions, which ``get_conversion_level`` reports as FULL, are not
    included.

    Args:
        source: Filter by source format (optional). Strings are lowercased
            and parsed as ``SourceFormat`` values.
        target: Filter by target format (optional); same rules as ``source``.

    Returns:
        List of dicts with "source", "target", and "level" string values,
        sorted by source and then by target.

    Raises:
        ValueError: If a string filter does not name a known ``SourceFormat``.
    """
    # Normalize each filter once, before scanning. An invalid filter is then
    # always rejected, rather than only when some matrix entry happens to
    # reach that comparison, and the parse is not repeated per entry.
    source_fmt: SourceFormat | None = None
    if source is not None:
        source_fmt = source if isinstance(source, SourceFormat) else SourceFormat(source.lower())

    target_fmt: SourceFormat | None = None
    if target is not None:
        target_fmt = target if isinstance(target, SourceFormat) else SourceFormat(target.lower())

    result: list[dict[str, str]] = [
        {"source": src.value, "target": tgt.value, "level": level.value}
        for (src, tgt), level in _CONVERSION_MATRIX.items()
        if (source_fmt is None or src == source_fmt)
        and (target_fmt is None or tgt == target_fmt)
    ]

    return sorted(result, key=lambda entry: (entry["source"], entry["target"]))
|
|
908
|
+
|
|
909
|
+
|
|
910
|
+
def can_convert(source: SourceFormat | str, target: SourceFormat | str) -> bool:
    """Report whether any conversion path exists from *source* to *target*.

    A path exists whenever ``get_conversion_level`` returns anything other
    than ``ConversionLevel.NONE``, i.e. FULL, PARTIAL, or LOSSY.

    Args:
        source: Format the model is currently in.
        target: Format the model should be converted to.

    Returns:
        ``True`` if the formats are convertible (possibly with limitations
        or information loss), ``False`` otherwise.
    """
    return get_conversion_level(source, target) != ConversionLevel.NONE
|
|
924
|
+
|
|
925
|
+
|
|
926
|
+
def convert_model(
    graph: UniversalGraph,
    target_format: SourceFormat | str,
    output_path: Path | str,
) -> Path:
    """Convert a model to a different format.

    The requested conversion is first checked with ``get_conversion_level``
    using ``graph.metadata.source_format`` as the source. Only ONNX output
    is currently written directly, via ``OnnxAdapter``.

    Args:
        graph: UniversalGraph to convert.
        target_format: Target format (e.g., "onnx", "tflite"). Strings are
            lowercased and parsed as ``SourceFormat`` values.
        output_path: Output file path.

    Returns:
        Path to the converted model.

    Raises:
        ValueError: If ``target_format`` is a string naming no known format,
            or if no conversion path exists from the graph's source format.
        NotImplementedError: If the conversion is possible in principle but
            the target is not ONNX (the only format written directly).
    """
    output_path = Path(output_path)

    source = graph.metadata.source_format
    if isinstance(target_format, str):
        # Enum lookup by value raises ValueError for unknown names.
        target_format = SourceFormat(target_format.lower())

    # Human-readable names for messages. The source is read via getattr so a
    # source that is not an enum member (e.g. None) still produces the
    # documented ValueError below instead of an AttributeError on ``.value``.
    source_name = getattr(source, "value", source)
    target_name = target_format.value

    level = get_conversion_level(source, target_format)
    if level == ConversionLevel.NONE:
        raise ValueError(
            f"Cannot convert from {source_name} to {target_name}. "
            "No conversion path available."
        )

    # Conversion is allowed but drops information: warn, don't fail.
    if level == ConversionLevel.LOSSY:
        logger.warning("Converting %s to %s may lose information", source_name, target_name)

    # For now, only ONNX writing is supported
    if target_format != SourceFormat.ONNX:
        raise NotImplementedError(
            f"Direct conversion to {target_name} not yet implemented. "
            "Export to ONNX first, then use format-specific tools."
        )

    adapter = OnnxAdapter()
    adapter.write(graph, output_path)

    return output_path
|
|
974
|
+
|
|
975
|
+
|
|
976
|
+
# =============================================================================
|
|
977
|
+
# Module Exports
|
|
978
|
+
# =============================================================================
|
|
979
|
+
|
|
980
|
+
# Public API of this module, in the same order as before.
__all__ = [
    # Adapter protocol and adapter registry
    "FormatAdapter",
    "register_adapter",
    "get_adapter",
    "list_adapters",
    # Concrete adapters
    "OnnxAdapter",
    "PyTorchAdapter",
    # Load/save helpers and ONNX op mapping
    "load_model",
    "save_model",
    "map_onnx_op_to_universal",
    "ONNX_TO_UNIVERSAL_OP",
    # Format conversion matrix and conversion entry point
    "ConversionLevel",
    "get_conversion_level",
    "list_conversion_paths",
    "can_convert",
    "convert_model",
]
|