morphml 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of morphml might be problematic. Click here for more details.

Files changed (158) hide show
  1. morphml/__init__.py +14 -0
  2. morphml/api/__init__.py +26 -0
  3. morphml/api/app.py +326 -0
  4. morphml/api/auth.py +193 -0
  5. morphml/api/client.py +338 -0
  6. morphml/api/models.py +132 -0
  7. morphml/api/rate_limit.py +192 -0
  8. morphml/benchmarking/__init__.py +36 -0
  9. morphml/benchmarking/comparison.py +430 -0
  10. morphml/benchmarks/__init__.py +56 -0
  11. morphml/benchmarks/comparator.py +409 -0
  12. morphml/benchmarks/datasets.py +280 -0
  13. morphml/benchmarks/metrics.py +199 -0
  14. morphml/benchmarks/openml_suite.py +201 -0
  15. morphml/benchmarks/problems.py +289 -0
  16. morphml/benchmarks/suite.py +318 -0
  17. morphml/cli/__init__.py +5 -0
  18. morphml/cli/commands/experiment.py +329 -0
  19. morphml/cli/main.py +457 -0
  20. morphml/cli/quickstart.py +312 -0
  21. morphml/config.py +278 -0
  22. morphml/constraints/__init__.py +19 -0
  23. morphml/constraints/handler.py +205 -0
  24. morphml/constraints/predicates.py +285 -0
  25. morphml/core/__init__.py +3 -0
  26. morphml/core/crossover.py +449 -0
  27. morphml/core/dsl/README.md +359 -0
  28. morphml/core/dsl/__init__.py +72 -0
  29. morphml/core/dsl/ast_nodes.py +364 -0
  30. morphml/core/dsl/compiler.py +318 -0
  31. morphml/core/dsl/layers.py +368 -0
  32. morphml/core/dsl/lexer.py +336 -0
  33. morphml/core/dsl/parser.py +455 -0
  34. morphml/core/dsl/search_space.py +386 -0
  35. morphml/core/dsl/syntax.py +199 -0
  36. morphml/core/dsl/type_system.py +361 -0
  37. morphml/core/dsl/validator.py +386 -0
  38. morphml/core/graph/__init__.py +40 -0
  39. morphml/core/graph/edge.py +124 -0
  40. morphml/core/graph/graph.py +507 -0
  41. morphml/core/graph/mutations.py +409 -0
  42. morphml/core/graph/node.py +196 -0
  43. morphml/core/graph/serialization.py +361 -0
  44. morphml/core/graph/visualization.py +431 -0
  45. morphml/core/objectives/__init__.py +20 -0
  46. morphml/core/search/__init__.py +33 -0
  47. morphml/core/search/individual.py +252 -0
  48. morphml/core/search/parameters.py +453 -0
  49. morphml/core/search/population.py +375 -0
  50. morphml/core/search/search_engine.py +340 -0
  51. morphml/distributed/__init__.py +76 -0
  52. morphml/distributed/fault_tolerance.py +497 -0
  53. morphml/distributed/health_monitor.py +348 -0
  54. morphml/distributed/master.py +709 -0
  55. morphml/distributed/proto/README.md +224 -0
  56. morphml/distributed/proto/__init__.py +74 -0
  57. morphml/distributed/proto/worker.proto +170 -0
  58. morphml/distributed/proto/worker_pb2.py +79 -0
  59. morphml/distributed/proto/worker_pb2_grpc.py +423 -0
  60. morphml/distributed/resource_manager.py +416 -0
  61. morphml/distributed/scheduler.py +567 -0
  62. morphml/distributed/storage/__init__.py +33 -0
  63. morphml/distributed/storage/artifacts.py +381 -0
  64. morphml/distributed/storage/cache.py +366 -0
  65. morphml/distributed/storage/checkpointing.py +329 -0
  66. morphml/distributed/storage/database.py +459 -0
  67. morphml/distributed/worker.py +549 -0
  68. morphml/evaluation/__init__.py +5 -0
  69. morphml/evaluation/heuristic.py +237 -0
  70. morphml/exceptions.py +55 -0
  71. morphml/execution/__init__.py +5 -0
  72. morphml/execution/local_executor.py +350 -0
  73. morphml/integrations/__init__.py +28 -0
  74. morphml/integrations/jax_adapter.py +206 -0
  75. morphml/integrations/pytorch_adapter.py +530 -0
  76. morphml/integrations/sklearn_adapter.py +206 -0
  77. morphml/integrations/tensorflow_adapter.py +230 -0
  78. morphml/logging_config.py +93 -0
  79. morphml/meta_learning/__init__.py +66 -0
  80. morphml/meta_learning/architecture_similarity.py +277 -0
  81. morphml/meta_learning/experiment_database.py +240 -0
  82. morphml/meta_learning/knowledge_base/__init__.py +19 -0
  83. morphml/meta_learning/knowledge_base/embedder.py +179 -0
  84. morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
  85. morphml/meta_learning/knowledge_base/meta_features.py +265 -0
  86. morphml/meta_learning/knowledge_base/vector_store.py +271 -0
  87. morphml/meta_learning/predictors/__init__.py +27 -0
  88. morphml/meta_learning/predictors/ensemble.py +221 -0
  89. morphml/meta_learning/predictors/gnn_predictor.py +552 -0
  90. morphml/meta_learning/predictors/learning_curve.py +231 -0
  91. morphml/meta_learning/predictors/proxy_metrics.py +261 -0
  92. morphml/meta_learning/strategy_evolution/__init__.py +27 -0
  93. morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
  94. morphml/meta_learning/strategy_evolution/bandit.py +276 -0
  95. morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
  96. morphml/meta_learning/transfer.py +581 -0
  97. morphml/meta_learning/warm_start.py +286 -0
  98. morphml/optimizers/__init__.py +74 -0
  99. morphml/optimizers/adaptive_operators.py +399 -0
  100. morphml/optimizers/bayesian/__init__.py +52 -0
  101. morphml/optimizers/bayesian/acquisition.py +387 -0
  102. morphml/optimizers/bayesian/base.py +319 -0
  103. morphml/optimizers/bayesian/gaussian_process.py +635 -0
  104. morphml/optimizers/bayesian/smac.py +534 -0
  105. morphml/optimizers/bayesian/tpe.py +411 -0
  106. morphml/optimizers/differential_evolution.py +220 -0
  107. morphml/optimizers/evolutionary/__init__.py +61 -0
  108. morphml/optimizers/evolutionary/cma_es.py +416 -0
  109. morphml/optimizers/evolutionary/differential_evolution.py +556 -0
  110. morphml/optimizers/evolutionary/encoding.py +426 -0
  111. morphml/optimizers/evolutionary/particle_swarm.py +449 -0
  112. morphml/optimizers/genetic_algorithm.py +486 -0
  113. morphml/optimizers/gradient_based/__init__.py +22 -0
  114. morphml/optimizers/gradient_based/darts.py +550 -0
  115. morphml/optimizers/gradient_based/enas.py +585 -0
  116. morphml/optimizers/gradient_based/operations.py +474 -0
  117. morphml/optimizers/gradient_based/utils.py +601 -0
  118. morphml/optimizers/hill_climbing.py +169 -0
  119. morphml/optimizers/multi_objective/__init__.py +56 -0
  120. morphml/optimizers/multi_objective/indicators.py +504 -0
  121. morphml/optimizers/multi_objective/nsga2.py +647 -0
  122. morphml/optimizers/multi_objective/visualization.py +427 -0
  123. morphml/optimizers/nsga2.py +308 -0
  124. morphml/optimizers/random_search.py +172 -0
  125. morphml/optimizers/simulated_annealing.py +181 -0
  126. morphml/plugins/__init__.py +35 -0
  127. morphml/plugins/custom_evaluator_example.py +81 -0
  128. morphml/plugins/custom_optimizer_example.py +63 -0
  129. morphml/plugins/plugin_system.py +454 -0
  130. morphml/reports/__init__.py +30 -0
  131. morphml/reports/generator.py +362 -0
  132. morphml/tracking/__init__.py +7 -0
  133. morphml/tracking/experiment.py +309 -0
  134. morphml/tracking/logger.py +301 -0
  135. morphml/tracking/reporter.py +357 -0
  136. morphml/utils/__init__.py +6 -0
  137. morphml/utils/checkpoint.py +189 -0
  138. morphml/utils/comparison.py +390 -0
  139. morphml/utils/export.py +407 -0
  140. morphml/utils/progress.py +392 -0
  141. morphml/utils/validation.py +392 -0
  142. morphml/version.py +7 -0
  143. morphml/visualization/__init__.py +50 -0
  144. morphml/visualization/analytics.py +423 -0
  145. morphml/visualization/architecture_diagrams.py +353 -0
  146. morphml/visualization/architecture_plot.py +223 -0
  147. morphml/visualization/convergence_plot.py +174 -0
  148. morphml/visualization/crossover_viz.py +386 -0
  149. morphml/visualization/graph_viz.py +338 -0
  150. morphml/visualization/pareto_plot.py +149 -0
  151. morphml/visualization/plotly_dashboards.py +422 -0
  152. morphml/visualization/population.py +309 -0
  153. morphml/visualization/progress.py +260 -0
  154. morphml-1.0.0.dist-info/METADATA +434 -0
  155. morphml-1.0.0.dist-info/RECORD +158 -0
  156. morphml-1.0.0.dist-info/WHEEL +4 -0
  157. morphml-1.0.0.dist-info/entry_points.txt +3 -0
  158. morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,353 @@
