morphml 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of morphml might be problematic. Click here for more details.
- morphml/__init__.py +14 -0
- morphml/api/__init__.py +26 -0
- morphml/api/app.py +326 -0
- morphml/api/auth.py +193 -0
- morphml/api/client.py +338 -0
- morphml/api/models.py +132 -0
- morphml/api/rate_limit.py +192 -0
- morphml/benchmarking/__init__.py +36 -0
- morphml/benchmarking/comparison.py +430 -0
- morphml/benchmarks/__init__.py +56 -0
- morphml/benchmarks/comparator.py +409 -0
- morphml/benchmarks/datasets.py +280 -0
- morphml/benchmarks/metrics.py +199 -0
- morphml/benchmarks/openml_suite.py +201 -0
- morphml/benchmarks/problems.py +289 -0
- morphml/benchmarks/suite.py +318 -0
- morphml/cli/__init__.py +5 -0
- morphml/cli/commands/experiment.py +329 -0
- morphml/cli/main.py +457 -0
- morphml/cli/quickstart.py +312 -0
- morphml/config.py +278 -0
- morphml/constraints/__init__.py +19 -0
- morphml/constraints/handler.py +205 -0
- morphml/constraints/predicates.py +285 -0
- morphml/core/__init__.py +3 -0
- morphml/core/crossover.py +449 -0
- morphml/core/dsl/README.md +359 -0
- morphml/core/dsl/__init__.py +72 -0
- morphml/core/dsl/ast_nodes.py +364 -0
- morphml/core/dsl/compiler.py +318 -0
- morphml/core/dsl/layers.py +368 -0
- morphml/core/dsl/lexer.py +336 -0
- morphml/core/dsl/parser.py +455 -0
- morphml/core/dsl/search_space.py +386 -0
- morphml/core/dsl/syntax.py +199 -0
- morphml/core/dsl/type_system.py +361 -0
- morphml/core/dsl/validator.py +386 -0
- morphml/core/graph/__init__.py +40 -0
- morphml/core/graph/edge.py +124 -0
- morphml/core/graph/graph.py +507 -0
- morphml/core/graph/mutations.py +409 -0
- morphml/core/graph/node.py +196 -0
- morphml/core/graph/serialization.py +361 -0
- morphml/core/graph/visualization.py +431 -0
- morphml/core/objectives/__init__.py +20 -0
- morphml/core/search/__init__.py +33 -0
- morphml/core/search/individual.py +252 -0
- morphml/core/search/parameters.py +453 -0
- morphml/core/search/population.py +375 -0
- morphml/core/search/search_engine.py +340 -0
- morphml/distributed/__init__.py +76 -0
- morphml/distributed/fault_tolerance.py +497 -0
- morphml/distributed/health_monitor.py +348 -0
- morphml/distributed/master.py +709 -0
- morphml/distributed/proto/README.md +224 -0
- morphml/distributed/proto/__init__.py +74 -0
- morphml/distributed/proto/worker.proto +170 -0
- morphml/distributed/proto/worker_pb2.py +79 -0
- morphml/distributed/proto/worker_pb2_grpc.py +423 -0
- morphml/distributed/resource_manager.py +416 -0
- morphml/distributed/scheduler.py +567 -0
- morphml/distributed/storage/__init__.py +33 -0
- morphml/distributed/storage/artifacts.py +381 -0
- morphml/distributed/storage/cache.py +366 -0
- morphml/distributed/storage/checkpointing.py +329 -0
- morphml/distributed/storage/database.py +459 -0
- morphml/distributed/worker.py +549 -0
- morphml/evaluation/__init__.py +5 -0
- morphml/evaluation/heuristic.py +237 -0
- morphml/exceptions.py +55 -0
- morphml/execution/__init__.py +5 -0
- morphml/execution/local_executor.py +350 -0
- morphml/integrations/__init__.py +28 -0
- morphml/integrations/jax_adapter.py +206 -0
- morphml/integrations/pytorch_adapter.py +530 -0
- morphml/integrations/sklearn_adapter.py +206 -0
- morphml/integrations/tensorflow_adapter.py +230 -0
- morphml/logging_config.py +93 -0
- morphml/meta_learning/__init__.py +66 -0
- morphml/meta_learning/architecture_similarity.py +277 -0
- morphml/meta_learning/experiment_database.py +240 -0
- morphml/meta_learning/knowledge_base/__init__.py +19 -0
- morphml/meta_learning/knowledge_base/embedder.py +179 -0
- morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
- morphml/meta_learning/knowledge_base/meta_features.py +265 -0
- morphml/meta_learning/knowledge_base/vector_store.py +271 -0
- morphml/meta_learning/predictors/__init__.py +27 -0
- morphml/meta_learning/predictors/ensemble.py +221 -0
- morphml/meta_learning/predictors/gnn_predictor.py +552 -0
- morphml/meta_learning/predictors/learning_curve.py +231 -0
- morphml/meta_learning/predictors/proxy_metrics.py +261 -0
- morphml/meta_learning/strategy_evolution/__init__.py +27 -0
- morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
- morphml/meta_learning/strategy_evolution/bandit.py +276 -0
- morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
- morphml/meta_learning/transfer.py +581 -0
- morphml/meta_learning/warm_start.py +286 -0
- morphml/optimizers/__init__.py +74 -0
- morphml/optimizers/adaptive_operators.py +399 -0
- morphml/optimizers/bayesian/__init__.py +52 -0
- morphml/optimizers/bayesian/acquisition.py +387 -0
- morphml/optimizers/bayesian/base.py +319 -0
- morphml/optimizers/bayesian/gaussian_process.py +635 -0
- morphml/optimizers/bayesian/smac.py +534 -0
- morphml/optimizers/bayesian/tpe.py +411 -0
- morphml/optimizers/differential_evolution.py +220 -0
- morphml/optimizers/evolutionary/__init__.py +61 -0
- morphml/optimizers/evolutionary/cma_es.py +416 -0
- morphml/optimizers/evolutionary/differential_evolution.py +556 -0
- morphml/optimizers/evolutionary/encoding.py +426 -0
- morphml/optimizers/evolutionary/particle_swarm.py +449 -0
- morphml/optimizers/genetic_algorithm.py +486 -0
- morphml/optimizers/gradient_based/__init__.py +22 -0
- morphml/optimizers/gradient_based/darts.py +550 -0
- morphml/optimizers/gradient_based/enas.py +585 -0
- morphml/optimizers/gradient_based/operations.py +474 -0
- morphml/optimizers/gradient_based/utils.py +601 -0
- morphml/optimizers/hill_climbing.py +169 -0
- morphml/optimizers/multi_objective/__init__.py +56 -0
- morphml/optimizers/multi_objective/indicators.py +504 -0
- morphml/optimizers/multi_objective/nsga2.py +647 -0
- morphml/optimizers/multi_objective/visualization.py +427 -0
- morphml/optimizers/nsga2.py +308 -0
- morphml/optimizers/random_search.py +172 -0
- morphml/optimizers/simulated_annealing.py +181 -0
- morphml/plugins/__init__.py +35 -0
- morphml/plugins/custom_evaluator_example.py +81 -0
- morphml/plugins/custom_optimizer_example.py +63 -0
- morphml/plugins/plugin_system.py +454 -0
- morphml/reports/__init__.py +30 -0
- morphml/reports/generator.py +362 -0
- morphml/tracking/__init__.py +7 -0
- morphml/tracking/experiment.py +309 -0
- morphml/tracking/logger.py +301 -0
- morphml/tracking/reporter.py +357 -0
- morphml/utils/__init__.py +6 -0
- morphml/utils/checkpoint.py +189 -0
- morphml/utils/comparison.py +390 -0
- morphml/utils/export.py +407 -0
- morphml/utils/progress.py +392 -0
- morphml/utils/validation.py +392 -0
- morphml/version.py +7 -0
- morphml/visualization/__init__.py +50 -0
- morphml/visualization/analytics.py +423 -0
- morphml/visualization/architecture_diagrams.py +353 -0
- morphml/visualization/architecture_plot.py +223 -0
- morphml/visualization/convergence_plot.py +174 -0
- morphml/visualization/crossover_viz.py +386 -0
- morphml/visualization/graph_viz.py +338 -0
- morphml/visualization/pareto_plot.py +149 -0
- morphml/visualization/plotly_dashboards.py +422 -0
- morphml/visualization/population.py +309 -0
- morphml/visualization/progress.py +260 -0
- morphml-1.0.0.dist-info/METADATA +434 -0
- morphml-1.0.0.dist-info/RECORD +158 -0
- morphml-1.0.0.dist-info/WHEEL +4 -0
- morphml-1.0.0.dist-info/entry_points.txt +3 -0
- morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,507 @@
|
|
|
1
|
+
"""Model graph representation for neural architectures."""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
from typing import Any, Dict, List, Optional, Set
|
|
6
|
+
|
|
7
|
+
import networkx as nx
|
|
8
|
+
|
|
9
|
+
from morphml.core.graph.edge import GraphEdge
|
|
10
|
+
from morphml.core.graph.node import GraphNode
|
|
11
|
+
from morphml.exceptions import GraphError
|
|
12
|
+
from morphml.logging_config import get_logger
|
|
13
|
+
|
|
14
|
+
# Module-level logger, obtained from morphml.logging_config.get_logger with this module's __name__.
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ModelGraph:
    """
    Directed Acyclic Graph (DAG) representation of a neural architecture.

    A ModelGraph consists of:
    - Nodes: Operations/layers (conv2d, maxpool, dense, etc.)
    - Edges: Connections between operations
    - Metadata: Additional information about the architecture

    ``add_edge`` refuses any edge that would introduce a cycle. ``is_valid``
    additionally requires at least one input node (no predecessors), at
    least one output node (no successors), and every node to be reachable
    from an input node. Graphs restored with ``from_dict`` are not checked
    for cycles until ``is_valid`` / ``topological_sort`` is called.

    Attributes:
        nodes: Dictionary mapping node IDs to GraphNode instances
        edges: Dictionary mapping edge IDs to GraphEdge instances
        metadata: Additional metadata

    Example:
        >>> graph = ModelGraph()
        >>> input_node = graph.add_node(GraphNode.create('input'))
        >>> conv = graph.add_node(GraphNode.create('conv2d', {'filters': 64}))
        >>> graph.add_edge(GraphEdge(input_node, conv))
    """

    def __init__(self, metadata: Optional[Dict[str, Any]] = None):
        """
        Initialize an empty model graph.

        Args:
            metadata: Optional metadata dict. A non-empty dict is stored by
                reference; None or an empty dict is replaced by a fresh {}.
        """
        self.nodes: Dict[str, GraphNode] = {}
        self.edges: Dict[str, GraphEdge] = {}
        self.metadata = metadata or {}

    def add_node(self, node: GraphNode) -> GraphNode:
        """
        Add a node to the graph.

        Args:
            node: GraphNode to add

        Returns:
            The added node (the same object that was passed in)

        Raises:
            GraphError: If a node with the same ID already exists
        """
        if node.id in self.nodes:
            raise GraphError(f"Node with ID {node.id} already exists")

        self.nodes[node.id] = node
        logger.debug(f"Added node: {node.operation} (id={node.id[:8]})")
        return node

    def add_edge(self, edge: GraphEdge) -> GraphEdge:
        """
        Add an edge to the graph.

        Also registers the target as a successor of the source and the
        source as a predecessor of the target.

        Args:
            edge: GraphEdge to add

        Returns:
            The added edge

        Raises:
            GraphError: If either endpoint is not in the graph, or if the
                edge would create a cycle
        """
        # Validate nodes exist
        if edge.source.id not in self.nodes:
            raise GraphError(f"Source node {edge.source.id} not in graph")
        if edge.target.id not in self.nodes:
            raise GraphError(f"Target node {edge.target.id} not in graph")

        # Reject edges that would close a cycle (keeps the graph a DAG)
        if self._would_create_cycle(edge.source, edge.target):
            raise GraphError(
                f"Adding edge from {edge.source.operation} to {edge.target.operation} "
                "would create a cycle"
            )

        self.edges[edge.id] = edge

        # Keep node-level adjacency in sync with the edge dictionary
        edge.source.add_successor(edge.target)
        edge.target.add_predecessor(edge.source)

        logger.debug(
            f"Added edge: {edge.source.operation} -> {edge.target.operation} (id={edge.id[:8]})"
        )
        return edge

    def remove_node(self, node_id: str) -> None:
        """
        Remove a node and every edge that starts or ends at it.

        Args:
            node_id: ID of node to remove

        Raises:
            GraphError: If node not found
        """
        if node_id not in self.nodes:
            raise GraphError(f"Node {node_id} not found")

        node = self.nodes[node_id]

        # Collect first: remove_edge mutates self.edges while we would be iterating it
        edges_to_remove = [
            edge_id
            for edge_id, edge in self.edges.items()
            if edge.source.id == node_id or edge.target.id == node_id
        ]

        for edge_id in edges_to_remove:
            self.remove_edge(edge_id)

        del self.nodes[node_id]
        logger.debug(f"Removed node: {node.operation} (id={node_id[:8]})")

    def remove_edge(self, edge_id: str) -> None:
        """
        Remove an edge from the graph and unlink its endpoint nodes.

        Args:
            edge_id: ID of edge to remove

        Raises:
            GraphError: If edge not found
        """
        if edge_id not in self.edges:
            raise GraphError(f"Edge {edge_id} not found")

        edge = self.edges[edge_id]

        # Update node-level adjacency
        edge.source.remove_successor(edge.target)
        edge.target.remove_predecessor(edge.source)

        del self.edges[edge_id]
        logger.debug(f"Removed edge: {edge_id[:8]}")

    def get_input_nodes(self) -> List[GraphNode]:
        """
        Get all input nodes (nodes with no predecessors).

        Returns:
            List of input nodes
        """
        return [node for node in self.nodes.values() if len(node.predecessors) == 0]

    def get_output_nodes(self) -> List[GraphNode]:
        """
        Get all output nodes (nodes with no successors).

        Returns:
            List of output nodes
        """
        return [node for node in self.nodes.values() if len(node.successors) == 0]

    def get_input_node(self) -> Optional[GraphNode]:
        """Get a single input node (the first if there are several), or None."""
        inputs = self.get_input_nodes()
        return inputs[0] if inputs else None

    def get_output_node(self) -> Optional[GraphNode]:
        """Get a single output node (the first if there are several), or None."""
        outputs = self.get_output_nodes()
        return outputs[0] if outputs else None

    def topological_sort(self) -> List[GraphNode]:
        """
        Return nodes in topological order.

        Returns:
            List of nodes in topological order

        Raises:
            GraphError: If the graph has a cycle
        """
        try:
            nx_graph = self.to_networkx()
            sorted_ids = list(nx.topological_sort(nx_graph))
        except (nx.NetworkXUnfeasible, nx.NetworkXError) as e:
            # networkx signals a cycle with NetworkXUnfeasible, which is NOT a
            # subclass of NetworkXError, so both must be caught explicitly.
            raise GraphError(f"Graph is not a DAG: {e}") from e
        return [self.nodes[node_id] for node_id in sorted_ids]

    def is_valid(self) -> bool:
        """
        Check if graph is valid.

        A valid graph:
        - Is a DAG (no cycles)
        - Has at least one input node
        - Has at least one output node
        - All nodes are reachable from input(s)

        Returns:
            True if valid, False otherwise
        """
        try:
            # Raises GraphError on cycles
            self.topological_sort()

            if not self.get_input_nodes():
                return False
            if not self.get_output_nodes():
                return False

            if not self._all_nodes_reachable():
                return False

            return True

        except (GraphError, nx.NetworkXError, nx.NetworkXUnfeasible):
            return False

    def is_valid_dag(self) -> bool:
        """
        Check if graph is a valid DAG.

        Alias for is_valid() method.

        Returns:
            True if valid DAG, False otherwise
        """
        return self.is_valid()

    def clone(self) -> "ModelGraph":
        """
        Create a copy of this graph with freshly cloned nodes and edges.

        Nodes are copied with ``GraphNode.clone()`` and edges are rebuilt
        between the cloned nodes. Graph and edge metadata dicts are copied
        shallowly (nested mutable values are shared with the original).

        Returns:
            Cloned ModelGraph
        """
        cloned = ModelGraph(metadata=self.metadata.copy())

        # Map original node IDs to their clones so edges can be re-targeted
        node_mapping = {}
        for node_id, node in self.nodes.items():
            cloned_node = node.clone()
            cloned.nodes[cloned_node.id] = cloned_node
            node_mapping[node_id] = cloned_node

        for edge in self.edges.values():
            cloned_edge = GraphEdge(
                source=node_mapping[edge.source.id],
                target=node_mapping[edge.target.id],
                operation=edge.operation,
                metadata=edge.metadata.copy(),
            )
            cloned.add_edge(cloned_edge)

        return cloned

    def to_networkx(self) -> nx.DiGraph:
        """
        Convert to NetworkX DiGraph.

        Node keys are node IDs with ``GraphNode.to_dict()`` as attributes;
        edge attributes come from ``GraphEdge.to_dict()``.

        Returns:
            NetworkX directed graph
        """
        G = nx.DiGraph()

        for node_id, node in self.nodes.items():
            G.add_node(node_id, **node.to_dict())

        for edge in self.edges.values():
            G.add_edge(edge.source.id, edge.target.id, **edge.to_dict())

        return G

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize graph to dictionary.

        Returns:
            Dictionary with "nodes", "edges" and "metadata" keys
        """
        return {
            "nodes": [node.to_dict() for node in self.nodes.values()],
            "edges": [edge.to_dict() for edge in self.edges.values()],
            "metadata": self.metadata,
        }

    def to_json(self) -> str:
        """
        Serialize graph to JSON string.

        Returns:
            JSON string representation
        """
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ModelGraph":
        """
        Deserialize graph from dictionary.

        Nodes and edges are inserted directly, without the duplicate-ID and
        cycle checks performed by add_node/add_edge; call is_valid() to
        verify the result.

        Args:
            data: Dictionary representation (as produced by to_dict)

        Returns:
            ModelGraph instance
        """
        graph = cls(metadata=data.get("metadata", {}))

        for node_data in data["nodes"]:
            node = GraphNode.from_dict(node_data)
            graph.nodes[node.id] = node

        for edge_data in data["edges"]:
            edge = GraphEdge.from_dict(edge_data, graph.nodes)
            graph.edges[edge.id] = edge

            # Restore node-level adjacency
            edge.source.add_successor(edge.target)
            edge.target.add_predecessor(edge.source)

        return graph

    @classmethod
    def from_json(cls, json_str: str) -> "ModelGraph":
        """
        Deserialize graph from JSON string.

        Args:
            json_str: JSON string representation

        Returns:
            ModelGraph instance
        """
        data = json.loads(json_str)
        return cls.from_dict(data)

    def hash(self) -> str:
        """
        Compute hash of graph structure.

        Used for deduplication and caching. Each node is described by its
        operation and parameters, and each edge by the descriptions of both
        of its endpoints, so graphs that differ only in where parameterized
        nodes sit (e.g. conv(32) -> conv(64) vs conv(64) -> conv(32)) hash
        differently. Node IDs and metadata are ignored. This is not a full
        graph-isomorphism canonical form, so rare collisions remain possible.

        Returns:
            SHA256 hex digest of the canonical representation

        Raises:
            TypeError: If a node parameter value is not JSON-serializable
        """

        def signature(node: GraphNode) -> str:
            # JSON text is always orderable, unlike raw parameter tuples whose
            # values may be of mutually incomparable types.
            return json.dumps([node.operation, sorted(node.params.items())], sort_keys=True)

        canonical = {
            "nodes": sorted(signature(node) for node in self.nodes.values()),
            "edges": sorted(
                [signature(edge.source), signature(edge.target)] for edge in self.edges.values()
            ),
        }

        canonical_str = json.dumps(canonical, sort_keys=True)
        return hashlib.sha256(canonical_str.encode()).hexdigest()

    def get_depth(self) -> int:
        """
        Get maximum depth of graph.

        Returns:
            Number of edges on the longest path, or 0 if the graph is empty
            or not a DAG
        """
        try:
            nx_graph = self.to_networkx()
            length: int = nx.dag_longest_path_length(nx_graph)
            return length
        except (nx.NetworkXUnfeasible, nx.NetworkXError):
            # NetworkXUnfeasible (raised for cycles) is not a NetworkXError subclass
            return 0

    def get_max_width(self) -> int:
        """
        Get maximum width of graph.

        Every node is placed, exactly once, at the level equal to the length
        of the longest path reaching it from an input node (the same notion
        of depth as get_depth). Runs in linear time in the graph size.

        Returns:
            Maximum number of nodes at any depth level, or 0 if the graph is
            empty or not a DAG
        """
        try:
            ordered = self.topological_sort()
        except GraphError:
            return 0

        level_of: Dict[str, int] = {}
        level_sizes: Dict[int, int] = {}
        for node in ordered:
            # Predecessors precede node in topological order, so their levels are known
            pred_levels = [level_of[pred.id] for pred in node.predecessors if pred.id in level_of]
            level = max(pred_levels) + 1 if pred_levels else 0
            level_of[node.id] = level
            level_sizes[level] = level_sizes.get(level, 0) + 1

        return max(level_sizes.values()) if level_sizes else 0

    def estimate_parameters(self) -> int:
        """
        Estimate number of parameters (simplified).

        Only conv2d and dense nodes contribute; biases are ignored. Missing
        parameters fall back to the defaults shown below. ``kernel_size`` may
        be an int (square kernel) or a per-dimension list/tuple.

        Returns:
            Estimated parameter count
        """
        total_params = 0

        for node in self.nodes.values():
            if node.operation == "conv2d":
                filters = node.get_param("filters", 64)
                kernel_size = node.get_param("kernel_size", 3)
                in_channels = node.get_param("in_channels", 3)
                total_params += in_channels * filters * self._kernel_area(kernel_size)

            elif node.operation == "dense":
                units = node.get_param("units", 128)
                in_features = node.get_param("in_features", 512)
                total_params += in_features * units

        return total_params

    @staticmethod
    def _kernel_area(kernel_size: Any) -> int:
        """Return weights per filter channel for an int or per-dimension kernel size."""
        if isinstance(kernel_size, (list, tuple)):
            area = 1
            for dim in kernel_size:
                area *= dim
            return area
        return kernel_size * kernel_size

    def _would_create_cycle(self, source: GraphNode, target: GraphNode) -> bool:
        """
        Check if adding edge source -> target would create a cycle.

        Searches forward from target (iteratively, so deep graphs cannot hit
        the recursion limit); a cycle exists iff source is reachable. A
        self-loop (source is target) is reported as a cycle.

        Args:
            source: Source node
            target: Target node

        Returns:
            True if would create cycle, False otherwise
        """
        visited: Set[str] = set()
        stack: List[GraphNode] = [target]

        while stack:
            node = stack.pop()
            if node.id == source.id:
                return True
            if node.id in visited:
                continue
            visited.add(node.id)
            stack.extend(node.successors)

        return False

    def _all_nodes_reachable(self) -> bool:
        """Check if all nodes are reachable from input nodes (iterative DFS)."""
        reachable: Set[str] = set()
        stack: List[GraphNode] = list(self.get_input_nodes())

        while stack:
            node = stack.pop()
            if node.id in reachable:
                continue
            reachable.add(node.id)
            stack.extend(node.successors)

        return len(reachable) == len(self.nodes)

    def __repr__(self) -> str:
        """String representation of graph."""
        return (
            f"ModelGraph(nodes={len(self.nodes)}, edges={len(self.edges)}, "
            f"depth={self.get_depth()}, hash={self.hash()[:8]})"
        )

    def __len__(self) -> int:
        """Return number of nodes."""
        return len(self.nodes)
|