morphml 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of morphml might be problematic. Click here for more details.

Files changed (158) hide show
  1. morphml/__init__.py +14 -0
  2. morphml/api/__init__.py +26 -0
  3. morphml/api/app.py +326 -0
  4. morphml/api/auth.py +193 -0
  5. morphml/api/client.py +338 -0
  6. morphml/api/models.py +132 -0
  7. morphml/api/rate_limit.py +192 -0
  8. morphml/benchmarking/__init__.py +36 -0
  9. morphml/benchmarking/comparison.py +430 -0
  10. morphml/benchmarks/__init__.py +56 -0
  11. morphml/benchmarks/comparator.py +409 -0
  12. morphml/benchmarks/datasets.py +280 -0
  13. morphml/benchmarks/metrics.py +199 -0
  14. morphml/benchmarks/openml_suite.py +201 -0
  15. morphml/benchmarks/problems.py +289 -0
  16. morphml/benchmarks/suite.py +318 -0
  17. morphml/cli/__init__.py +5 -0
  18. morphml/cli/commands/experiment.py +329 -0
  19. morphml/cli/main.py +457 -0
  20. morphml/cli/quickstart.py +312 -0
  21. morphml/config.py +278 -0
  22. morphml/constraints/__init__.py +19 -0
  23. morphml/constraints/handler.py +205 -0
  24. morphml/constraints/predicates.py +285 -0
  25. morphml/core/__init__.py +3 -0
  26. morphml/core/crossover.py +449 -0
  27. morphml/core/dsl/README.md +359 -0
  28. morphml/core/dsl/__init__.py +72 -0
  29. morphml/core/dsl/ast_nodes.py +364 -0
  30. morphml/core/dsl/compiler.py +318 -0
  31. morphml/core/dsl/layers.py +368 -0
  32. morphml/core/dsl/lexer.py +336 -0
  33. morphml/core/dsl/parser.py +455 -0
  34. morphml/core/dsl/search_space.py +386 -0
  35. morphml/core/dsl/syntax.py +199 -0
  36. morphml/core/dsl/type_system.py +361 -0
  37. morphml/core/dsl/validator.py +386 -0
  38. morphml/core/graph/__init__.py +40 -0
  39. morphml/core/graph/edge.py +124 -0
  40. morphml/core/graph/graph.py +507 -0
  41. morphml/core/graph/mutations.py +409 -0
  42. morphml/core/graph/node.py +196 -0
  43. morphml/core/graph/serialization.py +361 -0
  44. morphml/core/graph/visualization.py +431 -0
  45. morphml/core/objectives/__init__.py +20 -0
  46. morphml/core/search/__init__.py +33 -0
  47. morphml/core/search/individual.py +252 -0
  48. morphml/core/search/parameters.py +453 -0
  49. morphml/core/search/population.py +375 -0
  50. morphml/core/search/search_engine.py +340 -0
  51. morphml/distributed/__init__.py +76 -0
  52. morphml/distributed/fault_tolerance.py +497 -0
  53. morphml/distributed/health_monitor.py +348 -0
  54. morphml/distributed/master.py +709 -0
  55. morphml/distributed/proto/README.md +224 -0
  56. morphml/distributed/proto/__init__.py +74 -0
  57. morphml/distributed/proto/worker.proto +170 -0
  58. morphml/distributed/proto/worker_pb2.py +79 -0
  59. morphml/distributed/proto/worker_pb2_grpc.py +423 -0
  60. morphml/distributed/resource_manager.py +416 -0
  61. morphml/distributed/scheduler.py +567 -0
  62. morphml/distributed/storage/__init__.py +33 -0
  63. morphml/distributed/storage/artifacts.py +381 -0
  64. morphml/distributed/storage/cache.py +366 -0
  65. morphml/distributed/storage/checkpointing.py +329 -0
  66. morphml/distributed/storage/database.py +459 -0
  67. morphml/distributed/worker.py +549 -0
  68. morphml/evaluation/__init__.py +5 -0
  69. morphml/evaluation/heuristic.py +237 -0
  70. morphml/exceptions.py +55 -0
  71. morphml/execution/__init__.py +5 -0
  72. morphml/execution/local_executor.py +350 -0
  73. morphml/integrations/__init__.py +28 -0
  74. morphml/integrations/jax_adapter.py +206 -0
  75. morphml/integrations/pytorch_adapter.py +530 -0
  76. morphml/integrations/sklearn_adapter.py +206 -0
  77. morphml/integrations/tensorflow_adapter.py +230 -0
  78. morphml/logging_config.py +93 -0
  79. morphml/meta_learning/__init__.py +66 -0
  80. morphml/meta_learning/architecture_similarity.py +277 -0
  81. morphml/meta_learning/experiment_database.py +240 -0
  82. morphml/meta_learning/knowledge_base/__init__.py +19 -0
  83. morphml/meta_learning/knowledge_base/embedder.py +179 -0
  84. morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
  85. morphml/meta_learning/knowledge_base/meta_features.py +265 -0
  86. morphml/meta_learning/knowledge_base/vector_store.py +271 -0
  87. morphml/meta_learning/predictors/__init__.py +27 -0
  88. morphml/meta_learning/predictors/ensemble.py +221 -0
  89. morphml/meta_learning/predictors/gnn_predictor.py +552 -0
  90. morphml/meta_learning/predictors/learning_curve.py +231 -0
  91. morphml/meta_learning/predictors/proxy_metrics.py +261 -0
  92. morphml/meta_learning/strategy_evolution/__init__.py +27 -0
  93. morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
  94. morphml/meta_learning/strategy_evolution/bandit.py +276 -0
  95. morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
  96. morphml/meta_learning/transfer.py +581 -0
  97. morphml/meta_learning/warm_start.py +286 -0
  98. morphml/optimizers/__init__.py +74 -0
  99. morphml/optimizers/adaptive_operators.py +399 -0
  100. morphml/optimizers/bayesian/__init__.py +52 -0
  101. morphml/optimizers/bayesian/acquisition.py +387 -0
  102. morphml/optimizers/bayesian/base.py +319 -0
  103. morphml/optimizers/bayesian/gaussian_process.py +635 -0
  104. morphml/optimizers/bayesian/smac.py +534 -0
  105. morphml/optimizers/bayesian/tpe.py +411 -0
  106. morphml/optimizers/differential_evolution.py +220 -0
  107. morphml/optimizers/evolutionary/__init__.py +61 -0
  108. morphml/optimizers/evolutionary/cma_es.py +416 -0
  109. morphml/optimizers/evolutionary/differential_evolution.py +556 -0
  110. morphml/optimizers/evolutionary/encoding.py +426 -0
  111. morphml/optimizers/evolutionary/particle_swarm.py +449 -0
  112. morphml/optimizers/genetic_algorithm.py +486 -0
  113. morphml/optimizers/gradient_based/__init__.py +22 -0
  114. morphml/optimizers/gradient_based/darts.py +550 -0
  115. morphml/optimizers/gradient_based/enas.py +585 -0
  116. morphml/optimizers/gradient_based/operations.py +474 -0
  117. morphml/optimizers/gradient_based/utils.py +601 -0
  118. morphml/optimizers/hill_climbing.py +169 -0
  119. morphml/optimizers/multi_objective/__init__.py +56 -0
  120. morphml/optimizers/multi_objective/indicators.py +504 -0
  121. morphml/optimizers/multi_objective/nsga2.py +647 -0
  122. morphml/optimizers/multi_objective/visualization.py +427 -0
  123. morphml/optimizers/nsga2.py +308 -0
  124. morphml/optimizers/random_search.py +172 -0
  125. morphml/optimizers/simulated_annealing.py +181 -0
  126. morphml/plugins/__init__.py +35 -0
  127. morphml/plugins/custom_evaluator_example.py +81 -0
  128. morphml/plugins/custom_optimizer_example.py +63 -0
  129. morphml/plugins/plugin_system.py +454 -0
  130. morphml/reports/__init__.py +30 -0
  131. morphml/reports/generator.py +362 -0
  132. morphml/tracking/__init__.py +7 -0
  133. morphml/tracking/experiment.py +309 -0
  134. morphml/tracking/logger.py +301 -0
  135. morphml/tracking/reporter.py +357 -0
  136. morphml/utils/__init__.py +6 -0
  137. morphml/utils/checkpoint.py +189 -0
  138. morphml/utils/comparison.py +390 -0
  139. morphml/utils/export.py +407 -0
  140. morphml/utils/progress.py +392 -0
  141. morphml/utils/validation.py +392 -0
  142. morphml/version.py +7 -0
  143. morphml/visualization/__init__.py +50 -0
  144. morphml/visualization/analytics.py +423 -0
  145. morphml/visualization/architecture_diagrams.py +353 -0
  146. morphml/visualization/architecture_plot.py +223 -0
  147. morphml/visualization/convergence_plot.py +174 -0
  148. morphml/visualization/crossover_viz.py +386 -0
  149. morphml/visualization/graph_viz.py +338 -0
  150. morphml/visualization/pareto_plot.py +149 -0
  151. morphml/visualization/plotly_dashboards.py +422 -0
  152. morphml/visualization/population.py +309 -0
  153. morphml/visualization/progress.py +260 -0
  154. morphml-1.0.0.dist-info/METADATA +434 -0
  155. morphml-1.0.0.dist-info/RECORD +158 -0
  156. morphml-1.0.0.dist-info/WHEEL +4 -0
  157. morphml-1.0.0.dist-info/entry_points.txt +3 -0
  158. morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,409 @@
