haoline 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. haoline/.streamlit/config.toml +10 -0
  2. haoline/__init__.py +248 -0
  3. haoline/analyzer.py +935 -0
  4. haoline/cli.py +2712 -0
  5. haoline/compare.py +811 -0
  6. haoline/compare_visualizations.py +1564 -0
  7. haoline/edge_analysis.py +525 -0
  8. haoline/eval/__init__.py +131 -0
  9. haoline/eval/adapters.py +844 -0
  10. haoline/eval/cli.py +390 -0
  11. haoline/eval/comparison.py +542 -0
  12. haoline/eval/deployment.py +633 -0
  13. haoline/eval/schemas.py +833 -0
  14. haoline/examples/__init__.py +15 -0
  15. haoline/examples/basic_inspection.py +74 -0
  16. haoline/examples/compare_models.py +117 -0
  17. haoline/examples/hardware_estimation.py +78 -0
  18. haoline/format_adapters.py +1001 -0
  19. haoline/formats/__init__.py +123 -0
  20. haoline/formats/coreml.py +250 -0
  21. haoline/formats/gguf.py +483 -0
  22. haoline/formats/openvino.py +255 -0
  23. haoline/formats/safetensors.py +273 -0
  24. haoline/formats/tflite.py +369 -0
  25. haoline/hardware.py +2307 -0
  26. haoline/hierarchical_graph.py +462 -0
  27. haoline/html_export.py +1573 -0
  28. haoline/layer_summary.py +769 -0
  29. haoline/llm_summarizer.py +465 -0
  30. haoline/op_icons.py +618 -0
  31. haoline/operational_profiling.py +1492 -0
  32. haoline/patterns.py +1116 -0
  33. haoline/pdf_generator.py +265 -0
  34. haoline/privacy.py +250 -0
  35. haoline/pydantic_models.py +241 -0
  36. haoline/report.py +1923 -0
  37. haoline/report_sections.py +539 -0
  38. haoline/risks.py +521 -0
  39. haoline/schema.py +523 -0
  40. haoline/streamlit_app.py +2024 -0
  41. haoline/tests/__init__.py +4 -0
  42. haoline/tests/conftest.py +123 -0
  43. haoline/tests/test_analyzer.py +868 -0
  44. haoline/tests/test_compare_visualizations.py +293 -0
  45. haoline/tests/test_edge_analysis.py +243 -0
  46. haoline/tests/test_eval.py +604 -0
  47. haoline/tests/test_format_adapters.py +460 -0
  48. haoline/tests/test_hardware.py +237 -0
  49. haoline/tests/test_hardware_recommender.py +90 -0
  50. haoline/tests/test_hierarchical_graph.py +326 -0
  51. haoline/tests/test_html_export.py +180 -0
  52. haoline/tests/test_layer_summary.py +428 -0
  53. haoline/tests/test_llm_patterns.py +540 -0
  54. haoline/tests/test_llm_summarizer.py +339 -0
  55. haoline/tests/test_patterns.py +774 -0
  56. haoline/tests/test_pytorch.py +327 -0
  57. haoline/tests/test_report.py +383 -0
  58. haoline/tests/test_risks.py +398 -0
  59. haoline/tests/test_schema.py +417 -0
  60. haoline/tests/test_tensorflow.py +380 -0
  61. haoline/tests/test_visualizations.py +316 -0
  62. haoline/universal_ir.py +856 -0
  63. haoline/visualizations.py +1086 -0
  64. haoline/visualize_yolo.py +44 -0
  65. haoline/web.py +110 -0
  66. haoline-0.3.0.dist-info/METADATA +471 -0
  67. haoline-0.3.0.dist-info/RECORD +70 -0
  68. haoline-0.3.0.dist-info/WHEEL +4 -0
  69. haoline-0.3.0.dist-info/entry_points.txt +5 -0
  70. haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,856 @@
