morphml 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of morphml might be problematic. Click here for more details.

Files changed (158)
  1. morphml/__init__.py +14 -0
  2. morphml/api/__init__.py +26 -0
  3. morphml/api/app.py +326 -0
  4. morphml/api/auth.py +193 -0
  5. morphml/api/client.py +338 -0
  6. morphml/api/models.py +132 -0
  7. morphml/api/rate_limit.py +192 -0
  8. morphml/benchmarking/__init__.py +36 -0
  9. morphml/benchmarking/comparison.py +430 -0
  10. morphml/benchmarks/__init__.py +56 -0
  11. morphml/benchmarks/comparator.py +409 -0
  12. morphml/benchmarks/datasets.py +280 -0
  13. morphml/benchmarks/metrics.py +199 -0
  14. morphml/benchmarks/openml_suite.py +201 -0
  15. morphml/benchmarks/problems.py +289 -0
  16. morphml/benchmarks/suite.py +318 -0
  17. morphml/cli/__init__.py +5 -0
  18. morphml/cli/commands/experiment.py +329 -0
  19. morphml/cli/main.py +457 -0
  20. morphml/cli/quickstart.py +312 -0
  21. morphml/config.py +278 -0
  22. morphml/constraints/__init__.py +19 -0
  23. morphml/constraints/handler.py +205 -0
  24. morphml/constraints/predicates.py +285 -0
  25. morphml/core/__init__.py +3 -0
  26. morphml/core/crossover.py +449 -0
  27. morphml/core/dsl/README.md +359 -0
  28. morphml/core/dsl/__init__.py +72 -0
  29. morphml/core/dsl/ast_nodes.py +364 -0
  30. morphml/core/dsl/compiler.py +318 -0
  31. morphml/core/dsl/layers.py +368 -0
  32. morphml/core/dsl/lexer.py +336 -0
  33. morphml/core/dsl/parser.py +455 -0
  34. morphml/core/dsl/search_space.py +386 -0
  35. morphml/core/dsl/syntax.py +199 -0
  36. morphml/core/dsl/type_system.py +361 -0
  37. morphml/core/dsl/validator.py +386 -0
  38. morphml/core/graph/__init__.py +40 -0
  39. morphml/core/graph/edge.py +124 -0
  40. morphml/core/graph/graph.py +507 -0
  41. morphml/core/graph/mutations.py +409 -0
  42. morphml/core/graph/node.py +196 -0
  43. morphml/core/graph/serialization.py +361 -0
  44. morphml/core/graph/visualization.py +431 -0
  45. morphml/core/objectives/__init__.py +20 -0
  46. morphml/core/search/__init__.py +33 -0
  47. morphml/core/search/individual.py +252 -0
  48. morphml/core/search/parameters.py +453 -0
  49. morphml/core/search/population.py +375 -0
  50. morphml/core/search/search_engine.py +340 -0
  51. morphml/distributed/__init__.py +76 -0
  52. morphml/distributed/fault_tolerance.py +497 -0
  53. morphml/distributed/health_monitor.py +348 -0
  54. morphml/distributed/master.py +709 -0
  55. morphml/distributed/proto/README.md +224 -0
  56. morphml/distributed/proto/__init__.py +74 -0
  57. morphml/distributed/proto/worker.proto +170 -0
  58. morphml/distributed/proto/worker_pb2.py +79 -0
  59. morphml/distributed/proto/worker_pb2_grpc.py +423 -0
  60. morphml/distributed/resource_manager.py +416 -0
  61. morphml/distributed/scheduler.py +567 -0
  62. morphml/distributed/storage/__init__.py +33 -0
  63. morphml/distributed/storage/artifacts.py +381 -0
  64. morphml/distributed/storage/cache.py +366 -0
  65. morphml/distributed/storage/checkpointing.py +329 -0
  66. morphml/distributed/storage/database.py +459 -0
  67. morphml/distributed/worker.py +549 -0
  68. morphml/evaluation/__init__.py +5 -0
  69. morphml/evaluation/heuristic.py +237 -0
  70. morphml/exceptions.py +55 -0
  71. morphml/execution/__init__.py +5 -0
  72. morphml/execution/local_executor.py +350 -0
  73. morphml/integrations/__init__.py +28 -0
  74. morphml/integrations/jax_adapter.py +206 -0
  75. morphml/integrations/pytorch_adapter.py +530 -0
  76. morphml/integrations/sklearn_adapter.py +206 -0
  77. morphml/integrations/tensorflow_adapter.py +230 -0
  78. morphml/logging_config.py +93 -0
  79. morphml/meta_learning/__init__.py +66 -0
  80. morphml/meta_learning/architecture_similarity.py +277 -0
  81. morphml/meta_learning/experiment_database.py +240 -0
  82. morphml/meta_learning/knowledge_base/__init__.py +19 -0
  83. morphml/meta_learning/knowledge_base/embedder.py +179 -0
  84. morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
  85. morphml/meta_learning/knowledge_base/meta_features.py +265 -0
  86. morphml/meta_learning/knowledge_base/vector_store.py +271 -0
  87. morphml/meta_learning/predictors/__init__.py +27 -0
  88. morphml/meta_learning/predictors/ensemble.py +221 -0
  89. morphml/meta_learning/predictors/gnn_predictor.py +552 -0
  90. morphml/meta_learning/predictors/learning_curve.py +231 -0
  91. morphml/meta_learning/predictors/proxy_metrics.py +261 -0
  92. morphml/meta_learning/strategy_evolution/__init__.py +27 -0
  93. morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
  94. morphml/meta_learning/strategy_evolution/bandit.py +276 -0
  95. morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
  96. morphml/meta_learning/transfer.py +581 -0
  97. morphml/meta_learning/warm_start.py +286 -0
  98. morphml/optimizers/__init__.py +74 -0
  99. morphml/optimizers/adaptive_operators.py +399 -0
  100. morphml/optimizers/bayesian/__init__.py +52 -0
  101. morphml/optimizers/bayesian/acquisition.py +387 -0
  102. morphml/optimizers/bayesian/base.py +319 -0
  103. morphml/optimizers/bayesian/gaussian_process.py +635 -0
  104. morphml/optimizers/bayesian/smac.py +534 -0
  105. morphml/optimizers/bayesian/tpe.py +411 -0
  106. morphml/optimizers/differential_evolution.py +220 -0
  107. morphml/optimizers/evolutionary/__init__.py +61 -0
  108. morphml/optimizers/evolutionary/cma_es.py +416 -0
  109. morphml/optimizers/evolutionary/differential_evolution.py +556 -0
  110. morphml/optimizers/evolutionary/encoding.py +426 -0
  111. morphml/optimizers/evolutionary/particle_swarm.py +449 -0
  112. morphml/optimizers/genetic_algorithm.py +486 -0
  113. morphml/optimizers/gradient_based/__init__.py +22 -0
  114. morphml/optimizers/gradient_based/darts.py +550 -0
  115. morphml/optimizers/gradient_based/enas.py +585 -0
  116. morphml/optimizers/gradient_based/operations.py +474 -0
  117. morphml/optimizers/gradient_based/utils.py +601 -0
  118. morphml/optimizers/hill_climbing.py +169 -0
  119. morphml/optimizers/multi_objective/__init__.py +56 -0
  120. morphml/optimizers/multi_objective/indicators.py +504 -0
  121. morphml/optimizers/multi_objective/nsga2.py +647 -0
  122. morphml/optimizers/multi_objective/visualization.py +427 -0
  123. morphml/optimizers/nsga2.py +308 -0
  124. morphml/optimizers/random_search.py +172 -0
  125. morphml/optimizers/simulated_annealing.py +181 -0
  126. morphml/plugins/__init__.py +35 -0
  127. morphml/plugins/custom_evaluator_example.py +81 -0
  128. morphml/plugins/custom_optimizer_example.py +63 -0
  129. morphml/plugins/plugin_system.py +454 -0
  130. morphml/reports/__init__.py +30 -0
  131. morphml/reports/generator.py +362 -0
  132. morphml/tracking/__init__.py +7 -0
  133. morphml/tracking/experiment.py +309 -0
  134. morphml/tracking/logger.py +301 -0
  135. morphml/tracking/reporter.py +357 -0
  136. morphml/utils/__init__.py +6 -0
  137. morphml/utils/checkpoint.py +189 -0
  138. morphml/utils/comparison.py +390 -0
  139. morphml/utils/export.py +407 -0
  140. morphml/utils/progress.py +392 -0
  141. morphml/utils/validation.py +392 -0
  142. morphml/version.py +7 -0
  143. morphml/visualization/__init__.py +50 -0
  144. morphml/visualization/analytics.py +423 -0
  145. morphml/visualization/architecture_diagrams.py +353 -0
  146. morphml/visualization/architecture_plot.py +223 -0
  147. morphml/visualization/convergence_plot.py +174 -0
  148. morphml/visualization/crossover_viz.py +386 -0
  149. morphml/visualization/graph_viz.py +338 -0
  150. morphml/visualization/pareto_plot.py +149 -0
  151. morphml/visualization/plotly_dashboards.py +422 -0
  152. morphml/visualization/population.py +309 -0
  153. morphml/visualization/progress.py +260 -0
  154. morphml-1.0.0.dist-info/METADATA +434 -0
  155. morphml-1.0.0.dist-info/RECORD +158 -0
  156. morphml-1.0.0.dist-info/WHEEL +4 -0
  157. morphml-1.0.0.dist-info/entry_points.txt +3 -0
  158. morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,386 @@
