haoline 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. haoline/.streamlit/config.toml +10 -0
  2. haoline/__init__.py +248 -0
  3. haoline/analyzer.py +935 -0
  4. haoline/cli.py +2712 -0
  5. haoline/compare.py +811 -0
  6. haoline/compare_visualizations.py +1564 -0
  7. haoline/edge_analysis.py +525 -0
  8. haoline/eval/__init__.py +131 -0
  9. haoline/eval/adapters.py +844 -0
  10. haoline/eval/cli.py +390 -0
  11. haoline/eval/comparison.py +542 -0
  12. haoline/eval/deployment.py +633 -0
  13. haoline/eval/schemas.py +833 -0
  14. haoline/examples/__init__.py +15 -0
  15. haoline/examples/basic_inspection.py +74 -0
  16. haoline/examples/compare_models.py +117 -0
  17. haoline/examples/hardware_estimation.py +78 -0
  18. haoline/format_adapters.py +1001 -0
  19. haoline/formats/__init__.py +123 -0
  20. haoline/formats/coreml.py +250 -0
  21. haoline/formats/gguf.py +483 -0
  22. haoline/formats/openvino.py +255 -0
  23. haoline/formats/safetensors.py +273 -0
  24. haoline/formats/tflite.py +369 -0
  25. haoline/hardware.py +2307 -0
  26. haoline/hierarchical_graph.py +462 -0
  27. haoline/html_export.py +1573 -0
  28. haoline/layer_summary.py +769 -0
  29. haoline/llm_summarizer.py +465 -0
  30. haoline/op_icons.py +618 -0
  31. haoline/operational_profiling.py +1492 -0
  32. haoline/patterns.py +1116 -0
  33. haoline/pdf_generator.py +265 -0
  34. haoline/privacy.py +250 -0
  35. haoline/pydantic_models.py +241 -0
  36. haoline/report.py +1923 -0
  37. haoline/report_sections.py +539 -0
  38. haoline/risks.py +521 -0
  39. haoline/schema.py +523 -0
  40. haoline/streamlit_app.py +2024 -0
  41. haoline/tests/__init__.py +4 -0
  42. haoline/tests/conftest.py +123 -0
  43. haoline/tests/test_analyzer.py +868 -0
  44. haoline/tests/test_compare_visualizations.py +293 -0
  45. haoline/tests/test_edge_analysis.py +243 -0
  46. haoline/tests/test_eval.py +604 -0
  47. haoline/tests/test_format_adapters.py +460 -0
  48. haoline/tests/test_hardware.py +237 -0
  49. haoline/tests/test_hardware_recommender.py +90 -0
  50. haoline/tests/test_hierarchical_graph.py +326 -0
  51. haoline/tests/test_html_export.py +180 -0
  52. haoline/tests/test_layer_summary.py +428 -0
  53. haoline/tests/test_llm_patterns.py +540 -0
  54. haoline/tests/test_llm_summarizer.py +339 -0
  55. haoline/tests/test_patterns.py +774 -0
  56. haoline/tests/test_pytorch.py +327 -0
  57. haoline/tests/test_report.py +383 -0
  58. haoline/tests/test_risks.py +398 -0
  59. haoline/tests/test_schema.py +417 -0
  60. haoline/tests/test_tensorflow.py +380 -0
  61. haoline/tests/test_visualizations.py +316 -0
  62. haoline/universal_ir.py +856 -0
  63. haoline/visualizations.py +1086 -0
  64. haoline/visualize_yolo.py +44 -0
  65. haoline/web.py +110 -0
  66. haoline-0.3.0.dist-info/METADATA +471 -0
  67. haoline-0.3.0.dist-info/RECORD +70 -0
  68. haoline-0.3.0.dist-info/WHEEL +4 -0
  69. haoline-0.3.0.dist-info/entry_points.txt +5 -0
  70. haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,1001 @@
