morphml 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of morphml might be problematic. Click here for more details.

Files changed (158)
  1. morphml/__init__.py +14 -0
  2. morphml/api/__init__.py +26 -0
  3. morphml/api/app.py +326 -0
  4. morphml/api/auth.py +193 -0
  5. morphml/api/client.py +338 -0
  6. morphml/api/models.py +132 -0
  7. morphml/api/rate_limit.py +192 -0
  8. morphml/benchmarking/__init__.py +36 -0
  9. morphml/benchmarking/comparison.py +430 -0
  10. morphml/benchmarks/__init__.py +56 -0
  11. morphml/benchmarks/comparator.py +409 -0
  12. morphml/benchmarks/datasets.py +280 -0
  13. morphml/benchmarks/metrics.py +199 -0
  14. morphml/benchmarks/openml_suite.py +201 -0
  15. morphml/benchmarks/problems.py +289 -0
  16. morphml/benchmarks/suite.py +318 -0
  17. morphml/cli/__init__.py +5 -0
  18. morphml/cli/commands/experiment.py +329 -0
  19. morphml/cli/main.py +457 -0
  20. morphml/cli/quickstart.py +312 -0
  21. morphml/config.py +278 -0
  22. morphml/constraints/__init__.py +19 -0
  23. morphml/constraints/handler.py +205 -0
  24. morphml/constraints/predicates.py +285 -0
  25. morphml/core/__init__.py +3 -0
  26. morphml/core/crossover.py +449 -0
  27. morphml/core/dsl/README.md +359 -0
  28. morphml/core/dsl/__init__.py +72 -0
  29. morphml/core/dsl/ast_nodes.py +364 -0
  30. morphml/core/dsl/compiler.py +318 -0
  31. morphml/core/dsl/layers.py +368 -0
  32. morphml/core/dsl/lexer.py +336 -0
  33. morphml/core/dsl/parser.py +455 -0
  34. morphml/core/dsl/search_space.py +386 -0
  35. morphml/core/dsl/syntax.py +199 -0
  36. morphml/core/dsl/type_system.py +361 -0
  37. morphml/core/dsl/validator.py +386 -0
  38. morphml/core/graph/__init__.py +40 -0
  39. morphml/core/graph/edge.py +124 -0
  40. morphml/core/graph/graph.py +507 -0
  41. morphml/core/graph/mutations.py +409 -0
  42. morphml/core/graph/node.py +196 -0
  43. morphml/core/graph/serialization.py +361 -0
  44. morphml/core/graph/visualization.py +431 -0
  45. morphml/core/objectives/__init__.py +20 -0
  46. morphml/core/search/__init__.py +33 -0
  47. morphml/core/search/individual.py +252 -0
  48. morphml/core/search/parameters.py +453 -0
  49. morphml/core/search/population.py +375 -0
  50. morphml/core/search/search_engine.py +340 -0
  51. morphml/distributed/__init__.py +76 -0
  52. morphml/distributed/fault_tolerance.py +497 -0
  53. morphml/distributed/health_monitor.py +348 -0
  54. morphml/distributed/master.py +709 -0
  55. morphml/distributed/proto/README.md +224 -0
  56. morphml/distributed/proto/__init__.py +74 -0
  57. morphml/distributed/proto/worker.proto +170 -0
  58. morphml/distributed/proto/worker_pb2.py +79 -0
  59. morphml/distributed/proto/worker_pb2_grpc.py +423 -0
  60. morphml/distributed/resource_manager.py +416 -0
  61. morphml/distributed/scheduler.py +567 -0
  62. morphml/distributed/storage/__init__.py +33 -0
  63. morphml/distributed/storage/artifacts.py +381 -0
  64. morphml/distributed/storage/cache.py +366 -0
  65. morphml/distributed/storage/checkpointing.py +329 -0
  66. morphml/distributed/storage/database.py +459 -0
  67. morphml/distributed/worker.py +549 -0
  68. morphml/evaluation/__init__.py +5 -0
  69. morphml/evaluation/heuristic.py +237 -0
  70. morphml/exceptions.py +55 -0
  71. morphml/execution/__init__.py +5 -0
  72. morphml/execution/local_executor.py +350 -0
  73. morphml/integrations/__init__.py +28 -0
  74. morphml/integrations/jax_adapter.py +206 -0
  75. morphml/integrations/pytorch_adapter.py +530 -0
  76. morphml/integrations/sklearn_adapter.py +206 -0
  77. morphml/integrations/tensorflow_adapter.py +230 -0
  78. morphml/logging_config.py +93 -0
  79. morphml/meta_learning/__init__.py +66 -0
  80. morphml/meta_learning/architecture_similarity.py +277 -0
  81. morphml/meta_learning/experiment_database.py +240 -0
  82. morphml/meta_learning/knowledge_base/__init__.py +19 -0
  83. morphml/meta_learning/knowledge_base/embedder.py +179 -0
  84. morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
  85. morphml/meta_learning/knowledge_base/meta_features.py +265 -0
  86. morphml/meta_learning/knowledge_base/vector_store.py +271 -0
  87. morphml/meta_learning/predictors/__init__.py +27 -0
  88. morphml/meta_learning/predictors/ensemble.py +221 -0
  89. morphml/meta_learning/predictors/gnn_predictor.py +552 -0
  90. morphml/meta_learning/predictors/learning_curve.py +231 -0
  91. morphml/meta_learning/predictors/proxy_metrics.py +261 -0
  92. morphml/meta_learning/strategy_evolution/__init__.py +27 -0
  93. morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
  94. morphml/meta_learning/strategy_evolution/bandit.py +276 -0
  95. morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
  96. morphml/meta_learning/transfer.py +581 -0
  97. morphml/meta_learning/warm_start.py +286 -0
  98. morphml/optimizers/__init__.py +74 -0
  99. morphml/optimizers/adaptive_operators.py +399 -0
  100. morphml/optimizers/bayesian/__init__.py +52 -0
  101. morphml/optimizers/bayesian/acquisition.py +387 -0
  102. morphml/optimizers/bayesian/base.py +319 -0
  103. morphml/optimizers/bayesian/gaussian_process.py +635 -0
  104. morphml/optimizers/bayesian/smac.py +534 -0
  105. morphml/optimizers/bayesian/tpe.py +411 -0
  106. morphml/optimizers/differential_evolution.py +220 -0
  107. morphml/optimizers/evolutionary/__init__.py +61 -0
  108. morphml/optimizers/evolutionary/cma_es.py +416 -0
  109. morphml/optimizers/evolutionary/differential_evolution.py +556 -0
  110. morphml/optimizers/evolutionary/encoding.py +426 -0
  111. morphml/optimizers/evolutionary/particle_swarm.py +449 -0
  112. morphml/optimizers/genetic_algorithm.py +486 -0
  113. morphml/optimizers/gradient_based/__init__.py +22 -0
  114. morphml/optimizers/gradient_based/darts.py +550 -0
  115. morphml/optimizers/gradient_based/enas.py +585 -0
  116. morphml/optimizers/gradient_based/operations.py +474 -0
  117. morphml/optimizers/gradient_based/utils.py +601 -0
  118. morphml/optimizers/hill_climbing.py +169 -0
  119. morphml/optimizers/multi_objective/__init__.py +56 -0
  120. morphml/optimizers/multi_objective/indicators.py +504 -0
  121. morphml/optimizers/multi_objective/nsga2.py +647 -0
  122. morphml/optimizers/multi_objective/visualization.py +427 -0
  123. morphml/optimizers/nsga2.py +308 -0
  124. morphml/optimizers/random_search.py +172 -0
  125. morphml/optimizers/simulated_annealing.py +181 -0
  126. morphml/plugins/__init__.py +35 -0
  127. morphml/plugins/custom_evaluator_example.py +81 -0
  128. morphml/plugins/custom_optimizer_example.py +63 -0
  129. morphml/plugins/plugin_system.py +454 -0
  130. morphml/reports/__init__.py +30 -0
  131. morphml/reports/generator.py +362 -0
  132. morphml/tracking/__init__.py +7 -0
  133. morphml/tracking/experiment.py +309 -0
  134. morphml/tracking/logger.py +301 -0
  135. morphml/tracking/reporter.py +357 -0
  136. morphml/utils/__init__.py +6 -0
  137. morphml/utils/checkpoint.py +189 -0
  138. morphml/utils/comparison.py +390 -0
  139. morphml/utils/export.py +407 -0
  140. morphml/utils/progress.py +392 -0
  141. morphml/utils/validation.py +392 -0
  142. morphml/version.py +7 -0
  143. morphml/visualization/__init__.py +50 -0
  144. morphml/visualization/analytics.py +423 -0
  145. morphml/visualization/architecture_diagrams.py +353 -0
  146. morphml/visualization/architecture_plot.py +223 -0
  147. morphml/visualization/convergence_plot.py +174 -0
  148. morphml/visualization/crossover_viz.py +386 -0
  149. morphml/visualization/graph_viz.py +338 -0
  150. morphml/visualization/pareto_plot.py +149 -0
  151. morphml/visualization/plotly_dashboards.py +422 -0
  152. morphml/visualization/population.py +309 -0
  153. morphml/visualization/progress.py +260 -0
  154. morphml-1.0.0.dist-info/METADATA +434 -0
  155. morphml-1.0.0.dist-info/RECORD +158 -0
  156. morphml-1.0.0.dist-info/WHEEL +4 -0
  157. morphml-1.0.0.dist-info/entry_points.txt +3 -0
  158. morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,189 @@
