morphml 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of morphml might be problematic. Click here for more details.
- morphml/__init__.py +14 -0
- morphml/api/__init__.py +26 -0
- morphml/api/app.py +326 -0
- morphml/api/auth.py +193 -0
- morphml/api/client.py +338 -0
- morphml/api/models.py +132 -0
- morphml/api/rate_limit.py +192 -0
- morphml/benchmarking/__init__.py +36 -0
- morphml/benchmarking/comparison.py +430 -0
- morphml/benchmarks/__init__.py +56 -0
- morphml/benchmarks/comparator.py +409 -0
- morphml/benchmarks/datasets.py +280 -0
- morphml/benchmarks/metrics.py +199 -0
- morphml/benchmarks/openml_suite.py +201 -0
- morphml/benchmarks/problems.py +289 -0
- morphml/benchmarks/suite.py +318 -0
- morphml/cli/__init__.py +5 -0
- morphml/cli/commands/experiment.py +329 -0
- morphml/cli/main.py +457 -0
- morphml/cli/quickstart.py +312 -0
- morphml/config.py +278 -0
- morphml/constraints/__init__.py +19 -0
- morphml/constraints/handler.py +205 -0
- morphml/constraints/predicates.py +285 -0
- morphml/core/__init__.py +3 -0
- morphml/core/crossover.py +449 -0
- morphml/core/dsl/README.md +359 -0
- morphml/core/dsl/__init__.py +72 -0
- morphml/core/dsl/ast_nodes.py +364 -0
- morphml/core/dsl/compiler.py +318 -0
- morphml/core/dsl/layers.py +368 -0
- morphml/core/dsl/lexer.py +336 -0
- morphml/core/dsl/parser.py +455 -0
- morphml/core/dsl/search_space.py +386 -0
- morphml/core/dsl/syntax.py +199 -0
- morphml/core/dsl/type_system.py +361 -0
- morphml/core/dsl/validator.py +386 -0
- morphml/core/graph/__init__.py +40 -0
- morphml/core/graph/edge.py +124 -0
- morphml/core/graph/graph.py +507 -0
- morphml/core/graph/mutations.py +409 -0
- morphml/core/graph/node.py +196 -0
- morphml/core/graph/serialization.py +361 -0
- morphml/core/graph/visualization.py +431 -0
- morphml/core/objectives/__init__.py +20 -0
- morphml/core/search/__init__.py +33 -0
- morphml/core/search/individual.py +252 -0
- morphml/core/search/parameters.py +453 -0
- morphml/core/search/population.py +375 -0
- morphml/core/search/search_engine.py +340 -0
- morphml/distributed/__init__.py +76 -0
- morphml/distributed/fault_tolerance.py +497 -0
- morphml/distributed/health_monitor.py +348 -0
- morphml/distributed/master.py +709 -0
- morphml/distributed/proto/README.md +224 -0
- morphml/distributed/proto/__init__.py +74 -0
- morphml/distributed/proto/worker.proto +170 -0
- morphml/distributed/proto/worker_pb2.py +79 -0
- morphml/distributed/proto/worker_pb2_grpc.py +423 -0
- morphml/distributed/resource_manager.py +416 -0
- morphml/distributed/scheduler.py +567 -0
- morphml/distributed/storage/__init__.py +33 -0
- morphml/distributed/storage/artifacts.py +381 -0
- morphml/distributed/storage/cache.py +366 -0
- morphml/distributed/storage/checkpointing.py +329 -0
- morphml/distributed/storage/database.py +459 -0
- morphml/distributed/worker.py +549 -0
- morphml/evaluation/__init__.py +5 -0
- morphml/evaluation/heuristic.py +237 -0
- morphml/exceptions.py +55 -0
- morphml/execution/__init__.py +5 -0
- morphml/execution/local_executor.py +350 -0
- morphml/integrations/__init__.py +28 -0
- morphml/integrations/jax_adapter.py +206 -0
- morphml/integrations/pytorch_adapter.py +530 -0
- morphml/integrations/sklearn_adapter.py +206 -0
- morphml/integrations/tensorflow_adapter.py +230 -0
- morphml/logging_config.py +93 -0
- morphml/meta_learning/__init__.py +66 -0
- morphml/meta_learning/architecture_similarity.py +277 -0
- morphml/meta_learning/experiment_database.py +240 -0
- morphml/meta_learning/knowledge_base/__init__.py +19 -0
- morphml/meta_learning/knowledge_base/embedder.py +179 -0
- morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
- morphml/meta_learning/knowledge_base/meta_features.py +265 -0
- morphml/meta_learning/knowledge_base/vector_store.py +271 -0
- morphml/meta_learning/predictors/__init__.py +27 -0
- morphml/meta_learning/predictors/ensemble.py +221 -0
- morphml/meta_learning/predictors/gnn_predictor.py +552 -0
- morphml/meta_learning/predictors/learning_curve.py +231 -0
- morphml/meta_learning/predictors/proxy_metrics.py +261 -0
- morphml/meta_learning/strategy_evolution/__init__.py +27 -0
- morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
- morphml/meta_learning/strategy_evolution/bandit.py +276 -0
- morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
- morphml/meta_learning/transfer.py +581 -0
- morphml/meta_learning/warm_start.py +286 -0
- morphml/optimizers/__init__.py +74 -0
- morphml/optimizers/adaptive_operators.py +399 -0
- morphml/optimizers/bayesian/__init__.py +52 -0
- morphml/optimizers/bayesian/acquisition.py +387 -0
- morphml/optimizers/bayesian/base.py +319 -0
- morphml/optimizers/bayesian/gaussian_process.py +635 -0
- morphml/optimizers/bayesian/smac.py +534 -0
- morphml/optimizers/bayesian/tpe.py +411 -0
- morphml/optimizers/differential_evolution.py +220 -0
- morphml/optimizers/evolutionary/__init__.py +61 -0
- morphml/optimizers/evolutionary/cma_es.py +416 -0
- morphml/optimizers/evolutionary/differential_evolution.py +556 -0
- morphml/optimizers/evolutionary/encoding.py +426 -0
- morphml/optimizers/evolutionary/particle_swarm.py +449 -0
- morphml/optimizers/genetic_algorithm.py +486 -0
- morphml/optimizers/gradient_based/__init__.py +22 -0
- morphml/optimizers/gradient_based/darts.py +550 -0
- morphml/optimizers/gradient_based/enas.py +585 -0
- morphml/optimizers/gradient_based/operations.py +474 -0
- morphml/optimizers/gradient_based/utils.py +601 -0
- morphml/optimizers/hill_climbing.py +169 -0
- morphml/optimizers/multi_objective/__init__.py +56 -0
- morphml/optimizers/multi_objective/indicators.py +504 -0
- morphml/optimizers/multi_objective/nsga2.py +647 -0
- morphml/optimizers/multi_objective/visualization.py +427 -0
- morphml/optimizers/nsga2.py +308 -0
- morphml/optimizers/random_search.py +172 -0
- morphml/optimizers/simulated_annealing.py +181 -0
- morphml/plugins/__init__.py +35 -0
- morphml/plugins/custom_evaluator_example.py +81 -0
- morphml/plugins/custom_optimizer_example.py +63 -0
- morphml/plugins/plugin_system.py +454 -0
- morphml/reports/__init__.py +30 -0
- morphml/reports/generator.py +362 -0
- morphml/tracking/__init__.py +7 -0
- morphml/tracking/experiment.py +309 -0
- morphml/tracking/logger.py +301 -0
- morphml/tracking/reporter.py +357 -0
- morphml/utils/__init__.py +6 -0
- morphml/utils/checkpoint.py +189 -0
- morphml/utils/comparison.py +390 -0
- morphml/utils/export.py +407 -0
- morphml/utils/progress.py +392 -0
- morphml/utils/validation.py +392 -0
- morphml/version.py +7 -0
- morphml/visualization/__init__.py +50 -0
- morphml/visualization/analytics.py +423 -0
- morphml/visualization/architecture_diagrams.py +353 -0
- morphml/visualization/architecture_plot.py +223 -0
- morphml/visualization/convergence_plot.py +174 -0
- morphml/visualization/crossover_viz.py +386 -0
- morphml/visualization/graph_viz.py +338 -0
- morphml/visualization/pareto_plot.py +149 -0
- morphml/visualization/plotly_dashboards.py +422 -0
- morphml/visualization/population.py +309 -0
- morphml/visualization/progress.py +260 -0
- morphml-1.0.0.dist-info/METADATA +434 -0
- morphml-1.0.0.dist-info/RECORD +158 -0
- morphml-1.0.0.dist-info/WHEEL +4 -0
- morphml-1.0.0.dist-info/entry_points.txt +3 -0
- morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,353 @@
|
|
|
1
|
+
"""Professional architecture diagram generation for MorphML.
|
|
2
|
+
|
|
3
|
+
Generate publication-quality diagrams of neural architectures.
|
|
4
|
+
|
|
5
|
+
Example:
|
|
6
|
+
>>> from morphml.visualization.architecture_diagrams import ArchitectureDiagramGenerator
|
|
7
|
+
>>> generator = ArchitectureDiagramGenerator()
|
|
8
|
+
>>> generator.generate_svg(graph, "architecture")  # writes architecture.svg
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from typing import Dict, Optional
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
import graphviz
|
|
15
|
+
|
|
16
|
+
GRAPHVIZ_AVAILABLE = True
|
|
17
|
+
except ImportError:
|
|
18
|
+
GRAPHVIZ_AVAILABLE = False
|
|
19
|
+
graphviz = None
|
|
20
|
+
|
|
21
|
+
from morphml.core.graph import GraphNode, ModelGraph
|
|
22
|
+
from morphml.logging_config import get_logger
|
|
23
|
+
|
|
24
|
+
logger = get_logger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ArchitectureDiagramGenerator:
    """
    Generate professional architecture diagrams.

    Creates publication-quality visualizations of neural architectures
    using Graphviz with custom styling.

    Every ``generate_*`` method takes ``output_path`` WITHOUT a file
    extension: the path is handed to ``Digraph.render`` together with a
    ``format``, and the log messages below report ``{output_path}.{format}``
    as the resulting file.

    Example:
        >>> generator = ArchitectureDiagramGenerator()
        >>> generator.generate_svg(graph, "output")  # writes output.svg
        >>> generator.generate_png(graph, "output", dpi=300)  # writes output.png
    """

    # Color scheme for different operations
    OPERATION_COLORS = {
        "input": "#90EE90",  # Light green
        "output": "#FFB6C1",  # Light pink
        "conv2d": "#FF6B6B",  # Red
        "dense": "#4ECDC4",  # Teal
        "maxpool": "#45B7D1",  # Blue
        "avgpool": "#5DADE2",  # Light blue
        "relu": "#96CEB4",  # Green
        "sigmoid": "#FFEAA7",  # Yellow
        "tanh": "#DFE6E9",  # Gray
        "softmax": "#FD79A8",  # Pink
        "batchnorm": "#A29BFE",  # Purple
        "dropout": "#FDCB6E",  # Orange
        "flatten": "#E17055",  # Dark orange
    }

    # Fill color for operations that have no entry in OPERATION_COLORS
    _DEFAULT_COLOR = "#CCCCCC"

    # Graph-level attributes per visual style. Any style name that is not
    # listed here falls back to the "minimal" attributes.
    _STYLE_ATTRS = {
        "modern": {
            "rankdir": "TB",
            "bgcolor": "white",
            "fontname": "Arial",
            "fontsize": "12",
            "splines": "ortho",
            "nodesep": "0.5",
            "ranksep": "0.8",
        },
        "classic": {
            "rankdir": "TB",
            "bgcolor": "white",
            "fontname": "Times",
            "fontsize": "11",
        },
        "minimal": {
            "rankdir": "TB",
            "bgcolor": "white",
            "fontname": "Helvetica",
            "fontsize": "10",
            "splines": "line",
        },
    }

    def __init__(self, style: str = "modern"):
        """
        Initialize diagram generator.

        Args:
            style: Visual style ('modern', 'classic', 'minimal'). Any other
                value is accepted and rendered with the 'minimal' attributes.

        Raises:
            ImportError: If the ``graphviz`` package could not be imported.
        """
        if not GRAPHVIZ_AVAILABLE:
            raise ImportError(
                "Graphviz is required for architecture diagrams. "
                "Install with: pip install graphviz"
            )

        self.style = style
        logger.info(f"Initialized ArchitectureDiagramGenerator with style: {style}")

    def generate_svg(self, graph: ModelGraph, output_path: str, title: Optional[str] = None):
        """
        Generate SVG diagram.

        Args:
            graph: ModelGraph to visualize
            output_path: Output file path (without extension)
            title: Optional diagram title

        Example:
            >>> generator.generate_svg(graph, "architecture")
            # Creates architecture.svg
        """
        dot = self._create_graphviz_graph(graph, title)
        dot.render(output_path, format="svg", cleanup=True)
        logger.info(f"Generated SVG diagram: {output_path}.svg")

    def generate_png(
        self, graph: ModelGraph, output_path: str, dpi: int = 300, title: Optional[str] = None
    ):
        """
        Generate PNG diagram.

        Args:
            graph: ModelGraph to visualize
            output_path: Output file path (without extension)
            dpi: Resolution in DPI
            title: Optional diagram title

        Example:
            >>> generator.generate_png(graph, "architecture", dpi=300)
            # Creates architecture.png
        """
        dot = self._create_graphviz_graph(graph, title)
        # Graphviz attribute values are strings, hence the conversion.
        dot.graph_attr["dpi"] = str(dpi)
        dot.render(output_path, format="png", cleanup=True)
        logger.info(f"Generated PNG diagram: {output_path}.png")

    def generate_pdf(self, graph: ModelGraph, output_path: str, title: Optional[str] = None):
        """
        Generate PDF diagram.

        Args:
            graph: ModelGraph to visualize
            output_path: Output file path (without extension)
            title: Optional diagram title

        Example:
            >>> generator.generate_pdf(graph, "architecture")
            # Creates architecture.pdf
        """
        dot = self._create_graphviz_graph(graph, title)
        dot.render(output_path, format="pdf", cleanup=True)
        logger.info(f"Generated PDF diagram: {output_path}.pdf")

    def _create_graphviz_graph(
        self, graph: ModelGraph, title: Optional[str] = None
    ) -> "graphviz.Digraph":
        """
        Create Graphviz graph from ModelGraph.

        Args:
            graph: ModelGraph to convert
            title: Optional title, placed at the top of the diagram

        Returns:
            Graphviz Digraph
        """
        dot = graphviz.Digraph(comment="Neural Architecture")

        # Set graph attributes based on style (unknown styles -> minimal)
        dot.attr(**self._STYLE_ATTRS.get(self.style, self._STYLE_ATTRS["minimal"]))

        # Add title if provided
        if title:
            dot.attr(label=title, labelloc="t", fontsize="16")

        self._populate(dot, graph)
        return dot

    def _populate(self, dot: "graphviz.Digraph", graph: ModelGraph, id_prefix: str = ""):
        """
        Add every node and edge of ``graph`` to ``dot``.

        Args:
            dot: Graphviz graph (or cluster subgraph) to fill
            graph: ModelGraph whose ``nodes`` and ``edges`` mappings are read
            id_prefix: Prepended to each Graphviz node name
        """
        for node in graph.nodes.values():
            self._add_node(dot, node, id_prefix)

        for edge in graph.edges.values():
            self._add_edge(dot, edge, id_prefix)

    def _add_node(self, dot: "graphviz.Digraph", node: GraphNode, id_prefix: str = ""):
        """
        Add node to Graphviz graph.

        Args:
            dot: Graphviz graph
            node: GraphNode to add
            id_prefix: Prepended to the Graphviz node name; the visible label
                is not affected
        """
        operation = node.operation

        # Node attributes
        node_attrs = {
            "label": self._create_node_label(node),
            "shape": "box",
            "style": "filled,rounded",
            "fillcolor": self.OPERATION_COLORS.get(operation, self._DEFAULT_COLOR),
            "fontname": "Arial",
            "fontsize": "10",
            "margin": "0.2,0.1",
        }

        # Special styling for input/output
        if operation in ("input", "output"):
            node_attrs["shape"] = "ellipse"
            node_attrs["style"] = "filled"

        dot.node(id_prefix + str(node.id), **node_attrs)

    @staticmethod
    def _format_size(size) -> str:
        """
        Format a kernel/pool size as ``AxB``.

        A non-empty tuple or list is joined element-wise (``(3, 5)`` ->
        ``"3x5"``); any other value is repeated (``3`` -> ``"3x3"``,
        ``"?"`` -> ``"?x?"``).
        """
        if isinstance(size, (tuple, list)) and size:
            return "x".join(str(dim) for dim in size)
        return f"{size}x{size}"

    def _create_node_label(self, node: GraphNode) -> str:
        """
        Create formatted label for node.

        The label starts with the operation name; for conv2d, dense,
        maxpool/avgpool, dropout and input nodes the key parameters are
        appended on further lines. Missing parameters are shown as ``?``.

        Args:
            node: GraphNode

        Returns:
            Formatted label string
        """
        operation = node.operation
        # Tolerate nodes whose params attribute is None.
        params = node.params or {}

        # Base label
        label = f"{operation}"

        # Add key parameters
        if operation == "conv2d":
            filters = params.get("filters", "?")
            kernel = self._format_size(params.get("kernel_size", "?"))
            label += f"\n{filters} filters\n{kernel} kernel"

        elif operation == "dense":
            units = params.get("units", "?")
            label += f"\n{units} units"

        elif operation in ["maxpool", "avgpool"]:
            pool = self._format_size(params.get("pool_size", "?"))
            label += f"\n{pool} pool"

        elif operation == "dropout":
            rate = params.get("rate", "?")
            label += f"\nrate={rate}"

        elif operation == "input":
            shape = params.get("shape", "?")
            label += f"\nshape={shape}"

        return label

    def _add_edge(self, dot: "graphviz.Digraph", edge, id_prefix: str = ""):
        """
        Add edge to Graphviz graph.

        Args:
            dot: Graphviz graph
            edge: GraphEdge to add
            id_prefix: Prepended to both endpoint names; must match the prefix
                used when the endpoint nodes were added
        """
        dot.edge(
            id_prefix + str(edge.source.id),
            id_prefix + str(edge.target.id),
            color="#2C3E50",
            penwidth="1.5",
            arrowsize="0.8",
        )

    def generate_comparison_diagram(
        self, graphs: Dict[str, ModelGraph], output_path: str, format: str = "svg"
    ):
        """
        Generate side-by-side comparison of multiple architectures.

        Each architecture is drawn inside its own cluster.

        Args:
            graphs: Dict mapping names to ModelGraphs
            output_path: Output file path (without extension)
            format: Output format ('svg', 'png', 'pdf')

        Example:
            >>> generator.generate_comparison_diagram({
            ...     'Architecture A': graph_a,
            ...     'Architecture B': graph_b
            ... }, "comparison")
        """
        # Create compound graph
        dot = graphviz.Digraph(comment="Architecture Comparison")
        dot.attr(rankdir="LR", compound="true")

        # Add subgraphs for each architecture
        for i, (name, graph) in enumerate(graphs.items()):
            with dot.subgraph(name=f"cluster_{i}") as sub:
                sub.attr(label=name, style="rounded", color="gray")
                # Graphviz node names are global across clusters, so two
                # graphs that share node ids would be merged into one node.
                # A per-cluster prefix keeps the architectures separate.
                self._populate(sub, graph, id_prefix=f"g{i}_")

        # Render
        dot.render(output_path, format=format, cleanup=True)
        logger.info(f"Generated comparison diagram: {output_path}.{format}")

    def generate_layer_statistics_diagram(self, graph: ModelGraph, output_path: str):
        """
        Generate diagram with layer statistics.

        Renders the architecture as SVG together with an extra box listing
        the values of ``graph.estimate_parameters()``,
        ``graph.estimate_depth()``, ``graph.estimate_width()`` and the node
        count.

        Args:
            graph: ModelGraph to visualize
            output_path: Output file path (without extension)

        Example:
            >>> generator.generate_layer_statistics_diagram(graph, "stats")
        """
        dot = graphviz.Digraph(comment="Architecture with Statistics")
        dot.attr(rankdir="TB", bgcolor="white")

        # Calculate statistics
        total_params = graph.estimate_parameters()
        depth = graph.estimate_depth()
        width = graph.estimate_width()

        # Add info box ("\\n" is a literal backslash-n, i.e. a DOT line break)
        info_label = (
            f"Total Parameters: {total_params:,}\\n"
            f"Depth: {depth}\\n"
            f"Width: {width}\\n"
            f"Nodes: {len(graph.nodes)}"
        )

        dot.node(
            "info",
            label=info_label,
            shape="box",
            style="filled",
            fillcolor="#E8F4F8",
            fontname="Courier",
            fontsize="10",
        )

        # Add architecture nodes and edges
        self._populate(dot, graph)

        # Render
        dot.render(output_path, format="svg", cleanup=True)
        logger.info(f"Generated statistics diagram: {output_path}.svg")
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
"""Architecture visualization utilities.
|
|
2
|
+
|
|
3
|
+
Author: Eshan Roy <eshanized@proton.me>
|
|
4
|
+
Organization: TONMOY INFRASTRUCTURE & VISION
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
from morphml.core.graph import ModelGraph
|
|
10
|
+
from morphml.logging_config import get_logger
|
|
11
|
+
|
|
12
|
+
logger = get_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def plot_architecture(
    graph: ModelGraph, save_path: Optional[str] = None, title: str = "Neural Architecture"
) -> None:
    """
    Visualize neural architecture as a graph.

    Builds a NetworkX directed graph from ``graph``, positions it with a
    spring layout and draws it with matplotlib. Each node is labelled with
    its operation and the first eight characters of its id.

    If matplotlib or networkx cannot be imported, an error is logged and the
    function returns without plotting.

    Args:
        graph: ModelGraph to visualize
        save_path: Path to save plot (displays if None)
        title: Plot title

    Example:
        >>> from morphml.visualization.architecture_plot import plot_architecture
        >>> plot_architecture(best_individual.graph)
    """
    try:
        import matplotlib.pyplot as plt
        import networkx as nx
    except ImportError:
        logger.error("matplotlib and networkx required for architecture plotting")
        return

    # Create NetworkX graph
    G = nx.DiGraph()

    # Add nodes
    for node_id, node in graph.nodes.items():
        # str() keeps the slice valid when ids are not strings.
        G.add_node(node_id, label=f"{node.operation}\n{str(node_id)[:8]}")

    # Add edges. Iterating a mapping directly yields its keys, which have no
    # .source/.target, so use the values when graph.edges is a mapping.
    edge_container = graph.edges
    edges = edge_container.values() if hasattr(edge_container, "values") else edge_container
    for edge in edges:
        G.add_edge(edge.source.id, edge.target.id)

    # Create layout
    pos = nx.spring_layout(G, k=1, iterations=50)

    # Create plot
    plt.figure(figsize=(14, 10))

    try:
        # Draw nodes
        nx.draw_networkx_nodes(
            G,
            pos,
            node_size=3000,
            node_color="lightblue",
            alpha=0.9,
            edgecolors="black",
            linewidths=2,
        )

        # Draw edges
        nx.draw_networkx_edges(
            G, pos, edge_color="gray", arrows=True, arrowsize=20, arrowstyle="->", width=2
        )

        # Draw labels
        labels = nx.get_node_attributes(G, "label")
        nx.draw_networkx_labels(G, pos, labels, font_size=8)

        plt.title(title, fontsize=16, fontweight="bold")
        plt.axis("off")
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            logger.info(f"Architecture plot saved to {save_path}")
        else:
            plt.show()
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close()
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def plot_architecture_hierarchy(
    graph: ModelGraph, save_path: Optional[str] = None, title: str = "Architecture Hierarchy"
) -> None:
    """
    Visualize architecture with hierarchical layout.

    Nodes are positioned with the Graphviz ``dot`` layout when it can be
    computed, otherwise with a spring layout. Nodes are colored by operation:
    input/output light green, operations containing "conv" light blue,
    operations containing "dense" light yellow, everything else light gray.

    If matplotlib or networkx cannot be imported, an error is logged and the
    function returns without plotting.

    Args:
        graph: ModelGraph to visualize
        save_path: Path to save plot (displays if None)
        title: Plot title
    """
    try:
        import matplotlib.pyplot as plt
        import networkx as nx
    except ImportError:
        logger.error("matplotlib and networkx required")
        return

    # Create NetworkX graph
    G = nx.DiGraph()

    for node_id, node in graph.nodes.items():
        G.add_node(node_id, label=node.operation)

    # Iterating a mapping directly yields its keys, which have no
    # .source/.target, so use the values when graph.edges is a mapping.
    edge_container = graph.edges
    edges = edge_container.values() if hasattr(edge_container, "values") else edge_container
    for edge in edges:
        G.add_edge(edge.source.id, edge.target.id)

    # Use hierarchical layout
    try:
        pos = nx.nx_agraph.graphviz_layout(G, prog="dot")
    except Exception as exc:
        # Fallback to spring layout if graphviz not available
        logger.debug(f"Hierarchical layout unavailable, using spring layout: {exc}")
        pos = nx.spring_layout(G)

    plt.figure(figsize=(12, 10))

    try:
        # Color nodes by operation type. The operation is read back from the
        # "label" attribute so that a node known only as an edge endpoint
        # (no entry in graph.nodes) is drawn gray instead of raising KeyError.
        node_colors = []
        for node_id in G.nodes():
            operation = G.nodes[node_id].get("label", "")
            if operation in ["input", "output"]:
                node_colors.append("lightgreen")
            elif "conv" in operation:
                node_colors.append("lightblue")
            elif "dense" in operation:
                node_colors.append("lightyellow")
            else:
                node_colors.append("lightgray")

        nx.draw(
            G,
            pos,
            node_color=node_colors,
            with_labels=False,
            node_size=2000,
            alpha=0.9,
            edgecolors="black",
            linewidths=2,
            edge_color="gray",
            arrows=True,
            arrowsize=15,
        )

        # Add labels
        labels = nx.get_node_attributes(G, "label")
        nx.draw_networkx_labels(G, pos, labels, font_size=8)

        plt.title(title, fontsize=16, fontweight="bold")
        plt.axis("off")
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            logger.info(f"Architecture hierarchy plot saved to {save_path}")
        else:
            plt.show()
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close()
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def plot_architecture_stats(
    graphs: list, save_path: Optional[str] = None, title: str = "Architecture Statistics"
) -> None:
    """
    Plot statistics of multiple architectures.

    Draws three side-by-side histograms over ``graphs``: node count, depth,
    and estimated parameter count in millions. If matplotlib cannot be
    imported, an error is logged and nothing is plotted.

    Args:
        graphs: List of ModelGraph objects
        save_path: Path to save plot (displays if None)
        title: Overall figure title
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        logger.error("matplotlib required")
        return

    # One entry per histogram: (x-axis label, panel title, values, extra hist kwargs).
    # NOTE(review): depth is read via get_depth() here, whereas the diagram
    # generator calls estimate_depth() -- confirm ModelGraph provides get_depth().
    panels = (
        (
            "Number of Nodes",
            "Node Count Distribution",
            [len(g.nodes) for g in graphs],
            {},
        ),
        (
            "Depth",
            "Depth Distribution",
            [g.get_depth() for g in graphs],
            {"color": "green"},
        ),
        (
            "Parameters (Millions)",
            "Parameter Count Distribution",
            [g.estimate_parameters() / 1e6 for g in graphs],  # In millions
            {"color": "orange"},
        ),
    )

    _, axes = plt.subplots(1, 3, figsize=(15, 5))

    for ax, (x_label, heading, values, extra) in zip(axes, panels):
        ax.hist(values, bins=20, edgecolor="black", alpha=0.7, **extra)
        ax.set_xlabel(x_label, fontsize=12)
        ax.set_ylabel("Frequency", fontsize=12)
        ax.set_title(heading, fontsize=12, fontweight="bold")
        ax.grid(True, alpha=0.3)

    plt.suptitle(title, fontsize=16, fontweight="bold")
    plt.tight_layout()

    if not save_path:
        plt.show()
    else:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        logger.info(f"Architecture statistics plot saved to {save_path}")

    plt.close()
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
# Re-export from existing visualization module for convenience
# NOTE(review): the guarded block below contains only `pass`, so nothing is
# actually re-exported and the ImportError handler can never trigger. It looks
# like the import that belonged here was dropped -- restore it or delete the
# block once the intended re-export is confirmed.
try:
    pass
except ImportError:
    pass
|