morphml-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of morphml might be problematic; see the registry's advisory for more details.

Files changed (158):
  1. morphml/__init__.py +14 -0
  2. morphml/api/__init__.py +26 -0
  3. morphml/api/app.py +326 -0
  4. morphml/api/auth.py +193 -0
  5. morphml/api/client.py +338 -0
  6. morphml/api/models.py +132 -0
  7. morphml/api/rate_limit.py +192 -0
  8. morphml/benchmarking/__init__.py +36 -0
  9. morphml/benchmarking/comparison.py +430 -0
  10. morphml/benchmarks/__init__.py +56 -0
  11. morphml/benchmarks/comparator.py +409 -0
  12. morphml/benchmarks/datasets.py +280 -0
  13. morphml/benchmarks/metrics.py +199 -0
  14. morphml/benchmarks/openml_suite.py +201 -0
  15. morphml/benchmarks/problems.py +289 -0
  16. morphml/benchmarks/suite.py +318 -0
  17. morphml/cli/__init__.py +5 -0
  18. morphml/cli/commands/experiment.py +329 -0
  19. morphml/cli/main.py +457 -0
  20. morphml/cli/quickstart.py +312 -0
  21. morphml/config.py +278 -0
  22. morphml/constraints/__init__.py +19 -0
  23. morphml/constraints/handler.py +205 -0
  24. morphml/constraints/predicates.py +285 -0
  25. morphml/core/__init__.py +3 -0
  26. morphml/core/crossover.py +449 -0
  27. morphml/core/dsl/README.md +359 -0
  28. morphml/core/dsl/__init__.py +72 -0
  29. morphml/core/dsl/ast_nodes.py +364 -0
  30. morphml/core/dsl/compiler.py +318 -0
  31. morphml/core/dsl/layers.py +368 -0
  32. morphml/core/dsl/lexer.py +336 -0
  33. morphml/core/dsl/parser.py +455 -0
  34. morphml/core/dsl/search_space.py +386 -0
  35. morphml/core/dsl/syntax.py +199 -0
  36. morphml/core/dsl/type_system.py +361 -0
  37. morphml/core/dsl/validator.py +386 -0
  38. morphml/core/graph/__init__.py +40 -0
  39. morphml/core/graph/edge.py +124 -0
  40. morphml/core/graph/graph.py +507 -0
  41. morphml/core/graph/mutations.py +409 -0
  42. morphml/core/graph/node.py +196 -0
  43. morphml/core/graph/serialization.py +361 -0
  44. morphml/core/graph/visualization.py +431 -0
  45. morphml/core/objectives/__init__.py +20 -0
  46. morphml/core/search/__init__.py +33 -0
  47. morphml/core/search/individual.py +252 -0
  48. morphml/core/search/parameters.py +453 -0
  49. morphml/core/search/population.py +375 -0
  50. morphml/core/search/search_engine.py +340 -0
  51. morphml/distributed/__init__.py +76 -0
  52. morphml/distributed/fault_tolerance.py +497 -0
  53. morphml/distributed/health_monitor.py +348 -0
  54. morphml/distributed/master.py +709 -0
  55. morphml/distributed/proto/README.md +224 -0
  56. morphml/distributed/proto/__init__.py +74 -0
  57. morphml/distributed/proto/worker.proto +170 -0
  58. morphml/distributed/proto/worker_pb2.py +79 -0
  59. morphml/distributed/proto/worker_pb2_grpc.py +423 -0
  60. morphml/distributed/resource_manager.py +416 -0
  61. morphml/distributed/scheduler.py +567 -0
  62. morphml/distributed/storage/__init__.py +33 -0
  63. morphml/distributed/storage/artifacts.py +381 -0
  64. morphml/distributed/storage/cache.py +366 -0
  65. morphml/distributed/storage/checkpointing.py +329 -0
  66. morphml/distributed/storage/database.py +459 -0
  67. morphml/distributed/worker.py +549 -0
  68. morphml/evaluation/__init__.py +5 -0
  69. morphml/evaluation/heuristic.py +237 -0
  70. morphml/exceptions.py +55 -0
  71. morphml/execution/__init__.py +5 -0
  72. morphml/execution/local_executor.py +350 -0
  73. morphml/integrations/__init__.py +28 -0
  74. morphml/integrations/jax_adapter.py +206 -0
  75. morphml/integrations/pytorch_adapter.py +530 -0
  76. morphml/integrations/sklearn_adapter.py +206 -0
  77. morphml/integrations/tensorflow_adapter.py +230 -0
  78. morphml/logging_config.py +93 -0
  79. morphml/meta_learning/__init__.py +66 -0
  80. morphml/meta_learning/architecture_similarity.py +277 -0
  81. morphml/meta_learning/experiment_database.py +240 -0
  82. morphml/meta_learning/knowledge_base/__init__.py +19 -0
  83. morphml/meta_learning/knowledge_base/embedder.py +179 -0
  84. morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
  85. morphml/meta_learning/knowledge_base/meta_features.py +265 -0
  86. morphml/meta_learning/knowledge_base/vector_store.py +271 -0
  87. morphml/meta_learning/predictors/__init__.py +27 -0
  88. morphml/meta_learning/predictors/ensemble.py +221 -0
  89. morphml/meta_learning/predictors/gnn_predictor.py +552 -0
  90. morphml/meta_learning/predictors/learning_curve.py +231 -0
  91. morphml/meta_learning/predictors/proxy_metrics.py +261 -0
  92. morphml/meta_learning/strategy_evolution/__init__.py +27 -0
  93. morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
  94. morphml/meta_learning/strategy_evolution/bandit.py +276 -0
  95. morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
  96. morphml/meta_learning/transfer.py +581 -0
  97. morphml/meta_learning/warm_start.py +286 -0
  98. morphml/optimizers/__init__.py +74 -0
  99. morphml/optimizers/adaptive_operators.py +399 -0
  100. morphml/optimizers/bayesian/__init__.py +52 -0
  101. morphml/optimizers/bayesian/acquisition.py +387 -0
  102. morphml/optimizers/bayesian/base.py +319 -0
  103. morphml/optimizers/bayesian/gaussian_process.py +635 -0
  104. morphml/optimizers/bayesian/smac.py +534 -0
  105. morphml/optimizers/bayesian/tpe.py +411 -0
  106. morphml/optimizers/differential_evolution.py +220 -0
  107. morphml/optimizers/evolutionary/__init__.py +61 -0
  108. morphml/optimizers/evolutionary/cma_es.py +416 -0
  109. morphml/optimizers/evolutionary/differential_evolution.py +556 -0
  110. morphml/optimizers/evolutionary/encoding.py +426 -0
  111. morphml/optimizers/evolutionary/particle_swarm.py +449 -0
  112. morphml/optimizers/genetic_algorithm.py +486 -0
  113. morphml/optimizers/gradient_based/__init__.py +22 -0
  114. morphml/optimizers/gradient_based/darts.py +550 -0
  115. morphml/optimizers/gradient_based/enas.py +585 -0
  116. morphml/optimizers/gradient_based/operations.py +474 -0
  117. morphml/optimizers/gradient_based/utils.py +601 -0
  118. morphml/optimizers/hill_climbing.py +169 -0
  119. morphml/optimizers/multi_objective/__init__.py +56 -0
  120. morphml/optimizers/multi_objective/indicators.py +504 -0
  121. morphml/optimizers/multi_objective/nsga2.py +647 -0
  122. morphml/optimizers/multi_objective/visualization.py +427 -0
  123. morphml/optimizers/nsga2.py +308 -0
  124. morphml/optimizers/random_search.py +172 -0
  125. morphml/optimizers/simulated_annealing.py +181 -0
  126. morphml/plugins/__init__.py +35 -0
  127. morphml/plugins/custom_evaluator_example.py +81 -0
  128. morphml/plugins/custom_optimizer_example.py +63 -0
  129. morphml/plugins/plugin_system.py +454 -0
  130. morphml/reports/__init__.py +30 -0
  131. morphml/reports/generator.py +362 -0
  132. morphml/tracking/__init__.py +7 -0
  133. morphml/tracking/experiment.py +309 -0
  134. morphml/tracking/logger.py +301 -0
  135. morphml/tracking/reporter.py +357 -0
  136. morphml/utils/__init__.py +6 -0
  137. morphml/utils/checkpoint.py +189 -0
  138. morphml/utils/comparison.py +390 -0
  139. morphml/utils/export.py +407 -0
  140. morphml/utils/progress.py +392 -0
  141. morphml/utils/validation.py +392 -0
  142. morphml/version.py +7 -0
  143. morphml/visualization/__init__.py +50 -0
  144. morphml/visualization/analytics.py +423 -0
  145. morphml/visualization/architecture_diagrams.py +353 -0
  146. morphml/visualization/architecture_plot.py +223 -0
  147. morphml/visualization/convergence_plot.py +174 -0
  148. morphml/visualization/crossover_viz.py +386 -0
  149. morphml/visualization/graph_viz.py +338 -0
  150. morphml/visualization/pareto_plot.py +149 -0
  151. morphml/visualization/plotly_dashboards.py +422 -0
  152. morphml/visualization/population.py +309 -0
  153. morphml/visualization/progress.py +260 -0
  154. morphml-1.0.0.dist-info/METADATA +434 -0
  155. morphml-1.0.0.dist-info/RECORD +158 -0
  156. morphml-1.0.0.dist-info/WHEEL +4 -0
  157. morphml-1.0.0.dist-info/entry_points.txt +3 -0
  158. morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,409 @@