1
+ # Copyright (c) 2025 HaoLine Contributors
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """
5
+ Format Adapter system for Universal IR.
6
+
7
+ This module provides the plugin interface for model format readers/writers.
8
+ Each adapter converts format-specific models to/from UniversalGraph.
9
+
10
+ Usage:
11
+ from haoline.format_adapters import get_adapter, list_adapters
12
+
13
+ # Auto-detect and load
14
+ adapter = get_adapter("model.onnx")
15
+ graph = adapter.read("model.onnx")
16
+
17
+ # Explicit adapter selection
18
+ from haoline.format_adapters import OnnxAdapter
19
+ graph = OnnxAdapter().read("model.onnx")
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import logging
25
+ from abc import ABC, abstractmethod
26
+ from enum import Enum
27
+ from pathlib import Path
28
+ from typing import TYPE_CHECKING
29
+
30
+ if TYPE_CHECKING:
31
+ import onnx
32
+ import torch
33
+
34
+ from .universal_ir import (
35
+ DataType,
36
+ GraphMetadata,
37
+ SourceFormat,
38
+ TensorOrigin,
39
+ UniversalGraph,
40
+ UniversalNode,
41
+ UniversalTensor,
42
+ )
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
+ # =============================================================================
48
+ # Format Adapter Protocol
49
+ # =============================================================================
50
+
51
+
52
class FormatAdapter(ABC):
    """Plugin contract for reading (and optionally writing) one model format.

    Subclass this, override the three class attributes below, implement
    ``can_read`` and ``read``, and hand the class to ``register_adapter()``
    so it becomes discoverable by file extension.

    Example:
        class MyFormatAdapter(FormatAdapter):
            name = "myformat"
            extensions = [".myf", ".myformat"]
            source_format = SourceFormat.UNKNOWN

            def can_read(self, path: Path) -> bool:
                return path.suffix.lower() in self.extensions

            def read(self, path: Path) -> UniversalGraph:
                # Parse format-specific file
                # Build and return UniversalGraph
                ...
    """

    # Descriptive metadata; concrete adapters are expected to override these.
    name: str = "unknown"
    extensions: list[str] = []
    source_format: SourceFormat = SourceFormat.UNKNOWN

    @abstractmethod
    def can_read(self, path: Path) -> bool:
        """Report whether this adapter is able to parse ``path``.

        Args:
            path: Location of the candidate model file.

        Returns:
            True when the adapter can read the file, False otherwise.
        """

    @abstractmethod
    def read(self, path: Path) -> UniversalGraph:
        """Parse the model stored at ``path`` into a UniversalGraph.

        Args:
            path: Location of the model file.

        Returns:
            The model expressed as a UniversalGraph.

        Raises:
            FileNotFoundError: If the file is missing.
            ValueError: If the file content is not valid for this format.
        """

    def can_write(self) -> bool:
        """Report whether ``write()`` is implemented.

        The base implementation answers False; adapters that can export a
        UniversalGraph back to their native format override this.

        Returns:
            True if the adapter supports writing.
        """
        return False

    def write(self, graph: UniversalGraph, path: Path) -> None:
        """Export ``graph`` to ``path`` in this adapter's native format.

        Args:
            graph: Graph to serialize.
            path: Destination file.

        Raises:
            NotImplementedError: Always, unless a subclass overrides it.
        """
        message = f"{self.name} adapter does not support writing"
        raise NotImplementedError(message)
128
+
129
+
130
+ # =============================================================================
131
+ # Adapter Registry
132
+ # =============================================================================
133
+
134
# Global registry mapping extensions to adapters.
# Keys are lower-cased file extensions (register_adapter() lower-cases them
# before insertion); values are adapter classes, not instances.
_ADAPTER_REGISTRY: dict[str, type[FormatAdapter]] = {}
136
+
137
+
138
def register_adapter(adapter_class: type[FormatAdapter]) -> type[FormatAdapter]:
    """Add ``adapter_class`` to the global registry under each of its extensions.

    Extensions are stored lower-cased. If an extension is already claimed,
    a warning is logged and the newer class replaces the older one.

    Works as a class decorator:

        @register_adapter
        class MyFormatAdapter(FormatAdapter):
            ...

    Args:
        adapter_class: Adapter class to make discoverable.

    Returns:
        ``adapter_class`` itself, so the function can be used as a decorator.
    """
    for ext_lower in (ext.lower() for ext in adapter_class.extensions):
        already_claimed = ext_lower in _ADAPTER_REGISTRY
        if already_claimed:
            message = (
                f"Overwriting adapter for {ext_lower}: "
                f"{_ADAPTER_REGISTRY[ext_lower].name} -> {adapter_class.name}"
            )
            logger.warning(message)
        _ADAPTER_REGISTRY[ext_lower] = adapter_class

    logger.debug(f"Registered adapter: {adapter_class.name} for {adapter_class.extensions}")
    return adapter_class
163
+
164
+
165
def get_adapter(path: str | Path) -> FormatAdapter:
    """Instantiate the adapter registered for ``path``'s file extension.

    Only the final suffix of the path is considered, compared
    case-insensitively against the registry.

    Args:
        path: Location of the model file.

    Returns:
        A fresh instance of the matching FormatAdapter subclass.

    Raises:
        ValueError: If nothing is registered for the extension.
    """
    ext = Path(path).suffix.lower()
    adapter_class = _ADAPTER_REGISTRY.get(ext)

    if adapter_class is None:
        available = ", ".join(sorted(_ADAPTER_REGISTRY))
        raise ValueError(f"No adapter registered for extension '{ext}'. Available: {available}")

    return adapter_class()
188
+
189
+
190
def list_adapters() -> list[dict[str, str | list[str] | bool]]:
    """Describe every registered adapter once, ordered by adapter name.

    An adapter registered under several extensions appears a single time
    (de-duplicated on its ``name``). Each adapter class is instantiated to
    query ``can_write()``.

    Returns:
        List of info dicts with keys: name, extensions, source_format,
        can_write.
    """
    info_by_name: dict[str, dict[str, str | list[str] | bool]] = {}

    for adapter_class in _ADAPTER_REGISTRY.values():
        adapter_name = adapter_class.name
        if adapter_name in info_by_name:
            continue
        info_by_name[adapter_name] = {
            "name": adapter_name,
            "extensions": adapter_class.extensions,
            "source_format": adapter_class.source_format.value,
            "can_write": adapter_class().can_write(),
        }

    return sorted(info_by_name.values(), key=lambda info: str(info["name"]))
213
+
214
+
215
+ # =============================================================================
216
+ # Op Type Mapping (ONNX -> Universal)
217
+ # =============================================================================
218
+
219
# Map ONNX op types to universal op types.
# Ops missing from this table are passed through unchanged by
# map_onnx_op_to_universal(). The mapping is many-to-one in places
# (Gemm, MatMul and MatMulInteger all become "MatMul"), so it cannot be
# inverted without a tie-break.
ONNX_TO_UNIVERSAL_OP: dict[str, str] = {
    # Convolution
    "Conv": "Conv2D",
    "ConvTranspose": "ConvTranspose2D",
    # Linear/Dense
    "Gemm": "MatMul",
    "MatMul": "MatMul",
    "MatMulInteger": "MatMul",
    # Normalization
    "BatchNormalization": "BatchNorm",
    "LayerNormalization": "LayerNorm",
    "InstanceNormalization": "InstanceNorm",
    "GroupNormalization": "GroupNorm",
    # Activations
    "Relu": "Relu",
    "LeakyRelu": "LeakyRelu",
    "Sigmoid": "Sigmoid",
    "Tanh": "Tanh",
    "Softmax": "Softmax",
    "Gelu": "Gelu",
    "Silu": "Silu",
    "Mish": "Mish",
    # Pooling
    "MaxPool": "MaxPool2D",
    "AveragePool": "AvgPool2D",
    "GlobalAveragePool": "GlobalAvgPool",
    "GlobalMaxPool": "GlobalMaxPool",
    # Element-wise
    "Add": "Add",
    "Sub": "Sub",
    "Mul": "Mul",
    "Div": "Div",
    # Reshape/View
    "Reshape": "Reshape",
    "Flatten": "Flatten",
    "Squeeze": "Squeeze",
    "Unsqueeze": "Unsqueeze",
    "Transpose": "Transpose",
    # Attention (custom/subgraph)
    "Attention": "Attention",
    "MultiHeadAttention": "MultiHeadAttention",
    # Misc
    "Concat": "Concat",
    "Split": "Split",
    "Slice": "Slice",
    "Gather": "Gather",
    "Dropout": "Dropout",
    "Constant": "Constant",
    "Identity": "Identity",
    "Cast": "Cast",
    "ReduceMean": "ReduceMean",
    "ReduceSum": "ReduceSum",
    "Clip": "Clip",
    "Pad": "Pad",
    "Resize": "Resize",
    "Upsample": "Upsample",
}
277
+
278
+
279
def map_onnx_op_to_universal(onnx_op: str) -> str:
    """Translate an ONNX operator name into its universal counterpart.

    Operators without an entry in ``ONNX_TO_UNIVERSAL_OP`` are returned
    unchanged.

    Args:
        onnx_op: ONNX operator name (e.g., "Conv", "Gemm")

    Returns:
        Universal op type (e.g., "Conv2D", "MatMul")
    """
    try:
        return ONNX_TO_UNIVERSAL_OP[onnx_op]
    except KeyError:
        # Unknown operator: pass the ONNX name straight through.
        return onnx_op
289
+
290
+
291
+ # =============================================================================
292
+ # ONNX Adapter
293
+ # =============================================================================
294
+
295
+
296
@register_adapter
class OnnxAdapter(FormatAdapter):
    """Adapter for ONNX models (.onnx files).

    This is the primary adapter since ONNX is HaoLine's native format.
    Supports both reading and writing.
    """

    name = "onnx"
    extensions = [".onnx"]
    source_format = SourceFormat.ONNX

    # Both spellings denote the default ONNX operator-set domain.
    _DEFAULT_DOMAINS = ("", "ai.onnx")

    def can_read(self, path: Path) -> bool:
        """Return True when ``path`` has a ``.onnx`` suffix (case-insensitive)."""
        return path.suffix.lower() == ".onnx"

    def read(self, path: Path) -> UniversalGraph:
        """Read an ONNX model and convert it to a UniversalGraph.

        Tensors are collected in this order: initializers (as weights, with
        data), graph inputs that are not initializers, graph outputs, then
        any remaining ``value_info`` entries as activations.

        Args:
            path: Path to the ``.onnx`` file.

        Returns:
            UniversalGraph with nodes, tensors and metadata populated.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        import onnx
        from onnx import numpy_helper

        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"ONNX model not found: {path}")

        model = onnx.load(str(path))

        # Shape inference is best-effort: it only enriches metadata, so a
        # failure is logged and the un-inferred model is used instead.
        try:
            model = onnx.shape_inference.infer_shapes(model)
        except Exception as e:
            logger.warning(f"Shape inference failed: {e}")

        graph = model.graph

        metadata = GraphMetadata(
            name=graph.name or path.stem,
            source_format=SourceFormat.ONNX,
            source_path=str(path),
            ir_version=model.ir_version,
            producer_name=model.producer_name or None,
            producer_version=model.producer_version or None,
            opset_version=self._default_opset_version(model),
            input_names=[inp.name for inp in graph.input],
            output_names=[out.name for out in graph.output],
        )

        tensors: dict[str, UniversalTensor] = {}

        # Initializers (weights) carry their data.
        for init in graph.initializer:
            tensor_data = numpy_helper.to_array(init)
            tensors[init.name] = UniversalTensor(
                name=init.name,
                shape=list(tensor_data.shape),
                dtype=DataType.from_numpy_dtype(tensor_data.dtype),
                origin=TensorOrigin.WEIGHT,
                data=tensor_data,
                source_name=init.name,
            )

        # Graph inputs; initializers that are also listed as inputs stay weights.
        initializer_names = {init.name for init in graph.initializer}
        for inp in graph.input:
            if inp.name not in initializer_names:
                tensors[inp.name] = self._tensor_from_value_info(inp, TensorOrigin.INPUT)

        # Graph outputs (replace any earlier entry with the same name).
        for out in graph.output:
            tensors[out.name] = self._tensor_from_value_info(out, TensorOrigin.OUTPUT)

        # Remaining value_info entries are intermediate activations.
        for vi in graph.value_info:
            if vi.name not in tensors:
                tensors[vi.name] = self._tensor_from_value_info(vi, TensorOrigin.ACTIVATION)

        nodes: list[UniversalNode] = []
        for node in graph.node:
            # Outputs without a known tensor get an empty shape / UNKNOWN dtype.
            produced = [tensors.get(out_name) for out_name in node.output]
            output_shapes: list[list[int]] = [t.shape if t is not None else [] for t in produced]
            output_dtypes: list[DataType] = [
                t.dtype if t is not None else DataType.UNKNOWN for t in produced
            ]

            nodes.append(
                UniversalNode(
                    # Unnamed nodes get a synthetic id from op type and position.
                    id=node.name or f"{node.op_type}_{len(nodes)}",
                    op_type=map_onnx_op_to_universal(node.op_type),
                    inputs=list(node.input),
                    outputs=list(node.output),
                    attributes=self._extract_attributes(node),
                    output_shapes=output_shapes,
                    output_dtypes=output_dtypes,
                    source_op=node.op_type,
                    source_domain=node.domain or "ai.onnx",
                )
            )

        return UniversalGraph(
            nodes=nodes,
            tensors=tensors,
            metadata=metadata,
        )

    def can_write(self) -> bool:
        """ONNX adapter supports writing."""
        return True

    def write(self, graph: UniversalGraph, path: Path) -> None:
        """Write a UniversalGraph to an ONNX file.

        Weight tensors with data become initializers; every weight is also
        listed as a graph input after the real inputs. Dynamic dimensions
        (stored as negative ints by ``read``) are written as unknown
        dimensions rather than as a literal negative ``dim_value``.

        Args:
            graph: Graph to export.
            path: Output file path.
        """
        import onnx
        from onnx import helper, numpy_helper

        initializers = []
        graph_inputs = []
        weight_inputs = []
        outputs = []
        for tensor in graph.tensors.values():
            if tensor.origin == TensorOrigin.WEIGHT:
                if tensor.data is not None:
                    initializers.append(numpy_helper.from_array(tensor.data, name=tensor.name))
                weight_inputs.append(
                    helper.make_tensor_value_info(
                        tensor.name,
                        self._dtype_to_onnx(tensor.dtype),
                        self._dims_for_onnx(tensor.shape),
                    )
                )
            elif tensor.origin in (TensorOrigin.INPUT, TensorOrigin.OUTPUT):
                # An empty shape is written as "no shape information".
                dims = self._dims_for_onnx(tensor.shape) if tensor.shape else None
                value_info = helper.make_tensor_value_info(
                    tensor.name, self._dtype_to_onnx(tensor.dtype), dims
                )
                if tensor.origin == TensorOrigin.INPUT:
                    graph_inputs.append(value_info)
                else:
                    outputs.append(value_info)

        onnx_nodes = []
        for node in graph.nodes:
            # Prefer the op name recorded at read time; otherwise map back.
            onnx_op = node.source_op or self._universal_to_onnx_op(node.op_type)

            # read() records the default domain as "ai.onnx", but the model
            # only imports the "" opset below, so both spellings are
            # normalized to "" to keep node domains and imports in agreement.
            domain = node.source_domain or ""
            if domain in self._DEFAULT_DOMAINS:
                domain = ""

            onnx_nodes.append(
                helper.make_node(
                    onnx_op,
                    inputs=node.inputs,
                    outputs=node.outputs,
                    name=node.id,
                    domain=domain,
                    **node.attributes,
                )
            )

        onnx_graph = helper.make_graph(
            onnx_nodes,
            name=graph.metadata.name or "haoline_export",
            # Real inputs first, then weights (ONNX convention).
            inputs=graph_inputs + weight_inputs,
            outputs=outputs,
            initializer=initializers,
        )

        # NOTE: only the default-domain opset is imported; nodes from custom
        # domains are written without a matching opset import.
        opset_version = graph.metadata.opset_version or 17
        model = helper.make_model(
            onnx_graph,
            opset_imports=[helper.make_opsetid("", opset_version)],
            producer_name="haoline",
        )

        onnx.save(model, str(path))

    def _default_opset_version(self, model: onnx.ModelProto) -> int | None:
        """Return the opset version of the default ONNX domain.

        ``opset_import`` may list custom domains before the default one, so
        the default-domain entry is searched for explicitly. If no such entry
        exists, the first import is used; None is returned when there are no
        imports at all.
        """
        imports = list(model.opset_import)
        if not imports:
            return None
        for entry in imports:
            if entry.domain in self._DEFAULT_DOMAINS:
                return entry.version
        return imports[0].version

    def _tensor_from_value_info(
        self, value_info: onnx.ValueInfoProto, origin: TensorOrigin
    ) -> UniversalTensor:
        """Build a data-less UniversalTensor from a ValueInfoProto."""
        return UniversalTensor(
            name=value_info.name,
            shape=self._extract_shape(value_info),
            dtype=self._extract_dtype(value_info),
            origin=origin,
            source_name=value_info.name,
        )

    @staticmethod
    def _dims_for_onnx(shape: list[int]) -> list[int | None]:
        """Translate negative (dynamic) dims to None, leaving others untouched."""
        return [None if isinstance(dim, int) and dim < 0 else dim for dim in shape]

    def _extract_shape(self, value_info: onnx.ValueInfoProto) -> list[int]:
        """Extract shape from ONNX ValueInfoProto.

        Dimensions without a positive ``dim_value`` (symbolic or unset) are
        reported as -1. Any failure yields the dims collected so far.
        """
        shape: list[int] = []
        try:
            tensor_type = value_info.type.tensor_type
            for dim in tensor_type.shape.dim:
                shape.append(dim.dim_value if dim.dim_value > 0 else -1)
        except Exception:
            # Best-effort: shape metadata is optional.
            logger.debug(f"Could not extract shape for {value_info.name}", exc_info=True)
        return shape

    def _extract_dtype(self, value_info: onnx.ValueInfoProto) -> DataType:
        """Extract dtype from ONNX ValueInfoProto, or UNKNOWN on failure."""
        try:
            elem_type = value_info.type.tensor_type.elem_type
            return DataType.from_onnx_dtype(elem_type)
        except Exception:
            return DataType.UNKNOWN

    def _extract_attributes(self, node: onnx.NodeProto) -> dict[str, object]:
        """Extract scalar and list attributes from an ONNX NodeProto.

        FLOAT, INT, STRING, FLOATS, INTS and STRINGS are converted to Python
        values. Other attribute types (TENSOR, GRAPH, ...) are skipped.
        String bytes that are not valid UTF-8 are decoded with replacement
        characters instead of aborting the read.
        """
        from onnx import AttributeProto

        attrs: dict[str, object] = {}
        for attr in node.attribute:
            if attr.type == AttributeProto.FLOAT:
                attrs[attr.name] = attr.f
            elif attr.type == AttributeProto.INT:
                attrs[attr.name] = attr.i
            elif attr.type == AttributeProto.STRING:
                attrs[attr.name] = attr.s.decode("utf-8", errors="replace") if attr.s else ""
            elif attr.type == AttributeProto.FLOATS:
                attrs[attr.name] = list(attr.floats)
            elif attr.type == AttributeProto.INTS:
                attrs[attr.name] = list(attr.ints)
            elif attr.type == AttributeProto.STRINGS:
                attrs[attr.name] = [s.decode("utf-8", errors="replace") for s in attr.strings]
            # Skip TENSOR and GRAPH types for now
        return attrs

    def _dtype_to_onnx(self, dtype: DataType) -> int:
        """Convert DataType to ONNX TensorProto dtype (FLOAT when unmapped)."""
        from onnx import TensorProto

        mapping = {
            DataType.FLOAT32: TensorProto.FLOAT,
            DataType.FLOAT64: TensorProto.DOUBLE,
            DataType.FLOAT16: TensorProto.FLOAT16,
            DataType.BFLOAT16: TensorProto.BFLOAT16,
            DataType.INT64: TensorProto.INT64,
            DataType.INT32: TensorProto.INT32,
            DataType.INT16: TensorProto.INT16,
            DataType.INT8: TensorProto.INT8,
            DataType.UINT8: TensorProto.UINT8,
            DataType.BOOL: TensorProto.BOOL,
            DataType.STRING: TensorProto.STRING,
        }
        return mapping.get(dtype, TensorProto.FLOAT)

    def _universal_to_onnx_op(self, universal_op: str) -> str:
        """Map a universal op type back to an ONNX op name.

        A name that is itself an ONNX op in ``ONNX_TO_UNIVERSAL_OP`` is
        returned unchanged, so "MatMul" stays "MatMul" instead of resolving
        to whichever of Gemm/MatMul/MatMulInteger was listed last. Other
        names use the reverse mapping; unknown names pass through.
        """
        if universal_op in ONNX_TO_UNIVERSAL_OP:
            return universal_op

        reverse_map: dict[str, str] = {}
        for onnx_name, universal_name in ONNX_TO_UNIVERSAL_OP.items():
            # First listed ONNX op wins if several share a universal name.
            reverse_map.setdefault(universal_name, onnx_name)
        return reverse_map.get(universal_op, universal_op)