1
+ """Semantic validator for MorphML DSL.
2
+
3
+ Validates AST for semantic correctness beyond syntax.
4
+
5
+ Author: Eshan Roy <eshanized@proton.me>
6
+ Organization: TONMOY INFRASTRUCTURE & VISION
7
+ """
8
+
9
+ from typing import Dict, List, Set
10
+
11
+ from morphml.core.dsl.ast_nodes import (
12
+ ASTVisitor,
13
+ ConstraintNode,
14
+ EvolutionNode,
15
+ ExperimentNode,
16
+ LayerNode,
17
+ ParamNode,
18
+ SearchSpaceNode,
19
+ )
20
+ from morphml.core.dsl.syntax import EVOLUTION_STRATEGIES, LAYER_TYPES
21
+ from morphml.logging_config import get_logger
22
+
23
# Module-level logger; Validator.validate reports the start and outcome of each run through it.
logger = get_logger(__name__)
24
+
25
+
26
class ValidationError:
    """A single semantic problem detected while validating a DSL AST.

    Attributes:
        message: Human-readable description of the problem.
        node: Label of the AST node kind in which the problem was found.
    """

    def __init__(self, message: str, node: str = "Unknown"):
        """
        Record an error message together with the node label it refers to.

        Args:
            message: Description of what is wrong.
            node: Label of the offending node kind (defaults to ``"Unknown"``).
        """
        self.node = node
        self.message = message

    def __str__(self) -> str:
        """Render the error as ``[<node>] <message>``."""
        return "[{}] {}".format(self.node, self.message)