1
+ """Graph mutation operations for architecture search."""
2
+
3
+ import random
4
+ from typing import List, Optional
5
+
6
+ from morphml.core.graph.edge import GraphEdge
7
+ from morphml.core.graph.graph import ModelGraph
8
+ from morphml.core.graph.node import GraphNode
9
+ from morphml.logging_config import get_logger
10
+
11
# Module-level logger used by the mutation and crossover code in this module.
logger = get_logger(__name__)
12
+
13
+
14
class GraphMutator:
    """
    Applies mutations to ModelGraph for architecture search.

    Mutation operations:
        - add_node: Insert a new node by splitting an existing edge
        - remove_node: Remove an internal node, bypassing it with direct edges
        - modify_node: Change node parameters
        - add_edge: Add a new connection (skip connection)
        - remove_edge: Remove a redundant connection

    add_node and remove_node cannot introduce cycles by construction.
    add_edge relies on ``ModelGraph.add_edge`` raising for an edge that would
    create a cycle (NOTE(review): confirm ModelGraph enforces this).

    Example:
        >>> mutator = GraphMutator()
        >>> mutated_graph = mutator.mutate(original_graph, mutation_rate=0.1)
    """

    # Safety bound on loop iterations in ``mutate``. Without it, a
    # ``mutation_rate >= 1.0`` combined with ``max_mutations=None`` never
    # terminates.
    MAX_MUTATION_ATTEMPTS = 1000

    def __init__(self, operation_types: Optional[List[str]] = None):
        """
        Initialize graph mutator.

        Args:
            operation_types: Operation types that ``add_node_mutation`` may
                insert. ``None`` (or an empty list) selects a default set of
                common layer types.
        """
        self.operation_types = operation_types or [
            "conv2d",
            "maxpool",
            "avgpool",
            "dense",
            "relu",
            "batchnorm",
            "dropout",
        ]

    def mutate(
        self,
        graph: ModelGraph,
        mutation_rate: float = 0.1,
        max_mutations: Optional[int] = None,
    ) -> ModelGraph:
        """
        Apply a random number of random mutations to a copy of ``graph``.

        Before each mutation a uniform random number is drawn. The loop stops
        at the first draw that is >= ``mutation_rate``, so the number of
        mutations attempted is geometrically distributed. The loop is also
        capped at ``MAX_MUTATION_ATTEMPTS`` iterations.

        Args:
            graph: Original graph. It is not modified; a clone is mutated.
            mutation_rate: Probability of attempting one more mutation.
            max_mutations: Maximum number of successful mutations
                (None = no limit, 0 = no mutations).

        Returns:
            Mutated graph (new instance).
        """
        mutated = graph.clone()
        mutations_applied = 0

        # Available mutation operations
        mutation_ops = [
            self.add_node_mutation,
            self.remove_node_mutation,
            self.modify_node_mutation,
            self.add_edge_mutation,
            self.remove_edge_mutation,
        ]

        for _ in range(self.MAX_MUTATION_ATTEMPTS):
            if random.random() >= mutation_rate:
                break
            # Compare against None explicitly so max_mutations=0 means "none".
            if max_mutations is not None and mutations_applied >= max_mutations:
                break

            mutation_op = random.choice(mutation_ops)
            try:
                mutation_op(mutated)
            except Exception as e:
                # Best effort: skip the failed operator. It may have partially
                # modified the graph before raising.
                logger.debug(f"Mutation failed: {e}")
                continue
            mutations_applied += 1

        logger.debug(f"Applied {mutations_applied} mutations")
        return mutated

    def add_node_mutation(self, graph: ModelGraph) -> None:
        """
        Insert a new node by splitting an existing edge.

        Strategy:
            1. Pick a random edge ``source -> target``
            2. Create a node with a random operation type and parameters
            3. Replace the edge with ``source -> new_node -> target``

        Splitting an edge cannot create a cycle.

        Args:
            graph: Graph to mutate (modified in-place)
        """
        if not graph.edges:
            logger.debug("No edges to insert node between")
            return

        edge = random.choice(list(graph.edges.values()))

        operation = random.choice(self.operation_types)
        new_node = GraphNode.create(operation, params=self._random_params(operation))

        graph.add_node(new_node)
        graph.remove_edge(edge.id)
        graph.add_edge(GraphEdge(edge.source, new_node))
        graph.add_edge(GraphEdge(new_node, edge.target))

        logger.debug(
            f"Added node: {operation} between {edge.source.operation} and {edge.target.operation}"
        )

    def remove_node_mutation(self, graph: ModelGraph) -> None:
        """
        Remove an internal node from the graph.

        Strategy:
            1. Select a random node that is neither an input nor an output
            2. Connect each of its predecessors directly to each of its
               successors (skipping edges that already exist)
            3. Remove the node itself

        Args:
            graph: Graph to mutate (modified in-place)
        """
        protected_ids = {n.id for n in graph.get_input_nodes()} | {
            n.id for n in graph.get_output_nodes()
        }
        candidates = [node for node in graph.nodes.values() if node.id not in protected_ids]

        if not candidates:
            logger.debug("No nodes available for removal")
            return

        node_to_remove = random.choice(candidates)

        # Build the set of existing connections once instead of rescanning all
        # edges for every predecessor/successor pair.
        existing_pairs = {(e.source.id, e.target.id) for e in graph.edges.values()}

        # Iterate over snapshots so edge bookkeeping cannot disturb the loops.
        for pred in list(node_to_remove.predecessors):
            for succ in list(node_to_remove.successors):
                if (pred.id, succ.id) in existing_pairs:
                    continue
                try:
                    graph.add_edge(GraphEdge(pred, succ))
                except Exception as e:
                    logger.debug(f"Failed to add bypass edge: {e}")
                else:
                    existing_pairs.add((pred.id, succ.id))

        graph.remove_node(node_to_remove.id)

        logger.debug(f"Removed node: {node_to_remove.operation}")

    def modify_node_mutation(self, graph: ModelGraph) -> None:
        """
        Modify one parameter of a random node.

        Integers are scaled up (x2/x4/x8) or down (//2, //4) with a floor of 1.
        Floats are scaled by a factor in [0.5, 2.0]. Booleans are flipped.
        Other value types are left unchanged. A node without parameters
        receives fresh random parameters for its operation.

        Args:
            graph: Graph to mutate (modified in-place)
        """
        if not graph.nodes:
            return

        node = random.choice(list(graph.nodes.values()))

        if not node.params:
            node.params = self._random_params(node.operation)
            logger.debug(f"Added params to {node.operation}: {node.params}")
            return

        param_key = random.choice(list(node.params))
        old_value = node.params[param_key]

        if isinstance(old_value, bool):
            # bool is a subclass of int: handle it first so flags are flipped
            # rather than turned into integers such as 2, 4 or 8.
            node.params[param_key] = not old_value
        elif isinstance(old_value, int):
            if random.random() < 0.5:
                node.params[param_key] = max(1, old_value * random.choice([2, 4, 8]))
            else:
                node.params[param_key] = max(1, old_value // random.choice([2, 4]))
        elif isinstance(old_value, float):
            node.params[param_key] = old_value * random.uniform(0.5, 2.0)

        logger.debug(
            f"Modified {node.operation}.{param_key}: {old_value} -> {node.params[param_key]}"
        )

    def add_edge_mutation(self, graph: ModelGraph) -> None:
        """
        Add a new edge (skip connection) between two distinct nodes.

        Up to 10 random node pairs are tried. Pairs that are already connected
        are skipped. Pairs that ``graph.add_edge`` rejects (presumably because
        they would create a cycle) are also skipped.

        Args:
            graph: Graph to mutate (modified in-place)
        """
        if len(graph.nodes) < 2:
            return

        nodes = list(graph.nodes.values())
        existing_pairs = {(e.source.id, e.target.id) for e in graph.edges.values()}

        for _ in range(10):
            # sample() guarantees source and target are different nodes.
            source, target = random.sample(nodes, 2)

            if (source.id, target.id) in existing_pairs:
                continue

            try:
                graph.add_edge(GraphEdge(source, target))
            except Exception:
                continue

            logger.debug(f"Added edge: {source.operation} -> {target.operation}")
            return

        logger.debug("Failed to add edge after multiple attempts")

    def remove_edge_mutation(self, graph: ModelGraph) -> None:
        """
        Remove a redundant edge.

        An edge is removable only if its target keeps at least one other
        incoming edge and its source keeps at least one other outgoing edge.
        The removal therefore neither orphans the target nor turns the source
        into a dangling sink.

        Args:
            graph: Graph to mutate (modified in-place)
        """
        if len(graph.edges) <= len(graph.nodes) - 1:
            # Need more than n-1 edges for any edge to be redundant.
            logger.debug("Too few edges to remove")
            return

        candidates = [
            edge
            for edge in graph.edges.values()
            if len(edge.target.predecessors) > 1 and len(edge.source.successors) > 1
        ]

        if not candidates:
            logger.debug("No removable edges found")
            return

        edge_to_remove = random.choice(candidates)
        graph.remove_edge(edge_to_remove.id)

        logger.debug(
            f"Removed edge: {edge_to_remove.source.operation} -> "
            f"{edge_to_remove.target.operation}"
        )

    def _random_params(self, operation: str) -> dict:
        """
        Generate random parameters for an operation.

        Args:
            operation: Operation type

        Returns:
            Dictionary of parameters. It is empty for operations without
            tunable parameters (batchnorm, relu) and for unknown operations.
        """
        if operation == "conv2d":
            return {
                "filters": random.choice([32, 64, 128, 256]),
                "kernel_size": random.choice([3, 5, 7]),
                "padding": "same",
            }
        if operation == "dense":
            return {"units": random.choice([64, 128, 256, 512])}
        if operation in ("maxpool", "avgpool"):
            return {"pool_size": random.choice([2, 3]), "stride": 2}
        if operation == "dropout":
            return {"rate": random.uniform(0.1, 0.5)}
        return {}
325
+
326
+
327
def _splice_offspring(head: List[GraphNode], tail: List[GraphNode], label: str) -> ModelGraph:
    """Build a linear-chain graph from clones of ``head`` followed by ``tail``.

    Nodes are chained in exactly the order given. The list of clones is used
    directly, so the result does not depend on the graph's internal node
    ordering.
    """
    child = ModelGraph(metadata={"crossover": label})
    chain = [node.clone() for node in (*head, *tail)]
    for node in chain:
        child.add_node(node)
    for upstream, downstream in zip(chain, chain[1:]):
        child.add_edge(GraphEdge(upstream, downstream))
    return child


def crossover(parent1: ModelGraph, parent2: ModelGraph) -> tuple[ModelGraph, ModelGraph]:
    """
    Perform single-point crossover between two graphs.

    Each parent is topologically sorted and split at a random interior
    point. Offspring 1 is parent1's head followed by parent2's tail.
    Offspring 2 is parent2's head followed by parent1's tail. Offspring nodes
    are clones connected as a simple chain, so skip connections in the
    parents are not carried over.

    Args:
        parent1: First parent graph
        parent2: Second parent graph

    Returns:
        Tuple of two offspring graphs. Clones of the parents are returned
        instead if sorting fails, either parent has fewer than 3 nodes, or
        construction/validation of the offspring fails.
    """
    try:
        nodes1 = parent1.topological_sort()
        nodes2 = parent2.topological_sort()
    except Exception as e:
        logger.warning(f"Crossover failed during topological sort: {e}, returning clones")
        return parent1.clone(), parent2.clone()

    if len(nodes1) < 3 or len(nodes2) < 3:
        logger.debug("Parents too small for crossover, returning clones")
        return parent1.clone(), parent2.clone()

    # Interior crossover points: every head keeps the first sorted node and
    # every tail keeps at least the last two sorted nodes.
    point1 = random.randint(1, len(nodes1) - 2)
    point2 = random.randint(1, len(nodes2) - 2)

    try:
        offspring1 = _splice_offspring(
            nodes1[:point1], nodes2[point2:], "parent1_start + parent2_end"
        )
        offspring2 = _splice_offspring(
            nodes2[:point2], nodes1[point1:], "parent2_start + parent1_end"
        )

        logger.debug(f"Performed single-point crossover at points {point1}/{point2}")

        if not offspring1.is_valid() or not offspring2.is_valid():
            logger.warning("Crossover produced invalid offspring, returning clones")
            return parent1.clone(), parent2.clone()

        return offspring1, offspring2

    except Exception as e:
        logger.warning(f"Crossover failed: {e}, returning clones")
        return parent1.clone(), parent2.clone()
@@ -0,0 +1,196 @@
1
+ """Graph node representation for neural architecture."""
2
+
3
+ import uuid
4
+ from typing import Any, Dict, List, Optional
5
+
6
+
7
class GraphNode:
    """
    Represents a single operation/layer in a neural architecture graph.

    Each node contains:
        - Unique identifier
        - Operation type (conv2d, maxpool, dense, etc.)
        - Operation parameters (filters, kernel_size, etc.)
        - Connections to other nodes (predecessors/successors)
        - Metadata for tracking

    Equality and hashing are based solely on ``id``.

    Attributes:
        id: Unique node identifier
        operation: Operation type
        params: Operation parameters
        predecessors: List of predecessor nodes
        successors: List of successor nodes
        metadata: Additional metadata

    Example:
        >>> node = GraphNode.create(
        ...     operation='conv2d',
        ...     params={'filters': 64, 'kernel_size': 3}
        ... )
        >>> node.get_param('filters')
        64
    """

    def __init__(
        self,
        node_id: str,
        operation: str,
        params: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize graph node.

        Args:
            node_id: Unique identifier
            operation: Operation type
            params: Operation parameters. A non-empty dict is stored as-is
                (not copied).
            metadata: Additional metadata. A non-empty dict is stored as-is
                (not copied).
        """
        self.id = node_id
        self.operation = operation
        self.params = params or {}
        self.metadata = metadata or {}

        # Connections (maintained by the caller / owning graph)
        self.predecessors: List[GraphNode] = []
        self.successors: List[GraphNode] = []

    @classmethod
    def create(
        cls,
        operation: str,
        params: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> "GraphNode":
        """
        Factory method to create a new node with an auto-generated UUID4 ID.

        Args:
            operation: Operation type
            params: Operation parameters
            metadata: Additional metadata

        Returns:
            New GraphNode instance
        """
        node_id = str(uuid.uuid4())
        return cls(node_id, operation, params, metadata)

    def add_predecessor(self, node: "GraphNode") -> None:
        """
        Add a predecessor node (no-op if a node with the same ID is present).

        Args:
            node: Predecessor node to add
        """
        if node not in self.predecessors:
            self.predecessors.append(node)

    def add_successor(self, node: "GraphNode") -> None:
        """
        Add a successor node (no-op if a node with the same ID is present).

        Args:
            node: Successor node to add
        """
        if node not in self.successors:
            self.successors.append(node)

    def remove_predecessor(self, node: "GraphNode") -> None:
        """Remove a predecessor node (no-op if absent)."""
        try:
            self.predecessors.remove(node)
        except ValueError:
            pass

    def remove_successor(self, node: "GraphNode") -> None:
        """Remove a successor node (no-op if absent)."""
        try:
            self.successors.remove(node)
        except ValueError:
            pass

    def get_param(self, key: str, default: Any = None) -> Any:
        """
        Get operation parameter.

        Args:
            key: Parameter key
            default: Default value if key not found

        Returns:
            Parameter value
        """
        return self.params.get(key, default)

    def set_param(self, key: str, value: Any) -> None:
        """
        Set operation parameter.

        Args:
            key: Parameter key
            value: Parameter value
        """
        self.params[key] = value

    def clone(self) -> "GraphNode":
        """
        Create a deep copy of this node.

        The clone gets a new ID, the same operation, and deep copies of
        ``params`` and ``metadata``. Nested mutable values (e.g. a list-valued
        kernel size) are therefore not shared with the original. Connections
        are not copied.

        Returns:
            Cloned GraphNode
        """
        import copy

        return GraphNode.create(
            operation=self.operation,
            params=copy.deepcopy(self.params),
            metadata=copy.deepcopy(self.metadata),
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize node to dictionary.

        ``params`` and ``metadata`` are copied. Mutating the returned
        dictionary therefore does not affect this node.

        Returns:
            Dictionary representation, with neighbours referenced by ID
        """
        return {
            "id": self.id,
            "operation": self.operation,
            "params": dict(self.params),
            "metadata": dict(self.metadata),
            "predecessor_ids": [p.id for p in self.predecessors],
            "successor_ids": [s.id for s in self.successors],
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GraphNode":
        """
        Deserialize node from dictionary.

        Args:
            data: Dictionary representation. ``params`` and ``metadata`` are
                copied, so the new node does not share them with ``data``.

        Returns:
            GraphNode instance

        Note:
            Predecessor/successor connections must be restored separately.
        """
        return cls(
            node_id=data["id"],
            operation=data["operation"],
            params=dict(data.get("params") or {}),
            metadata=dict(data.get("metadata") or {}),
        )

    def __repr__(self) -> str:
        """String representation of node (ID truncated to 8 characters)."""
        return f"GraphNode(id={self.id[:8]}, operation={self.operation}, params={self.params})"

    def __eq__(self, other: object) -> bool:
        """Check equality based on ID."""
        if not isinstance(other, GraphNode):
            # Let Python try the reflected comparison / identity fallback.
            return NotImplemented
        return self.id == other.id

    def __hash__(self) -> int:
        """Hash based on ID (consistent with ``__eq__``)."""
        return hash(self.id)