579
+
580
+
581
+ # =============================================================================
582
+ # PyTorch Adapter
583
+ # =============================================================================
584
+
585
+
586
@register_adapter
class PyTorchAdapter(FormatAdapter):
    """Adapter for PyTorch models (.pt, .pth files).

    Converts PyTorch models to UniversalGraph by first exporting to ONNX,
    then using the OnnxAdapter. This ensures consistent representation.

    For full models (nn.Module), uses torch.onnx.export.
    For state_dicts, extracts weights without graph structure.
    """

    name = "pytorch"
    extensions = [".pt", ".pth"]
    source_format = SourceFormat.PYTORCH

    # Keys under which a checkpoint dict may nest the actual state_dict.
    _NESTED_STATE_DICT_KEYS = ("state_dict", "model_state_dict", "model")

    def can_read(self, path: Path) -> bool:
        """Return True when ``path`` ends in ``.pt`` or ``.pth`` (any case)."""
        return path.suffix.lower() in [".pt", ".pth"]

    def read(self, path: Path) -> UniversalGraph:
        """Read a PyTorch file and convert it to a UniversalGraph.

        Dispatch depends on what ``torch.load`` returns:
        - an ``nn.Module`` is exported to ONNX with an auto-detected dummy input;
        - a dict whose ``"model"`` entry is not itself a dict is handled as
          an Ultralytics checkpoint;
        - any other dict is treated as a (possibly wrapped) state_dict.

        Args:
            path: Path to the ``.pt`` / ``.pth`` file.

        Returns:
            UniversalGraph for the model (weights only for state_dicts).

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If the loaded object is neither a module nor a dict.
        """
        import torch

        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"PyTorch model not found: {path}")

        # SECURITY: weights_only=False performs full unpickling, which can
        # execute arbitrary code embedded in the file. Only load files from
        # trusted sources. It is required here to restore pickled nn.Module
        # objects.
        loaded = torch.load(str(path), map_location="cpu", weights_only=False)

        if isinstance(loaded, torch.nn.Module):
            return self._convert_module(loaded, path)
        if isinstance(loaded, dict):
            # A "model" entry that is a plain dict is a nested state_dict,
            # not an Ultralytics model object, so it must not require the
            # ultralytics package.
            if "model" in loaded and not isinstance(loaded["model"], dict):
                return self._convert_ultralytics(loaded, path)
            return self._convert_state_dict(loaded, path)
        raise ValueError(f"Unknown PyTorch file format: {type(loaded)}")

    def _convert_module(self, model: torch.nn.Module, path: Path) -> UniversalGraph:
        """Convert torch.nn.Module to UniversalGraph via a temporary ONNX export."""
        import tempfile

        import torch

        model.eval()

        dummy_input = self._create_dummy_input(model)

        # Reserve a file name only; the export writes to it after the handle
        # is closed, and the finally block removes it.
        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
            onnx_path = Path(f.name)

        try:
            torch.onnx.export(
                model,
                (dummy_input,),
                str(onnx_path),
                opset_version=17,
                do_constant_folding=True,
            )

            graph = OnnxAdapter().read(onnx_path)

            # Report the original PyTorch file, not the temporary export.
            graph.metadata.source_format = SourceFormat.PYTORCH
            graph.metadata.source_path = str(path)

            return graph

        finally:
            if onnx_path.exists():
                onnx_path.unlink()

    def _convert_ultralytics(self, loaded: dict[str, object], path: Path) -> UniversalGraph:
        """Convert an Ultralytics YOLO checkpoint to UniversalGraph.

        The checkpoint is copied into a temporary directory and loaded from
        there, so that the ONNX file produced by the export is created in the
        temporary directory instead of next to the user's model.
        NOTE(review): this relies on Ultralytics writing the export beside
        the weights file it loaded — confirm against the installed version.

        Raises:
            ImportError: If the ``ultralytics`` package is not installed.
        """
        import shutil
        import tempfile

        try:
            from ultralytics import YOLO
        except ImportError as e:
            raise ImportError(
                "Ultralytics YOLO model detected. Install ultralytics: pip install ultralytics"
            ) from e

        with tempfile.TemporaryDirectory() as tmpdir:
            staged_path = Path(tmpdir) / path.name
            shutil.copy2(path, staged_path)

            yolo = YOLO(str(staged_path))
            onnx_path_str: str = yolo.export(format="onnx")
            onnx_path = Path(onnx_path_str)

            # Read while the temporary directory still exists.
            graph = OnnxAdapter().read(onnx_path)

        graph.metadata.source_format = SourceFormat.PYTORCH
        graph.metadata.source_path = str(path)
        graph.metadata.extra["ultralytics"] = True

        return graph

    def _unwrap_state_dict(self, loaded: dict[str, object]) -> dict[str, object]:
        """Return the dict that actually holds the tensors.

        If ``loaded`` has tensors at the top level it is returned as is.
        Otherwise the first dict found under one of
        ``_NESTED_STATE_DICT_KEYS`` is used; failing that, ``loaded`` itself.
        """
        import torch

        if any(isinstance(value, torch.Tensor) for value in loaded.values()):
            return loaded
        for key in self._NESTED_STATE_DICT_KEYS:
            nested = loaded.get(key)
            if isinstance(nested, dict):
                return nested
        return loaded

    def _convert_state_dict(self, state_dict: dict[str, object], path: Path) -> UniversalGraph:
        """Convert state_dict to UniversalGraph (weights only, no graph).

        Non-tensor entries are ignored. bfloat16 tensors have no NumPy
        equivalent, so their data is stored upcast to float32 while the
        recorded dtype remains BFLOAT16.
        """
        import torch

        tensors: dict[str, UniversalTensor] = {}

        for name, param in self._unwrap_state_dict(state_dict).items():
            if not isinstance(param, torch.Tensor):
                continue

            cpu_param = param.detach().cpu()
            if cpu_param.dtype == torch.bfloat16:
                np_data = cpu_param.to(torch.float32).numpy()
                dtype = DataType.BFLOAT16
            else:
                np_data = cpu_param.numpy()
                dtype = DataType.from_numpy_dtype(np_data.dtype)

            tensors[name] = UniversalTensor(
                name=name,
                shape=list(np_data.shape),
                dtype=dtype,
                origin=TensorOrigin.WEIGHT,
                data=np_data,
                source_name=name,
            )

        return UniversalGraph(
            nodes=[],  # No graph structure for state_dict
            tensors=tensors,
            metadata=GraphMetadata(
                name=path.stem,
                source_format=SourceFormat.PYTORCH,
                source_path=str(path),
                extra={"type": "state_dict"},
            ),
        )

    def _create_dummy_input(self, model: torch.nn.Module) -> torch.Tensor:
        """Create dummy input for ONNX export.

        Uses the first Conv2d or Linear module yielded by ``model.modules()``:
        a Conv2d gives ``(1, in_channels, 224, 224)``, a Linear gives
        ``(1, in_features)``. With neither present, ``(1, 3, 224, 224)`` is
        used. The 224x224 spatial size is an assumption about image models.
        """
        import torch

        for module in model.modules():
            if isinstance(module, torch.nn.Conv2d):
                return torch.randn(1, module.in_channels, 224, 224)
            if isinstance(module, torch.nn.Linear):
                return torch.randn(1, module.in_features)

        # Default: batch of 224x224 RGB images
        return torch.randn(1, 3, 224, 224)
