morphml 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of morphml might be problematic. Click here for more details.
- morphml/__init__.py +14 -0
- morphml/api/__init__.py +26 -0
- morphml/api/app.py +326 -0
- morphml/api/auth.py +193 -0
- morphml/api/client.py +338 -0
- morphml/api/models.py +132 -0
- morphml/api/rate_limit.py +192 -0
- morphml/benchmarking/__init__.py +36 -0
- morphml/benchmarking/comparison.py +430 -0
- morphml/benchmarks/__init__.py +56 -0
- morphml/benchmarks/comparator.py +409 -0
- morphml/benchmarks/datasets.py +280 -0
- morphml/benchmarks/metrics.py +199 -0
- morphml/benchmarks/openml_suite.py +201 -0
- morphml/benchmarks/problems.py +289 -0
- morphml/benchmarks/suite.py +318 -0
- morphml/cli/__init__.py +5 -0
- morphml/cli/commands/experiment.py +329 -0
- morphml/cli/main.py +457 -0
- morphml/cli/quickstart.py +312 -0
- morphml/config.py +278 -0
- morphml/constraints/__init__.py +19 -0
- morphml/constraints/handler.py +205 -0
- morphml/constraints/predicates.py +285 -0
- morphml/core/__init__.py +3 -0
- morphml/core/crossover.py +449 -0
- morphml/core/dsl/README.md +359 -0
- morphml/core/dsl/__init__.py +72 -0
- morphml/core/dsl/ast_nodes.py +364 -0
- morphml/core/dsl/compiler.py +318 -0
- morphml/core/dsl/layers.py +368 -0
- morphml/core/dsl/lexer.py +336 -0
- morphml/core/dsl/parser.py +455 -0
- morphml/core/dsl/search_space.py +386 -0
- morphml/core/dsl/syntax.py +199 -0
- morphml/core/dsl/type_system.py +361 -0
- morphml/core/dsl/validator.py +386 -0
- morphml/core/graph/__init__.py +40 -0
- morphml/core/graph/edge.py +124 -0
- morphml/core/graph/graph.py +507 -0
- morphml/core/graph/mutations.py +409 -0
- morphml/core/graph/node.py +196 -0
- morphml/core/graph/serialization.py +361 -0
- morphml/core/graph/visualization.py +431 -0
- morphml/core/objectives/__init__.py +20 -0
- morphml/core/search/__init__.py +33 -0
- morphml/core/search/individual.py +252 -0
- morphml/core/search/parameters.py +453 -0
- morphml/core/search/population.py +375 -0
- morphml/core/search/search_engine.py +340 -0
- morphml/distributed/__init__.py +76 -0
- morphml/distributed/fault_tolerance.py +497 -0
- morphml/distributed/health_monitor.py +348 -0
- morphml/distributed/master.py +709 -0
- morphml/distributed/proto/README.md +224 -0
- morphml/distributed/proto/__init__.py +74 -0
- morphml/distributed/proto/worker.proto +170 -0
- morphml/distributed/proto/worker_pb2.py +79 -0
- morphml/distributed/proto/worker_pb2_grpc.py +423 -0
- morphml/distributed/resource_manager.py +416 -0
- morphml/distributed/scheduler.py +567 -0
- morphml/distributed/storage/__init__.py +33 -0
- morphml/distributed/storage/artifacts.py +381 -0
- morphml/distributed/storage/cache.py +366 -0
- morphml/distributed/storage/checkpointing.py +329 -0
- morphml/distributed/storage/database.py +459 -0
- morphml/distributed/worker.py +549 -0
- morphml/evaluation/__init__.py +5 -0
- morphml/evaluation/heuristic.py +237 -0
- morphml/exceptions.py +55 -0
- morphml/execution/__init__.py +5 -0
- morphml/execution/local_executor.py +350 -0
- morphml/integrations/__init__.py +28 -0
- morphml/integrations/jax_adapter.py +206 -0
- morphml/integrations/pytorch_adapter.py +530 -0
- morphml/integrations/sklearn_adapter.py +206 -0
- morphml/integrations/tensorflow_adapter.py +230 -0
- morphml/logging_config.py +93 -0
- morphml/meta_learning/__init__.py +66 -0
- morphml/meta_learning/architecture_similarity.py +277 -0
- morphml/meta_learning/experiment_database.py +240 -0
- morphml/meta_learning/knowledge_base/__init__.py +19 -0
- morphml/meta_learning/knowledge_base/embedder.py +179 -0
- morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
- morphml/meta_learning/knowledge_base/meta_features.py +265 -0
- morphml/meta_learning/knowledge_base/vector_store.py +271 -0
- morphml/meta_learning/predictors/__init__.py +27 -0
- morphml/meta_learning/predictors/ensemble.py +221 -0
- morphml/meta_learning/predictors/gnn_predictor.py +552 -0
- morphml/meta_learning/predictors/learning_curve.py +231 -0
- morphml/meta_learning/predictors/proxy_metrics.py +261 -0
- morphml/meta_learning/strategy_evolution/__init__.py +27 -0
- morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
- morphml/meta_learning/strategy_evolution/bandit.py +276 -0
- morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
- morphml/meta_learning/transfer.py +581 -0
- morphml/meta_learning/warm_start.py +286 -0
- morphml/optimizers/__init__.py +74 -0
- morphml/optimizers/adaptive_operators.py +399 -0
- morphml/optimizers/bayesian/__init__.py +52 -0
- morphml/optimizers/bayesian/acquisition.py +387 -0
- morphml/optimizers/bayesian/base.py +319 -0
- morphml/optimizers/bayesian/gaussian_process.py +635 -0
- morphml/optimizers/bayesian/smac.py +534 -0
- morphml/optimizers/bayesian/tpe.py +411 -0
- morphml/optimizers/differential_evolution.py +220 -0
- morphml/optimizers/evolutionary/__init__.py +61 -0
- morphml/optimizers/evolutionary/cma_es.py +416 -0
- morphml/optimizers/evolutionary/differential_evolution.py +556 -0
- morphml/optimizers/evolutionary/encoding.py +426 -0
- morphml/optimizers/evolutionary/particle_swarm.py +449 -0
- morphml/optimizers/genetic_algorithm.py +486 -0
- morphml/optimizers/gradient_based/__init__.py +22 -0
- morphml/optimizers/gradient_based/darts.py +550 -0
- morphml/optimizers/gradient_based/enas.py +585 -0
- morphml/optimizers/gradient_based/operations.py +474 -0
- morphml/optimizers/gradient_based/utils.py +601 -0
- morphml/optimizers/hill_climbing.py +169 -0
- morphml/optimizers/multi_objective/__init__.py +56 -0
- morphml/optimizers/multi_objective/indicators.py +504 -0
- morphml/optimizers/multi_objective/nsga2.py +647 -0
- morphml/optimizers/multi_objective/visualization.py +427 -0
- morphml/optimizers/nsga2.py +308 -0
- morphml/optimizers/random_search.py +172 -0
- morphml/optimizers/simulated_annealing.py +181 -0
- morphml/plugins/__init__.py +35 -0
- morphml/plugins/custom_evaluator_example.py +81 -0
- morphml/plugins/custom_optimizer_example.py +63 -0
- morphml/plugins/plugin_system.py +454 -0
- morphml/reports/__init__.py +30 -0
- morphml/reports/generator.py +362 -0
- morphml/tracking/__init__.py +7 -0
- morphml/tracking/experiment.py +309 -0
- morphml/tracking/logger.py +301 -0
- morphml/tracking/reporter.py +357 -0
- morphml/utils/__init__.py +6 -0
- morphml/utils/checkpoint.py +189 -0
- morphml/utils/comparison.py +390 -0
- morphml/utils/export.py +407 -0
- morphml/utils/progress.py +392 -0
- morphml/utils/validation.py +392 -0
- morphml/version.py +7 -0
- morphml/visualization/__init__.py +50 -0
- morphml/visualization/analytics.py +423 -0
- morphml/visualization/architecture_diagrams.py +353 -0
- morphml/visualization/architecture_plot.py +223 -0
- morphml/visualization/convergence_plot.py +174 -0
- morphml/visualization/crossover_viz.py +386 -0
- morphml/visualization/graph_viz.py +338 -0
- morphml/visualization/pareto_plot.py +149 -0
- morphml/visualization/plotly_dashboards.py +422 -0
- morphml/visualization/population.py +309 -0
- morphml/visualization/progress.py +260 -0
- morphml-1.0.0.dist-info/METADATA +434 -0
- morphml-1.0.0.dist-info/RECORD +158 -0
- morphml-1.0.0.dist-info/WHEEL +4 -0
- morphml-1.0.0.dist-info/entry_points.txt +3 -0
- morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,431 @@
|
|
|
1
|
+
"""Visualization utilities for model graphs.
|
|
2
|
+
|
|
3
|
+
Provides functions to visualize ModelGraph structures using matplotlib
|
|
4
|
+
and networkx for graph plotting, and to export them in Graphviz DOT format.
|
|
5
|
+
|
|
6
|
+
Author: Eshan Roy <eshanized@proton.me>
|
|
7
|
+
Organization: TONMOY INFRASTRUCTURE & VISION
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from collections import deque
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

from morphml.core.graph.graph import ModelGraph
from morphml.exceptions import GraphError
from morphml.logging_config import get_logger
|
|
16
|
+
|
|
17
|
+
# Module-level logger; the helpers below use it to report saved/exported
# files and the unknown-layout fallback.
logger = get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def plot_graph(
    graph: ModelGraph,
    output_path: Optional[Union[str, Path]] = None,
    layout: str = "hierarchical",
    figsize: Tuple[int, int] = (12, 8),
    node_size: int = 3000,
    font_size: int = 10,
    with_labels: bool = True,
    show_params: bool = False,
    dpi: int = 150,
) -> None:
    """
    Visualize model graph structure.

    Args:
        graph: ModelGraph to visualize
        output_path: Path to save figure (if None, displays interactively).
            Missing parent directories are created.
        layout: Layout algorithm ('hierarchical', 'spring', 'circular',
            'kamada_kawai'); any other value logs a warning and uses 'spring'
        figsize: Figure size (width, height)
        node_size: Size of nodes
        font_size: Font size for labels
        with_labels: Whether to label nodes with their operation name
        show_params: Whether to append the graph's estimated parameter count
            (``graph.estimate_parameters()``) to the title
        dpi: Resolution for saved figure

    Raises:
        GraphError: If matplotlib or networkx is not installed.

    Example:
        >>> plot_graph(graph, 'model.png')
        >>> plot_graph(graph, layout='spring', show_params=True)
    """
    try:
        import matplotlib.pyplot as plt
        import networkx as nx
    except ImportError as exc:
        raise GraphError(
            "Visualization requires matplotlib and networkx. "
            "Install with: pip install matplotlib networkx"
        ) from exc

    # Mirror the ModelGraph as a NetworkX DiGraph
    G = nx.DiGraph()
    for node_id, node in graph.nodes.items():
        G.add_node(node_id, label=node.operation, operation=node.operation)
    for edge in graph.edges.values():
        if edge.source and edge.target:
            G.add_edge(edge.source.id, edge.target.id)

    # Choose layout
    if layout == "hierarchical":
        pos = _hierarchical_layout(graph, G)
    elif layout == "spring":
        pos = nx.spring_layout(G, k=2, iterations=50)
    elif layout == "circular":
        pos = nx.circular_layout(G)
    elif layout == "kamada_kawai":
        pos = nx.kamada_kawai_layout(G)
    else:
        logger.warning("Unknown layout '%s', using spring", layout)
        pos = nx.spring_layout(G)

    # networkx draws nodes in G.nodes order while _get_node_colors follows
    # graph.nodes order; key the colors by node id so they always line up.
    color_by_id = dict(zip(graph.nodes, _get_node_colors(graph)))
    node_colors = [color_by_id.get(node_id, "#CCCCCC") for node_id in G.nodes]

    title = (
        f"Model Graph: {len(graph.nodes)} nodes, "
        f"{len(graph.edges)} edges, "
        f"depth {graph.get_depth()}"
    )
    if show_params:
        # estimate_parameters() is a whole-graph total, so report it once in
        # the title rather than recomputing it and repeating it on every node.
        title += f", {graph.estimate_parameters():,} params"

    fig, ax = plt.subplots(figsize=figsize)
    shown = False
    try:
        nx.draw_networkx_nodes(
            G, pos, node_color=node_colors, node_size=node_size, alpha=0.9, ax=ax
        )
        nx.draw_networkx_edges(
            G, pos, edge_color="gray", arrows=True, arrowsize=20, alpha=0.6, ax=ax
        )
        if with_labels:
            labels = nx.get_node_attributes(G, "label")
            nx.draw_networkx_labels(G, pos, labels, font_size=font_size, ax=ax)

        ax.set_title(title, fontsize=14, fontweight="bold")
        ax.axis("off")
        fig.tight_layout()

        # Save or show
        if output_path:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(output_path, dpi=dpi, bbox_inches="tight")
            logger.info("Saved graph visualization to %s", output_path)
        else:
            plt.show()
            shown = True
    finally:
        # Release this figure after saving or on any error; leave it open
        # after show() so non-blocking interactive windows stay alive.
        if not shown:
            plt.close(fig)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _hierarchical_layout(graph: ModelGraph, G) -> Dict:
    """
    Create a top-to-bottom layered layout based on topological order.

    Each node's level is the length of the longest path reaching it, computed
    in one pass over ``graph.topological_sort()`` by following ``successors``.
    Every node therefore sits below all of its predecessors, so skip
    connections never produce horizontal or upward edges. Nodes on the same
    level are centred horizontally in topological order, two units apart.

    Args:
        graph: ModelGraph; the nodes returned by ``topological_sort()`` must
            expose ``id`` and ``successors``.
        G: NetworkX graph, used only for the spring-layout fallback.

    Returns:
        Dictionary mapping node id to an ``(x, y)`` position. An empty graph
        yields ``{}``.
    """
    try:
        topo_order = graph.topological_sort()
    except Exception as exc:
        # Best-effort fallback when a topological order cannot be obtained.
        import networkx as nx

        logger.warning("Topological sort failed (%s); using spring layout", exc)
        return nx.spring_layout(G)

    # Longest-path layering: topological order guarantees a node's depth is
    # final before it is used to push depths onto its successors.
    depth: Dict[str, int] = {node.id: 0 for node in topo_order}
    for node in topo_order:
        for succ in node.successors:
            depth[succ.id] = max(depth.get(succ.id, 0), depth[node.id] + 1)

    # Group node ids by level, preserving topological order within a level
    levels: Dict[int, list] = {}
    for node in topo_order:
        levels.setdefault(depth[node.id], []).append(node.id)

    # Assign positions: centre each level on x=0, negative y goes downward
    pos = {}
    for level, node_ids in sorted(levels.items()):
        x_start = -(len(node_ids) - 1) / 2
        for offset, node_id in enumerate(node_ids):
            pos[node_id] = ((x_start + offset) * 2, -level * 2)  # Scale for spacing

    return pos
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _compute_node_depth(graph: ModelGraph, node_id: str) -> int:
    """
    Compute the breadth-first (shortest-path) depth of a node from the input.

    The search starts at the first node returned by ``graph.get_input_nodes()``
    and follows ``successors`` links. If the graph has no input nodes, it
    starts at ``node_id`` itself, so the result is 0.

    Args:
        graph: ModelGraph to search
        node_id: Id of the node whose depth is wanted

    Returns:
        Number of edges on the shortest path from the start node. Returns 0
        for ids that are unreachable from it or absent from the graph.
    """
    input_nodes = graph.get_input_nodes()
    start_id = input_nodes[0].id if input_nodes else node_id

    # Standard BFS: a node's depth is fixed the first time it is discovered,
    # which is its shortest distance because the queue is FIFO.
    depths = {start_id: 0}
    queue = deque([start_id])
    while queue:
        current_id = queue.popleft()
        node = graph.nodes.get(current_id)
        if not node:
            continue
        for succ in node.successors:
            if succ.id not in depths:
                depths[succ.id] = depths[current_id] + 1
                queue.append(succ.id)

    return depths.get(node_id, 0)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _get_node_colors(graph: ModelGraph) -> list:
    """
    Pick a fill color for every node according to its operation.

    Operations that play the same role (all pooling variants, all
    normalizations, all activations, ...) share one color. Operations not
    listed below fall back to gray ``#CCCCCC``.

    Args:
        graph: ModelGraph whose nodes are colored

    Returns:
        Hex color strings, one per node, in ``graph.nodes`` iteration order
    """
    palette = {
        "#90EE90": ("input",),  # Light green
        "#FFB6C1": ("output",),  # Light pink
        "#87CEEB": ("conv2d", "conv3d"),  # Sky blue
        "#DDA0DD": ("dense", "linear"),  # Plum
        "#F0E68C": ("maxpool", "avgpool", "max_pool", "avg_pool"),  # Khaki
        "#FFE4B5": ("dropout",),  # Moccasin
        "#B0E0E6": ("batch_norm", "layer_norm", "batchnorm"),  # Powder blue
        "#FFA07A": ("relu", "elu", "gelu", "sigmoid", "tanh"),  # Light salmon
        "#D3D3D3": ("flatten",),  # Light gray
    }
    op_to_color = {op: hex_code for hex_code, ops in palette.items() for op in ops}
    fallback = "#CCCCCC"  # Default gray

    return [op_to_color.get(node.operation, fallback) for node in graph.nodes.values()]
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def plot_training_history(
    history: Union[Dict, list],
    output_path: Optional[Union[str, Path]] = None,
    figsize: Tuple[int, int] = (12, 6),
    dpi: int = 150,
) -> None:
    """
    Plot training history (fitness and, if present, diversity over generations).

    Args:
        history: Either a list of per-generation dicts with keys such as
            'generation', 'best_fitness', 'mean_fitness' and optionally
            'diversity', or a dict mapping those keys to per-generation
            lists. A missing 'generation' defaults to the entry index;
            missing fitness/diversity values default to 0.
        output_path: Path to save figure (if None, displays interactively)
        figsize: Figure size
        dpi: Resolution

    Raises:
        ValueError: If ``history`` has no entries.
        GraphError: If matplotlib is not installed.

    Example:
        >>> history = optimizer.get_history()
        >>> plot_training_history(history, 'training.png')
    """
    records = _history_to_records(history)
    if not records:
        raise ValueError("history is empty; nothing to plot")

    try:
        import matplotlib.pyplot as plt
    except ImportError as exc:
        raise GraphError(
            "Visualization requires matplotlib. Install with: pip install matplotlib"
        ) from exc

    # Extract data
    generations = [entry.get("generation", i) for i, entry in enumerate(records)]
    best_fitness = [entry.get("best_fitness", 0) for entry in records]
    mean_fitness = [entry.get("mean_fitness", 0) for entry in records]
    # Look at every entry, not only the first, for diversity data
    has_diversity = any("diversity" in entry for entry in records)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    shown = False
    try:
        # Plot fitness evolution
        ax1.plot(generations, best_fitness, label="Best Fitness", marker="o", linewidth=2)
        ax1.plot(
            generations, mean_fitness, label="Mean Fitness", marker="s", linewidth=2, alpha=0.7
        )
        ax1.set_xlabel("Generation")
        ax1.set_ylabel("Fitness")
        ax1.set_title("Fitness Evolution")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot diversity if available
        if has_diversity:
            diversity = [entry.get("diversity", 0) for entry in records]
            ax2.plot(
                generations, diversity, label="Diversity", marker="^", linewidth=2, color="green"
            )
            ax2.set_xlabel("Generation")
            ax2.set_ylabel("Diversity")
            ax2.set_title("Population Diversity")
            ax2.legend()
            ax2.grid(True, alpha=0.3)
        else:
            ax2.text(
                0.5,
                0.5,
                "Diversity data not available",
                ha="center",
                va="center",
                transform=ax2.transAxes,
            )
            ax2.axis("off")

        fig.tight_layout()

        if output_path:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(output_path, dpi=dpi, bbox_inches="tight")
            logger.info("Saved training history to %s", output_path)
        else:
            plt.show()
            shown = True
    finally:
        # Release this figure after saving or on any error; keep it open
        # after show() so non-blocking interactive windows stay alive.
        if not shown:
            plt.close(fig)


def _history_to_records(history: Union[Dict, list]) -> list:
    """
    Normalize a training history into a list of per-generation dicts.

    A list (or other iterable) of dicts is returned as a new list. A
    column-oriented dict such as ``{"generation": [0, 1], "best_fitness":
    [0.5, 0.7]}`` is transposed into one dict per index. Keys whose column
    is shorter than the longest one are omitted from the later records.
    """
    if isinstance(history, dict):
        columns = {key: list(values) for key, values in history.items()}
        length = max((len(values) for values in columns.values()), default=0)
        return [
            {key: values[i] for key, values in columns.items() if i < len(values)}
            for i in range(length)
        ]
    return list(history)
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def plot_architecture_comparison(
    graphs: Dict[str, ModelGraph],
    output_path: Optional[Union[str, Path]] = None,
    figsize: Tuple[int, int] = (14, 8),
    dpi: int = 150,
) -> None:
    """
    Compare multiple architectures side by side.

    Draws one subplot per entry of ``graphs``, in dict order, using
    ``_plot_graph_on_axis`` with the entry's name as the title.

    Args:
        graphs: Non-empty dictionary mapping names to ModelGraphs
        output_path: Path to save figure (if None, displays interactively)
        figsize: Figure size
        dpi: Resolution

    Raises:
        ValueError: If ``graphs`` is empty.
        GraphError: If matplotlib or networkx is not installed.

    Example:
        >>> graphs = {'Model A': graph1, 'Model B': graph2}
        >>> plot_architecture_comparison(graphs, 'comparison.png')
    """
    if not graphs:
        raise ValueError("graphs is empty; nothing to compare")

    try:
        import matplotlib.pyplot as plt
        import networkx  # noqa: F401 -- needed by _plot_graph_on_axis; fail early
    except ImportError as exc:
        raise GraphError(
            "Visualization requires matplotlib and networkx. "
            "Install with: pip install matplotlib networkx"
        ) from exc

    # squeeze=False always yields a 2-D array of axes, so a single model
    # needs no special case.
    fig, axes = plt.subplots(1, len(graphs), figsize=figsize, squeeze=False)
    shown = False
    try:
        for ax, (name, graph) in zip(axes[0], graphs.items()):
            _plot_graph_on_axis(graph, ax, title=name)

        fig.tight_layout()

        if output_path:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(output_path, dpi=dpi, bbox_inches="tight")
            logger.info("Saved comparison to %s", output_path)
        else:
            plt.show()
            shown = True
    finally:
        # Release this figure after saving or on any error; keep it open
        # after show() so non-blocking interactive windows stay alive.
        if not shown:
            plt.close(fig)
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def _plot_graph_on_axis(graph: ModelGraph, ax, title: str) -> None:
    """
    Draw ``graph`` onto an existing matplotlib axis.

    Nodes are placed with a spring layout, colored by operation type via
    ``_get_node_colors`` and labelled with their operation. The axis title
    is ``title`` followed by the node count and ``graph.get_depth()``.

    Args:
        graph: ModelGraph to draw
        ax: Matplotlib axis to draw on
        title: Title prefix for the axis

    Raises:
        GraphError: If networkx is not installed (instead of silently leaving
            the axis blank).
    """
    try:
        import networkx as nx
    except ImportError as exc:
        raise GraphError(
            "Visualization requires networkx. Install with: pip install networkx"
        ) from exc

    G = nx.DiGraph()
    for node_id, node in graph.nodes.items():
        G.add_node(node_id, label=node.operation)
    for edge in graph.edges.values():
        if edge.source and edge.target:
            G.add_edge(edge.source.id, edge.target.id)

    pos = nx.spring_layout(G, k=1, iterations=30)

    # networkx draws nodes in G.nodes order while _get_node_colors follows
    # graph.nodes order; key the colors by node id so they always line up.
    color_by_id = dict(zip(graph.nodes, _get_node_colors(graph)))
    node_colors = [color_by_id.get(node_id, "#CCCCCC") for node_id in G.nodes]

    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=1500, alpha=0.9, ax=ax)
    nx.draw_networkx_edges(G, pos, edge_color="gray", arrows=True, arrowsize=15, alpha=0.6, ax=ax)

    labels = nx.get_node_attributes(G, "label")
    nx.draw_networkx_labels(G, pos, labels, font_size=8, ax=ax)

    ax.set_title(f"{title}\n{len(graph.nodes)} nodes, depth {graph.get_depth()}", fontsize=10)
    ax.axis("off")
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def export_graphviz(graph: ModelGraph, output_path: Union[str, Path]) -> None:
    """
    Export graph in Graphviz DOT format.

    Each node is written with its operation as the label and a fill color
    from ``_get_graphviz_color``. Each edge whose source and target are both
    set is written as ``source -> target``. Full node ids are used, because
    truncated prefixes can collide and merge distinct nodes. Every id and
    label is quoted and escaped, so names containing quotes or backslashes
    still yield valid DOT. Missing parent directories are created, and the
    file is written as UTF-8.

    Args:
        graph: ModelGraph to export
        output_path: Output .dot file path

    Example:
        >>> export_graphviz(graph, 'model.dot')
        >>> # Then: dot -Tpng model.dot -o model.png
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    lines = ["digraph ModelGraph {", " rankdir=TB;", " node [shape=box, style=filled];", ""]

    # Add nodes
    for node_id, node in graph.nodes.items():
        label = _dot_quote(node.operation)
        color = _dot_quote(_get_graphviz_color(node.operation))
        lines.append(f" {_dot_quote(node_id)} [label={label}, fillcolor={color}];")

    lines.append("")

    # Add edges
    for edge in graph.edges.values():
        if edge.source and edge.target:
            lines.append(f" {_dot_quote(edge.source.id)} -> {_dot_quote(edge.target.id)};")

    lines.append("}")

    output_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
    logger.info("Exported Graphviz DOT to %s", output_path)