1
+ # Copyright (c) 2025 HaoLine Contributors
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ """
5
+ Universal Internal Representation (IR) for format-agnostic model analysis.
6
+
7
+ This module provides a backend-neutral representation of neural network models,
8
+ enabling analysis and comparison across different frameworks (ONNX, PyTorch,
9
+ TensorFlow, TensorRT, CoreML, etc.).
10
+
11
+ Design inspired by:
12
+ - OpenVINO IR (graph + weights separation)
13
+ - TVM Relay (typed graph representation)
14
+ - MLIR (extensible operation types)
15
+
16
+ Usage:
17
+ from haoline.universal_ir import UniversalGraph, UniversalNode, UniversalTensor
18
+
19
+ # Load from ONNX (via adapter)
20
+ graph = OnnxAdapter().read("model.onnx")
21
+
22
+ # Analyze
23
+ print(f"Nodes: {len(graph.nodes)}")
24
+ print(f"Parameters: {graph.total_parameters}")
25
+
26
+ # Compare two graphs
27
+ if graph1.is_structurally_equal(graph2):
28
+ print("Same architecture!")
29
+
30
+ # Export to JSON
31
+ graph.to_json("model_ir.json")
32
+ """
33
+
34
+ from __future__ import annotations
35
+
36
+ from enum import Enum
37
+ from pathlib import Path
38
+ from typing import Any
39
+
40
+ from pydantic import BaseModel, Field
41
+
42
+
43
class TensorOrigin(str, Enum):
    """Origin type for tensors in the graph.

    Records where a tensor comes from: a stored parameter, a graph
    boundary (input/output), or a value computed while the graph runs.
    """

    # Constant model parameter
    WEIGHT = "weight"
    # Model input
    INPUT = "input"
    # Model output
    OUTPUT = "output"
    # Intermediate activation (runtime)
    ACTIVATION = "activation"
50
+
51
+
52
class DataType(str, Enum):
    """Supported data types for tensors."""

    FLOAT64 = "float64"
    FLOAT32 = "float32"
    FLOAT16 = "float16"
    BFLOAT16 = "bfloat16"
    INT64 = "int64"
    INT32 = "int32"
    INT16 = "int16"
    INT8 = "int8"
    UINT8 = "uint8"
    BOOL = "bool"
    STRING = "string"
    UNKNOWN = "unknown"

    @classmethod
    def from_numpy_dtype(cls, dtype: Any) -> DataType:
        """Convert a numpy dtype specification to DataType.

        Args:
            dtype: Anything ``numpy.dtype()`` accepts: a ``numpy.dtype``
                instance, a scalar type such as ``np.float32``, or a type
                string such as ``"float32"``.

        Returns:
            The matching DataType. Returns ``DataType.UNKNOWN`` in three
            cases: ``None``, input numpy cannot interpret, or a numpy type
            with no equivalent here (e.g. complex).
        """
        import numpy as np

        # np.dtype(None) silently means float64; treat a missing dtype as unknown.
        if dtype is None:
            return cls.UNKNOWN

        try:
            # Normalize through np.dtype() so that scalar types (np.float32)
            # and strings ("float32") work as well as dtype objects.
            scalar_type = np.dtype(dtype).type
        except (TypeError, ValueError):
            return cls.UNKNOWN

        dtype_map = {
            np.float64: cls.FLOAT64,
            np.float32: cls.FLOAT32,
            np.float16: cls.FLOAT16,
            np.int64: cls.INT64,
            np.int32: cls.INT32,
            np.int16: cls.INT16,
            np.int8: cls.INT8,
            np.uint8: cls.UINT8,
            np.bool_: cls.BOOL,
            np.str_: cls.STRING,
        }
        return dtype_map.get(scalar_type, cls.UNKNOWN)

    @classmethod
    def from_onnx_dtype(cls, onnx_dtype: int) -> DataType:
        """Convert an ONNX TensorProto dtype code to DataType.

        Args:
            onnx_dtype: Integer value of ONNX ``TensorProto.DataType``.

        Returns:
            The matching DataType, or ``DataType.UNKNOWN`` for codes not
            listed below (e.g. unsigned 16/32/64-bit, complex).
        """
        # ONNX TensorProto.DataType values
        onnx_map = {
            1: cls.FLOAT32,  # FLOAT
            2: cls.UINT8,  # UINT8
            3: cls.INT8,  # INT8
            5: cls.INT16,  # INT16
            6: cls.INT32,  # INT32
            7: cls.INT64,  # INT64
            8: cls.STRING,  # STRING
            9: cls.BOOL,  # BOOL
            10: cls.FLOAT16,  # FLOAT16
            11: cls.FLOAT64,  # DOUBLE
            16: cls.BFLOAT16,  # BFLOAT16
        }
        return onnx_map.get(onnx_dtype, cls.UNKNOWN)

    @property
    def bytes_per_element(self) -> int:
        """Return bytes per element for this dtype (0 when not fixed-size)."""
        size_map = {
            DataType.FLOAT64: 8,
            DataType.INT64: 8,
            DataType.FLOAT32: 4,
            DataType.INT32: 4,
            DataType.FLOAT16: 2,
            DataType.BFLOAT16: 2,
            DataType.INT16: 2,
            DataType.INT8: 1,
            DataType.UINT8: 1,
            DataType.BOOL: 1,
            DataType.STRING: 0,  # Variable
            DataType.UNKNOWN: 0,
        }
        return size_map.get(self, 0)