745
+
746
+
747
+ # =============================================================================
748
+ # Utility Functions
749
+ # =============================================================================
750
+
751
+
752
def load_model(path: str | Path) -> UniversalGraph:
    """Read the model at ``path`` into a UniversalGraph.

    The adapter is chosen from the file extension via ``get_adapter``.

    Args:
        path: Location of the model file.

    Returns:
        The model as a UniversalGraph.

    Example:
        graph = load_model("model.onnx")
        graph = load_model("model.pt")
    """
    reader = get_adapter(path)
    model_path = Path(path)
    return reader.read(model_path)
769
+
770
+
771
def save_model(graph: UniversalGraph, path: str | Path) -> None:
    """Serialize ``graph`` to ``path`` using the adapter for its extension.

    Args:
        graph: Graph to persist.
        path: Destination file; its extension selects the output format.

    Raises:
        ValueError: If no adapter matches the extension, or the matching
            adapter reports that it cannot write.
    """
    destination = Path(path)
    adapter = get_adapter(destination)

    if adapter.can_write():
        adapter.write(graph, destination)
        return

    raise ValueError(f"{adapter.name} adapter does not support writing")
788
+
789
+
790
+ # =============================================================================
791
+ # Conversion Matrix (Task 18.3)
792
+ # =============================================================================
793
+
794
+
795
class ConversionLevel(str, Enum):
    """How faithfully one model format can be converted into another.

    Members are strings, so they compare equal to their lower-case values:
    - FULL: everything is preserved (lossless)
    - PARTIAL: works with limitations, or needs more than one step
    - LOSSY: works, but some information is dropped
    - NONE: there is no conversion path
    """

    # Lossless, complete conversion
    FULL = "full"
    # Some limitations or multi-step required
    PARTIAL = "partial"
    # Information loss during conversion
    LOSSY = "lossy"
    # No conversion path
    NONE = "none"
