morphml 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of morphml might be problematic. Click here for more details.
- morphml/__init__.py +14 -0
- morphml/api/__init__.py +26 -0
- morphml/api/app.py +326 -0
- morphml/api/auth.py +193 -0
- morphml/api/client.py +338 -0
- morphml/api/models.py +132 -0
- morphml/api/rate_limit.py +192 -0
- morphml/benchmarking/__init__.py +36 -0
- morphml/benchmarking/comparison.py +430 -0
- morphml/benchmarks/__init__.py +56 -0
- morphml/benchmarks/comparator.py +409 -0
- morphml/benchmarks/datasets.py +280 -0
- morphml/benchmarks/metrics.py +199 -0
- morphml/benchmarks/openml_suite.py +201 -0
- morphml/benchmarks/problems.py +289 -0
- morphml/benchmarks/suite.py +318 -0
- morphml/cli/__init__.py +5 -0
- morphml/cli/commands/experiment.py +329 -0
- morphml/cli/main.py +457 -0
- morphml/cli/quickstart.py +312 -0
- morphml/config.py +278 -0
- morphml/constraints/__init__.py +19 -0
- morphml/constraints/handler.py +205 -0
- morphml/constraints/predicates.py +285 -0
- morphml/core/__init__.py +3 -0
- morphml/core/crossover.py +449 -0
- morphml/core/dsl/README.md +359 -0
- morphml/core/dsl/__init__.py +72 -0
- morphml/core/dsl/ast_nodes.py +364 -0
- morphml/core/dsl/compiler.py +318 -0
- morphml/core/dsl/layers.py +368 -0
- morphml/core/dsl/lexer.py +336 -0
- morphml/core/dsl/parser.py +455 -0
- morphml/core/dsl/search_space.py +386 -0
- morphml/core/dsl/syntax.py +199 -0
- morphml/core/dsl/type_system.py +361 -0
- morphml/core/dsl/validator.py +386 -0
- morphml/core/graph/__init__.py +40 -0
- morphml/core/graph/edge.py +124 -0
- morphml/core/graph/graph.py +507 -0
- morphml/core/graph/mutations.py +409 -0
- morphml/core/graph/node.py +196 -0
- morphml/core/graph/serialization.py +361 -0
- morphml/core/graph/visualization.py +431 -0
- morphml/core/objectives/__init__.py +20 -0
- morphml/core/search/__init__.py +33 -0
- morphml/core/search/individual.py +252 -0
- morphml/core/search/parameters.py +453 -0
- morphml/core/search/population.py +375 -0
- morphml/core/search/search_engine.py +340 -0
- morphml/distributed/__init__.py +76 -0
- morphml/distributed/fault_tolerance.py +497 -0
- morphml/distributed/health_monitor.py +348 -0
- morphml/distributed/master.py +709 -0
- morphml/distributed/proto/README.md +224 -0
- morphml/distributed/proto/__init__.py +74 -0
- morphml/distributed/proto/worker.proto +170 -0
- morphml/distributed/proto/worker_pb2.py +79 -0
- morphml/distributed/proto/worker_pb2_grpc.py +423 -0
- morphml/distributed/resource_manager.py +416 -0
- morphml/distributed/scheduler.py +567 -0
- morphml/distributed/storage/__init__.py +33 -0
- morphml/distributed/storage/artifacts.py +381 -0
- morphml/distributed/storage/cache.py +366 -0
- morphml/distributed/storage/checkpointing.py +329 -0
- morphml/distributed/storage/database.py +459 -0
- morphml/distributed/worker.py +549 -0
- morphml/evaluation/__init__.py +5 -0
- morphml/evaluation/heuristic.py +237 -0
- morphml/exceptions.py +55 -0
- morphml/execution/__init__.py +5 -0
- morphml/execution/local_executor.py +350 -0
- morphml/integrations/__init__.py +28 -0
- morphml/integrations/jax_adapter.py +206 -0
- morphml/integrations/pytorch_adapter.py +530 -0
- morphml/integrations/sklearn_adapter.py +206 -0
- morphml/integrations/tensorflow_adapter.py +230 -0
- morphml/logging_config.py +93 -0
- morphml/meta_learning/__init__.py +66 -0
- morphml/meta_learning/architecture_similarity.py +277 -0
- morphml/meta_learning/experiment_database.py +240 -0
- morphml/meta_learning/knowledge_base/__init__.py +19 -0
- morphml/meta_learning/knowledge_base/embedder.py +179 -0
- morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
- morphml/meta_learning/knowledge_base/meta_features.py +265 -0
- morphml/meta_learning/knowledge_base/vector_store.py +271 -0
- morphml/meta_learning/predictors/__init__.py +27 -0
- morphml/meta_learning/predictors/ensemble.py +221 -0
- morphml/meta_learning/predictors/gnn_predictor.py +552 -0
- morphml/meta_learning/predictors/learning_curve.py +231 -0
- morphml/meta_learning/predictors/proxy_metrics.py +261 -0
- morphml/meta_learning/strategy_evolution/__init__.py +27 -0
- morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
- morphml/meta_learning/strategy_evolution/bandit.py +276 -0
- morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
- morphml/meta_learning/transfer.py +581 -0
- morphml/meta_learning/warm_start.py +286 -0
- morphml/optimizers/__init__.py +74 -0
- morphml/optimizers/adaptive_operators.py +399 -0
- morphml/optimizers/bayesian/__init__.py +52 -0
- morphml/optimizers/bayesian/acquisition.py +387 -0
- morphml/optimizers/bayesian/base.py +319 -0
- morphml/optimizers/bayesian/gaussian_process.py +635 -0
- morphml/optimizers/bayesian/smac.py +534 -0
- morphml/optimizers/bayesian/tpe.py +411 -0
- morphml/optimizers/differential_evolution.py +220 -0
- morphml/optimizers/evolutionary/__init__.py +61 -0
- morphml/optimizers/evolutionary/cma_es.py +416 -0
- morphml/optimizers/evolutionary/differential_evolution.py +556 -0
- morphml/optimizers/evolutionary/encoding.py +426 -0
- morphml/optimizers/evolutionary/particle_swarm.py +449 -0
- morphml/optimizers/genetic_algorithm.py +486 -0
- morphml/optimizers/gradient_based/__init__.py +22 -0
- morphml/optimizers/gradient_based/darts.py +550 -0
- morphml/optimizers/gradient_based/enas.py +585 -0
- morphml/optimizers/gradient_based/operations.py +474 -0
- morphml/optimizers/gradient_based/utils.py +601 -0
- morphml/optimizers/hill_climbing.py +169 -0
- morphml/optimizers/multi_objective/__init__.py +56 -0
- morphml/optimizers/multi_objective/indicators.py +504 -0
- morphml/optimizers/multi_objective/nsga2.py +647 -0
- morphml/optimizers/multi_objective/visualization.py +427 -0
- morphml/optimizers/nsga2.py +308 -0
- morphml/optimizers/random_search.py +172 -0
- morphml/optimizers/simulated_annealing.py +181 -0
- morphml/plugins/__init__.py +35 -0
- morphml/plugins/custom_evaluator_example.py +81 -0
- morphml/plugins/custom_optimizer_example.py +63 -0
- morphml/plugins/plugin_system.py +454 -0
- morphml/reports/__init__.py +30 -0
- morphml/reports/generator.py +362 -0
- morphml/tracking/__init__.py +7 -0
- morphml/tracking/experiment.py +309 -0
- morphml/tracking/logger.py +301 -0
- morphml/tracking/reporter.py +357 -0
- morphml/utils/__init__.py +6 -0
- morphml/utils/checkpoint.py +189 -0
- morphml/utils/comparison.py +390 -0
- morphml/utils/export.py +407 -0
- morphml/utils/progress.py +392 -0
- morphml/utils/validation.py +392 -0
- morphml/version.py +7 -0
- morphml/visualization/__init__.py +50 -0
- morphml/visualization/analytics.py +423 -0
- morphml/visualization/architecture_diagrams.py +353 -0
- morphml/visualization/architecture_plot.py +223 -0
- morphml/visualization/convergence_plot.py +174 -0
- morphml/visualization/crossover_viz.py +386 -0
- morphml/visualization/graph_viz.py +338 -0
- morphml/visualization/pareto_plot.py +149 -0
- morphml/visualization/plotly_dashboards.py +422 -0
- morphml/visualization/population.py +309 -0
- morphml/visualization/progress.py +260 -0
- morphml-1.0.0.dist-info/METADATA +434 -0
- morphml-1.0.0.dist-info/RECORD +158 -0
- morphml-1.0.0.dist-info/WHEEL +4 -0
- morphml-1.0.0.dist-info/entry_points.txt +3 -0
- morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,338 @@
|
|
|
1
|
+
"""Graph visualization for neural architectures."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Tuple
|
|
4
|
+
|
|
5
|
+
from morphml.core.graph import ModelGraph
|
|
6
|
+
from morphml.logging_config import get_logger
|
|
7
|
+
|
|
8
|
+
logger = get_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GraphVisualizer:
    """
    Visualize neural architecture graphs.

    Supports multiple output formats:
    - Graphviz (rendered PNG, or a plain DOT file when the ``graphviz``
      package is not installed)
    - NetworkX / matplotlib plots
    - ASCII art
    - HTML (pyvis network, or a static table when ``pyvis`` is not installed)

    Example:
        >>> viz = GraphVisualizer()
        >>> viz.to_graphviz(graph, 'architecture')
        >>> viz.plot_networkx(graph, 'architecture.png')
    """

    # Fill colour used for operations that have no entry in ``node_colors``.
    DEFAULT_NODE_COLOR = "#9E9E9E"

    def __init__(self, style: str = "default"):
        """
        Initialize visualizer.

        Args:
            style: Visualization style ('default', 'modern', 'simple').
                The value is stored on the instance; none of the methods
                in this class currently read it.
        """
        self.style = style
        # Operation name -> hex fill colour. Related operations (pooling,
        # activations, input/output) deliberately share a colour.
        self.node_colors = {
            "input": "#4CAF50",
            "conv2d": "#2196F3",
            "dense": "#FF9800",
            "maxpool": "#9C27B0",
            "avgpool": "#9C27B0",
            "dropout": "#F44336",
            "batchnorm": "#00BCD4",
            "relu": "#FFEB3B",
            "sigmoid": "#FFEB3B",
            "tanh": "#FFEB3B",
            "softmax": "#FFEB3B",
            "output": "#4CAF50",
        }

    def _color_for(self, node) -> str:
        """Return the fill colour for *node*, falling back to the default grey."""
        return self.node_colors.get(node.operation, self.DEFAULT_NODE_COLOR)

    @staticmethod
    def _dot_quote(text) -> str:
        """Escape double quotes so *text* can be placed inside a quoted DOT string.

        Backslashes are left alone on purpose: node labels carry ``\\n``
        escape sequences that Graphviz must interpret as line breaks.
        """
        return str(text).replace('"', '\\"')

    def to_graphviz(self, graph: ModelGraph, output_path: str) -> None:
        """
        Render the graph to a PNG through the ``graphviz`` package.

        When ``graphviz`` cannot be imported, a plain DOT file is written to
        ``output_path`` instead (see ``_write_dot_file``).

        Args:
            graph: Graph to visualize
            output_path: Output file path. On the rendering path this is
                passed to ``Digraph.render`` with ``format="png"``; the
                success message logs the result as ``<output_path>.png``.
        """
        try:
            import graphviz
        except ImportError:
            logger.warning("graphviz not installed")
            self._write_dot_file(graph, output_path)
            return

        dot = graphviz.Digraph(comment="Neural Architecture")
        dot.attr(rankdir="TB")

        # Add nodes
        for node_id, node in graph.nodes.items():
            dot.node(
                node_id,
                self._format_node_label(node),
                style="filled",
                fillcolor=self._color_for(node),
                shape="box",
                fontname="Arial",
            )

        # Add edges
        for edge in graph.edges:
            dot.edge(edge.source.id, edge.target.id)

        # Render
        dot.render(output_path, format="png", cleanup=True)
        logger.info(f"Graph visualization saved to {output_path}.png")

    def _write_dot_file(self, graph: ModelGraph, output_path: str) -> None:
        """
        Write a DOT file by hand, without the ``graphviz`` library.

        Node ids, labels and edge endpoints are emitted as quoted DOT
        strings, so embedded double quotes are escaped first.

        Args:
            graph: Graph to serialize
            output_path: Path of the DOT file to write (UTF-8)
        """
        quote = self._dot_quote

        chunks = [
            "digraph G {\n",
            " rankdir=TB;\n",
            " node [shape=box, style=filled];\n\n",
        ]

        # Nodes
        for node_id, node in graph.nodes.items():
            label = quote(self._format_node_label(node))
            color = self._color_for(node)
            chunks.append(f' "{quote(node_id)}" [label="{label}", fillcolor="{color}"];\n')

        chunks.append("\n")

        # Edges
        for edge in graph.edges:
            chunks.append(f' "{quote(edge.source.id)}" -> "{quote(edge.target.id)}";\n')

        chunks.append("}\n")

        # Explicit encoding keeps the output independent of the platform locale.
        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(chunks)

        logger.info(f"DOT file saved to {output_path}")

    def _format_node_label(self, node) -> str:
        """
        Format a node label from its operation and parameters.

        Args:
            node: Object exposing ``operation`` and ``params``

        Returns:
            The operation name, followed by up to three ``key=value`` pairs.
            Parts are joined with a literal backslash-n (the DOT line-break
            escape), not a real newline character.
        """
        if not node.params:
            return node.operation

        # Limit to 3 params so the rendered box stays compact.
        shown = [f"{key}={value}" for key, value in list(node.params.items())[:3]]
        return "\\n".join([node.operation, *shown])

    def to_ascii(self, graph: ModelGraph) -> str:
        """
        Generate ASCII art representation.

        Nodes are listed one per entry in topological order with an arrow
        between consecutive entries; the arrows reflect listing order only,
        not the actual edges of the graph.

        Args:
            graph: Graph to visualize

        Returns:
            ASCII art string
        """
        try:
            ordered = graph.topological_sort()
        except Exception:
            # Best effort: if the graph cannot be sorted, fall back to the
            # order in which the nodes are stored.
            ordered = list(graph.nodes.values())

        rule = "=" * 60
        lines = [rule, "NEURAL ARCHITECTURE", rule, ""]
        last_index = len(ordered) - 1

        for index, node in enumerate(ordered):
            # Node info (at most two params per node); tolerate missing params.
            node_params = node.params or {}
            params = ", ".join(f"{k}={v}" for k, v in list(node_params.items())[:2])

            lines.append(f"[{index + 1}] {node.operation}")
            if params:
                lines.append(f" {params}")

            # Show connections
            if index < last_index:
                lines.append(" |")
                lines.append(" v")

            lines.append("")

        lines.append(rule)
        lines.append(f"Total Nodes: {len(graph.nodes)}")
        lines.append(f"Total Edges: {len(graph.edges)}")
        lines.append(f"Depth: {graph.get_depth()}")
        lines.append(f"Est. Parameters: {graph.estimate_parameters():,}")
        lines.append(rule)

        return "\n".join(lines)

    def plot_networkx(
        self,
        graph: ModelGraph,
        output_path: Optional[str] = None,
        figsize: Tuple[int, int] = (12, 8),
    ) -> None:
        """
        Plot graph using NetworkX and matplotlib.

        Logs a warning and returns without plotting when either package is
        missing. The figure is closed once it has been saved or shown.

        Args:
            graph: Graph to visualize
            output_path: Path to save plot (displays if None)
            figsize: Figure size
        """
        try:
            import matplotlib.pyplot as plt
            import networkx as nx
        except ImportError:
            logger.warning("matplotlib or networkx not installed")
            return

        # Convert to NetworkX
        G = graph.to_networkx()

        # Create layout
        pos = nx.spring_layout(G, k=2, iterations=50)

        # Collect colours and labels in a single pass over the nodes.
        colors = []
        labels = {}
        for node_id in G.nodes():
            node = graph.nodes[node_id]
            colors.append(self._color_for(node))
            labels[node_id] = node.operation

        fig = plt.figure(figsize=figsize)
        try:
            nx.draw_networkx_nodes(G, pos, node_color=colors, node_size=2000, alpha=0.9)
            nx.draw_networkx_edges(G, pos, edge_color="gray", arrows=True, arrowsize=20, width=2)
            nx.draw_networkx_labels(G, pos, labels, font_size=10, font_weight="bold")

            plt.title(
                f"Neural Architecture (Nodes: {len(graph.nodes)}, Depth: {graph.get_depth()})"
            )
            plt.axis("off")
            plt.tight_layout()

            if output_path:
                plt.savefig(output_path, dpi=300, bbox_inches="tight")
                logger.info(f"NetworkX plot saved to {output_path}")
            else:
                plt.show()
        finally:
            # Release the figure even if drawing or saving failed.
            plt.close(fig)

    def to_html(self, graph: ModelGraph, output_path: str) -> None:
        """
        Generate interactive HTML visualization.

        Uses ``pyvis`` when it is importable; otherwise falls back to the
        static table produced by ``_simple_html``.

        Args:
            graph: Graph to visualize
            output_path: Output HTML file path
        """
        try:
            from pyvis.network import Network
        except ImportError:
            logger.warning("pyvis not installed, falling back to simple HTML")
            self._simple_html(graph, output_path)
            return

        net = Network(height="750px", width="100%", directed=True)
        net.barnes_hut()

        # Add nodes
        for node_id, node in graph.nodes.items():
            node_params = node.params or {}
            # Tooltip text: the operation, then one "key: value" line per param.
            title = f"{node.operation}\n" + "\n".join(f"{k}: {v}" for k, v in node_params.items())

            net.add_node(node_id, label=node.operation, title=title, color=self._color_for(node))

        # Add edges
        for edge in graph.edges:
            net.add_edge(edge.source.id, edge.target.id)

        net.save_graph(output_path)
        logger.info(f"Interactive HTML saved to {output_path}")

    def _simple_html(self, graph: ModelGraph, output_path: str) -> None:
        """
        Generate a static HTML page with summary figures and a layer table.

        All node-derived text is HTML-escaped before being embedded.

        Args:
            graph: Graph to describe
            output_path: Output HTML file path (written as UTF-8)
        """
        from html import escape

        parts = ["<!DOCTYPE html>"]
        parts.append('<html><head><meta charset="utf-8"><title>Neural Architecture</title>')
        parts.append("<style>")
        parts.append("body { font-family: Arial; margin: 20px; }")
        parts.append("table { border-collapse: collapse; width: 100%; }")
        parts.append("th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }")
        parts.append("th { background-color: #4CAF50; color: white; }")
        parts.append("</style></head><body>")
        parts.append("<h1>Neural Architecture</h1>")
        parts.append(f"<p><strong>Nodes:</strong> {len(graph.nodes)}</p>")
        parts.append(f"<p><strong>Depth:</strong> {graph.get_depth()}</p>")
        parts.append(f"<p><strong>Parameters:</strong> {graph.estimate_parameters():,}</p>")
        parts.append("<h2>Layers</h2>")
        parts.append("<table><tr><th>ID</th><th>Operation</th><th>Parameters</th></tr>")

        for node_id, node in graph.nodes.items():
            node_params = node.params or {}
            params = ", ".join(f"{k}={v}" for k, v in node_params.items())
            # Only the first 8 characters of the id are shown to keep the column narrow.
            parts.append(
                f"<tr><td>{escape(str(node_id[:8]))}</td>"
                f"<td>{escape(str(node.operation))}</td>"
                f"<td>{escape(params)}</td></tr>"
            )

        parts.append("</table></body></html>")

        with open(output_path, "w", encoding="utf-8") as f:
            f.write("\n".join(parts))

        logger.info(f"Simple HTML saved to {output_path}")

    def compare_graphs(self, graphs: list, labels: list, output_path: Optional[str] = None) -> None:
        """
        Compare multiple graphs side by side.

        One subplot is drawn per graph. A graph that fails to draw gets the
        error message rendered in its subplot instead of aborting the whole
        comparison. The figure is closed once it has been saved or shown.

        Args:
            graphs: List of ModelGraph instances
            labels: List of labels for each graph
            output_path: Output path for comparison plot (displays if None)
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.warning("matplotlib not installed")
            return

        try:
            import networkx as nx
        except ImportError:
            logger.warning("networkx not installed")
            return

        n = len(graphs)
        if n == 0:
            # plt.subplots cannot create a grid with zero columns.
            logger.warning("No graphs to compare")
            return

        if len(labels) != n:
            # zip() below stops at the shorter sequence, so say so explicitly.
            logger.warning(
                f"Got {n} graphs but {len(labels)} labels; extra entries are ignored"
            )

        # squeeze=False always yields a 2-D array of axes, so n == 1 needs
        # no special case.
        fig, axes_grid = plt.subplots(1, n, figsize=(6 * n, 6), squeeze=False)

        try:
            for ax, graph, label in zip(axes_grid[0], graphs, labels):
                # Plot each graph
                try:
                    G = graph.to_networkx()
                    pos = nx.spring_layout(G)

                    nx.draw(
                        G,
                        pos,
                        ax=ax,
                        with_labels=False,
                        node_color="lightblue",
                        node_size=500,
                        arrows=True,
                    )

                    ax.set_title(
                        f"{label}\n(Nodes: {len(graph.nodes)}, Depth: {graph.get_depth()})"
                    )
                except Exception as e:
                    # Best effort: show the failure in place and keep going.
                    ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
                    ax.set_title(label)

            plt.tight_layout()

            if output_path:
                plt.savefig(output_path, dpi=300, bbox_inches="tight")
                logger.info(f"Comparison plot saved to {output_path}")
            else:
                plt.show()
        finally:
            # Release the figure even if drawing or saving failed.
            plt.close(fig)
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""Pareto front visualization utilities.
|
|
2
|
+
|
|
3
|
+
Author: Eshan Roy <eshanized@proton.me>
|
|
4
|
+
Organization: TONMOY INFRASTRUCTURE & VISION
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import List, Optional
|
|
8
|
+
|
|
9
|
+
from morphml.logging_config import get_logger
|
|
10
|
+
|
|
11
|
+
logger = get_logger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def plot_pareto_front_2d(
    pareto_front: List,
    objective_names: List[str],
    save_path: Optional[str] = None,
    title: str = "Pareto Front",
) -> None:
    """
    Plot 2D Pareto front.

    Logs an error and returns without plotting when matplotlib is missing
    or when ``objective_names`` does not contain exactly two names. The
    figure is always closed before returning, including on failure.

    Args:
        pareto_front: List of MultiObjectiveIndividual objects; each item's
            ``objectives`` is indexed by objective name
        objective_names: Names of the two objectives to plot (x, then y)
        save_path: Path to save plot (displays if None)
        title: Plot title

    Raises:
        KeyError: If ``objectives`` is a dict and lacks a requested name
            (presumably the case for MultiObjectiveIndividual - confirm).

    Example:
        >>> from morphml.visualization.pareto_plot import plot_pareto_front_2d
        >>> plot_pareto_front_2d(pareto_front, ['accuracy', 'latency'])
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        logger.error("matplotlib required for plotting. Install with: pip install matplotlib")
        return

    if len(objective_names) != 2:
        logger.error("2D plot requires exactly 2 objectives")
        return

    # Extract objectives
    x_name, y_name = objective_names
    x_values = [ind.objectives[x_name] for ind in pareto_front]
    y_values = [ind.objectives[y_name] for ind in pareto_front]

    # Work on an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", so exactly this figure is saved and closed.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.scatter(x_values, y_values, c="blue", s=100, alpha=0.6, edgecolors="black")
        ax.set_xlabel(x_name.capitalize(), fontsize=12)
        ax.set_ylabel(y_name.capitalize(), fontsize=12)
        ax.set_title(title, fontsize=14, fontweight="bold")
        ax.grid(True, alpha=0.3)

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
            logger.info(f"Pareto front plot saved to {save_path}")
        else:
            plt.show()
    finally:
        # Release the figure even if saving or showing raised.
        plt.close(fig)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def plot_pareto_front_3d(
    pareto_front: List,
    objective_names: List[str],
    save_path: Optional[str] = None,
    title: str = "3D Pareto Front",
) -> None:
    """
    Plot 3D Pareto front.

    Logs an error and returns without plotting when matplotlib is missing
    or when ``objective_names`` does not contain exactly three names. The
    figure is always closed before returning, including on failure.

    Args:
        pareto_front: List of MultiObjectiveIndividual objects; each item's
            ``objectives`` is indexed by objective name
        objective_names: Names of the three objectives to plot (x, y, z)
        save_path: Path to save plot (displays if None)
        title: Plot title

    Raises:
        KeyError: If ``objectives`` is a dict and lacks a requested name
            (presumably the case for MultiObjectiveIndividual - confirm).
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        logger.error("matplotlib required for plotting")
        return

    if len(objective_names) != 3:
        logger.error("3D plot requires exactly 3 objectives")
        return

    # Extract objectives
    x_name, y_name, z_name = objective_names
    x_values = [ind.objectives[x_name] for ind in pareto_front]
    y_values = [ind.objectives[y_name] for ind in pareto_front]
    z_values = [ind.objectives[z_name] for ind in pareto_front]

    # Create 3D plot
    fig = plt.figure(figsize=(12, 8))
    try:
        ax = fig.add_subplot(111, projection="3d")

        ax.scatter(
            x_values, y_values, z_values, c="blue", s=100, alpha=0.6, edgecolors="black"
        )

        ax.set_xlabel(x_name.capitalize(), fontsize=12)
        ax.set_ylabel(y_name.capitalize(), fontsize=12)
        ax.set_zlabel(z_name.capitalize(), fontsize=12)
        ax.set_title(title, fontsize=14, fontweight="bold")

        if save_path:
            # Save through the figure handle so the right figure is written
            # regardless of pyplot's current-figure state.
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
            logger.info(f"3D Pareto front plot saved to {save_path}")
        else:
            plt.show()
    finally:
        # Release the figure even if saving or showing raised.
        plt.close(fig)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def plot_parallel_coordinates(
    pareto_front: List,
    objective_names: List[str],
    save_path: Optional[str] = None,
    title: str = "Parallel Coordinates Plot",
) -> None:
    """
    Draw the Pareto front as a parallel-coordinates plot.

    Handy when there are too many objectives for a 2D or 3D scatter. This
    function is a thin wrapper: the drawing is delegated to
    ``morphml.optimizers.multi_objective.visualization.plot_parallel_coordinates``,
    which is imported lazily on each call.

    Any ImportError raised by the import or by the delegated call itself
    is logged as an error and swallowed.

    Args:
        pareto_front: List of MultiObjectiveIndividual objects
        objective_names: Names of objectives to plot
        save_path: Path to save plot (displays if None)
        title: Plot title
    """
    try:
        from morphml.optimizers.multi_objective.visualization import (
            plot_parallel_coordinates as _delegate,
        )

        # Arguments are forwarded positionally, in this function's own order.
        _delegate(pareto_front, objective_names, save_path, title)
    except ImportError:
        logger.error("Multi-objective visualization module required")
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
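# NOTE: intended to re-export helpers from the multi_objective module for convenience, but the try body below is only `pass`, so nothing is re-exported at present.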
|
|
146
|
+
try:
|
|
147
|
+
pass
|
|
148
|
+
except ImportError:
|
|
149
|
+
pass
|