1
+ """Optimizer comparison utilities."""
2
+
3
+ import statistics
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ from morphml.logging_config import get_logger
7
+
8
+ logger = get_logger(__name__)
9
+
10
+
11
class OptimizerComparator:
    """
    Compare multiple optimizers on benchmark problems.

    Results are stored per ``(optimizer_name, problem_name)`` pair as a list of
    fitness values, one per run. Throughout this class a larger fitness is
    treated as better: rankings are sorted in descending order and the
    "winner" of a comparison is the optimizer with the larger mean.

    Example:
        >>> comparator = OptimizerComparator()
        >>> comparator.add_result("GA", "Problem1", [0.9, 0.92, 0.91])
        >>> comparator.add_result("RS", "Problem1", [0.85, 0.87, 0.86])
        >>> comparator.print_comparison()
    """

    def __init__(self):
        """Initialize comparator."""
        # (optimizer_name, problem_name) -> fitness values of the individual runs
        self.results: Dict[tuple, List[float]] = {}
        # (optimizer_name, problem_name) -> keyword metadata given to add_result
        self.metadata: Dict[tuple, Dict[str, Any]] = {}

    def add_result(
        self, optimizer_name: str, problem_name: str, fitnesses: List[float], **metadata
    ) -> None:
        """
        Add results for an optimizer on a problem.

        Calling this again for the same pair replaces the earlier entry.

        Args:
            optimizer_name: Name of optimizer
            problem_name: Name of problem
            fitnesses: Fitness values from multiple runs (any iterable)
            **metadata: Additional metadata (time, evaluations, etc.)
        """
        key = (optimizer_name, problem_name)
        # Store a private copy: later mutation of the caller's sequence must not
        # change recorded results, and array-like inputs become a plain list so
        # the emptiness checks in the other methods behave.
        values = list(fitnesses)
        self.results[key] = values
        self.metadata[key] = metadata

        logger.debug(f"Added results: {optimizer_name} on {problem_name} " f"(n={len(values)})")

    def _optimizer_names(self) -> List[str]:
        """Return all optimizer names that have results, sorted."""
        return sorted({opt for opt, _ in self.results})

    def _problem_names(self) -> List[str]:
        """Return all problem names that have results, sorted."""
        return sorted({prob for _, prob in self.results})

    def _overall_scores(self) -> Dict[str, float]:
        """
        Average the per-problem mean fitness of every optimizer.

        Pairs with an empty result list are ignored. The returned dict is
        ordered by optimizer name so that ties are resolved deterministically
        by the callers.
        """
        problem_means: Dict[str, List[float]] = {}
        for (opt_name, _prob_name), fitnesses in self.results.items():
            if fitnesses:
                problem_means.setdefault(opt_name, []).append(statistics.mean(fitnesses))

        return {opt: statistics.mean(problem_means[opt]) for opt in sorted(problem_means)}

    def get_statistics(self, optimizer_name: str, problem_name: str) -> Dict[str, float]:
        """
        Get statistics for an optimizer-problem pair.

        Args:
            optimizer_name: Optimizer name
            problem_name: Problem name

        Returns:
            Dictionary with ``mean``, ``median``, ``std`` (sample standard
            deviation, 0.0 for a single run), ``min``, ``max`` and ``count``.
            Empty dictionary if the pair is unknown or has no values.
        """
        fitnesses = self.results.get((optimizer_name, problem_name))

        if not fitnesses:
            return {}

        return {
            "mean": statistics.mean(fitnesses),
            "median": statistics.median(fitnesses),
            # stdev() needs at least two data points
            "std": statistics.stdev(fitnesses) if len(fitnesses) > 1 else 0.0,
            "min": min(fitnesses),
            "max": max(fitnesses),
            "count": len(fitnesses),
        }

    def rank_optimizers(self, problem_name: str) -> List[tuple]:
        """
        Rank optimizers for a specific problem.

        Args:
            problem_name: Problem name

        Returns:
            List of (optimizer_name, mean_fitness) tuples, best (largest mean)
            first. Optimizers without values for the problem are omitted.
        """
        rankings = [
            (opt_name, statistics.mean(fitnesses))
            for (opt_name, prob_name), fitnesses in self.results.items()
            if prob_name == problem_name and fitnesses
        ]
        rankings.sort(key=lambda entry: entry[1], reverse=True)
        return rankings

    def compare_pair(self, optimizer1: str, optimizer2: str, problem_name: str) -> Dict[str, Any]:
        """
        Statistical comparison of two optimizers.

        When both optimizers have more than one run and scipy is importable,
        an independent two-sample t-test is performed.

        Args:
            optimizer1: First optimizer name
            optimizer2: Second optimizer name
            problem_name: Problem name

        Returns:
            Comparison results with keys ``optimizer1``, ``optimizer2``,
            ``problem``, ``mean_diff``, ``median_diff``, ``winner``,
            ``p_value`` and ``significant``. ``p_value`` and ``significant``
            are None when no test was run. On equal means ``winner`` is
            ``optimizer2``. Empty dictionary if either optimizer has no
            values for the problem.
        """
        stats1 = self.get_statistics(optimizer1, problem_name)
        stats2 = self.get_statistics(optimizer2, problem_name)

        # Covers unknown pairs as well as pairs registered with an empty list;
        # the latter would otherwise fail on the stats lookups below.
        if not stats1 or not stats2:
            return {}

        # Perform t-test if enough samples
        p_value = None
        if stats1["count"] > 1 and stats2["count"] > 1:
            try:
                from scipy import stats as scipy_stats

                _, raw_p_value = scipy_stats.ttest_ind(
                    self.results[(optimizer1, problem_name)],
                    self.results[(optimizer2, problem_name)],
                )
                p_value = float(raw_p_value)
            except ImportError:
                logger.warning("scipy not available for statistical tests")

        return {
            "optimizer1": optimizer1,
            "optimizer2": optimizer2,
            "problem": problem_name,
            "mean_diff": stats1["mean"] - stats2["mean"],
            "median_diff": stats1["median"] - stats2["median"],
            "winner": optimizer1 if stats1["mean"] > stats2["mean"] else optimizer2,
            "p_value": p_value,
            # Compare against None explicitly: a p-value of exactly 0.0 is the
            # strongest possible result and must not be reported as "no test".
            "significant": bool(p_value < 0.05) if p_value is not None else None,
        }

    def get_dominance_matrix(self) -> Dict[tuple, int]:
        """
        Get dominance matrix showing which optimizer beats which.

        Returns:
            Dictionary mapping (opt1, opt2) to number of problems on which
            opt1 has a strictly larger mean than opt2. Problems where either
            optimizer has no values are not counted.
        """
        optimizers = self._optimizer_names()
        problems = self._problem_names()

        dominance: Dict[tuple, int] = {}

        for opt1 in optimizers:
            for opt2 in optimizers:
                if opt1 == opt2:
                    continue

                wins = 0
                for problem in problems:
                    stats1 = self.get_statistics(opt1, problem)
                    stats2 = self.get_statistics(opt2, problem)

                    if stats1 and stats2 and stats1["mean"] > stats2["mean"]:
                        wins += 1

                dominance[(opt1, opt2)] = wins

        return dominance

    def print_comparison(self) -> None:
        """Print detailed comparison of all optimizers."""
        print("\n" + "=" * 100)
        print("OPTIMIZER COMPARISON")
        print("=" * 100)

        for problem in self._problem_names():
            print(f"\n{problem}")
            print("-" * 100)
            print(
                f"{'Optimizer':<20} {'Mean':<12} {'Median':<12} {'Std':<12} {'Min':<12} {'Max':<12}"
            )
            print("-" * 100)

            for opt_name, _mean_fitness in self.rank_optimizers(problem):
                stats = self.get_statistics(opt_name, problem)

                print(
                    f"{opt_name:<20} "
                    f"{stats['mean']:<12.4f} "
                    f"{stats['median']:<12.4f} "
                    f"{stats['std']:<12.4f} "
                    f"{stats['min']:<12.4f} "
                    f"{stats['max']:<12.4f}"
                )

        # Overall rankings
        print("\n" + "=" * 100)
        print("OVERALL RANKINGS (by mean fitness across all problems)")
        print("=" * 100)

        # sorted() is stable (also with reverse=True), so equal scores keep the
        # name order produced by _overall_scores().
        overall_rankings = sorted(
            self._overall_scores().items(), key=lambda entry: entry[1], reverse=True
        )

        for rank, (opt, score) in enumerate(overall_rankings, 1):
            print(f"{rank}. {opt:<20} Average: {score:.4f}")

        print("\n" + "=" * 100)

    def get_best_optimizer(self) -> Optional[str]:
        """
        Get overall best optimizer.

        Returns:
            Name of the optimizer with the largest average of per-problem
            means, or None if no values have been recorded. On a tie the
            name that sorts first is returned.
        """
        overall_scores = self._overall_scores()

        if not overall_scores:
            return None

        # max() returns the first maximal key; keys are in sorted name order.
        return max(overall_scores, key=overall_scores.get)

    def export_latex_table(self, filename: str) -> None:
        """
        Export comparison as LaTeX table.

        One row per problem, one column per optimizer; each cell holds
        ``mean +/- std`` or ``--`` when there are no values. Optimizer and
        problem names are written verbatim, so they must already be valid
        LaTeX text (e.g. underscores escaped by the caller).

        Args:
            filename: Path of the file to write (overwritten if it exists)
        """
        problems = self._problem_names()
        optimizers = self._optimizer_names()

        # Explicit encoding so non-ASCII names do not depend on the platform default.
        with open(filename, "w", encoding="utf-8") as f:
            # Header
            f.write("\\begin{table}[htbp]\n")
            f.write("\\centering\n")
            f.write("\\caption{Optimizer Comparison}\n")
            f.write("\\begin{tabular}{l" + "c" * len(optimizers) + "}\n")
            f.write("\\hline\n")

            # Column headers
            f.write("Problem & " + " & ".join(optimizers) + " \\\\\n")
            f.write("\\hline\n")

            # Data rows
            for problem in problems:
                row = [problem]
                for opt in optimizers:
                    stats = self.get_statistics(opt, problem)
                    if stats:
                        row.append(f"{stats['mean']:.3f}$\\pm${stats['std']:.3f}")
                    else:
                        row.append("--")

                f.write(" & ".join(row) + " \\\\\n")

            # Footer
            f.write("\\hline\n")
            f.write("\\end{tabular}\n")
            f.write("\\end{table}\n")

        logger.info(f"LaTeX table exported to {filename}")

    def plot_comparison(self, save_path: Optional[str] = None) -> None:
        """
        Plot comparison of optimizers.

        Draws one bar chart per problem (mean fitness with standard deviation
        as error bars). Does nothing besides logging a warning when matplotlib
        is missing or when no results have been added.

        Args:
            save_path: Path to save plot (displays if None)
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.warning("matplotlib not available for plotting")
            return

        problems = self._problem_names()
        if not problems:
            # plt.subplots() rejects a grid with zero columns.
            logger.warning("No results available for plotting")
            return

        optimizers = self._optimizer_names()

        # squeeze=False always yields a 2-D array of axes, so a single problem
        # needs no special handling.
        fig, axes = plt.subplots(
            1, len(problems), figsize=(5 * len(problems), 5), squeeze=False
        )

        for ax, problem in zip(axes[0], problems):
            means = []
            stds = []
            labels = []

            for opt in optimizers:
                stats = self.get_statistics(opt, problem)
                if stats:
                    means.append(stats["mean"])
                    stds.append(stats["std"])
                    labels.append(opt)

            positions = list(range(len(labels)))
            ax.bar(positions, means, yerr=stds, capsize=5)
            ax.set_xticks(positions)
            ax.set_xticklabels(labels, rotation=45, ha="right")
            ax.set_ylabel("Fitness")
            ax.set_title(problem)
            ax.grid(True, alpha=0.3)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            # Release the figure; repeated calls would otherwise accumulate
            # open figures in pyplot's global state.
            plt.close(fig)
            logger.info(f"Plot saved to {save_path}")
        else:
            plt.show()
336
+
337
+
338
class ConvergenceAnalyzer:
    """
    Analyze convergence behavior of optimizers.

    Each optimizer may have several recorded runs; a run is a list of fitness
    values, one per generation.
    """

    def __init__(self):
        """Initialize analyzer."""
        # optimizer name -> list of runs, each run a list of per-generation values
        self.histories: Dict[str, List[List[float]]] = {}

    def add_history(self, optimizer_name: str, history: List[float]) -> None:
        """
        Add optimization history.

        Args:
            optimizer_name: Optimizer the run belongs to
            history: Fitness value per generation for one run
        """
        # Copy so that the caller can keep appending to / reusing its own list
        # without silently altering the recorded run.
        self.histories.setdefault(optimizer_name, []).append(list(history))

    def get_mean_convergence(self, optimizer_name: str) -> List[float]:
        """
        Get mean convergence curve.

        Runs of different length are truncated to the shortest one.

        Args:
            optimizer_name: Optimizer name

        Returns:
            Mean fitness per generation across all recorded runs; empty list
            if the optimizer is unknown or has no runs.
        """
        histories = self.histories.get(optimizer_name)

        if not histories:
            return []

        # zip() stops at the shortest run, i.e. at the minimum length.
        return [statistics.mean(generation) for generation in zip(*histories)]

    def calculate_auc(self, optimizer_name: str) -> float:
        """
        Calculate a length-normalized area under the mean convergence curve.

        The value is the arithmetic mean of the curve's points (a rectangle
        rule with unit spacing divided by the number of generations), not a
        trapezoidal integral.

        Args:
            optimizer_name: Optimizer name

        Returns:
            Mean height of the convergence curve, or 0.0 if there is no curve.
        """
        curve = self.get_mean_convergence(optimizer_name)

        if not curve:
            return 0.0

        return sum(curve) / len(curve)

    def plot_convergence(self, save_path: Optional[str] = None) -> None:
        """
        Plot convergence curves.

        Args:
            save_path: Path to save the plot; the plot is shown when None
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.warning("matplotlib not available")
            return

        fig = plt.figure(figsize=(10, 6))

        has_curve = False
        for opt_name in sorted(self.histories.keys()):
            curve = self.get_mean_convergence(opt_name)
            if curve:
                plt.plot(curve, label=opt_name, linewidth=2)
                has_curve = True

        plt.xlabel("Generation")
        plt.ylabel("Best Fitness")
        plt.title("Convergence Comparison")
        if has_curve:
            # legend() warns when there is no labeled line to show
            plt.legend()
        plt.grid(True, alpha=0.3)

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            # Release the figure so repeated calls do not pile up open figures.
            plt.close(fig)
            logger.info(f"Convergence plot saved to {save_path}")
        else:
            plt.show()