809
+
810
+
811
# Conversion matrix: (source, target) -> ConversionLevel
# Format: CONVERSION_MATRIX[(source_format, target_format)] = level
# Pairs are directional. Pairs missing from the table are treated as
# ConversionLevel.NONE by get_conversion_level(), which also handles the
# identity case (source == target) without consulting this table.
_CONVERSION_MATRIX: dict[tuple[SourceFormat, SourceFormat], ConversionLevel] = {
    # ONNX conversions (primary interchange format)
    (SourceFormat.ONNX, SourceFormat.TENSORRT): ConversionLevel.PARTIAL,  # TensorRT-specific ops
    (SourceFormat.ONNX, SourceFormat.TFLITE): ConversionLevel.PARTIAL,  # Some ops unsupported
    (SourceFormat.ONNX, SourceFormat.COREML): ConversionLevel.PARTIAL,  # iOS-specific limits
    (SourceFormat.ONNX, SourceFormat.OPENVINO): ConversionLevel.FULL,  # Good ONNX support
    # PyTorch conversions (via ONNX)
    (SourceFormat.PYTORCH, SourceFormat.ONNX): ConversionLevel.FULL,  # torch.onnx.export
    (SourceFormat.PYTORCH, SourceFormat.TENSORRT): ConversionLevel.PARTIAL,  # Via ONNX
    (SourceFormat.PYTORCH, SourceFormat.TFLITE): ConversionLevel.PARTIAL,  # Via ONNX
    (SourceFormat.PYTORCH, SourceFormat.COREML): ConversionLevel.PARTIAL,  # coremltools
    # TensorFlow conversions
    (SourceFormat.TENSORFLOW, SourceFormat.ONNX): ConversionLevel.PARTIAL,  # tf2onnx
    (SourceFormat.TENSORFLOW, SourceFormat.TFLITE): ConversionLevel.FULL,  # TFLite converter
    (SourceFormat.TENSORFLOW, SourceFormat.COREML): ConversionLevel.PARTIAL,  # coremltools
    # TensorRT (inference-only, limited export)
    (SourceFormat.TENSORRT, SourceFormat.ONNX): ConversionLevel.NONE,  # Cannot export
    # CoreML (Apple ecosystem)
    (SourceFormat.COREML, SourceFormat.ONNX): ConversionLevel.LOSSY,  # Some info lost
    # TFLite (mobile)
    (SourceFormat.TFLITE, SourceFormat.ONNX): ConversionLevel.PARTIAL,  # tflite2onnx
    # Weights-only formats (no graph structure)
    (SourceFormat.SAFETENSORS, SourceFormat.ONNX): ConversionLevel.NONE,  # Weights only
    (SourceFormat.GGUF, SourceFormat.ONNX): ConversionLevel.NONE,  # Weights only
}
838
+
839
+
840
def get_conversion_level(source: SourceFormat | str, target: SourceFormat | str) -> ConversionLevel:
    """Look up how well ``source`` can be converted into ``target``.

    String arguments are lower-cased and parsed as SourceFormat values; a
    string that is not a known format yields ``ConversionLevel.NONE``.
    Converting a format to itself is always ``FULL``. Pairs that are not
    listed in the conversion matrix are ``NONE``.

    Args:
        source: Source model format
        target: Target model format

    Returns:
        ConversionLevel indicating conversion capability
    """
    endpoints = []
    for fmt in (source, target):
        if isinstance(fmt, str):
            try:
                fmt = SourceFormat(fmt.lower())
            except ValueError:
                # Unrecognized format name: nothing can be converted.
                return ConversionLevel.NONE
        endpoints.append(fmt)

    src_fmt, tgt_fmt = endpoints
    if src_fmt == tgt_fmt:
        return ConversionLevel.FULL

    return _CONVERSION_MATRIX.get((src_fmt, tgt_fmt), ConversionLevel.NONE)