43
+
44
+
45
class Validator(ASTVisitor):
    """
    Validates AST for semantic correctness.

    Errors (returned by ``validate``):
    - Experiment has a search space, and the search space has at least one layer
    - Layer types are supported
    - Required layer parameters are present
    - Parameter values are valid (positive integers, dropout rate in [0, 1))
    - Evolution strategy is known; genetic-algorithm parameter ranges

    Warnings (collected on ``self.warnings``, not returned):
    - Missing input/output layers, even kernel sizes, mixed parameter types
    - Missing recommended optimizer parameters, unknown constraint types
    - Low layer diversity and incomplete CNN patterns

    Example:
        >>> validator = Validator()
        >>> errors = validator.validate(ast)
        >>> if errors:
        ...     for error in errors:
        ...         print(error)
    """

    def __init__(self) -> None:
        """Initialize validator; all per-run state is reset by ``validate``."""
        self.errors: List[ValidationError] = []
        # Non-fatal findings; collected here but not part of validate()'s result.
        self.warnings: List[str] = []
        # Supported layer types encountered during the current run.
        self.layer_types_seen: Set[str] = set()

        # Required parameters for each layer type
        self.required_params: Dict[str, List[str]] = {
            "conv2d": ["filters", "kernel_size"],
            "conv3d": ["filters", "kernel_size"],
            "dense": ["units"],
            "linear": ["units"],
            "dropout": ["rate"],
            "max_pool": ["pool_size"],
            "avg_pool": ["pool_size"],
        }

    def validate(self, ast: ExperimentNode) -> List[ValidationError]:
        """
        Validate complete experiment AST.

        Args:
            ast: ExperimentNode to validate

        Returns:
            A new list of validation errors (empty if valid). Warnings are
            left on ``self.warnings``.
        """
        logger.info("Starting validation")

        # Reset state so one Validator instance can be reused across ASTs.
        self.errors.clear()
        self.warnings.clear()
        self.layer_types_seen.clear()

        # Per-node checks run through the visit_* methods below.
        ast.accept(self)

        # Whole-search-space checks that need every layer to have been visited.
        self._validate_layer_diversity()

        if self.errors:
            logger.warning(f"Validation found {len(self.errors)} errors")
        else:
            logger.info("Validation successful")

        # Return a copy: self.errors is cleared on the next call, which would
        # otherwise silently empty a list the caller is still holding.
        return list(self.errors)

    def visit_experiment(self, node: ExperimentNode) -> None:
        """Validate experiment node."""
        if not node.search_space:
            self.errors.append(ValidationError("Experiment must have a search space", "Experiment"))

        super().visit_experiment(node)

    def visit_search_space(self, node: SearchSpaceNode) -> None:
        """Validate search space node."""
        layers = node.layers or []

        if not layers:
            self.errors.append(
                ValidationError("SearchSpace must contain at least one layer", "SearchSpace")
            )

        # Input/output presence is advisory and only checked for spaces with
        # more than two layers.
        # NOTE(review): the original comment said "Only check if there are actual
        # layers", which hints that `> 0` may have been intended — confirm.
        if len(layers) > 2:
            layer_types = {layer.layer_type for layer in layers}
            if "input" not in layer_types:
                self.warnings.append("SearchSpace should have an input layer")
            if "output" not in layer_types:
                self.warnings.append("SearchSpace should have an output layer")

        super().visit_search_space(node)

    def visit_layer(self, node: LayerNode) -> None:
        """Validate layer node: supported type, required params, param values."""
        if node.layer_type not in LAYER_TYPES:
            self.errors.append(
                ValidationError(
                    f"Unsupported layer type: '{node.layer_type}'. "
                    f"Valid types: {', '.join(LAYER_TYPES)}",
                    "Layer",
                )
            )
            # Parameters of an unknown layer cannot be checked meaningfully.
            return

        # Only supported types feed the diversity / CNN-pattern checks, so a
        # misspelled type (e.g. "conv2dd") cannot masquerade as a conv layer.
        self.layer_types_seen.add(node.layer_type)

        for param_name in self.required_params.get(node.layer_type, []):
            if param_name not in node.params:
                self.errors.append(
                    ValidationError(
                        f"Layer '{node.layer_type}' missing required parameter: '{param_name}'",
                        "Layer",
                    )
                )

        for param_name, param_node in node.params.items():
            self._validate_layer_param(node.layer_type, param_name, param_node)

        # NOTE(review): if ASTVisitor.visit_layer dispatches to visit_param for
        # each parameter, an empty value list is reported twice (here and in
        # visit_param) — confirm against ASTVisitor.
        super().visit_layer(node)

    def _validate_layer_param(
        self, layer_type: str, param_name: str, param_node: ParamNode
    ) -> None:
        """
        Validate a layer parameter.

        Args:
            layer_type: Type of layer
            param_name: Parameter name
            param_node: Parameter node
        """
        if not param_node.values:
            self.errors.append(
                ValidationError(f"Parameter '{param_name}' has no values", "Parameter")
            )
            return

        # Parameters without a dedicated checker are accepted as-is.
        checkers = {
            "filters": self._validate_filters,
            "kernel_size": self._validate_kernel_size,
            "units": self._validate_units,
            "rate": self._validate_dropout_rate,
            "pool_size": self._validate_pool_size,
        }
        checker = checkers.get(param_name)
        if checker is not None:
            checker(param_node)

    @staticmethod
    def _is_int(value: object) -> bool:
        """Return True for integers, excluding ``bool`` (which subclasses ``int``)."""
        return isinstance(value, int) and not isinstance(value, bool)

    def _validate_positive_ints(self, param_node: ParamNode, label: str) -> List[int]:
        """
        Report every value of *param_node* that is not a positive integer.

        Args:
            param_node: Parameter whose ``values`` are checked.
            label: Human-readable parameter name used in the error messages.

        Returns:
            The values that passed, in their original order.
        """
        valid: List[int] = []
        for value in param_node.values:
            if not self._is_int(value):
                self.errors.append(
                    ValidationError(
                        f"{label} must be integer, got {type(value).__name__}", "Parameter"
                    )
                )
            elif value <= 0:
                self.errors.append(
                    ValidationError(f"{label} must be positive, got {value}", "Parameter")
                )
            else:
                valid.append(value)
        return valid

    def _validate_filters(self, param_node: ParamNode) -> None:
        """Validate filter count parameter."""
        self._validate_positive_ints(param_node, "Filter count")

    def _validate_kernel_size(self, param_node: ParamNode) -> None:
        """Validate kernel size parameter; even sizes only produce a warning."""
        for value in self._validate_positive_ints(param_node, "Kernel size"):
            if value % 2 == 0:
                self.warnings.append(
                    f"Kernel size {value} is even (odd sizes are typically preferred)"
                )

    def _validate_units(self, param_node: ParamNode) -> None:
        """Validate units parameter for dense layers."""
        self._validate_positive_ints(param_node, "Units")

    def _validate_dropout_rate(self, param_node: ParamNode) -> None:
        """Validate dropout rate parameter (numeric, in [0, 1))."""
        for value in param_node.values:
            # bool is an int subclass but is not a meaningful rate.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                self.errors.append(
                    ValidationError(
                        f"Dropout rate must be numeric, got {type(value).__name__}", "Parameter"
                    )
                )
            elif not 0 <= value < 1:
                self.errors.append(
                    ValidationError(f"Dropout rate must be in [0, 1), got {value}", "Parameter")
                )

    def _validate_pool_size(self, param_node: ParamNode) -> None:
        """Validate pooling size parameter."""
        self._validate_positive_ints(param_node, "Pool size")

    def visit_param(self, node: ParamNode) -> None:
        """Validate parameter node: non-empty values and consistent value types."""
        if not node.values:
            self.errors.append(
                ValidationError(f"Parameter '{node.name}' has no values", "Parameter")
            )
            # Nothing left to compare (and len() would fail if values is None).
            return

        # Mixing int and float is fine; any other mix of types is suspicious.
        if len(node.values) > 1:
            types = {type(v) for v in node.values}
            if len(types) > 1 and not types <= {int, float}:
                # Sorted so the message is deterministic across runs.
                self.warnings.append(
                    f"Parameter '{node.name}' has mixed types: {sorted(t.__name__ for t in types)}"
                )

    def visit_evolution(self, node: EvolutionNode) -> None:
        """Validate evolution node."""
        if node.strategy not in EVOLUTION_STRATEGIES:
            self.errors.append(
                ValidationError(
                    f"Unknown evolution strategy: '{node.strategy}'. "
                    f"Valid strategies: {', '.join(EVOLUTION_STRATEGIES)}",
                    "Evolution",
                )
            )

        if node.strategy == "genetic":
            self._validate_genetic_params(node)
        elif node.strategy == "bayesian":
            self._validate_bayesian_params(node)

    def _validate_genetic_params(self, node: EvolutionNode) -> None:
        """Validate genetic algorithm parameters."""
        for param in ("population_size", "num_generations", "mutation_rate"):
            if param not in node.params:
                self.warnings.append(f"Genetic algorithm should specify '{param}' parameter")

        # NOTE(review): the range checks below only fire when node.params holds
        # raw numbers; if it holds ParamNode objects they are skipped — confirm.
        if "population_size" in node.params:
            pop_size = node.params["population_size"]
            if isinstance(pop_size, int) and pop_size < 2:
                self.errors.append(
                    ValidationError("Population size must be at least 2", "Evolution")
                )

        if "mutation_rate" in node.params:
            rate = node.params["mutation_rate"]
            if isinstance(rate, (int, float)) and not (0 <= rate <= 1):
                self.errors.append(
                    ValidationError(f"Mutation rate must be in [0, 1], got {rate}", "Evolution")
                )

    def _validate_bayesian_params(self, node: EvolutionNode) -> None:
        """Validate Bayesian optimization parameters."""
        if "num_iterations" not in node.params and "max_evaluations" not in node.params:
            self.warnings.append("Bayesian optimization should specify iteration budget")

    def visit_constraint(self, node: ConstraintNode) -> None:
        """Warn about constraint types this validator does not recognise."""
        known_constraints = ("max_depth", "max_nodes", "max_params", "required_layers")

        if node.constraint_type not in known_constraints:
            self.warnings.append(
                f"Unknown constraint type: '{node.constraint_type}' "
                f"(known types: {', '.join(known_constraints)})"
            )

    def _validate_layer_diversity(self) -> None:
        """Warn on low layer-type diversity and incomplete CNN patterns."""
        seen = self.layer_types_seen

        if len(seen) < 2:
            self.warnings.append(
                f"Search space has low layer diversity: only {len(seen)} "
                "different layer types"
            )

        has_conv = any(lt.startswith("conv") for lt in seen)
        has_pool = any("pool" in lt for lt in seen)
        # "linear" is a dense layer too (both require "units" in required_params).
        has_dense = not seen.isdisjoint({"dense", "linear"})

        if has_conv and not has_pool:
            self.warnings.append("CNN architecture detected but no pooling layers specified")

        if has_conv and not has_dense:
            self.warnings.append("CNN architecture detected but no dense layers for classification")