@@ -0,0 +1,280 @@
1
+ """Dataset loaders and utilities for benchmarking.
2
+
3
+ Author: Eshan Roy <eshanized@proton.me>
4
+ Organization: TONMOY INFRASTRUCTURE & VISION
5
+ """
6
+
7
+ from typing import Dict, Optional, Tuple
8
+
9
+ import numpy as np
10
+
11
+ from morphml.logging_config import get_logger
12
+
13
+ logger = get_logger(__name__)
14
+
15
+
16
def _torchvision_to_numpy(dataset) -> Tuple[np.ndarray, np.ndarray]:
    """Convert an indexable dataset of (tensor, label) samples to (images, labels) arrays."""
    images = []
    labels = []
    # Index every sample exactly once: each access decodes and transforms the
    # image, so separate passes for images and labels would double the work.
    for index in range(len(dataset)):
        image, label = dataset[index]
        images.append(image.numpy())
        labels.append(label)
    return np.array(images), np.array(labels)


def load_cifar10(
    data_dir: Optional[str] = None, normalize: bool = True
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load CIFAR-10 dataset.

    Tries multiple backends in order:
    1. torchvision (if available)
    2. tensorflow/keras (if available)
    3. OpenML (fallback, may be slow)

    A backend is skipped only when it raises ImportError; other errors from
    the torchvision or Keras backends (e.g. a failed download) propagate.

    Args:
        data_dir: Directory to download/cache data. Used by the torchvision
            (default "./data") and OpenML backends; the Keras backend does
            not receive it.
        normalize: Whether to normalize to [0, 1]; otherwise pixel values
            are floats in [0, 255]

    Returns:
        Tuple of (X_train, y_train, X_test, y_test) with images in
        channel-first layout (N, 3, 32, 32)

    Raises:
        RuntimeError: If the OpenML fallback fails as well (including
            scikit-learn not being installed).
    """
    # Try torchvision first
    try:
        import torchvision
        import torchvision.transforms as transforms

        logger.info("Loading CIFAR-10 via torchvision...")

        # ToTensor yields float tensors already scaled to [0, 1]
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )
        root = data_dir or "./data"

        trainset = torchvision.datasets.CIFAR10(
            root=root, train=True, download=True, transform=transform
        )

        testset = torchvision.datasets.CIFAR10(
            root=root, train=False, download=True, transform=transform
        )

        # Convert to numpy
        X_train, y_train = _torchvision_to_numpy(trainset)
        X_test, y_test = _torchvision_to_numpy(testset)

        if not normalize:
            # Undo the [0, 1] scaling applied by ToTensor
            X_train *= 255.0
            X_test *= 255.0

        logger.info(f"Loaded CIFAR-10 via torchvision: train={X_train.shape}, test={X_test.shape}")
        return X_train, y_train, X_test, y_test

    except ImportError:
        pass

    # Try Keras/TensorFlow
    try:
        from tensorflow import keras

        logger.info("Loading CIFAR-10 via Keras...")

        (X_train, y_train), (X_test, y_test) = keras.datasets.cifar10.load_data()

        # Convert to channel-first format (3, 32, 32)
        X_train = np.transpose(X_train, (0, 3, 1, 2))
        X_test = np.transpose(X_test, (0, 3, 1, 2))

        # Flatten labels
        y_train = y_train.flatten()
        y_test = y_test.flatten()

        X_train = X_train.astype("float32")
        X_test = X_test.astype("float32")

        # Normalize
        if normalize:
            X_train /= 255.0
            X_test /= 255.0

        logger.info(f"Loaded CIFAR-10 via Keras: train={X_train.shape}, test={X_test.shape}")
        return X_train, y_train, X_test, y_test

    except ImportError:
        pass

    # Fallback to OpenML (slower)
    try:
        from sklearn.datasets import fetch_openml

        logger.info("Loading CIFAR-10 via OpenML (may be slow)...")
        logger.warning("For faster loading, install torchvision or tensorflow")

        cifar = fetch_openml("CIFAR_10", version=1, cache=True, data_home=data_dir)
        # fetch_openml may hand back a pandas DataFrame/Series (its default when
        # pandas is installed); convert to plain arrays before reshaping.
        X = np.asarray(cifar.data, dtype="float32")
        y = np.asarray(cifar.target).astype("int64")

        # Reshape to (N, 3, 32, 32)
        X = X.reshape(-1, 3, 32, 32)

        if normalize:
            X /= 255.0

        # Split train/test (first 50000 train, rest test)
        # NOTE(review): assumes the OpenML rows are ordered train-first — confirm
        X_train, X_test = X[:50000], X[50000:]
        y_train, y_test = y[:50000], y[50000:]

        logger.info(f"Loaded CIFAR-10 via OpenML: train={X_train.shape}, test={X_test.shape}")
        return X_train, y_train, X_test, y_test

    except Exception as e:
        logger.error(f"Failed to load CIFAR-10: {e}")
        logger.error("Please install torchvision or tensorflow for CIFAR-10 support")
        # Chain explicitly so the traceback shows the root cause as the direct cause.
        raise RuntimeError(
            "CIFAR-10 dataset loading failed. Install one of:\n"
            "  pip install torchvision\n"
            "  pip install tensorflow\n"
            "  pip install scikit-learn (slower)"
        ) from e
134
+
135
+
136
def load_mnist(
    data_dir: Optional[str] = None, normalize: bool = True
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load MNIST dataset.

    The data is fetched from OpenML ("mnist_784", version 1) through
    scikit-learn; images are returned flattened, one row per image.

    Args:
        data_dir: Directory to download/cache data
        normalize: Whether to normalize to [0, 1]

    Returns:
        Tuple of (X_train, y_train, X_test, y_test); the first 60000 rows
        form the training split, the remainder the test split

    Raises:
        ImportError: If scikit-learn is not installed.
    """
    try:
        from sklearn.datasets import fetch_openml
    except ImportError:
        logger.error("scikit-learn required")
        raise

    logger.info("Loading MNIST dataset from OpenML...")

    mnist = fetch_openml("mnist_784", version=1, cache=True, data_home=data_dir)
    # fetch_openml may return a pandas DataFrame/Series (its default when pandas
    # is installed); convert so callers really get numpy arrays as annotated.
    X = np.asarray(mnist.data, dtype="float32")
    y = np.asarray(mnist.target).astype("int64")

    if normalize:
        X /= 255.0

    # Split train/test
    X_train, X_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]

    logger.info(f"Loaded MNIST: train={X_train.shape}, test={X_test.shape}")

    return X_train, y_train, X_test, y_test
