morphml 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of morphml might be problematic. Click here for more details.

Files changed (158) hide show
  1. morphml/__init__.py +14 -0
  2. morphml/api/__init__.py +26 -0
  3. morphml/api/app.py +326 -0
  4. morphml/api/auth.py +193 -0
  5. morphml/api/client.py +338 -0
  6. morphml/api/models.py +132 -0
  7. morphml/api/rate_limit.py +192 -0
  8. morphml/benchmarking/__init__.py +36 -0
  9. morphml/benchmarking/comparison.py +430 -0
  10. morphml/benchmarks/__init__.py +56 -0
  11. morphml/benchmarks/comparator.py +409 -0
  12. morphml/benchmarks/datasets.py +280 -0
  13. morphml/benchmarks/metrics.py +199 -0
  14. morphml/benchmarks/openml_suite.py +201 -0
  15. morphml/benchmarks/problems.py +289 -0
  16. morphml/benchmarks/suite.py +318 -0
  17. morphml/cli/__init__.py +5 -0
  18. morphml/cli/commands/experiment.py +329 -0
  19. morphml/cli/main.py +457 -0
  20. morphml/cli/quickstart.py +312 -0
  21. morphml/config.py +278 -0
  22. morphml/constraints/__init__.py +19 -0
  23. morphml/constraints/handler.py +205 -0
  24. morphml/constraints/predicates.py +285 -0
  25. morphml/core/__init__.py +3 -0
  26. morphml/core/crossover.py +449 -0
  27. morphml/core/dsl/README.md +359 -0
  28. morphml/core/dsl/__init__.py +72 -0
  29. morphml/core/dsl/ast_nodes.py +364 -0
  30. morphml/core/dsl/compiler.py +318 -0
  31. morphml/core/dsl/layers.py +368 -0
  32. morphml/core/dsl/lexer.py +336 -0
  33. morphml/core/dsl/parser.py +455 -0
  34. morphml/core/dsl/search_space.py +386 -0
  35. morphml/core/dsl/syntax.py +199 -0
  36. morphml/core/dsl/type_system.py +361 -0
  37. morphml/core/dsl/validator.py +386 -0
  38. morphml/core/graph/__init__.py +40 -0
  39. morphml/core/graph/edge.py +124 -0
  40. morphml/core/graph/graph.py +507 -0
  41. morphml/core/graph/mutations.py +409 -0
  42. morphml/core/graph/node.py +196 -0
  43. morphml/core/graph/serialization.py +361 -0
  44. morphml/core/graph/visualization.py +431 -0
  45. morphml/core/objectives/__init__.py +20 -0
  46. morphml/core/search/__init__.py +33 -0
  47. morphml/core/search/individual.py +252 -0
  48. morphml/core/search/parameters.py +453 -0
  49. morphml/core/search/population.py +375 -0
  50. morphml/core/search/search_engine.py +340 -0
  51. morphml/distributed/__init__.py +76 -0
  52. morphml/distributed/fault_tolerance.py +497 -0
  53. morphml/distributed/health_monitor.py +348 -0
  54. morphml/distributed/master.py +709 -0
  55. morphml/distributed/proto/README.md +224 -0
  56. morphml/distributed/proto/__init__.py +74 -0
  57. morphml/distributed/proto/worker.proto +170 -0
  58. morphml/distributed/proto/worker_pb2.py +79 -0
  59. morphml/distributed/proto/worker_pb2_grpc.py +423 -0
  60. morphml/distributed/resource_manager.py +416 -0
  61. morphml/distributed/scheduler.py +567 -0
  62. morphml/distributed/storage/__init__.py +33 -0
  63. morphml/distributed/storage/artifacts.py +381 -0
  64. morphml/distributed/storage/cache.py +366 -0
  65. morphml/distributed/storage/checkpointing.py +329 -0
  66. morphml/distributed/storage/database.py +459 -0
  67. morphml/distributed/worker.py +549 -0
  68. morphml/evaluation/__init__.py +5 -0
  69. morphml/evaluation/heuristic.py +237 -0
  70. morphml/exceptions.py +55 -0
  71. morphml/execution/__init__.py +5 -0
  72. morphml/execution/local_executor.py +350 -0
  73. morphml/integrations/__init__.py +28 -0
  74. morphml/integrations/jax_adapter.py +206 -0
  75. morphml/integrations/pytorch_adapter.py +530 -0
  76. morphml/integrations/sklearn_adapter.py +206 -0
  77. morphml/integrations/tensorflow_adapter.py +230 -0
  78. morphml/logging_config.py +93 -0
  79. morphml/meta_learning/__init__.py +66 -0
  80. morphml/meta_learning/architecture_similarity.py +277 -0
  81. morphml/meta_learning/experiment_database.py +240 -0
  82. morphml/meta_learning/knowledge_base/__init__.py +19 -0
  83. morphml/meta_learning/knowledge_base/embedder.py +179 -0
  84. morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
  85. morphml/meta_learning/knowledge_base/meta_features.py +265 -0
  86. morphml/meta_learning/knowledge_base/vector_store.py +271 -0
  87. morphml/meta_learning/predictors/__init__.py +27 -0
  88. morphml/meta_learning/predictors/ensemble.py +221 -0
  89. morphml/meta_learning/predictors/gnn_predictor.py +552 -0
  90. morphml/meta_learning/predictors/learning_curve.py +231 -0
  91. morphml/meta_learning/predictors/proxy_metrics.py +261 -0
  92. morphml/meta_learning/strategy_evolution/__init__.py +27 -0
  93. morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
  94. morphml/meta_learning/strategy_evolution/bandit.py +276 -0
  95. morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
  96. morphml/meta_learning/transfer.py +581 -0
  97. morphml/meta_learning/warm_start.py +286 -0
  98. morphml/optimizers/__init__.py +74 -0
  99. morphml/optimizers/adaptive_operators.py +399 -0
  100. morphml/optimizers/bayesian/__init__.py +52 -0
  101. morphml/optimizers/bayesian/acquisition.py +387 -0
  102. morphml/optimizers/bayesian/base.py +319 -0
  103. morphml/optimizers/bayesian/gaussian_process.py +635 -0
  104. morphml/optimizers/bayesian/smac.py +534 -0
  105. morphml/optimizers/bayesian/tpe.py +411 -0
  106. morphml/optimizers/differential_evolution.py +220 -0
  107. morphml/optimizers/evolutionary/__init__.py +61 -0
  108. morphml/optimizers/evolutionary/cma_es.py +416 -0
  109. morphml/optimizers/evolutionary/differential_evolution.py +556 -0
  110. morphml/optimizers/evolutionary/encoding.py +426 -0
  111. morphml/optimizers/evolutionary/particle_swarm.py +449 -0
  112. morphml/optimizers/genetic_algorithm.py +486 -0
  113. morphml/optimizers/gradient_based/__init__.py +22 -0
  114. morphml/optimizers/gradient_based/darts.py +550 -0
  115. morphml/optimizers/gradient_based/enas.py +585 -0
  116. morphml/optimizers/gradient_based/operations.py +474 -0
  117. morphml/optimizers/gradient_based/utils.py +601 -0
  118. morphml/optimizers/hill_climbing.py +169 -0
  119. morphml/optimizers/multi_objective/__init__.py +56 -0
  120. morphml/optimizers/multi_objective/indicators.py +504 -0
  121. morphml/optimizers/multi_objective/nsga2.py +647 -0
  122. morphml/optimizers/multi_objective/visualization.py +427 -0
  123. morphml/optimizers/nsga2.py +308 -0
  124. morphml/optimizers/random_search.py +172 -0
  125. morphml/optimizers/simulated_annealing.py +181 -0
  126. morphml/plugins/__init__.py +35 -0
  127. morphml/plugins/custom_evaluator_example.py +81 -0
  128. morphml/plugins/custom_optimizer_example.py +63 -0
  129. morphml/plugins/plugin_system.py +454 -0
  130. morphml/reports/__init__.py +30 -0
  131. morphml/reports/generator.py +362 -0
  132. morphml/tracking/__init__.py +7 -0
  133. morphml/tracking/experiment.py +309 -0
  134. morphml/tracking/logger.py +301 -0
  135. morphml/tracking/reporter.py +357 -0
  136. morphml/utils/__init__.py +6 -0
  137. morphml/utils/checkpoint.py +189 -0
  138. morphml/utils/comparison.py +390 -0
  139. morphml/utils/export.py +407 -0
  140. morphml/utils/progress.py +392 -0
  141. morphml/utils/validation.py +392 -0
  142. morphml/version.py +7 -0
  143. morphml/visualization/__init__.py +50 -0
  144. morphml/visualization/analytics.py +423 -0
  145. morphml/visualization/architecture_diagrams.py +353 -0
  146. morphml/visualization/architecture_plot.py +223 -0
  147. morphml/visualization/convergence_plot.py +174 -0
  148. morphml/visualization/crossover_viz.py +386 -0
  149. morphml/visualization/graph_viz.py +338 -0
  150. morphml/visualization/pareto_plot.py +149 -0
  151. morphml/visualization/plotly_dashboards.py +422 -0
  152. morphml/visualization/population.py +309 -0
  153. morphml/visualization/progress.py +260 -0
  154. morphml-1.0.0.dist-info/METADATA +434 -0
  155. morphml-1.0.0.dist-info/RECORD +158 -0
  156. morphml-1.0.0.dist-info/WHEEL +4 -0
  157. morphml-1.0.0.dist-info/entry_points.txt +3 -0
  158. morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,601 @@