373
+
374
+
375
def validate_ast(ast: ExperimentNode) -> List[ValidationError]:
    """
    Validate an experiment AST with a freshly constructed ``Validator``.

    Args:
        ast: Root ExperimentNode of the parsed DSL program.

    Returns:
        The validation errors reported by ``Validator.validate``.
    """
    return Validator().validate(ast)
@@ -0,0 +1,40 @@
1
+ """Graph-based model representation."""
2
+
3
+ from morphml.core.graph.edge import GraphEdge
4
+ from morphml.core.graph.graph import ModelGraph
5
+ from morphml.core.graph.mutations import GraphMutator, crossover
6
+ from morphml.core.graph.node import GraphNode
7
+ from morphml.core.graph.serialization import (
8
+ batch_load_graphs,
9
+ batch_save_graphs,
10
+ export_graph_summary,
11
+ graph_from_json_string,
12
+ graph_to_json_string,
13
+ load_graph,
14
+ save_graph,
15
+ )
16
+ from morphml.core.graph.visualization import (
17
+ export_graphviz,
18
+ plot_architecture_comparison,
19
+ plot_graph,
20
+ plot_training_history,
21
+ )
22
+
23
# Public API of the graph package, listed in the same order as the imports:
# core types and operators, serialization helpers, then visualization helpers.
__all__ = (
    # Core types (node, edge, graph) and mutation/crossover operators
    ["GraphNode", "GraphEdge", "ModelGraph", "GraphMutator", "crossover"]
    # morphml.core.graph.serialization
    + [
        "save_graph",
        "load_graph",
        "graph_to_json_string",
        "graph_from_json_string",
        "export_graph_summary",
        "batch_save_graphs",
        "batch_load_graphs",
    ]
    # morphml.core.graph.visualization
    + [
        "plot_graph",
        "plot_training_history",
        "plot_architecture_comparison",
        "export_graphviz",
    ]
)
@@ -0,0 +1,124 @@
1
+ """Graph edge representation for neural architecture."""
2
+
3
+ import uuid
4
+ from typing import Any, Dict, Optional
5
+
6
+ from morphml.core.graph.node import GraphNode
7
+ from morphml.exceptions import GraphError
8
+
9
+
10
class GraphEdge:
    """
    Represents a connection between two nodes in a neural architecture graph.

    Each edge contains:
    - Unique identifier
    - Source and target nodes
    - Optional operation (for edge-level operations)
    - Metadata

    Equality and hashing are based solely on ``id``.

    Attributes:
        id: Unique edge identifier
        source: Source node
        target: Target node
        operation: Optional edge operation
        metadata: Additional metadata

    Example:
        >>> source = GraphNode.create('conv2d')
        >>> target = GraphNode.create('relu')
        >>> edge = GraphEdge(source, target)
    """

    def __init__(
        self,
        source: GraphNode,
        target: GraphNode,
        operation: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        edge_id: Optional[str] = None,
    ):
        """
        Initialize graph edge.

        Args:
            source: Source node
            target: Target node
            operation: Optional edge operation
            metadata: Additional metadata (stored as given; an empty or missing
                value is replaced by a new empty dict)
            edge_id: Optional edge ID (a random UUID4 string if falsy)

        Raises:
            GraphError: If source or target is None
        """
        if source is None or target is None:
            raise GraphError("Source and target nodes cannot be None")

        self.id = edge_id or str(uuid.uuid4())
        self.source = source
        self.target = target
        self.operation = operation
        self.metadata = metadata or {}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize edge to dictionary.

        Nodes are referenced by ID only. ``metadata`` is shallow-copied so that
        mutating the returned dict does not mutate this edge.

        Returns:
            Dictionary representation
        """
        return {
            "id": self.id,
            "source_id": self.source.id,
            "target_id": self.target.id,
            "operation": self.operation,
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any], node_map: Dict[str, GraphNode]) -> "GraphEdge":
        """
        Deserialize edge from dictionary.

        Args:
            data: Dictionary representation (as produced by ``to_dict``). A
                missing ``"id"`` gets a fresh UUID, as in ``__init__``.
            node_map: Mapping of node IDs to GraphNode instances

        Returns:
            GraphEdge instance

        Raises:
            GraphError: If source or target node not found
            KeyError: If ``source_id`` or ``target_id`` is missing from *data*
        """
        source_id = data["source_id"]
        target_id = data["target_id"]

        if source_id not in node_map or target_id not in node_map:
            raise GraphError(f"Source or target node not found: {source_id}, {target_id}")

        return cls(
            source=node_map[source_id],
            target=node_map[target_id],
            operation=data.get("operation"),
            # Copy so the new edge does not share a dict with the input data.
            metadata=dict(data.get("metadata") or {}),
            edge_id=data.get("id"),
        )

    def __repr__(self) -> str:
        """String representation of edge (ID truncated to 8 characters)."""
        op_str = f", operation={self.operation}" if self.operation else ""
        return (
            f"GraphEdge(id={self.id[:8]}, "
            f"source={self.source.operation}, "
            f"target={self.target.operation}{op_str})"
        )

    def __eq__(self, other: object) -> bool:
        """Check equality based on ID."""
        if not isinstance(other, GraphEdge):
            # Let Python try the reflected comparison / fall back to identity.
            return NotImplemented
        return self.id == other.id

    def __hash__(self) -> int:
        """Hash based on ID (consistent with ``__eq__``)."""
        return hash(self.id)