1
+ """Professional architecture diagram generation for MorphML.
2
+
3
+ Generate publication-quality diagrams of neural architectures.
4
+
5
+ Example:
6
+ >>> from morphml.visualization.architecture_diagrams import ArchitectureDiagramGenerator
7
+ >>> generator = ArchitectureDiagramGenerator()
8
+ >>> generator.generate_svg(graph, "architecture.svg")
9
+ """
10
+
11
+ from typing import Dict, Optional
12
+
13
+ try:
14
+ import graphviz
15
+
16
+ GRAPHVIZ_AVAILABLE = True
17
+ except ImportError:
18
+ GRAPHVIZ_AVAILABLE = False
19
+ graphviz = None
20
+
21
+ from morphml.core.graph import GraphNode, ModelGraph
22
+ from morphml.logging_config import get_logger
23
+
24
+ logger = get_logger(__name__)
25
+
26
+
class ArchitectureDiagramGenerator:
    """
    Generate professional architecture diagrams.

    Creates publication-quality visualizations of neural architectures
    using Graphviz with custom styling.

    Every ``generate_*`` method takes an output path WITHOUT a file
    extension; the format suffix (``.svg``, ``.png``, ``.pdf``) is appended
    when the diagram is rendered.

    Example:
        >>> generator = ArchitectureDiagramGenerator()
        >>> generator.generate_svg(graph, "output")  # writes output.svg
        >>> generator.generate_png(graph, "output", dpi=300)  # writes output.png
    """

    # Fill color per GraphNode.operation. Operations missing from this table
    # are drawn with DEFAULT_COLOR.
    OPERATION_COLORS = {
        "input": "#90EE90",  # Light green
        "output": "#FFB6C1",  # Light pink
        "conv2d": "#FF6B6B",  # Red
        "dense": "#4ECDC4",  # Teal
        "maxpool": "#45B7D1",  # Blue
        "avgpool": "#5DADE2",  # Light blue
        "relu": "#96CEB4",  # Green
        "sigmoid": "#FFEAA7",  # Yellow
        "tanh": "#DFE6E9",  # Gray
        "softmax": "#FD79A8",  # Pink
        "batchnorm": "#A29BFE",  # Purple
        "dropout": "#FDCB6E",  # Orange
        "flatten": "#E17055",  # Dark orange
    }

    # Fallback fill color for operations not present in OPERATION_COLORS.
    DEFAULT_COLOR = "#CCCCCC"

    # Styles with dedicated graph attributes in _create_graphviz_graph; any
    # other value is rendered with the "minimal" attributes.
    STYLES = ("modern", "classic", "minimal")

    def __init__(self, style: str = "modern"):
        """
        Initialize diagram generator.

        Args:
            style: Visual style ('modern', 'classic', 'minimal'). Any other
                value is accepted but rendered with the 'minimal' attributes
                (a warning is logged).

        Raises:
            ImportError: If the ``graphviz`` Python package could not be imported.
        """
        if not GRAPHVIZ_AVAILABLE:
            raise ImportError(
                "Graphviz is required for architecture diagrams. "
                "Install with: pip install graphviz"
            )

        if style not in self.STYLES:
            # Keep the historical fallback to minimal styling, but make it visible.
            logger.warning(
                f"Unknown diagram style {style!r}; using 'minimal' "
                f"(expected one of {self.STYLES})"
            )

        self.style = style
        logger.info(f"Initialized ArchitectureDiagramGenerator with style: {style}")

    def generate_svg(self, graph: ModelGraph, output_path: str, title: Optional[str] = None):
        """
        Generate SVG diagram.

        Args:
            graph: ModelGraph to visualize
            output_path: Output file path (without extension); ``.svg`` is appended
            title: Optional diagram title

        Example:
            >>> generator.generate_svg(graph, "architecture")
            # Creates architecture.svg
        """
        dot = self._create_graphviz_graph(graph, title)
        # cleanup=True removes the intermediate DOT source after rendering.
        dot.render(output_path, format="svg", cleanup=True)
        logger.info(f"Generated SVG diagram: {output_path}.svg")

    def generate_png(
        self, graph: ModelGraph, output_path: str, dpi: int = 300, title: Optional[str] = None
    ):
        """
        Generate PNG diagram.

        Args:
            graph: ModelGraph to visualize
            output_path: Output file path (without extension); ``.png`` is appended
            dpi: Resolution in DPI
            title: Optional diagram title

        Example:
            >>> generator.generate_png(graph, "architecture", dpi=300)
            # Creates architecture.png
        """
        dot = self._create_graphviz_graph(graph, title)
        dot.graph_attr["dpi"] = str(dpi)
        dot.render(output_path, format="png", cleanup=True)
        logger.info(f"Generated PNG diagram: {output_path}.png")

    def generate_pdf(self, graph: ModelGraph, output_path: str, title: Optional[str] = None):
        """
        Generate PDF diagram.

        Args:
            graph: ModelGraph to visualize
            output_path: Output file path (without extension); ``.pdf`` is appended
            title: Optional diagram title

        Example:
            >>> generator.generate_pdf(graph, "architecture")
            # Creates architecture.pdf
        """
        dot = self._create_graphviz_graph(graph, title)
        dot.render(output_path, format="pdf", cleanup=True)
        logger.info(f"Generated PDF diagram: {output_path}.pdf")

    @staticmethod
    def _iter_edges(graph: ModelGraph):
        """
        Return the edge objects of ``graph``.

        If ``graph.edges`` is a mapping (edge_id -> edge) its values are
        returned; otherwise it is assumed to already be an iterable of edges.
        """
        edges = graph.edges
        return edges.values() if hasattr(edges, "values") else edges

    def _create_graphviz_graph(
        self, graph: ModelGraph, title: Optional[str] = None
    ) -> "graphviz.Digraph":
        """
        Create Graphviz graph from ModelGraph.

        Args:
            graph: ModelGraph to convert
            title: Optional title, drawn at the top of the diagram

        Returns:
            Graphviz Digraph styled according to ``self.style``
        """
        dot = graphviz.Digraph(comment="Neural Architecture")

        if self.style == "modern":
            dot.attr(
                rankdir="TB",
                bgcolor="white",
                fontname="Arial",
                fontsize="12",
                splines="ortho",
                nodesep="0.5",
                ranksep="0.8",
            )
        elif self.style == "classic":
            dot.attr(rankdir="TB", bgcolor="white", fontname="Times", fontsize="11")
        else:  # "minimal" and any unrecognised style
            dot.attr(
                rankdir="TB", bgcolor="white", fontname="Helvetica", fontsize="10", splines="line"
            )

        if title:
            dot.attr(label=title, labelloc="t", fontsize="16")

        for node in graph.nodes.values():
            self._add_node(dot, node)

        for edge in self._iter_edges(graph):
            self._add_edge(dot, edge)

        return dot

    def _add_node(self, dot: "graphviz.Digraph", node: GraphNode, prefix: str = ""):
        """
        Add node to Graphviz graph.

        Args:
            dot: Graphviz graph (or subgraph)
            node: GraphNode to add
            prefix: String prepended to the node id; used to keep ids unique
                when several architectures share one Graphviz digraph
        """
        operation = node.operation

        node_attrs = {
            "label": self._create_node_label(node),
            "shape": "box",
            "style": "filled,rounded",
            "fillcolor": self.OPERATION_COLORS.get(operation, self.DEFAULT_COLOR),
            "fontname": "Arial",
            "fontsize": "10",
            "margin": "0.2,0.1",
        }

        # Input/output nodes are drawn as plain filled ellipses.
        if operation in ("input", "output"):
            node_attrs["shape"] = "ellipse"
            node_attrs["style"] = "filled"

        dot.node(f"{prefix}{node.id}", **node_attrs)

    @staticmethod
    def _format_window(size) -> str:
        """
        Format a kernel/pool size for display.

        A scalar (including the ``"?"`` placeholder) ``k`` becomes ``"kxk"``;
        a list/tuple such as ``(3, 5)`` becomes ``"3x5"``.
        """
        if isinstance(size, (list, tuple)):
            return "x".join(str(s) for s in size)
        return f"{size}x{size}"

    def _create_node_label(self, node: GraphNode) -> str:
        """
        Create formatted label for node.

        The first line is the operation name; selected parameters follow on
        extra lines for conv2d, dense, maxpool/avgpool, dropout and input.
        Missing parameters are shown as ``?``.

        Args:
            node: GraphNode

        Returns:
            Formatted label string
        """
        operation = node.operation
        params = node.params

        label = f"{operation}"

        if operation == "conv2d":
            filters = params.get("filters", "?")
            kernel = self._format_window(params.get("kernel_size", "?"))
            label += f"\n{filters} filters\n{kernel} kernel"
        elif operation == "dense":
            label += f"\n{params.get('units', '?')} units"
        elif operation in ("maxpool", "avgpool"):
            pool = self._format_window(params.get("pool_size", "?"))
            label += f"\n{pool} pool"
        elif operation == "dropout":
            label += f"\nrate={params.get('rate', '?')}"
        elif operation == "input":
            label += f"\nshape={params.get('shape', '?')}"

        return label

    def _add_edge(self, dot: "graphviz.Digraph", edge, prefix: str = ""):
        """
        Add edge to Graphviz graph.

        Args:
            dot: Graphviz graph (or subgraph)
            edge: GraphEdge to add (reads ``edge.source.id`` / ``edge.target.id``)
            prefix: String prepended to both endpoint ids; must match the
                prefix used when the nodes were added
        """
        dot.edge(
            f"{prefix}{edge.source.id}",
            f"{prefix}{edge.target.id}",
            color="#2C3E50",
            penwidth="1.5",
            arrowsize="0.8",
        )

    def generate_comparison_diagram(
        self, graphs: Dict[str, ModelGraph], output_path: str, format: str = "svg"
    ):
        """
        Generate side-by-side comparison of multiple architectures.

        Each architecture is drawn in its own labelled cluster.

        Args:
            graphs: Dict mapping names to ModelGraphs
            output_path: Output file path (without extension)
            format: Output format ('svg', 'png', 'pdf')

        Example:
            >>> generator.generate_comparison_diagram({
            ...     'Architecture A': graph_a,
            ...     'Architecture B': graph_b
            ... }, "comparison")
        """
        dot = graphviz.Digraph(comment="Architecture Comparison")
        dot.attr(rankdir="LR", compound="true")

        for i, (name, graph) in enumerate(graphs.items()):
            # Graphviz node names are global to the whole digraph, so a node
            # id appearing in more than one architecture would otherwise be
            # merged into a single node. Namespace ids per cluster.
            prefix = f"g{i}_"
            with dot.subgraph(name=f"cluster_{i}") as sub:
                sub.attr(label=name, style="rounded", color="gray")

                for node in graph.nodes.values():
                    self._add_node(sub, node, prefix=prefix)

                for edge in self._iter_edges(graph):
                    self._add_edge(sub, edge, prefix=prefix)

        dot.render(output_path, format=format, cleanup=True)
        logger.info(f"Generated comparison diagram: {output_path}.{format}")

    def generate_layer_statistics_diagram(self, graph: ModelGraph, output_path: str):
        """
        Generate SVG diagram with a layer-statistics info box.

        The info box shows estimated parameter count, depth, width and node
        count, alongside the architecture itself.

        Args:
            graph: ModelGraph to visualize
            output_path: Output file path (without extension); ``.svg`` is appended

        Example:
            >>> generator.generate_layer_statistics_diagram(graph, "stats")
        """
        dot = graphviz.Digraph(comment="Architecture with Statistics")
        dot.attr(rankdir="TB", bgcolor="white")

        # NOTE(review): architecture_plot.plot_architecture_stats uses
        # ModelGraph.get_depth(); confirm ModelGraph also provides
        # estimate_depth() and estimate_width().
        total_params = graph.estimate_parameters()
        depth = graph.estimate_depth()
        width = graph.estimate_width()

        # "\\n" yields a literal backslash-n, which Graphviz renders as a
        # centered line break inside the label.
        info_label = (
            f"Total Parameters: {total_params:,}\\n"
            f"Depth: {depth}\\n"
            f"Width: {width}\\n"
            f"Nodes: {len(graph.nodes)}"
        )

        dot.node(
            "info",
            label=info_label,
            shape="box",
            style="filled",
            fillcolor="#E8F4F8",
            fontname="Courier",
            fontsize="10",
        )

        for node in graph.nodes.values():
            self._add_node(dot, node)

        for edge in self._iter_edges(graph):
            self._add_edge(dot, edge)

        dot.render(output_path, format="svg", cleanup=True)
        logger.info(f"Generated statistics diagram: {output_path}.svg")