1
+ """Checkpointing for saving and resuming optimization.
2
+
3
+ Example:
4
+ >>> from morphml.utils import Checkpoint
5
+ >>>
6
+ >>> # Save
7
+ >>> Checkpoint.save(optimizer, 'checkpoint.json')
8
+ >>>
9
+ >>> # Load
10
+ >>> optimizer = Checkpoint.load('checkpoint.json', space)
11
+ """
12
+
13
+ import json
14
+ from pathlib import Path
15
+ from typing import Any
16
+
17
+ from morphml.core.dsl.search_space import SearchSpace
18
+ from morphml.core.search import Individual, Population
19
+ from morphml.logging_config import get_logger
20
+
21
+ logger = get_logger(__name__)
22
+
23
+
24
class Checkpoint:
    """
    Checkpoint management for optimization.

    Saves and loads optimizer state as JSON so that long-running
    searches can be resumed.

    Example:
        >>> # During optimization
        >>> if generation % 10 == 0:
        ...     Checkpoint.save(ga, f'checkpoint_{generation}.json')
        >>>
        >>> # Resume later
        >>> ga = Checkpoint.load('checkpoint_50.json', search_space)
        >>> ga.optimize(evaluator)  # Continue from generation 50
    """

    # Names (exported by ``morphml.optimizers``) that ``load`` can instantiate
    # from the ``optimizer_type`` recorded in a checkpoint when no
    # ``optimizer_class`` is passed explicitly.
    _AUTO_DETECTABLE = ("GeneticAlgorithm", "RandomSearch", "HillClimbing")

    @staticmethod
    def save(optimizer: Any, filepath: str) -> None:
        """
        Save optimizer state to file.

        The attributes ``config``, ``population``, ``history``,
        ``best_individual`` and ``evaluated`` are saved when present on
        ``optimizer`` (``population``, ``best_individual`` and ``evaluated``
        are skipped when they are ``None``).

        The state is serialized to a string first and then written to a
        temporary sibling file that is renamed over ``filepath``. A failure
        during serialization (e.g. a non-JSON-serializable ``config`` or
        ``history``) or during the write therefore never leaves a truncated
        file in place of a previous, valid checkpoint.

        Args:
            optimizer: Optimizer instance (GA, RandomSearch, etc.)
            filepath: Path to save checkpoint; parent directories are created.

        Raises:
            Exception: Any error while collecting, serializing or writing the
                state is logged and re-raised unchanged.

        Example:
            >>> Checkpoint.save(ga, 'my_checkpoint.json')
        """
        try:
            path = Path(filepath)
            path.parent.mkdir(parents=True, exist_ok=True)

            checkpoint = {
                "optimizer_type": type(optimizer).__name__,
                "config": getattr(optimizer, "config", {}),
            }

            # Population-based optimizers (GA, etc.). An attribute that exists
            # but is still None is skipped instead of failing on None.generation.
            pop = getattr(optimizer, "population", None)
            if pop is not None:
                checkpoint["population"] = {
                    "generation": pop.generation,
                    "max_size": pop.max_size,
                    "elitism": pop.elitism,
                    "individuals": [ind.to_dict() for ind in pop.individuals],
                }

            if hasattr(optimizer, "history"):
                checkpoint["history"] = optimizer.history

            best = getattr(optimizer, "best_individual", None)
            if best is not None:
                checkpoint["best_individual"] = best.to_dict()

            evaluated = getattr(optimizer, "evaluated", None)
            if evaluated is not None:
                checkpoint["evaluated"] = [ind.to_dict() for ind in evaluated]

            # Serialize fully before touching the destination file, then swap
            # the new file in with a rename (same directory => same filesystem).
            payload = json.dumps(checkpoint, indent=2)
            tmp_path = path.with_name(path.name + ".tmp")
            try:
                tmp_path.write_text(payload, encoding="utf-8")
                tmp_path.replace(path)
            except OSError:
                tmp_path.unlink(missing_ok=True)
                raise

            logger.info(f"Checkpoint saved to {path}")

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")
            raise

    @staticmethod
    def load(
        filepath: str,
        search_space: SearchSpace,
        optimizer_class: Any = None,
    ) -> Any:
        """
        Load optimizer from checkpoint.

        The optimizer is constructed as
        ``optimizer_class(search_space=search_space, **config)``. After that,
        ``population``, ``history``, ``best_individual`` and ``evaluated`` are
        restored onto it when they are present in the checkpoint.

        Args:
            filepath: Path to checkpoint file
            search_space: SearchSpace instance
            optimizer_class: Optimizer class. When None, it is resolved from
                the checkpoint's ``optimizer_type``, which defaults to
                ``"GeneticAlgorithm"``; only GeneticAlgorithm, RandomSearch
                and HillClimbing can be resolved this way.

        Returns:
            Restored optimizer instance

        Raises:
            ValueError: If the optimizer type cannot be auto-detected.
            Exception: Any other error (I/O, JSON, construction) is logged and
                re-raised unchanged.

        Example:
            >>> from morphml.optimizers import GeneticAlgorithm
            >>> ga = Checkpoint.load('checkpoint.json', space, GeneticAlgorithm)
        """
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                checkpoint = json.load(f)

            if optimizer_class is None:
                optimizer_type = checkpoint.get("optimizer_type", "GeneticAlgorithm")
                if optimizer_type not in Checkpoint._AUTO_DETECTABLE:
                    raise ValueError(f"Unknown optimizer type: {optimizer_type}")

                # Deferred import, as in the original code (presumably to
                # avoid an import cycle with morphml.optimizers — verify).
                from morphml import optimizers

                optimizer_class = getattr(optimizers, optimizer_type)

            # A checkpoint saved from an optimizer whose config was None stores
            # null; treat that as "no extra constructor arguments".
            config = checkpoint.get("config") or {}
            optimizer = optimizer_class(search_space=search_space, **config)

            if "population" in checkpoint:
                pop_data = checkpoint["population"]
                population = Population(max_size=pop_data["max_size"], elitism=pop_data["elitism"])

                for ind_data in pop_data["individuals"]:
                    population.add(Individual.from_dict(ind_data))

                population.generation = pop_data["generation"]
                optimizer.population = population

            if "history" in checkpoint:
                optimizer.history = checkpoint["history"]

            if "best_individual" in checkpoint:
                optimizer.best_individual = Individual.from_dict(checkpoint["best_individual"])

            if "evaluated" in checkpoint:
                optimizer.evaluated = [
                    Individual.from_dict(ind_data) for ind_data in checkpoint["evaluated"]
                ]

            logger.info(f"Checkpoint loaded from {filepath}")
            return optimizer

        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise

    @staticmethod
    def list_checkpoints(directory: str = ".") -> list:
        """
        List checkpoint files in a directory.

        Every ``*.json`` file directly inside ``directory`` is returned. File
        contents are not inspected, so unrelated JSON files are included as
        well. Paths are sorted so the result is deterministic; ``glob`` itself
        yields entries in filesystem-dependent order.

        Args:
            directory: Directory to search

        Returns:
            Sorted list of checkpoint file paths (as strings)
        """
        return sorted(str(cp) for cp in Path(directory).glob("*.json"))