867
+
868
+
869
def list_conversion_paths(
    source: SourceFormat | str | None = None,
    target: SourceFormat | str | None = None,
) -> list[dict[str, str]]:
    """List the conversion paths recorded in the conversion matrix.

    Filters are normalized once, before the matrix is scanned, so an
    unrecognized format name is always rejected regardless of which entries
    the other filter lets through.

    Args:
        source: Filter by source format (optional)
        target: Filter by target format (optional)

    Returns:
        List of dicts with keys ``source``, ``target`` and ``level`` (all
        string values), sorted by source and then target.

    Raises:
        ValueError: If a filter is a string that is not a valid SourceFormat.
    """
    source_fmt: SourceFormat | None = None
    if source is not None:
        source_fmt = source if isinstance(source, SourceFormat) else SourceFormat(source.lower())

    target_fmt: SourceFormat | None = None
    if target is not None:
        target_fmt = target if isinstance(target, SourceFormat) else SourceFormat(target.lower())

    result: list[dict[str, str]] = [
        {
            "source": src.value,
            "target": tgt.value,
            "level": level.value,
        }
        for (src, tgt), level in _CONVERSION_MATRIX.items()
        if (source_fmt is None or src == source_fmt)
        and (target_fmt is None or tgt == target_fmt)
    ]

    return sorted(result, key=lambda x: (x["source"], x["target"]))