1
+ """Utility functions for gradient-based NAS.
2
+
3
+ This module provides helper functions for DARTS and ENAS implementations including:
4
+ - GPU management utilities
5
+ - Parameter counting
6
+ - Learning rate scheduling
7
+ - Drop path regularization
8
+ - Architecture visualization
9
+
10
+ Author: Eshan Roy <eshanized@proton.me>
11
+ Organization: TONMOY INFRASTRUCTURE & VISION
12
+ """
13
+
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ try:
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ TORCH_AVAILABLE = True
21
+ except ImportError:
22
+ TORCH_AVAILABLE = False
23
+
24
+ from morphml.logging_config import get_logger
25
+
26
+ logger = get_logger(__name__)
27
+
28
+
29
def check_cuda_available() -> bool:
    """
    Report whether a CUDA device can be used.

    Returns:
        False when PyTorch could not be imported; otherwise whatever
        ``torch.cuda.is_available()`` reports.
    """
    # Short-circuit keeps ``torch`` untouched when the import failed.
    return TORCH_AVAILABLE and torch.cuda.is_available()
39
+
40
+
41
def get_device(use_cuda: bool = True) -> "torch.device":
    """
    Select the PyTorch device to run on and log the choice.

    Args:
        use_cuda: Request the GPU; honoured only if CUDA is available.

    Returns:
        ``torch.device("cuda")`` when requested and available, otherwise
        ``torch.device("cpu")``.

    Raises:
        ImportError: If PyTorch is not installed.

    Example:
        >>> device = get_device()
        >>> model = model.to(device)
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch not available")

    wants_gpu = use_cuda and torch.cuda.is_available()
    if not wants_gpu:
        cpu = torch.device("cpu")
        logger.info("Using CPU")
        return cpu

    gpu = torch.device("cuda")
    # Only the name of device index 0 is reported.
    logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
    return gpu
66
+
67
+
68
def count_parameters(model: "nn.Module") -> int:
    """
    Sum the element counts of every trainable parameter of ``model``.

    Parameters with ``requires_grad`` set to False are not counted.

    Args:
        model: PyTorch model

    Returns:
        Number of trainable parameter elements, or 0 when PyTorch is not
        installed.

    Example:
        >>> from torchvision import models
        >>> model = models.resnet18()
        >>> params = count_parameters(model)
        >>> print(f"Parameters: {params:,}")
    """
    if not TORCH_AVAILABLE:
        return 0

    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total
87
+
88
+
89
def count_parameters_by_layer(model: "nn.Module") -> Dict[str, int]:
    """
    Break the trainable parameter count down by parameter name.

    Args:
        model: PyTorch model

    Returns:
        Mapping from each name yielded by ``model.named_parameters()`` to
        that parameter's element count; frozen parameters are omitted.
        Empty when PyTorch is not installed.
    """
    if not TORCH_AVAILABLE:
        return {}

    return {
        layer_name: tensor.numel()
        for layer_name, tensor in model.named_parameters()
        if tensor.requires_grad
    }
108
+
109
+
110
def get_model_size_mb(model: "nn.Module") -> float:
    """
    Estimate the in-memory footprint of a model in megabytes.

    Both parameters and buffers are counted (frozen parameters included),
    each contributing ``numel() * element_size()`` bytes.

    Args:
        model: PyTorch model

    Returns:
        Size in MB (bytes / 1024**2), or 0.0 when PyTorch is not installed.
    """
    if not TORCH_AVAILABLE:
        return 0.0

    def _total_bytes(tensors) -> int:
        # Byte size of every tensor in the iterable, summed.
        return sum(t.numel() * t.element_size() for t in tensors)

    parameter_bytes = _total_bytes(model.parameters())
    buffer_bytes = _total_bytes(model.buffers())
    return (parameter_bytes + buffer_bytes) / (1024**2)
133
+
134
+
135
def drop_path(x: "torch.Tensor", drop_prob: float, training: bool = True) -> "torch.Tensor":
    """
    Drop path (Stochastic Depth) regularization.

    During training each sample along dimension 0 is zeroed with
    probability ``drop_prob``; surviving samples are divided by the keep
    probability. Outside training, or with ``drop_prob == 0.0``, the input
    is returned unchanged.

    Args:
        x: Input tensor; dimension 0 is treated as the per-sample axis
        drop_prob: Probability of dropping the path, in [0, 1]
        training: Whether in training mode

    Returns:
        Tensor with paths potentially dropped. With ``drop_prob == 1.0``
        every path is dropped and a zero tensor shaped like ``x`` is returned.

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If dropping is active and ``drop_prob`` is outside [0, 1].

    Reference:
        Huang et al. "Deep Networks with Stochastic Depth." ECCV 2016.

    Example:
        >>> x = torch.randn(32, 128, 16, 16)
        >>> out = drop_path(x, drop_prob=0.2, training=True)
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch not available")

    if not training or drop_prob == 0.0:
        return x

    # Also rejects NaN, for which both comparisons are False.
    if not 0.0 <= drop_prob <= 1.0:
        raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")

    keep_prob = 1.0 - drop_prob
    if keep_prob == 0.0:
        # Dividing by a zero keep probability would yield NaN/inf; when every
        # path is dropped the result is simply all zeros.
        return torch.zeros_like(x)

    # One random draw per sample, broadcast over all remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # Binarize: 1 with probability keep_prob, else 0

    # Rescale survivors so the expected value of the output equals x.
    return x.div(keep_prob) * random_tensor