123
+
124
+
125
class UniversalTensor(BaseModel):
    """A tensor in the IR: a weight, a graph input/output, or an activation.

    Tensors are the edges of the computation graph. They link nodes together
    and carry either real values (weights) or only metadata (activations).

    Attributes:
        name: Unique identifier for this tensor
        shape: Tensor dimensions; an empty list denotes a scalar
        dtype: Element data type (float32, float16, int8, ...)
        origin: Whether this is a weight, input, output, or activation
        data: Tensor values for weights; None for activations or when the
            values are loaded lazily
        source_name: Name used by the source format (kept for round-trip)
    """

    name: str
    shape: list[int] = Field(default_factory=list)
    dtype: DataType = DataType.FLOAT32
    origin: TensorOrigin = TensorOrigin.ACTIVATION
    data: Any | None = None  # numpy array, or None when absent / lazily loaded
    source_name: str | None = None  # name in the source format

    model_config = {"arbitrary_types_allowed": True}

    @property
    def num_elements(self) -> int:
        """Element count; non-positive (dynamic) dims are skipped, scalars give 1."""
        total = 1
        for extent in self.shape:
            if extent <= 0:
                # Dynamic/unknown dimension: contributes nothing to the count.
                continue
            total *= extent
        return total

    @property
    def size_bytes(self) -> int:
        """Byte size of the tensor, or 0 when any dimension is dynamic."""
        fully_static = all(extent > 0 for extent in self.shape)
        if not fully_static:
            return 0
        return self.num_elements * self.dtype.bytes_per_element

    def __repr__(self) -> str:
        if self.shape:
            shape_str = "x".join(map(str, self.shape))
        else:
            shape_str = "scalar"
        return f"UniversalTensor({self.name}: {shape_str} {self.dtype.value})"
171
+
172
+
173
class UniversalNode(BaseModel):
    """One operation (a vertex) of the computation graph.

    A node applies an operation such as Conv2D, MatMul or Relu to its input
    tensors and produces output tensors. ``op_type`` is a framework-neutral
    category rather than a framework-specific op name, which is what makes
    comparison across formats possible.

    Attributes:
        id: Unique identifier for this node
        op_type: High-level operation type (Conv2D, MatMul, Relu, etc.)
        inputs: Names of the tensors this node consumes
        outputs: Names of the tensors this node produces
        attributes: Operation-specific parameters (kernel_size, strides, etc.)
        output_shapes: Shapes of the output tensors, when known
        output_dtypes: Data types of the output tensors
        source_op: Op name in the source format (kept for round-trip)
        source_domain: Op domain in the source format (e.g. "ai.onnx")
    """

    id: str
    op_type: str  # framework-neutral category: Conv2D, MatMul, Relu, Add, ...
    inputs: list[str] = Field(default_factory=list)
    outputs: list[str] = Field(default_factory=list)
    attributes: dict[str, Any] = Field(default_factory=dict)
    output_shapes: list[list[int]] = Field(default_factory=list)
    output_dtypes: list[DataType] = Field(default_factory=list)

    # Provenance in the source format, kept for round-trip conversion.
    source_op: str | None = None  # e.g. "Conv" for ONNX
    source_domain: str | None = None  # e.g. "ai.onnx", "com.microsoft"

    def __repr__(self) -> str:
        return f"UniversalNode({self.id}: {self.op_type})"

    @property
    def is_compute_op(self) -> bool:
        """True when op_type is one of the heavy compute categories."""
        heavy_ops = frozenset(
            (
                "Conv2D",
                "Conv3D",
                "MatMul",
                "Gemm",
                "ConvTranspose",
                "Attention",
                "MultiHeadAttention",
            )
        )
        return self.op_type in heavy_ops

    @property
    def is_activation(self) -> bool:
        """True when op_type is a recognized activation function."""
        activation_names = frozenset(
            (
                "Relu",
                "LeakyRelu",
                "Sigmoid",
                "Tanh",
                "Gelu",
                "Silu",
                "Swish",
                "Softmax",
                "Mish",
            )
        )
        return self.op_type in activation_names
238
+
239
+
240
class SourceFormat(str, Enum):
    """Model file formats a UniversalGraph can originate from.

    ``UNKNOWN`` is the fallback when the origin was not recorded.
    """

    ONNX = "onnx"
    PYTORCH = "pytorch"
    TENSORFLOW = "tensorflow"
    TENSORRT = "tensorrt"
    COREML = "coreml"
    TFLITE = "tflite"
    OPENVINO = "openvino"
    SAFETENSORS = "safetensors"
    GGUF = "gguf"
    # Fallback when the source format is not known.
    UNKNOWN = "unknown"