def _dot_quote(value: object) -> str:
    """Return ``str(value)`` as a double-quoted DOT ID with ``\\`` and ``"`` escaped."""
    text = str(value).replace("\\", "\\\\").replace('"', '\\"')
    return f'"{text}"'
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def _get_graphviz_color(operation: str) -> str:
    """
    Get the Graphviz (X11) color name for an operation.

    Uses the same operation aliases and the same colors (as X11 names) as the
    matplotlib palette in ``_get_node_colors``. This keeps DOT exports and
    plots consistent. Unknown operations get "lightgray".
    """
    color_map = {
        "input": "lightgreen",
        "output": "lightpink",
        "conv2d": "skyblue",
        "conv3d": "skyblue",
        "dense": "plum",
        "linear": "plum",
        "maxpool": "khaki",
        "avgpool": "khaki",
        "max_pool": "khaki",
        "avg_pool": "khaki",
        "dropout": "moccasin",
        "batch_norm": "powderblue",
        "layer_norm": "powderblue",
        "batchnorm": "powderblue",
        "relu": "lightsalmon",
        "elu": "lightsalmon",
        "gelu": "lightsalmon",
        "sigmoid": "lightsalmon",
        "tanh": "lightsalmon",
        "flatten": "lightgray",
    }
    return color_map.get(operation, "lightgray")
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Objective functions and multi-objective support.
|
|
2
|
+
|
|
3
|
+
This module will provide utilities for defining and evaluating optimization objectives (not yet implemented; ``__all__`` is currently empty):
|
|
4
|
+
- Single objective evaluation
|
|
5
|
+
- Multi-objective evaluation
|
|
6
|
+
- Pareto dominance relationships
|
|
7
|
+
- Quality indicators (hypervolume, IGD)
|
|
8
|
+
|
|
9
|
+
Planned usage (``MultiObjectiveEvaluator`` is not available yet):
|
|
10
|
+
>>> from morphml.core.objectives import MultiObjectiveEvaluator
|
|
11
|
+
>>> evaluator = MultiObjectiveEvaluator(
|
|
12
|
+
... objectives={
|
|
13
|
+
... 'accuracy': lambda g: train_and_evaluate(g),
|
|
14
|
+
... 'latency': lambda g: estimate_latency(g),
|
|
15
|
+
... 'parameters': lambda g: count_parameters(g)
|
|
16
|
+
... }
|
|
17
|
+
... )
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
# Public API of this package. Empty for now: no objective components are
# exported yet.
__all__ = []  # Will be populated as components are implemented
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Search space and optimization components."""
|
|
2
|
+
|
|
3
|
+
from morphml.core.search.individual import Individual
|
|
4
|
+
from morphml.core.search.parameters import (
|
|
5
|
+
BooleanParameter,
|
|
6
|
+
CategoricalParameter,
|
|
7
|
+
ConstantParameter,
|
|
8
|
+
FloatParameter,
|
|
9
|
+
IntegerParameter,
|
|
10
|
+
Parameter,
|
|
11
|
+
create_parameter,
|
|
12
|
+
)
|
|
13
|
+
from morphml.core.search.population import Population
|
|
14
|
+
from morphml.core.search.search_engine import (
|
|
15
|
+
GridSearchEngine,
|
|
16
|
+
RandomSearchEngine,
|
|
17
|
+
SearchEngine,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
# Public API re-exported from the submodules imported above.
__all__ = [
    # Containers (individual.py, population.py)
    "Individual",
    "Population",
    # Parameter types and factory (parameters.py)
    "Parameter",
    "CategoricalParameter",
    "IntegerParameter",
    "FloatParameter",
    "BooleanParameter",
    "ConstantParameter",
    "create_parameter",
    # Search engines (search_engine.py)
    "SearchEngine",
    "RandomSearchEngine",
    "GridSearchEngine",
]
|