morphml-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of morphml might be problematic.
- morphml/__init__.py +14 -0
- morphml/api/__init__.py +26 -0
- morphml/api/app.py +326 -0
- morphml/api/auth.py +193 -0
- morphml/api/client.py +338 -0
- morphml/api/models.py +132 -0
- morphml/api/rate_limit.py +192 -0
- morphml/benchmarking/__init__.py +36 -0
- morphml/benchmarking/comparison.py +430 -0
- morphml/benchmarks/__init__.py +56 -0
- morphml/benchmarks/comparator.py +409 -0
- morphml/benchmarks/datasets.py +280 -0
- morphml/benchmarks/metrics.py +199 -0
- morphml/benchmarks/openml_suite.py +201 -0
- morphml/benchmarks/problems.py +289 -0
- morphml/benchmarks/suite.py +318 -0
- morphml/cli/__init__.py +5 -0
- morphml/cli/commands/experiment.py +329 -0
- morphml/cli/main.py +457 -0
- morphml/cli/quickstart.py +312 -0
- morphml/config.py +278 -0
- morphml/constraints/__init__.py +19 -0
- morphml/constraints/handler.py +205 -0
- morphml/constraints/predicates.py +285 -0
- morphml/core/__init__.py +3 -0
- morphml/core/crossover.py +449 -0
- morphml/core/dsl/README.md +359 -0
- morphml/core/dsl/__init__.py +72 -0
- morphml/core/dsl/ast_nodes.py +364 -0
- morphml/core/dsl/compiler.py +318 -0
- morphml/core/dsl/layers.py +368 -0
- morphml/core/dsl/lexer.py +336 -0
- morphml/core/dsl/parser.py +455 -0
- morphml/core/dsl/search_space.py +386 -0
- morphml/core/dsl/syntax.py +199 -0
- morphml/core/dsl/type_system.py +361 -0
- morphml/core/dsl/validator.py +386 -0
- morphml/core/graph/__init__.py +40 -0
- morphml/core/graph/edge.py +124 -0
- morphml/core/graph/graph.py +507 -0
- morphml/core/graph/mutations.py +409 -0
- morphml/core/graph/node.py +196 -0
- morphml/core/graph/serialization.py +361 -0
- morphml/core/graph/visualization.py +431 -0
- morphml/core/objectives/__init__.py +20 -0
- morphml/core/search/__init__.py +33 -0
- morphml/core/search/individual.py +252 -0
- morphml/core/search/parameters.py +453 -0
- morphml/core/search/population.py +375 -0
- morphml/core/search/search_engine.py +340 -0
- morphml/distributed/__init__.py +76 -0
- morphml/distributed/fault_tolerance.py +497 -0
- morphml/distributed/health_monitor.py +348 -0
- morphml/distributed/master.py +709 -0
- morphml/distributed/proto/README.md +224 -0
- morphml/distributed/proto/__init__.py +74 -0
- morphml/distributed/proto/worker.proto +170 -0
- morphml/distributed/proto/worker_pb2.py +79 -0
- morphml/distributed/proto/worker_pb2_grpc.py +423 -0
- morphml/distributed/resource_manager.py +416 -0
- morphml/distributed/scheduler.py +567 -0
- morphml/distributed/storage/__init__.py +33 -0
- morphml/distributed/storage/artifacts.py +381 -0
- morphml/distributed/storage/cache.py +366 -0
- morphml/distributed/storage/checkpointing.py +329 -0
- morphml/distributed/storage/database.py +459 -0
- morphml/distributed/worker.py +549 -0
- morphml/evaluation/__init__.py +5 -0
- morphml/evaluation/heuristic.py +237 -0
- morphml/exceptions.py +55 -0
- morphml/execution/__init__.py +5 -0
- morphml/execution/local_executor.py +350 -0
- morphml/integrations/__init__.py +28 -0
- morphml/integrations/jax_adapter.py +206 -0
- morphml/integrations/pytorch_adapter.py +530 -0
- morphml/integrations/sklearn_adapter.py +206 -0
- morphml/integrations/tensorflow_adapter.py +230 -0
- morphml/logging_config.py +93 -0
- morphml/meta_learning/__init__.py +66 -0
- morphml/meta_learning/architecture_similarity.py +277 -0
- morphml/meta_learning/experiment_database.py +240 -0
- morphml/meta_learning/knowledge_base/__init__.py +19 -0
- morphml/meta_learning/knowledge_base/embedder.py +179 -0
- morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
- morphml/meta_learning/knowledge_base/meta_features.py +265 -0
- morphml/meta_learning/knowledge_base/vector_store.py +271 -0
- morphml/meta_learning/predictors/__init__.py +27 -0
- morphml/meta_learning/predictors/ensemble.py +221 -0
- morphml/meta_learning/predictors/gnn_predictor.py +552 -0
- morphml/meta_learning/predictors/learning_curve.py +231 -0
- morphml/meta_learning/predictors/proxy_metrics.py +261 -0
- morphml/meta_learning/strategy_evolution/__init__.py +27 -0
- morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
- morphml/meta_learning/strategy_evolution/bandit.py +276 -0
- morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
- morphml/meta_learning/transfer.py +581 -0
- morphml/meta_learning/warm_start.py +286 -0
- morphml/optimizers/__init__.py +74 -0
- morphml/optimizers/adaptive_operators.py +399 -0
- morphml/optimizers/bayesian/__init__.py +52 -0
- morphml/optimizers/bayesian/acquisition.py +387 -0
- morphml/optimizers/bayesian/base.py +319 -0
- morphml/optimizers/bayesian/gaussian_process.py +635 -0
- morphml/optimizers/bayesian/smac.py +534 -0
- morphml/optimizers/bayesian/tpe.py +411 -0
- morphml/optimizers/differential_evolution.py +220 -0
- morphml/optimizers/evolutionary/__init__.py +61 -0
- morphml/optimizers/evolutionary/cma_es.py +416 -0
- morphml/optimizers/evolutionary/differential_evolution.py +556 -0
- morphml/optimizers/evolutionary/encoding.py +426 -0
- morphml/optimizers/evolutionary/particle_swarm.py +449 -0
- morphml/optimizers/genetic_algorithm.py +486 -0
- morphml/optimizers/gradient_based/__init__.py +22 -0
- morphml/optimizers/gradient_based/darts.py +550 -0
- morphml/optimizers/gradient_based/enas.py +585 -0
- morphml/optimizers/gradient_based/operations.py +474 -0
- morphml/optimizers/gradient_based/utils.py +601 -0
- morphml/optimizers/hill_climbing.py +169 -0
- morphml/optimizers/multi_objective/__init__.py +56 -0
- morphml/optimizers/multi_objective/indicators.py +504 -0
- morphml/optimizers/multi_objective/nsga2.py +647 -0
- morphml/optimizers/multi_objective/visualization.py +427 -0
- morphml/optimizers/nsga2.py +308 -0
- morphml/optimizers/random_search.py +172 -0
- morphml/optimizers/simulated_annealing.py +181 -0
- morphml/plugins/__init__.py +35 -0
- morphml/plugins/custom_evaluator_example.py +81 -0
- morphml/plugins/custom_optimizer_example.py +63 -0
- morphml/plugins/plugin_system.py +454 -0
- morphml/reports/__init__.py +30 -0
- morphml/reports/generator.py +362 -0
- morphml/tracking/__init__.py +7 -0
- morphml/tracking/experiment.py +309 -0
- morphml/tracking/logger.py +301 -0
- morphml/tracking/reporter.py +357 -0
- morphml/utils/__init__.py +6 -0
- morphml/utils/checkpoint.py +189 -0
- morphml/utils/comparison.py +390 -0
- morphml/utils/export.py +407 -0
- morphml/utils/progress.py +392 -0
- morphml/utils/validation.py +392 -0
- morphml/version.py +7 -0
- morphml/visualization/__init__.py +50 -0
- morphml/visualization/analytics.py +423 -0
- morphml/visualization/architecture_diagrams.py +353 -0
- morphml/visualization/architecture_plot.py +223 -0
- morphml/visualization/convergence_plot.py +174 -0
- morphml/visualization/crossover_viz.py +386 -0
- morphml/visualization/graph_viz.py +338 -0
- morphml/visualization/pareto_plot.py +149 -0
- morphml/visualization/plotly_dashboards.py +422 -0
- morphml/visualization/population.py +309 -0
- morphml/visualization/progress.py +260 -0
- morphml-1.0.0.dist-info/METADATA +434 -0
- morphml-1.0.0.dist-info/RECORD +158 -0
- morphml-1.0.0.dist-info/WHEEL +4 -0
- morphml-1.0.0.dist-info/entry_points.txt +3 -0
- morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
morphml/optimizers/gradient_based/utils.py
@@ -0,0 +1,601 @@
"""Utility functions for gradient-based NAS.

This module provides helper functions for DARTS and ENAS implementations including:
- GPU management utilities
- Parameter counting
- Learning rate scheduling
- Drop path regularization
- Architecture visualization

Author: Eshan Roy <eshanized@proton.me>
Organization: TONMOY INFRASTRUCTURE & VISION
"""

from typing import Any, Dict, List, Optional, Tuple

try:
    import torch
    import torch.nn as nn

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

from morphml.logging_config import get_logger

logger = get_logger(__name__)


def check_cuda_available() -> bool:
    """
    Check if CUDA is available.

    Returns:
        True if CUDA is available, False otherwise
    """
    if not TORCH_AVAILABLE:
        return False
    return torch.cuda.is_available()


def get_device(use_cuda: bool = True) -> "torch.device":
    """
    Get PyTorch device (CUDA or CPU).

    Args:
        use_cuda: Whether to use CUDA if available

    Returns:
        torch.device object

    Example:
        >>> device = get_device()
        >>> model = model.to(device)
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch not available")

    if use_cuda and torch.cuda.is_available():
        device = torch.device("cuda")
        logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        logger.info("Using CPU")

    return device


def count_parameters(model: "nn.Module") -> int:
    """
    Count total number of trainable parameters in a model.

    Args:
        model: PyTorch model

    Returns:
        Number of trainable parameters

    Example:
        >>> from torchvision import models
        >>> model = models.resnet18()
        >>> params = count_parameters(model)
        >>> print(f"Parameters: {params:,}")
    """
    if not TORCH_AVAILABLE:
        return 0
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def count_parameters_by_layer(model: "nn.Module") -> Dict[str, int]:
    """
    Count parameters for each layer in the model.

    Args:
        model: PyTorch model

    Returns:
        Dictionary mapping layer names to parameter counts
    """
    if not TORCH_AVAILABLE:
        return {}

    param_dict = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            param_dict[name] = param.numel()

    return param_dict


def get_model_size_mb(model: "nn.Module") -> float:
    """
    Estimate model size in megabytes.

    Args:
        model: PyTorch model

    Returns:
        Model size in MB
    """
    if not TORCH_AVAILABLE:
        return 0.0

    param_size = 0
    for param in model.parameters():
        param_size += param.numel() * param.element_size()

    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.numel() * buffer.element_size()

    size_mb = (param_size + buffer_size) / (1024**2)
    return size_mb
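# Worked example (illustrative sketch, not part of the released file): an
# nn.Linear(10, 10) holds 10 * 10 weights + 10 biases = 110 float32 values,
# so get_model_size_mb returns 110 * 4 / (1024**2), roughly 0.0004 MB.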


def drop_path(x: "torch.Tensor", drop_prob: float, training: bool = True) -> "torch.Tensor":
    """
    Drop path (Stochastic Depth) regularization.

    Randomly drops entire paths during training with probability drop_prob.

    Args:
        x: Input tensor
        drop_prob: Probability of dropping the path
        training: Whether in training mode

    Returns:
        Tensor with paths potentially dropped

    Reference:
        Huang et al. "Deep Networks with Stochastic Depth." ECCV 2016.

    Example:
        >>> x = torch.randn(32, 128, 16, 16)
        >>> out = drop_path(x, drop_prob=0.2, training=True)
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch not available")

    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1.0 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # Binarize

    output = x.div(keep_prob) * random_tensor
    return output
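# Usage sketch (illustrative, not part of the released file): with
# drop_prob=0.5 each sample survives with probability 0.5 and is rescaled by
# 1 / keep_prob = 2.0, so the expected activation matches evaluation mode:
#
#     x = torch.ones(4, 3)
#     y = drop_path(x, drop_prob=0.5, training=True)
#     # each row of y is either all zeros or all 2.0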


class AverageMeter:
    """
    Computes and stores the average and current value.

    Useful for tracking metrics during training.

    Example:
        >>> losses = AverageMeter()
        >>> for batch in data_loader:
        ...     loss = compute_loss(batch)
        ...     losses.update(loss.item(), batch_size)
        >>> print(f"Average loss: {losses.avg:.4f}")
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val: float, n: int = 1):
        """
        Update statistics with new value.

        Args:
            val: New value
            n: Number of samples this value represents
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count > 0 else 0
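# Worked example (illustrative, not part of the released file):
# update(2.0, n=3) followed by update(4.0, n=1) gives sum = 2.0*3 + 4.0 = 10.0
# and count = 4, so avg = 2.5: a sample-weighted mean, not a mean of batch means.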


def accuracy(
    output: "torch.Tensor", target: "torch.Tensor", topk: Tuple[int, ...] = (1,)
) -> List[float]:
    """
    Compute top-k accuracy.

    Args:
        output: Model predictions (logits), shape (N, num_classes)
        target: Ground truth labels, shape (N,)
        topk: Tuple of k values for top-k accuracy

    Returns:
        List of top-k accuracies

    Example:
        >>> logits = torch.randn(32, 10)
        >>> targets = torch.randint(0, 10, (32,))
        >>> top1, top5 = accuracy(logits, targets, topk=(1, 5))
    """
    if not TORCH_AVAILABLE:
        return [0.0] * len(topk)

    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size).item())

    return res
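# Worked example (illustrative, not part of the released file): for a single
# sample with logits [[0.1, 0.7, 0.2]] and target [2], class 1 is the top-1
# prediction (wrong) but class 2 falls within the top 2, so
# accuracy(logits, target, topk=(1, 2)) returns [0.0, 100.0].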


class CosineAnnealingLR:
    """
    Cosine annealing learning rate scheduler.

    Gradually decreases learning rate following a cosine curve.

    Args:
        optimizer: PyTorch optimizer
        T_max: Maximum number of iterations
        eta_min: Minimum learning rate

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        >>> scheduler = CosineAnnealingLR(optimizer, T_max=100)
        >>> for epoch in range(100):
        ...     train(...)
        ...     scheduler.step()
    """

    def __init__(self, optimizer: "torch.optim.Optimizer", T_max: int, eta_min: float = 0):
        self.optimizer = optimizer
        self.T_max = T_max
        self.eta_min = eta_min
        self.base_lrs = [group["lr"] for group in optimizer.param_groups]
        self.last_epoch = 0

    def step(self):
        """Update learning rate."""
        import math

        for i, param_group in enumerate(self.optimizer.param_groups):
            lr = (
                self.eta_min
                + (self.base_lrs[i] - self.eta_min)
                * (1 + math.cos(math.pi * self.last_epoch / self.T_max))
                / 2
            )
            param_group["lr"] = lr

        self.last_epoch += 1

    def get_lr(self) -> List[float]:
        """Get current learning rates."""
        return [group["lr"] for group in self.optimizer.param_groups]
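# Schedule sketch (illustrative, not part of the released file): step() applies
#     lr(t) = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2
# so with base_lr=0.1, eta_min=0, T_max=100 the applied rate is 0.1 at t=0,
# 0.05 at t=50, and approaches 0 as t nears T_max.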


def save_checkpoint(state: Dict[str, Any], filepath: str, is_best: bool = False):
    """
    Save training checkpoint.

    Args:
        state: State dictionary containing model, optimizer, etc.
        filepath: Path to save checkpoint
        is_best: Whether this is the best model so far

    Example:
        >>> state = {
        ...     'epoch': epoch,
        ...     'model_state': model.state_dict(),
        ...     'optimizer_state': optimizer.state_dict(),
        ...     'best_acc': best_acc
        ... }
        >>> save_checkpoint(state, 'checkpoint.pth.tar', is_best=True)
    """
    if not TORCH_AVAILABLE:
        logger.warning("PyTorch not available, cannot save checkpoint")
        return

    torch.save(state, filepath)
    logger.info(f"Checkpoint saved to {filepath}")

    if is_best:
        import shutil

        best_path = filepath.replace(".pth.tar", "_best.pth.tar")
        shutil.copyfile(filepath, best_path)
        logger.info(f"Best model saved to {best_path}")


def load_checkpoint(
    filepath: str, model: "nn.Module", optimizer: Optional["torch.optim.Optimizer"] = None
) -> Dict[str, Any]:
    """
    Load training checkpoint.

    Args:
        filepath: Path to checkpoint file
        model: Model to load weights into
        optimizer: Optional optimizer to load state

    Returns:
        Checkpoint dictionary

    Example:
        >>> checkpoint = load_checkpoint('checkpoint.pth.tar', model, optimizer)
        >>> start_epoch = checkpoint['epoch']
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch not available")

    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint["model_state"])

    if optimizer is not None and "optimizer_state" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state"])

    logger.info(f"Checkpoint loaded from {filepath}")
    return checkpoint


def set_seed(seed: int):
    """
    Set random seed for reproducibility.

    Args:
        seed: Random seed

    Example:
        >>> set_seed(42)
    """
    import random

    import numpy as np

    random.seed(seed)
    np.random.seed(seed)

    if TORCH_AVAILABLE:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    logger.info(f"Random seed set to {seed}")


def print_model_summary(model: "nn.Module", input_size: Tuple[int, ...]):
    """
    Print model summary including layer shapes and parameters.

    Args:
        model: PyTorch model
        input_size: Input tensor size (C, H, W)

    Example:
        >>> model = MyModel()
        >>> print_model_summary(model, (3, 32, 32))
    """
    if not TORCH_AVAILABLE:
        logger.warning("PyTorch not available")
        return

    try:
        from torchsummary import summary

        summary(model, input_size)
    except ImportError:
        logger.warning("torchsummary not installed. Install with: pip install torchsummary")
        # Fallback: simple parameter count
        total_params = count_parameters(model)
        print(f"Total parameters: {total_params:,}")


def freeze_model(model: "nn.Module"):
    """
    Freeze all model parameters (set requires_grad=False).

    Args:
        model: PyTorch model
    """
    if not TORCH_AVAILABLE:
        return

    for param in model.parameters():
        param.requires_grad = False

    logger.info("Model parameters frozen")


def unfreeze_model(model: "nn.Module"):
    """
    Unfreeze all model parameters (set requires_grad=True).

    Args:
        model: PyTorch model
    """
    if not TORCH_AVAILABLE:
        return

    for param in model.parameters():
        param.requires_grad = True

    logger.info("Model parameters unfrozen")


def get_memory_usage() -> Dict[str, float]:
    """
    Get current GPU memory usage.

    Returns:
        Dictionary with memory statistics in MB
    """
    if not TORCH_AVAILABLE or not torch.cuda.is_available():
        return {}

    return {
        "allocated": torch.cuda.memory_allocated() / (1024**2),
        "cached": torch.cuda.memory_reserved() / (1024**2),
        "max_allocated": torch.cuda.max_memory_allocated() / (1024**2),
    }


def print_memory_usage():
    """Print current GPU memory usage."""
    mem = get_memory_usage()
    if mem:
        print(
            f"GPU Memory: Allocated={mem['allocated']:.2f}MB, "
            f"Cached={mem['cached']:.2f}MB, "
            f"Max={mem['max_allocated']:.2f}MB"
        )
    else:
        print("GPU not available or PyTorch not installed")


def get_lr(optimizer: "torch.optim.Optimizer") -> List[float]:
    """
    Get current learning rates from optimizer.

    Args:
        optimizer: PyTorch optimizer

    Returns:
        List of current learning rates
    """
    if not TORCH_AVAILABLE:
        return []
    return [group["lr"] for group in optimizer.param_groups]


def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float:
    """
    Clip gradient norm of parameters.

    Args:
        parameters: Model parameters
        max_norm: Maximum norm
        norm_type: Type of norm (2 for L2)

    Returns:
        Total norm before clipping
    """
    if not TORCH_AVAILABLE:
        return 0.0
    return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
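# Usage sketch (not part of the released file): call between backward() and
# the optimizer step, e.g.
#
#     loss.backward()
#     clip_grad_norm(model.parameters(), max_norm=5.0)
#     optimizer.step()
#
# torch.nn.utils.clip_grad_norm_ returns the total norm as a tensor; callers
# that need the annotated plain float can wrap the result in float().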


class EarlyStopping:
    """
    Early stopping to stop training when validation metric stops improving.

    Args:
        patience: Number of epochs to wait for improvement
        mode: 'min' to minimize metric, 'max' to maximize
        delta: Minimum change to qualify as improvement

    Example:
        >>> early_stopping = EarlyStopping(patience=10, mode='max')
        >>> for epoch in range(epochs):
        ...     val_acc = validate(...)
        ...     if early_stopping(val_acc):
        ...         print("Early stopping triggered")
        ...         break
    """

    def __init__(self, patience: int = 10, mode: str = "max", delta: float = 0.0):
        self.patience = patience
        self.mode = mode
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, score: float) -> bool:
        """
        Check if should stop training.

        Args:
            score: Current validation metric

        Returns:
            True if should stop, False otherwise
        """
        if self.best_score is None:
            self.best_score = score
        elif self._is_better(score):
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop

    def _is_better(self, score: float) -> bool:
        """Check if score is better than best."""
        if self.mode == "min":
            return score < self.best_score - self.delta
        else:
            return score > self.best_score + self.delta
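# Behavior sketch (not part of the released file): with patience=2 and
# mode="max", scores 0.80, 0.79, 0.78 set best_score=0.80 and then count two
# epochs without improvement; the third call returns True and the training
# loop should break.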


# Version check
def check_pytorch_version(min_version: str = "1.7.0") -> bool:
    """
    Check if PyTorch version meets minimum requirement.

    Args:
        min_version: Minimum required version

    Returns:
        True if version is sufficient
    """
    if not TORCH_AVAILABLE:
        return False

    from packaging import version

    current_version = torch.__version__.split("+")[0]  # Remove +cu111 suffix

    return version.parse(current_version) >= version.parse(min_version)


if __name__ == "__main__":
    print("Gradient-Based NAS Utilities")
    print("=" * 50)

    if TORCH_AVAILABLE:
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {check_cuda_available()}")

        if check_cuda_available():
            print(f"CUDA device: {torch.cuda.get_device_name(0)}")
            print_memory_usage()

        # Test utilities
        print("\nTesting utilities...")
        meter = AverageMeter()
        for i in range(10):
            meter.update(i, 1)
        print(f"Average meter test: avg={meter.avg:.2f}")

        print("\nAll utilities loaded successfully!")
    else:
        print("PyTorch not available. Install with: pip install torch")