171
+
172
+
173
def load_fashion_mnist(
    data_dir: Optional[str] = None, normalize: bool = True
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load Fashion-MNIST dataset.

    The data is fetched from OpenML ("Fashion-MNIST", version 1) through
    scikit-learn; images are returned flattened, one row per image.

    Args:
        data_dir: Directory to download/cache data
        normalize: Whether to normalize to [0, 1]

    Returns:
        Tuple of (X_train, y_train, X_test, y_test); the first 60000 rows
        form the training split, the remainder the test split

    Raises:
        ImportError: If scikit-learn is not installed.
    """
    try:
        from sklearn.datasets import fetch_openml
    except ImportError:
        logger.error("scikit-learn required")
        raise

    logger.info("Loading Fashion-MNIST dataset from OpenML...")

    fashion = fetch_openml("Fashion-MNIST", version=1, cache=True, data_home=data_dir)
    # fetch_openml may return a pandas DataFrame/Series (its default when pandas
    # is installed); convert so callers really get numpy arrays as annotated.
    X = np.asarray(fashion.data, dtype="float32")
    y = np.asarray(fashion.target).astype("int64")

    if normalize:
        X /= 255.0

    # Split train/test
    X_train, X_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]

    logger.info(f"Loaded Fashion-MNIST: train={X_train.shape}, test={X_test.shape}")

    return X_train, y_train, X_test, y_test
208
+
209
+
210
def get_dataset_info(dataset_name: str) -> Dict:
    """
    Look up static metadata for one of the built-in benchmark datasets.

    Args:
        dataset_name: Dataset key ("cifar10", "mnist" or "fashion_mnist")

    Returns:
        Freshly built dictionary with the keys "name", "num_classes",
        "input_shape", "train_size", "test_size" and "type"; an empty
        dictionary when the key is not known.
    """

    def describe(title, shape, n_train):
        # All built-in datasets are 10-class image classification tasks with a
        # 10000-sample test split; only title, shape and train size vary.
        return {
            "name": title,
            "num_classes": 10,
            "input_shape": shape,
            "train_size": n_train,
            "test_size": 10000,
            "type": "image_classification",
        }

    catalog = {
        "cifar10": describe("CIFAR-10", (3, 32, 32), 50000),
        "mnist": describe("MNIST", (784,), 60000),
        "fashion_mnist": describe("Fashion-MNIST", (784,), 60000),
    }

    if dataset_name in catalog:
        return catalog[dataset_name]
    return {}
248
+
249
+
250
class DatasetLoader:
    """Name-based front end for the dataset loading functions of this module."""

    # Dataset key -> loader function
    LOADERS = {
        "cifar10": load_cifar10,
        "mnist": load_mnist,
        "fashion_mnist": load_fashion_mnist,
    }

    @classmethod
    def load(cls, dataset_name: str, **kwargs):
        """
        Dispatch to the loader registered under the given name.

        Args:
            dataset_name: Key of the dataset in ``LOADERS``
            **kwargs: Forwarded unchanged to the selected loader

        Returns:
            Whatever the loader returns: a tuple of
            (X_train, y_train, X_test, y_test)

        Raises:
            ValueError: If no loader is registered under ``dataset_name``.
        """
        load_fn = cls.LOADERS.get(dataset_name)
        if load_fn is not None:
            return load_fn(**kwargs)

        raise ValueError(f"Unknown dataset: {dataset_name}")

    @classmethod
    def list_available(cls):
        """Return the registered dataset keys as a new list."""
        return [name for name in cls.LOADERS]