@@ -0,0 +1,390 @@
1
+ """Architecture comparison utilities.
2
+
3
+ Compare multiple architectures across various metrics.
4
+
5
+ Example:
6
+ >>> from morphml.utils.comparison import compare_architectures
7
+ >>>
8
+ >>> comparison = compare_architectures([arch1, arch2, arch3])
9
+ >>> comparison.print_table()
10
+ >>> comparison.plot()
11
+ """
12
+
13
+ from typing import Any, Dict, List, Optional
14
+
15
+ import numpy as np
16
+
17
+ from morphml.core.graph import ModelGraph
18
+ from morphml.logging_config import get_logger
19
+
20
+ logger = get_logger(__name__)
21
+
22
+
23
class ArchitectureComparison:
    """
    Compare multiple architectures across metrics.

    The built-in metrics ``nodes``, ``edges``, ``parameters``, ``depth`` and
    ``width`` are always computed. Extra per-architecture metrics can be
    registered with ``add_metric``.

    Attributes:
        architectures: List of ModelGraph instances
        names: Display names, one per architecture (``Arch_1``, ``Arch_2``, ...
            when none are given)
        custom_metrics: Mapping of custom metric name to its callable

    Example:
        >>> comparison = ArchitectureComparison([arch1, arch2, arch3])
        >>> comparison.add_metric("custom", lambda g: len(g.nodes) * 2)
        >>> comparison.print_table()
    """

    # Metrics computed for every architecture, in table-column / plot order.
    _BUILTIN_METRICS = ("nodes", "edges", "parameters", "depth", "width")

    def __init__(
        self,
        architectures: List[ModelGraph],
        names: Optional[List[str]] = None,
    ):
        """
        Initialize comparison.

        Args:
            architectures: List of architectures to compare
            names: Optional names for each architecture

        Raises:
            ValueError: If the number of names differs from the number of
                architectures.
        """
        self.architectures = architectures
        self.names = names or [f"Arch_{i+1}" for i in range(len(architectures))]
        self.custom_metrics = {}

        if len(self.architectures) != len(self.names):
            raise ValueError("Number of names must match number of architectures")

    def add_metric(self, name: str, func):
        """
        Add custom metric.

        Args:
            name: Metric name (must not be one of the built-in metric names)
            func: Function that takes ModelGraph and returns numeric value

        Raises:
            ValueError: If ``name`` collides with a built-in metric. Such a
                metric would share the built-in's value list, so that column
                would hold two values per architecture and every result
                derived from it would be misaligned.

        Example:
            >>> comparison.add_metric("complexity", lambda g: g.depth() * len(g.nodes))
        """
        if name in self._BUILTIN_METRICS:
            raise ValueError(f"Cannot override built-in metric: {name}")
        self.custom_metrics[name] = func

    def compute_metrics(self) -> Dict[str, List[Any]]:
        """
        Compute all metrics for all architectures.

        A custom metric whose function raises is logged as a warning and
        recorded as ``None`` for that architecture.

        Returns:
            Dictionary mapping metric names (built-ins first, then custom
            metrics in registration order) to lists of values aligned with
            ``self.architectures``
        """
        metrics: Dict[str, List[Any]] = {name: [] for name in self._BUILTIN_METRICS}
        for name in self.custom_metrics:
            metrics[name] = []

        for arch in self.architectures:
            metrics["nodes"].append(len(arch.nodes))
            metrics["edges"].append(len(arch.edges))
            metrics["parameters"].append(arch.estimate_parameters())
            metrics["depth"].append(arch.depth())
            metrics["width"].append(arch.width())

            for name, func in self.custom_metrics.items():
                try:
                    value = func(arch)
                except Exception as e:
                    logger.warning(f"Failed to compute {name}: {e}")
                    value = None
                metrics[name].append(value)

        return metrics

    @staticmethod
    def _format_value(value: Any) -> str:
        """Render one table cell: N/A, 2-decimal float, or thousands-separated int."""
        if value is None:
            return "N/A"
        if isinstance(value, float):
            return f"{value:.2f}"
        if isinstance(value, int) and value > 1000:
            return f"{value:,}"
        return str(value)

    def _print_plain_table(self, metrics: Dict[str, List[Any]]) -> None:
        """Print ``metrics`` as a fixed-width text table (fallback without rich)."""
        print("\n" + "=" * 80)
        print("Architecture Comparison")
        print("=" * 80)

        header = f"{'Architecture':<20}" + "".join(
            f"{metric_name.title():<15}" for metric_name in metrics
        )
        print(header)
        print("-" * 80)

        for i, name in enumerate(self.names):
            row = f"{name:<20}" + "".join(
                f"{self._format_value(values[i]):<15}" for values in metrics.values()
            )
            print(row)
        print("=" * 80 + "\n")

    def print_table(self):
        """Print comparison table, using ``rich`` when installed and plain text otherwise."""
        metrics = self.compute_metrics()

        # Only the import is guarded, so unrelated errors are not mistaken
        # for a missing optional dependency.
        try:
            from rich.console import Console
            from rich.table import Table
        except ImportError:
            self._print_plain_table(metrics)
            return

        table = Table(title="Architecture Comparison", show_header=True)
        table.add_column("Architecture", style="cyan")
        for metric_name in metrics:
            table.add_column(metric_name.title(), style="green")

        for i, name in enumerate(self.names):
            table.add_row(name, *(self._format_value(values[i]) for values in metrics.values()))

        Console().print(table)

    def plot(self, output_file: Optional[str] = None):
        """
        Plot bar charts of the built-in metrics on a 2x3 grid.

        Custom metrics are not plotted.

        Args:
            output_file: Optional path to save the figure. When given, the
                figure is saved at 300 dpi and then closed (pyplot otherwise
                keeps every created figure alive). When falsy, the figure is
                shown with ``plt.show()``.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.error("Matplotlib required for plotting. Install with: pip install matplotlib")
            return

        metrics = self.compute_metrics()

        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        fig.suptitle("Architecture Comparison", fontsize=16, fontweight="bold")

        panels = list(axes.flat)  # row-major: [0,0], [0,1], [0,2], [1,0], ...
        x = np.arange(len(self.names))
        for ax, metric_name in zip(panels, self._BUILTIN_METRICS):
            ax.bar(x, metrics[metric_name], color="skyblue", edgecolor="black")
            ax.set_xticks(x)
            ax.set_xticklabels(self.names, rotation=45, ha="right")
            ax.set_title(metric_name.title())
            ax.set_ylabel("Value")
            ax.grid(True, alpha=0.3)

        # Hide panels left over after the five built-in metrics.
        for ax in panels[len(self._BUILTIN_METRICS):]:
            ax.axis("off")

        plt.tight_layout()

        if output_file:
            fig.savefig(output_file, dpi=300, bbox_inches="tight")
            plt.close(fig)
            logger.info(f"Saved comparison plot to {output_file}")
        else:
            plt.show()

    def get_best(self, metric: str = "parameters", minimize: bool = True) -> tuple:
        """
        Get best architecture by metric.

        Architectures whose value is ``None`` (a failed custom metric) are
        ignored. Ties go to the earliest architecture.

        Args:
            metric: Metric name
            minimize: Whether to minimize (True) or maximize (False)

        Returns:
            Tuple of (architecture, name, value)

        Raises:
            ValueError: If the metric is unknown, or no architecture has a
                value for it (including when there are no architectures).
        """
        metrics = self.compute_metrics()

        if metric not in metrics:
            raise ValueError(f"Unknown metric: {metric}")

        candidates = [(i, v) for i, v in enumerate(metrics[metric]) if v is not None]
        if not candidates:
            raise ValueError(f"No valid values for metric: {metric}")

        pick = min if minimize else max
        best_idx, best_value = pick(candidates, key=lambda item: item[1])

        return (self.architectures[best_idx], self.names[best_idx], best_value)

    def get_summary(self) -> Dict[str, Dict[str, float]]:
        """
        Get statistical summary of metrics.

        ``None`` values are excluded. Metrics with no valid values are
        omitted from the result.

        Returns:
            Dictionary with mean, std, min, max for each metric
        """
        metrics = self.compute_metrics()
        summary = {}

        for metric_name, values in metrics.items():
            valid_values = [v for v in values if v is not None]

            if not valid_values:
                continue

            summary[metric_name] = {
                "mean": np.mean(valid_values),
                "std": np.std(valid_values),
                "min": np.min(valid_values),
                "max": np.max(valid_values),
            }

        return summary
267
+
268
+
269
def compare_architectures(
    architectures: List[ModelGraph],
    names: Optional[List[str]] = None,
    print_table: bool = True,
    plot: bool = False,
    output_file: Optional[str] = None,
) -> ArchitectureComparison:
    """
    Build an ArchitectureComparison and optionally render it.

    Args:
        architectures: Architectures to compare
        names: Optional display names, one per architecture
        print_table: When true, print the comparison table
        plot: When true, draw the comparison charts
        output_file: Where the charts are saved; only consulted when ``plot``
            is true

    Returns:
        The ArchitectureComparison instance, for further queries

    Example:
        >>> comparison = compare_architectures([arch1, arch2, arch3])
        >>> best_arch, name, params = comparison.get_best("parameters")
        >>> print(f"Best: {name} with {params:,} parameters")
    """
    result = ArchitectureComparison(architectures, names)

    # Rendering order is fixed: the table is printed before any figure.
    if print_table:
        result.print_table()
    if plot:
        result.plot(output_file=output_file)

    return result
303
+
304
+
305
def find_similar_architectures(
    target: ModelGraph,
    candidates: List[ModelGraph],
    top_k: int = 5,
    metric: str = "structure",
) -> List[tuple]:
    """
    Find architectures similar to target.

    Similarity is ``1 / (1 + d)``, where ``d`` is a relative difference to the
    target (denominators are clamped to at least 1):

    - ``"structure"``: relative node-count difference plus relative
      edge-count difference
    - ``"parameters"``: relative difference of ``estimate_parameters()``
    - ``"depth"``: relative difference of ``depth()``

    Only the statistics needed for the chosen metric are computed.

    Args:
        target: Target architecture
        candidates: List of candidate architectures
        top_k: Number of similar architectures to return; values <= 0 yield
            an empty list
        metric: Similarity metric ("structure", "parameters", "depth")

    Returns:
        List of (architecture, similarity_score) tuples, highest score first
        (ties keep the order of ``candidates``)

    Raises:
        ValueError: If ``metric`` is unknown. This is checked up front, so it
            is raised even when ``candidates`` is empty.

    Example:
        >>> similar = find_similar_architectures(my_arch, all_archs, top_k=3)
        >>> for arch, score in similar:
        ...     print(f"Similarity: {score:.3f}")
    """
    if metric == "structure":
        target_nodes = len(target.nodes)
        target_edges = len(target.edges)

        def score(arch):
            node_diff = abs(len(arch.nodes) - target_nodes) / max(target_nodes, 1)
            edge_diff = abs(len(arch.edges) - target_edges) / max(target_edges, 1)
            return 1.0 / (1.0 + node_diff + edge_diff)

    elif metric == "parameters":
        target_params = target.estimate_parameters()

        def score(arch):
            param_diff = abs(arch.estimate_parameters() - target_params) / max(target_params, 1)
            return 1.0 / (1.0 + param_diff)

    elif metric == "depth":
        target_depth = target.depth()

        def score(arch):
            depth_diff = abs(arch.depth() - target_depth) / max(target_depth, 1)
            return 1.0 / (1.0 + depth_diff)

    else:
        raise ValueError(f"Unknown metric: {metric}")

    # A negative top_k would otherwise slice off the tail ([:-k]) and return
    # almost everything.
    if top_k <= 0:
        return []

    similarities = [(candidate, score(candidate)) for candidate in candidates]
    similarities.sort(key=lambda pair: pair[1], reverse=True)  # stable sort

    return similarities[:top_k]
363
+
364
+
365
def diff_architectures(arch1: ModelGraph, arch2: ModelGraph) -> Dict[str, Any]:
    """
    Compute differences between two architectures.

    Every ``*_diff`` value is ``arch2``'s value minus ``arch1``'s. Parameter
    counts are estimated once per architecture and reused, so
    ``parameters_diff == parameters_2 - parameters_1`` always holds.

    Args:
        arch1: First architecture
        arch2: Second architecture

    Returns:
        Dictionary with keys ``nodes_diff``, ``edges_diff``,
        ``parameters_diff``, ``depth_diff``, ``width_diff``, ``nodes_1``,
        ``nodes_2``, ``parameters_1`` and ``parameters_2``

    Example:
        >>> diff = diff_architectures(arch1, arch2)
        >>> print(f"Node difference: {diff['nodes_diff']}")
    """
    nodes_1 = len(arch1.nodes)
    nodes_2 = len(arch2.nodes)
    params_1 = arch1.estimate_parameters()
    params_2 = arch2.estimate_parameters()

    return {
        "nodes_diff": nodes_2 - nodes_1,
        "edges_diff": len(arch2.edges) - len(arch1.edges),
        "parameters_diff": params_2 - params_1,
        "depth_diff": arch2.depth() - arch1.depth(),
        "width_diff": arch2.width() - arch1.width(),
        "nodes_1": nodes_1,
        "nodes_2": nodes_2,
        "parameters_1": params_1,
        "parameters_2": params_2,
    }