253
+
254
+
255
class GraphMetadata(BaseModel):
    """Graph-level information that does not belong to any node or tensor.

    Covers where the model came from (format, path, producer), version
    numbers reported by the source, the names of the model's inputs and
    outputs, and a free-form bag of extra source metadata.
    """

    name: str = ""
    source_format: SourceFormat = SourceFormat.UNKNOWN
    source_path: str | None = None

    # --- Version information reported by the source format ---
    ir_version: int | None = None  # ONNX-style IR version, if any
    producer_name: str | None = None  # exporting tool, e.g. "pytorch", "tf2onnx"
    producer_version: str | None = None
    opset_version: int | None = None  # ONNX-style opset, if any

    # --- Model boundary ---
    input_names: list[str] = Field(default_factory=list)
    output_names: list[str] = Field(default_factory=list)

    # --- Anything else from the source, preserved for round-trip ---
    extra: dict[str, Any] = Field(default_factory=dict)
278
+
279
+
280
class UniversalGraph(BaseModel):
    """Universal representation of a neural network computation graph.

    This is the top-level container for a model's IR. It holds all nodes
    (operations), tensors (weights and I/O), and metadata.

    The graph is designed to be:
    - Format-agnostic: Works with ONNX, PyTorch, TensorFlow, etc.
    - Serializable: Can be saved to JSON for debugging/interchange
    - Comparable: Supports structural comparison between models
    - Extensible: Easy to add new op types or metadata

    Attributes:
        nodes: List of computation nodes (operations)
        tensors: Dict mapping tensor name to UniversalTensor
        metadata: Graph-level metadata (name, source format, etc.)
    """

    nodes: list[UniversalNode] = Field(default_factory=list)
    tensors: dict[str, UniversalTensor] = Field(default_factory=dict)
    metadata: GraphMetadata = Field(default_factory=GraphMetadata)

    model_config = {"arbitrary_types_allowed": True}

    # -------------------------------------------------------------------------
    # Properties
    # -------------------------------------------------------------------------

    @property
    def num_nodes(self) -> int:
        """Total number of nodes in the graph."""
        return len(self.nodes)

    @property
    def num_tensors(self) -> int:
        """Total number of tensors (weights + I/O + activations)."""
        return len(self.tensors)

    @property
    def weight_tensors(self) -> list[UniversalTensor]:
        """Get all weight tensors."""
        return [t for t in self.tensors.values() if t.origin == TensorOrigin.WEIGHT]

    @property
    def input_tensors(self) -> list[UniversalTensor]:
        """Get all input tensors."""
        return [t for t in self.tensors.values() if t.origin == TensorOrigin.INPUT]

    @property
    def output_tensors(self) -> list[UniversalTensor]:
        """Get all output tensors."""
        return [t for t in self.tensors.values() if t.origin == TensorOrigin.OUTPUT]

    @property
    def total_parameters(self) -> int:
        """Total number of parameters (weight elements)."""
        return sum(t.num_elements for t in self.weight_tensors)

    @property
    def total_weight_bytes(self) -> int:
        """Total size of weights in bytes (dynamic-shaped weights count as 0)."""
        return sum(t.size_bytes for t in self.weight_tensors)

    @property
    def op_type_counts(self) -> dict[str, int]:
        """Count of each operation type, keyed in order of first appearance."""
        counts: dict[str, int] = {}
        for node in self.nodes:
            counts[node.op_type] = counts.get(node.op_type, 0) + 1
        return counts

    # -------------------------------------------------------------------------
    # Node/Tensor Access
    # -------------------------------------------------------------------------

    def get_node(self, node_id: str) -> UniversalNode | None:
        """Get a node by its ID (linear scan), or None if absent."""
        for node in self.nodes:
            if node.id == node_id:
                return node
        return None

    def get_tensor(self, name: str) -> UniversalTensor | None:
        """Get a tensor by its name, or None if absent."""
        return self.tensors.get(name)

    def get_node_inputs(self, node: UniversalNode) -> list[UniversalTensor]:
        """Get the input tensors for a node (names without a tensor entry are skipped)."""
        return [self.tensors[name] for name in node.inputs if name in self.tensors]

    def get_node_outputs(self, node: UniversalNode) -> list[UniversalTensor]:
        """Get the output tensors for a node (names without a tensor entry are skipped)."""
        return [self.tensors[name] for name in node.outputs if name in self.tensors]

    # -------------------------------------------------------------------------
    # Comparison (Task 18.4.1, 18.4.2)
    # -------------------------------------------------------------------------

    def is_structurally_equal(self, other: UniversalGraph) -> bool:
        """Check if two graphs have the same structure.

        Two graphs are structurally equal if they have:
        - Same number of nodes
        - Same op_type sequence
        - Same number of outputs per node
        - Same connectivity (which node feeds which). Each node input must
          come from the same (node position, output slot) in both graphs,
          or from a tensor that no node produces in both graphs.

        Ignores:
        - Weight values
        - Precision differences (FP32 vs FP16)
        - Node/tensor names

        Nodes are matched by their position in ``nodes``. Two graphs that
        list the same nodes in a different order are reported as not equal.
        """
        if len(self.nodes) != len(other.nodes):
            return False

        # Compare op types in order
        if [n.op_type for n in self.nodes] != [n.op_type for n in other.nodes]:
            return False

        for n1, n2 in zip(self.nodes, other.nodes, strict=True):
            if len(n1.outputs) != len(n2.outputs):
                return False

        # Each signature entry holds one key per input, so this also checks
        # that input counts match.
        return self._connectivity_signature() == other._connectivity_signature()

    def _connectivity_signature(self) -> list[tuple[tuple[Any, ...], ...]]:
        """Describe, per node, where each of its inputs comes from, without names.

        Each input is encoded as one of:
        - ("node", producer_index, output_slot) when a node in this graph
          produces the tensor
        - ("external",) when no node produces it (e.g. graph inputs, weights)
        - ("absent",) for an empty input name (presumably an omitted
          optional input)
        """
        producers: dict[str, tuple[int, int]] = {}
        for node_index, node in enumerate(self.nodes):
            for slot, output_name in enumerate(node.outputs):
                if output_name:
                    producers[output_name] = (node_index, slot)

        signature: list[tuple[tuple[Any, ...], ...]] = []
        for node in self.nodes:
            keys: list[tuple[Any, ...]] = []
            for input_name in node.inputs:
                if not input_name:
                    keys.append(("absent",))
                elif input_name in producers:
                    keys.append(("node", *producers[input_name]))
                else:
                    keys.append(("external",))
            signature.append(tuple(keys))
        return signature

    def diff(self, other: UniversalGraph) -> dict[str, Any]:
        """Generate a detailed diff between two graphs.

        Returns a dict with:
        - 'structurally_equal': bool (see ``is_structurally_equal``)
        - 'node_count_diff': (self_count, other_count)
        - 'param_count_diff': (self_params, other_params)
        - 'weight_bytes_diff': (self_bytes, other_bytes)
        - 'op_type_diff': {op_type: (self_count, other_count)}, only for op
          types whose counts differ, in sorted op_type order
        - 'dtype_changes': [{'tensor', 'self_dtype', 'other_dtype'}]. Covers
          weight tensors present in both graphs (matched by name) whose
          dtype differs, sorted by tensor name.
        - 'missing_in_self': sorted node ids in other but not self
        - 'missing_in_other': sorted node ids in self but not other
        """
        result: dict[str, Any] = {
            "structurally_equal": self.is_structurally_equal(other),
            "node_count_diff": (len(self.nodes), len(other.nodes)),
            "param_count_diff": (self.total_parameters, other.total_parameters),
            "weight_bytes_diff": (self.total_weight_bytes, other.total_weight_bytes),
            "op_type_diff": {},
            "dtype_changes": [],
            "missing_in_self": [],
            "missing_in_other": [],
        }

        # Compare op type counts (sorted so the output is deterministic)
        self_ops = self.op_type_counts
        other_ops = other.op_type_counts
        for op in sorted(set(self_ops) | set(other_ops)):
            self_count = self_ops.get(op, 0)
            other_count = other_ops.get(op, 0)
            if self_count != other_count:
                result["op_type_diff"][op] = (self_count, other_count)

        # Compare weight dtypes for weights present in both graphs
        self_weights = {t.name: t for t in self.weight_tensors}
        other_weights = {t.name: t for t in other.weight_tensors}
        for name in sorted(set(self_weights) & set(other_weights)):
            if self_weights[name].dtype != other_weights[name].dtype:
                result["dtype_changes"].append(
                    {
                        "tensor": name,
                        "self_dtype": self_weights[name].dtype.value,
                        "other_dtype": other_weights[name].dtype.value,
                    }
                )

        # Node presence by id (documented keys that were previously missing)
        self_ids = {n.id for n in self.nodes}
        other_ids = {n.id for n in other.nodes}
        result["missing_in_self"] = sorted(other_ids - self_ids)
        result["missing_in_other"] = sorted(self_ids - other_ids)

        return result

    # -------------------------------------------------------------------------
    # Serialization (Task 18.5.1)
    # -------------------------------------------------------------------------

    def to_dict(self, include_weights: bool = False) -> dict[str, Any]:
        """Convert graph to a dictionary for JSON serialization.

        Args:
            include_weights: If True, include actual weight data (large!).
                The data is returned as stored on each tensor (e.g. numpy
                arrays), not converted to lists.

        Returns:
            Dict with 'metadata', 'nodes', 'tensors' and a computed 'summary'
        """
        # Build tensors dict
        tensors_dict: dict[str, Any] = {}
        for name, tensor in self.tensors.items():
            t_dict = tensor.model_dump()
            if not include_weights:
                t_dict["data"] = None  # Don't serialize actual weight data
            tensors_dict[name] = t_dict

        result: dict[str, Any] = {
            "metadata": self.metadata.model_dump(),
            "nodes": [node.model_dump() for node in self.nodes],
            "tensors": tensors_dict,
            "summary": {
                "num_nodes": self.num_nodes,
                "num_tensors": self.num_tensors,
                "total_parameters": self.total_parameters,
                "total_weight_bytes": self.total_weight_bytes,
                "op_type_counts": self.op_type_counts,
            },
        }

        return result

    def to_json(self, path: str | Path, include_weights: bool = False) -> None:
        """Save graph to a UTF-8 JSON file.

        Args:
            path: Output file path
            include_weights: If True, include actual weight data (large!).
                Arrays are written as nested lists.
        """
        import json

        def _encode(obj: Any) -> Any:
            # numpy arrays/scalars expose tolist(). str() would write numpy's
            # abbreviated repr (with "...") for large arrays and lose the data.
            tolist = getattr(obj, "tolist", None)
            if callable(tolist):
                return tolist()
            return str(obj)

        data = self.to_dict(include_weights=include_weights)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, default=_encode)

    @classmethod
    def from_json(cls, path: str | Path) -> UniversalGraph:
        """Load graph from a UTF-8 JSON file written by ``to_json``.

        The 'summary' section is ignored (it is recomputed from the graph).

        Note: Weight data is not restored (would need separate binary file).
        """
        import json

        with open(path, encoding="utf-8") as f:
            data = json.load(f)

        metadata = GraphMetadata(**data.get("metadata", {}))
        nodes = [UniversalNode(**n) for n in data.get("nodes", [])]
        tensors = {name: UniversalTensor(**t) for name, t in data.get("tensors", {}).items()}

        return cls(nodes=nodes, tensors=tensors, metadata=metadata)

    # -------------------------------------------------------------------------
    # String representation
    # -------------------------------------------------------------------------

    def __repr__(self) -> str:
        return (
            f"UniversalGraph("
            f"nodes={self.num_nodes}, "
            f"params={self.total_parameters:,}, "
            f"source={self.metadata.source_format.value})"
        )

    def summary(self) -> str:
        """Return a human-readable summary of the graph (top 10 op types by count)."""
        lines = [
            f"Universal IR Graph: {self.metadata.name or 'unnamed'}",
            f" Source: {self.metadata.source_format.value}",
            f" Nodes: {self.num_nodes}",
            f" Parameters: {self.total_parameters:,}",
            f" Weight Size: {self.total_weight_bytes / 1024 / 1024:.2f} MB",
            f" Inputs: {len(self.input_tensors)}",
            f" Outputs: {len(self.output_tensors)}",
            "",
            " Top Operations:",
        ]

        for op, count in sorted(self.op_type_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
            lines.append(f" {op}: {count}")

        return "\n".join(lines)

    # -------------------------------------------------------------------------
    # Visualization (Task 18.5.2)
    # -------------------------------------------------------------------------

    def to_dot(self, max_nodes: int = 500, cluster_by_op: bool = False) -> str:
        """Export graph to Graphviz DOT format.

        Node ids, op types and the model name are escaped, so double quotes
        in them do not break the generated DOT. Cluster names are quoted, so
        op types containing characters such as '.' remain valid.

        Args:
            max_nodes: Maximum nodes to include (for large graphs); negative
                values are treated as 0
            cluster_by_op: Group nodes by operation type in subgraphs

        Returns:
            DOT format string
        """

        def esc(text: Any) -> str:
            # Escape for use inside a double-quoted DOT string.
            return str(text).replace("\\", "\\\\").replace('"', '\\"')

        lines = [
            "digraph UniversalGraph {",
            " rankdir=TB;",
            ' node [shape=box, style=filled, fontname="Arial"];',
            ' edge [fontname="Arial", fontsize=10];',
            "",
        ]

        # Add title
        name = self.metadata.name or "model"
        lines.append(
            f' label="{esc(name)} ({self.num_nodes} nodes, {self.total_parameters:,} params)";'
        )
        lines.append(' labelloc="t";')
        lines.append("")

        # Limit nodes for large graphs (a negative limit would slice from the end)
        max_nodes = max(max_nodes, 0)
        nodes_to_render = self.nodes[:max_nodes]
        if len(self.nodes) > max_nodes:
            lines.append(f" // Showing {max_nodes} of {len(self.nodes)} nodes")
            lines.append("")

        # Color mapping for op types
        op_colors = {
            "Conv2D": "#a8d5ba",  # Green for convolutions
            "MatMul": "#f4a261",  # Orange for matrix ops
            "Relu": "#e9c46a",  # Yellow for activations
            "LeakyRelu": "#e9c46a",
            "Sigmoid": "#e9c46a",
            "Softmax": "#e9c46a",
            "BatchNorm": "#b8c5d6",  # Blue-gray for normalization
            "LayerNorm": "#b8c5d6",
            "Add": "#d4a5a5",  # Pink for element-wise
            "Concat": "#c8b6ff",  # Purple for structural
            "Reshape": "#c8b6ff",
            "MaxPool2D": "#95d5b2",  # Light green for pooling
            "AvgPool2D": "#95d5b2",
        }
        default_color = "#ffffff"

        if cluster_by_op:
            # Group by op type
            ops_to_nodes: dict[str, list[UniversalNode]] = {}
            for node in nodes_to_render:
                ops_to_nodes.setdefault(node.op_type, []).append(node)

            for op_type, op_nodes in ops_to_nodes.items():
                color = op_colors.get(op_type, default_color)
                # Quoted so arbitrary op names form a valid ID; the "cluster"
                # prefix is what makes Graphviz draw it as a cluster.
                lines.append(f' subgraph "cluster_{esc(op_type)}" {{')
                lines.append(f' label="{esc(op_type)}";')
                lines.append(" style=filled;")
                lines.append(f' color="{color}";')
                for node in op_nodes:
                    label = f"{esc(node.id)}\\n{esc(node.op_type)}"
                    lines.append(f' "{esc(node.id)}" [label="{label}"];')
                lines.append(" }")
                lines.append("")
        else:
            # Flat node list
            for node in nodes_to_render:
                color = op_colors.get(node.op_type, default_color)
                label = f"{esc(node.id)}\\n{esc(node.op_type)}"
                lines.append(f' "{esc(node.id)}" [label="{label}", fillcolor="{color}"];')

        lines.append("")

        # Add edges based on tensor connections (only between rendered nodes)
        node_ids = {n.id for n in nodes_to_render}
        tensor_to_producer: dict[str, str] = {}

        # Map tensors to their producing nodes
        for node in nodes_to_render:
            for output in node.outputs:
                tensor_to_producer[output] = node.id

        # Create edges
        for node in nodes_to_render:
            for inp in node.inputs:
                producer = tensor_to_producer.get(inp)
                if producer is not None and producer in node_ids:
                    lines.append(f' "{esc(producer)}" -> "{esc(node.id)}";')

        lines.append("}")
        return "\n".join(lines)

    def to_networkx(self) -> Any:
        """Export graph to NetworkX DiGraph.

        Node keys are node ids. Each edge carries the connecting tensor name
        as its 'tensor' attribute.

        Returns:
            networkx.DiGraph with nodes and edges

        Raises:
            ImportError: If networkx is not installed
        """
        try:
            import networkx as nx
        except ImportError as e:
            raise ImportError(
                "NetworkX is required for graph export. Install with: pip install networkx"
            ) from e

        G = nx.DiGraph()

        # Add nodes with attributes
        for node in self.nodes:
            G.add_node(
                node.id,
                op_type=node.op_type,
                inputs=node.inputs,
                outputs=node.outputs,
                attributes=node.attributes,
            )

        # Add edges based on tensor connections
        tensor_to_producer: dict[str, str] = {}
        for node in self.nodes:
            for output in node.outputs:
                tensor_to_producer[output] = node.id

        for node in self.nodes:
            for inp in node.inputs:
                if inp in tensor_to_producer:
                    G.add_edge(tensor_to_producer[inp], node.id, tensor=inp)

        return G

    def save_dot(self, path: str | Path, max_nodes: int = 500) -> None:
        """Save graph to a UTF-8 DOT file.

        Args:
            path: Output file path (.dot)
            max_nodes: Maximum nodes to include (same default as ``to_dot``)
        """
        dot_content = self.to_dot(max_nodes=max_nodes)
        with open(path, "w", encoding="utf-8") as f:
            f.write(dot_content)

    def save_png(self, path: str | Path, max_nodes: int = 500) -> None:
        """Render graph to PNG using Graphviz.

        Args:
            path: Output file path (.png)
            max_nodes: Maximum nodes to render

        Raises:
            ImportError: If graphviz is not installed
        """
        try:
            import graphviz
        except ImportError as e:
            raise ImportError(
                "Graphviz Python package is required. Install with: pip install graphviz\n"
                "Also ensure Graphviz system package is installed."
            ) from e

        dot_content = self.to_dot(max_nodes=max_nodes)
        source = graphviz.Source(dot_content)

        # graphviz renders to path without extension, then adds it
        path = Path(path)
        output_path = path.with_suffix("")  # Remove .png
        source.render(str(output_path), format="png", cleanup=True)

    def to_hierarchical(self) -> dict[str, Any]:
        """Convert UniversalGraph to HierarchicalGraph-compatible dict for D3.js.

        Returns a nested dict structure that matches the format expected by the
        interactive D3.js graph visualization. This enables using UniversalGraph
        as a drop-in replacement for HierarchicalGraph.

        Per-node FLOPs and memory are rough estimates from output shapes. Each
        node's 'total_params' counts every weight tensor among its inputs. The
        root total counts each weight tensor once, even if several nodes share
        it. If no node consumes a weight, the root falls back to
        ``total_parameters``.

        Returns:
            Dict structure matching HierarchicalGraph.to_dict() format
        """

        def count_elements(shape: list[int]) -> int:
            """Product of the positive dims of one shape (dynamic dims are skipped)."""
            elements = 1
            for dim in shape:
                if dim > 0:
                    elements *= dim
            return elements

        def estimate_node_flops(node: UniversalNode) -> int:
            """Rough FLOP estimate: total output elements times a per-op multiplier."""
            if not node.output_shapes:
                return 0

            # Sum over outputs; multiplying all outputs' dims together would
            # massively overestimate multi-output nodes.
            output_elements = sum(count_elements(shape) for shape in node.output_shapes)

            # Simple heuristics
            flop_multipliers = {
                "Conv2D": 9,  # Approximate for 3x3 kernel
                "MatMul": 2,  # multiply-add
                "Gemm": 2,
                "Attention": 4,
                "MultiHeadAttention": 4,
            }
            return output_elements * flop_multipliers.get(node.op_type, 1)

        def estimate_node_memory(node: UniversalNode) -> int:
            """Rough output-memory estimate in bytes, sized by the first output dtype."""
            if not node.output_shapes:
                return 0

            total_elements = sum(count_elements(shape) for shape in node.output_shapes)

            # Assume float32 (4 bytes) by default
            bytes_per_elem = 4
            if node.output_dtypes:
                dtype = node.output_dtypes[0]
                if isinstance(dtype, DataType):
                    # STRING/UNKNOWN report 0 bytes, giving a 0 estimate.
                    bytes_per_elem = dtype.bytes_per_element  # Property, not method
                elif isinstance(dtype, int):
                    # Handle case where dtype is stored as int
                    # NOTE(review): the int is used directly as a byte count.
                    bytes_per_elem = dtype if dtype > 0 else 4

            return total_elements * bytes_per_elem

        # Convert nodes to hierarchical format
        children = []
        total_flops = 0
        total_params = 0
        total_memory = 0
        # Weight tensors already added to total_params, so that shared
        # weights (consumed by several nodes) are counted once at the root.
        counted_weights: set[str] = set()

        for node in self.nodes:
            node_flops = estimate_node_flops(node)
            node_memory = estimate_node_memory(node)
            total_flops += node_flops
            total_memory += node_memory

            # Get param count from associated weight tensors
            node_params = 0
            for inp in node.inputs:
                tensor = self.tensors.get(inp)
                if tensor is not None and tensor.origin == TensorOrigin.WEIGHT:
                    node_params += tensor.num_elements
                    if inp not in counted_weights:
                        counted_weights.add(inp)
                        total_params += tensor.num_elements

            child = {
                "id": node.id,
                "name": node.id,
                "display_name": node.op_type,
                "node_type": "op",
                "op_type": node.op_type,
                "depth": 1,
                "is_collapsed": False,
                "is_repeated": False,
                "repeat_count": 1,
                "total_flops": node_flops,
                "total_params": node_params,
                "total_memory_bytes": node_memory,
                "node_count": 1,
                "inputs": node.inputs,
                "outputs": node.outputs,
                "attributes": node.attributes,
            }
            children.append(child)

        # Create root node
        model_name = self.metadata.name or "Model"
        root = {
            "id": "root",
            "name": model_name,
            "display_name": model_name,
            "node_type": "model",
            "op_type": None,
            "depth": 0,
            "is_collapsed": False,
            "is_repeated": False,
            "repeat_count": 1,
            "total_flops": total_flops,
            "total_params": total_params or self.total_parameters,
            "total_memory_bytes": total_memory,
            "node_count": len(self.nodes),
            "inputs": self.metadata.input_names,
            "outputs": self.metadata.output_names,
            "attributes": {},
            "children": children,
        }

        return root