@@ -0,0 +1,223 @@
1
+ """Architecture visualization utilities.
2
+
3
+ Author: Eshan Roy <eshanized@proton.me>
4
+ Organization: TONMOY INFRASTRUCTURE & VISION
5
+ """
6
+
7
+ from typing import Optional
8
+
9
+ from morphml.core.graph import ModelGraph
10
+ from morphml.logging_config import get_logger
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
def plot_architecture(
    graph: ModelGraph, save_path: Optional[str] = None, title: str = "Neural Architecture"
) -> None:
    """
    Visualize neural architecture as a graph (spring layout).

    Each node is labelled with its operation name and the first 8 characters
    of its node id. If matplotlib or networkx is missing, an error is logged
    and nothing is drawn.

    Args:
        graph: ModelGraph to visualize
        save_path: Path to save plot (displays if None)
        title: Plot title

    Example:
        >>> from morphml.visualization.architecture_plot import plot_architecture
        >>> plot_architecture(best_individual.graph)
    """
    try:
        import matplotlib.pyplot as plt
        import networkx as nx
    except ImportError:
        logger.error("matplotlib and networkx required for architecture plotting")
        return

    G = nx.DiGraph()

    for node_id, node in graph.nodes.items():
        # str() so that non-string ids can still be truncated for the label.
        G.add_node(node_id, label=f"{node.operation}\n{str(node_id)[:8]}")

    # graph.edges may be a mapping (edge_id -> edge) or a plain iterable of
    # edges; iterating a mapping directly would yield ids, not edge objects.
    edges = graph.edges.values() if hasattr(graph.edges, "values") else graph.edges
    for edge in edges:
        G.add_edge(edge.source.id, edge.target.id)

    pos = nx.spring_layout(G, k=1, iterations=50)

    fig = plt.figure(figsize=(14, 10))
    try:
        nx.draw_networkx_nodes(
            G,
            pos,
            node_size=3000,
            node_color="lightblue",
            alpha=0.9,
            edgecolors="black",
            linewidths=2,
        )
        nx.draw_networkx_edges(
            G, pos, edge_color="gray", arrows=True, arrowsize=20, arrowstyle="->", width=2
        )
        labels = nx.get_node_attributes(G, "label")
        nx.draw_networkx_labels(G, pos, labels, font_size=8)

        plt.title(title, fontsize=16, fontweight="bold")
        plt.axis("off")
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            logger.info(f"Architecture plot saved to {save_path}")
        else:
            plt.show()
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
79
+
80
+
def plot_architecture_hierarchy(
    graph: ModelGraph, save_path: Optional[str] = None, title: str = "Architecture Hierarchy"
) -> None:
    """
    Visualize architecture with hierarchical layout.

    Uses Graphviz's ``dot`` layout via networkx when available, falling back
    to a spring layout otherwise. Nodes are colored by operation type:
    input/output green, operations containing "conv" blue, containing "dense"
    yellow, everything else gray.

    Args:
        graph: ModelGraph to visualize
        save_path: Path to save plot (displays if None)
        title: Plot title
    """
    try:
        import matplotlib.pyplot as plt
        import networkx as nx
    except ImportError:
        logger.error("matplotlib and networkx required")
        return

    G = nx.DiGraph()

    for node_id, node in graph.nodes.items():
        G.add_node(node_id, label=node.operation)

    # graph.edges may be a mapping (edge_id -> edge) or a plain iterable of
    # edges; iterating a mapping directly would yield ids, not edge objects.
    edges = graph.edges.values() if hasattr(graph.edges, "values") else graph.edges
    for edge in edges:
        G.add_edge(edge.source.id, edge.target.id)

    try:
        pos = nx.nx_agraph.graphviz_layout(G, prog="dot")
    except Exception:
        # Deliberate best-effort: any failure of the Graphviz-backed layout
        # (e.g. missing pygraphviz) falls back to a spring layout.
        pos = nx.spring_layout(G)

    fig = plt.figure(figsize=(12, 10))
    try:
        node_colors = []
        for node_id in G.nodes():
            operation = graph.nodes[node_id].operation
            if operation in ["input", "output"]:
                node_colors.append("lightgreen")
            elif "conv" in operation:
                node_colors.append("lightblue")
            elif "dense" in operation:
                node_colors.append("lightyellow")
            else:
                node_colors.append("lightgray")

        nx.draw(
            G,
            pos,
            node_color=node_colors,
            with_labels=False,
            node_size=2000,
            alpha=0.9,
            edgecolors="black",
            linewidths=2,
            edge_color="gray",
            arrows=True,
            arrowsize=15,
        )

        labels = {node_id: graph.nodes[node_id].operation for node_id in G.nodes()}
        nx.draw_networkx_labels(G, pos, labels, font_size=8)

        plt.title(title, fontsize=16, fontweight="bold")
        plt.axis("off")
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            logger.info(f"Architecture hierarchy plot saved to {save_path}")
        else:
            plt.show()
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
159
+
160
+
def plot_architecture_stats(
    graphs: list, save_path: Optional[str] = None, title: str = "Architecture Statistics"
) -> None:
    """
    Plot statistics of multiple architectures.

    Draws three side-by-side histograms (20 bins each): node count, depth
    (``get_depth()``) and estimated parameter count in millions
    (``estimate_parameters() / 1e6``).

    Args:
        graphs: List (or any iterable) of ModelGraph objects
        save_path: Path to save plot (displays if None)
        title: Plot title
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        logger.error("matplotlib required")
        return

    # Materialize once so a generator argument is not exhausted by the first pass.
    graphs = list(graphs)

    node_counts = [len(g.nodes) for g in graphs]
    depths = [g.get_depth() for g in graphs]
    params = [g.estimate_parameters() / 1e6 for g in graphs]  # In millions

    # (data, x-axis label, panel title, bar color; None = matplotlib default)
    panels = (
        (node_counts, "Number of Nodes", "Node Count Distribution", None),
        (depths, "Depth", "Depth Distribution", "green"),
        (params, "Parameters (Millions)", "Parameter Count Distribution", "orange"),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        for ax, (data, xlabel, panel_title, color) in zip(axes, panels):
            ax.hist(data, bins=20, edgecolor="black", alpha=0.7, color=color)
            ax.set_xlabel(xlabel, fontsize=12)
            ax.set_ylabel("Frequency", fontsize=12)
            ax.set_title(panel_title, fontsize=12, fontweight="bold")
            ax.grid(True, alpha=0.3)

        plt.suptitle(title, fontsize=16, fontweight="bold")
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            logger.info(f"Architecture statistics plot saved to {save_path}")
        else:
            plt.show()
    finally:
        # Close the figure we created, even if plotting or saving fails.
        plt.close(fig)
217
+
218
+
# This module defines its own plotting helpers and does not re-export
# anything from other visualization modules. (A former `try: pass /
# except ImportError: pass` placeholder here was a no-op and was removed.)