169
+
170
+
171
class AverageMeter:
    """
    Running tracker for the latest value and the weighted mean of a metric.

    Attributes set by the methods: ``val`` (last value), ``sum`` (weighted
    sum), ``count`` (total weight) and ``avg`` (``sum / count``).

    Example:
        >>> losses = AverageMeter()
        >>> for batch in data_loader:
        ...     loss = compute_loss(batch)
        ...     losses.update(loss.item(), batch_size)
        >>> print(f"Average loss: {losses.avg:.4f}")
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val: float, n: int = 1):
        """
        Fold a new observation into the statistics.

        Args:
            val: New value
            n: Number of samples this value represents (its weight)
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against division by zero while no samples have been counted.
        if self.count > 0:
            self.avg = self.sum / self.count
        else:
            self.avg = 0
207
+
208
+
209
def accuracy(
    output: "torch.Tensor", target: "torch.Tensor", topk: Tuple[int, ...] = (1,)
) -> List[float]:
    """
    Compute top-k accuracy.

    Args:
        output: Model predictions (logits), shape (N, num_classes)
        target: Ground truth labels, shape (N,)
        topk: Tuple of k values for top-k accuracy

    Returns:
        List of top-k accuracies as percentages, one per entry of ``topk``.
        All zeros when PyTorch is not installed or the batch is empty; an
        empty list when ``topk`` is empty.

    Example:
        >>> logits = torch.randn(32, 10)
        >>> targets = torch.randint(0, 10, (32,))
        >>> top1, top5 = accuracy(logits, targets, topk=(1, 5))
    """
    if not TORCH_AVAILABLE:
        return [0.0] * len(topk)

    batch_size = target.size(0)
    # An empty batch would divide by zero below; an empty topk has no maximum.
    if batch_size == 0 or not topk:
        return [0.0] * len(topk)

    # torch.topk cannot return more entries than there are classes, so cap k.
    maxk = min(max(topk), output.size(1))

    # Metric computation only; no autograd graph is needed.
    with torch.no_grad():
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, N): row i holds the (i+1)-th best guess
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[: min(k, maxk)].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size).item())

    return res