908
+
909
+
910
def can_convert(source: SourceFormat | str, target: SourceFormat | str) -> bool:
    """Tell whether any conversion from ``source`` to ``target`` exists.

    FULL, PARTIAL and LOSSY all count as convertible; only NONE does not.

    Args:
        source: Source model format
        target: Target model format

    Returns:
        True if any conversion path exists
    """
    if get_conversion_level(source, target) == ConversionLevel.NONE:
        return False
    return True
924
+
925
+
926
def convert_model(
    graph: UniversalGraph,
    target_format: SourceFormat | str,
    output_path: Path | str,
) -> Path:
    """Export ``graph`` in another format and return where it was written.

    The source format is taken from ``graph.metadata.source_format``. The
    conversion level is checked first; a LOSSY path only logs a warning.
    At present ONNX is the only target that can actually be written.

    Args:
        graph: UniversalGraph to convert
        target_format: Target format (e.g., "onnx", "tflite")
        output_path: Output file path

    Returns:
        Path to the converted model

    Raises:
        ValueError: If ``target_format`` is not a known format name, or no
            conversion path exists between the two formats.
        NotImplementedError: If a path exists but the target is not ONNX.
    """
    output_path = Path(output_path)
    source = graph.metadata.source_format

    # Accept plain strings such as "onnx" for the target.
    if isinstance(target_format, str):
        target_format = SourceFormat(target_format.lower())

    level = get_conversion_level(source, target_format)

    if level == ConversionLevel.NONE:
        raise ValueError(
            f"Cannot convert from {source.value} to {target_format.value}. "
            f"No conversion path available."
        )
    if level == ConversionLevel.LOSSY:
        logger.warning(f"Converting {source.value} to {target_format.value} may lose information")

    # For now, only ONNX writing is supported
    if target_format != SourceFormat.ONNX:
        raise NotImplementedError(
            f"Direct conversion to {target_format.value} not yet implemented. "
            f"Export to ONNX first, then use format-specific tools."
        )

    OnnxAdapter().write(graph, output_path)
    return output_path
974
+
975
+
976
+ # =============================================================================
977
+ # Module Exports
978
+ # =============================================================================
979
+
980
# Names exported by `from haoline.format_adapters import *`.
__all__ = [
    # Protocol
    "FormatAdapter",
    # Registry
    "register_adapter",
    "get_adapter",
    "list_adapters",
    # Adapters
    "OnnxAdapter",
    "PyTorchAdapter",
    # Utilities
    "load_model",
    "save_model",
    "map_onnx_op_to_universal",
    "ONNX_TO_UNIVERSAL_OP",
    # Conversion Matrix
    "ConversionLevel",
    "get_conversion_level",
    "list_conversion_paths",
    "can_convert",
    "convert_model",
]