244
+
245
+
246
class CosineAnnealingLR:
    """
    Cosine annealing learning rate scheduler.

    After ``k`` calls to :meth:`step` each parameter group's learning rate is
    ``eta_min + (base_lr - eta_min) * (1 + cos(pi * k / T_max)) / 2``, so the
    rate starts at the optimizer's initial value and reaches ``eta_min``
    after exactly ``T_max`` steps.

    Args:
        optimizer: PyTorch optimizer (only ``param_groups`` is accessed)
        T_max: Number of steps over which to anneal; must be positive
        eta_min: Minimum learning rate

    Raises:
        ValueError: If ``T_max`` is not positive.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        >>> scheduler = CosineAnnealingLR(optimizer, T_max=100)
        >>> for epoch in range(100):
        ...     train(...)
        ...     scheduler.step()
    """

    def __init__(self, optimizer: "torch.optim.Optimizer", T_max: int, eta_min: float = 0):
        if T_max <= 0:
            # step() divides by T_max; fail early with a clear message.
            raise ValueError(f"T_max must be a positive integer, got {T_max}")

        self.optimizer = optimizer
        self.T_max = T_max
        self.eta_min = eta_min
        # Snapshot of the initial rates; the schedule is computed from these.
        self.base_lrs = [group["lr"] for group in optimizer.param_groups]
        self.last_epoch = 0

    def step(self):
        """Advance one epoch and write the new learning rates to the optimizer."""
        import math

        # Advance first: epoch 0 already ran at the base rate, so the first
        # step() must produce the rate for epoch 1, not repeat epoch 0.
        self.last_epoch += 1
        cosine_factor = (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2

        for base_lr, param_group in zip(self.base_lrs, self.optimizer.param_groups):
            param_group["lr"] = self.eta_min + (base_lr - self.eta_min) * cosine_factor

    def get_lr(self) -> List[float]:
        """Get current learning rates."""
        return [group["lr"] for group in self.optimizer.param_groups]
290
+
291
+
292
def save_checkpoint(state: Dict[str, Any], filepath: str, is_best: bool = False):
    """
    Save training checkpoint.

    The state is written with ``torch.save``. When ``is_best`` is True the
    file is also copied to a sibling path with ``_best`` inserted before the
    extension (``ckpt.pth.tar`` -> ``ckpt_best.pth.tar``, ``ckpt.pt`` ->
    ``ckpt_best.pt``).

    Args:
        state: State dictionary containing model, optimizer, etc.
        filepath: Path to save checkpoint
        is_best: Whether this is the best model so far

    Example:
        >>> state = {
        ...     'epoch': epoch,
        ...     'model_state': model.state_dict(),
        ...     'optimizer_state': optimizer.state_dict(),
        ...     'best_acc': best_acc
        ... }
        >>> save_checkpoint(state, 'checkpoint.pth.tar', is_best=True)
    """
    if not TORCH_AVAILABLE:
        logger.warning("PyTorch not available, cannot save checkpoint")
        return

    torch.save(state, filepath)
    logger.info(f"Checkpoint saved to {filepath}")

    if is_best:
        import os
        import shutil

        path = os.fspath(filepath)
        suffix = ".pth.tar"
        if path.endswith(suffix):
            best_path = path[: -len(suffix)] + "_best" + suffix
        else:
            # Any other extension: a plain str.replace would leave the path
            # unchanged and copyfile would fail with SameFileError.
            root, ext = os.path.splitext(path)
            best_path = f"{root}_best{ext}"

        shutil.copyfile(path, best_path)
        logger.info(f"Best model saved to {best_path}")
323
+
324
+
325
def load_checkpoint(
    filepath: str, model: "nn.Module", optimizer: Optional["torch.optim.Optimizer"] = None
) -> Dict[str, Any]:
    """
    Load training checkpoint.

    Restores ``checkpoint["model_state"]`` into ``model`` and, when an
    optimizer is given and the checkpoint has an ``"optimizer_state"`` entry,
    restores that into the optimizer.

    Args:
        filepath: Path to checkpoint file
        model: Model to load weights into
        optimizer: Optional optimizer to load state

    Returns:
        Checkpoint dictionary

    Raises:
        ImportError: If PyTorch is not installed.
        KeyError: If the checkpoint has no ``"model_state"`` entry.

    Example:
        >>> checkpoint = load_checkpoint('checkpoint.pth.tar', model, optimizer)
        >>> start_epoch = checkpoint['epoch']
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch not available")

    # A checkpoint saved on a GPU cannot be deserialized on a CPU-only host
    # unless its tensors are remapped; leave placement untouched when CUDA exists.
    map_location = None if torch.cuda.is_available() else "cpu"

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(filepath, map_location=map_location)

    if "model_state" not in checkpoint:
        raise KeyError(f"Checkpoint {filepath} has no 'model_state' entry")
    model.load_state_dict(checkpoint["model_state"])

    if optimizer is not None and "optimizer_state" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state"])

    logger.info(f"Checkpoint loaded from {filepath}")
    return checkpoint
354
+
355
+
356
def set_seed(seed: int):
    """
    Seed every random number generator this module relies on.

    Seeds Python's ``random`` and NumPy; when PyTorch is installed it also
    seeds torch (CPU and all CUDA devices) and switches cuDNN to
    deterministic mode with benchmarking disabled.

    Args:
        seed: Random seed

    Example:
        >>> set_seed(42)
    """
    import random

    import numpy as np

    random.seed(seed)
    np.random.seed(seed)

    if TORCH_AVAILABLE:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    logger.info(f"Random seed set to {seed}")
380
+
381
+
382
def print_model_summary(model: "nn.Module", input_size: Tuple[int, ...]):
    """
    Print a summary of the model.

    Delegates to ``torchsummary.summary`` when that package can be imported;
    otherwise logs a warning and prints only the trainable parameter count.

    Args:
        model: PyTorch model
        input_size: Input tensor size (C, H, W), passed straight to torchsummary

    Example:
        >>> model = MyModel()
        >>> print_model_summary(model, (3, 32, 32))
    """
    if not TORCH_AVAILABLE:
        logger.warning("PyTorch not available")
        return

    try:
        from torchsummary import summary

        summary(model, input_size)
        return
    except ImportError:
        logger.warning("torchsummary not installed. Install with: pip install torchsummary")

    # Fallback: simple parameter count
    trainable = count_parameters(model)
    print(f"Total parameters: {trainable:,}")
407
+
408
+
409
def freeze_model(model: "nn.Module"):
    """
    Disable gradient tracking on every parameter of ``model``.

    Does nothing (and logs nothing) when PyTorch is not installed.

    Args:
        model: PyTorch model
    """
    if TORCH_AVAILABLE:
        for tensor in model.parameters():
            tensor.requires_grad_(False)

        logger.info("Model parameters frozen")
423
+
424
+
425
def unfreeze_model(model: "nn.Module"):
    """
    Re-enable gradient tracking on every parameter of ``model``.

    Does nothing (and logs nothing) when PyTorch is not installed.

    Args:
        model: PyTorch model
    """
    if TORCH_AVAILABLE:
        for tensor in model.parameters():
            tensor.requires_grad_(True)

        logger.info("Model parameters unfrozen")
439
+
440
+
441
def get_memory_usage() -> Dict[str, float]:
    """
    Snapshot the CUDA memory counters, converted to MB.

    Returns:
        Dictionary with keys ``"allocated"``, ``"cached"`` (the value of
        ``torch.cuda.memory_reserved()``) and ``"max_allocated"``; empty when
        PyTorch or CUDA is unavailable.
    """
    if not (TORCH_AVAILABLE and torch.cuda.is_available()):
        return {}

    bytes_per_mb = 1024**2
    cuda = torch.cuda
    return {
        "allocated": cuda.memory_allocated() / bytes_per_mb,
        "cached": cuda.memory_reserved() / bytes_per_mb,
        "max_allocated": cuda.max_memory_allocated() / bytes_per_mb,
    }
456
+
457
+
458
def print_memory_usage():
    """Print the GPU memory statistics, or a notice when none are available."""
    stats = get_memory_usage()
    if not stats:
        # An empty dict means PyTorch or CUDA is missing.
        print("GPU not available or PyTorch not installed")
        return

    allocated = stats["allocated"]
    cached = stats["cached"]
    peak = stats["max_allocated"]
    print(f"GPU Memory: Allocated={allocated:.2f}MB, Cached={cached:.2f}MB, Max={peak:.2f}MB")
469
+
470
+
471
def get_lr(optimizer: "torch.optim.Optimizer") -> List[float]:
    """
    Read the learning rate of every parameter group.

    Args:
        optimizer: PyTorch optimizer

    Returns:
        The ``"lr"`` entry of each item in ``optimizer.param_groups``, in
        order; an empty list when PyTorch is not installed.
    """
    rates = []
    if TORCH_AVAILABLE:
        for group in optimizer.param_groups:
            rates.append(group["lr"])
    return rates
484
+
485
+
486
def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float:
    """
    Clip the gradient norm of ``parameters`` via ``torch.nn.utils.clip_grad_norm_``.

    Args:
        parameters: Model parameters
        max_norm: Maximum norm
        norm_type: Type of norm (2 for L2)

    Returns:
        Whatever ``clip_grad_norm_`` returns (the total norm before clipping),
        or 0.0 when PyTorch is not installed.
        NOTE(review): the torch call presumably returns a tensor rather than a
        Python float - confirm before relying on the annotation.
    """
    if TORCH_AVAILABLE:
        return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
    return 0.0
501
+
502
+
503
class EarlyStopping:
    """
    Signal that training should stop once a metric has stopped improving.

    The instance is called once per evaluation with the current score. After
    ``patience`` consecutive calls without an improvement it starts returning
    True and keeps doing so.

    Args:
        patience: Number of epochs to wait for improvement
        mode: 'min' to minimize metric; any other value maximizes it
        delta: Minimum change to qualify as improvement

    Example:
        >>> early_stopping = EarlyStopping(patience=10, mode='max')
        >>> for epoch in range(epochs):
        ...     val_acc = validate(...)
        ...     if early_stopping(val_acc):
        ...         print("Early stopping triggered")
        ...         break
    """

    def __init__(self, patience: int = 10, mode: str = "max", delta: float = 0.0):
        self.patience = patience
        self.mode = mode
        self.delta = delta
        self.counter = 0  # consecutive calls without improvement
        self.best_score = None  # None until the first score is seen
        self.early_stop = False

    def __call__(self, score: float) -> bool:
        """
        Record ``score`` and report whether training should stop.

        Args:
            score: Current validation metric

        Returns:
            True if should stop, False otherwise
        """
        if self.best_score is None:
            # First observation: becomes the baseline, never counts as a miss.
            self.best_score = score
            return self.early_stop

        if self._is_better(score):
            self.best_score, self.counter = score, 0
            return self.early_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop

    def _is_better(self, score: float) -> bool:
        """Return True when ``score`` beats the best score by more than ``delta``."""
        if self.mode == "min":
            return score < self.best_score - self.delta
        return score > self.best_score + self.delta
557
+
558
+
559
+ # Version check
560
def check_pytorch_version(min_version: str = "1.7.0") -> bool:
    """
    Compare the installed PyTorch version against a minimum.

    Args:
        min_version: Minimum required version

    Returns:
        True if the installed version is at least ``min_version``; False if
        it is older or PyTorch is not installed.
    """
    if not TORCH_AVAILABLE:
        return False

    from packaging import version

    # Drop any local build tag after "+" (e.g. "+cu111") before parsing.
    installed = version.parse(torch.__version__.split("+")[0])
    required = version.parse(min_version)
    return installed >= required
578
+
579
+
580
if __name__ == "__main__":
    # Smoke test run when the module is executed directly.
    print("Gradient-Based NAS Utilities")
    print("=" * 50)

    if not TORCH_AVAILABLE:
        print("PyTorch not available. Install with: pip install torch")
    else:
        print(f"PyTorch version: {torch.__version__}")
        cuda_ready = check_cuda_available()
        print(f"CUDA available: {cuda_ready}")

        if cuda_ready:
            print(f"CUDA device: {torch.cuda.get_device_name(0)}")
            print_memory_usage()

        # Test utilities
        print("\nTesting utilities...")
        meter = AverageMeter()
        for sample in range(10):
            meter.update(sample, 1)
        print(f"Average meter test: avg={meter.avg:.2f}")

        print("\nAll utilities loaded successfully!")