morphml-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of morphml might be problematic.
- morphml/__init__.py +14 -0
- morphml/api/__init__.py +26 -0
- morphml/api/app.py +326 -0
- morphml/api/auth.py +193 -0
- morphml/api/client.py +338 -0
- morphml/api/models.py +132 -0
- morphml/api/rate_limit.py +192 -0
- morphml/benchmarking/__init__.py +36 -0
- morphml/benchmarking/comparison.py +430 -0
- morphml/benchmarks/__init__.py +56 -0
- morphml/benchmarks/comparator.py +409 -0
- morphml/benchmarks/datasets.py +280 -0
- morphml/benchmarks/metrics.py +199 -0
- morphml/benchmarks/openml_suite.py +201 -0
- morphml/benchmarks/problems.py +289 -0
- morphml/benchmarks/suite.py +318 -0
- morphml/cli/__init__.py +5 -0
- morphml/cli/commands/experiment.py +329 -0
- morphml/cli/main.py +457 -0
- morphml/cli/quickstart.py +312 -0
- morphml/config.py +278 -0
- morphml/constraints/__init__.py +19 -0
- morphml/constraints/handler.py +205 -0
- morphml/constraints/predicates.py +285 -0
- morphml/core/__init__.py +3 -0
- morphml/core/crossover.py +449 -0
- morphml/core/dsl/README.md +359 -0
- morphml/core/dsl/__init__.py +72 -0
- morphml/core/dsl/ast_nodes.py +364 -0
- morphml/core/dsl/compiler.py +318 -0
- morphml/core/dsl/layers.py +368 -0
- morphml/core/dsl/lexer.py +336 -0
- morphml/core/dsl/parser.py +455 -0
- morphml/core/dsl/search_space.py +386 -0
- morphml/core/dsl/syntax.py +199 -0
- morphml/core/dsl/type_system.py +361 -0
- morphml/core/dsl/validator.py +386 -0
- morphml/core/graph/__init__.py +40 -0
- morphml/core/graph/edge.py +124 -0
- morphml/core/graph/graph.py +507 -0
- morphml/core/graph/mutations.py +409 -0
- morphml/core/graph/node.py +196 -0
- morphml/core/graph/serialization.py +361 -0
- morphml/core/graph/visualization.py +431 -0
- morphml/core/objectives/__init__.py +20 -0
- morphml/core/search/__init__.py +33 -0
- morphml/core/search/individual.py +252 -0
- morphml/core/search/parameters.py +453 -0
- morphml/core/search/population.py +375 -0
- morphml/core/search/search_engine.py +340 -0
- morphml/distributed/__init__.py +76 -0
- morphml/distributed/fault_tolerance.py +497 -0
- morphml/distributed/health_monitor.py +348 -0
- morphml/distributed/master.py +709 -0
- morphml/distributed/proto/README.md +224 -0
- morphml/distributed/proto/__init__.py +74 -0
- morphml/distributed/proto/worker.proto +170 -0
- morphml/distributed/proto/worker_pb2.py +79 -0
- morphml/distributed/proto/worker_pb2_grpc.py +423 -0
- morphml/distributed/resource_manager.py +416 -0
- morphml/distributed/scheduler.py +567 -0
- morphml/distributed/storage/__init__.py +33 -0
- morphml/distributed/storage/artifacts.py +381 -0
- morphml/distributed/storage/cache.py +366 -0
- morphml/distributed/storage/checkpointing.py +329 -0
- morphml/distributed/storage/database.py +459 -0
- morphml/distributed/worker.py +549 -0
- morphml/evaluation/__init__.py +5 -0
- morphml/evaluation/heuristic.py +237 -0
- morphml/exceptions.py +55 -0
- morphml/execution/__init__.py +5 -0
- morphml/execution/local_executor.py +350 -0
- morphml/integrations/__init__.py +28 -0
- morphml/integrations/jax_adapter.py +206 -0
- morphml/integrations/pytorch_adapter.py +530 -0
- morphml/integrations/sklearn_adapter.py +206 -0
- morphml/integrations/tensorflow_adapter.py +230 -0
- morphml/logging_config.py +93 -0
- morphml/meta_learning/__init__.py +66 -0
- morphml/meta_learning/architecture_similarity.py +277 -0
- morphml/meta_learning/experiment_database.py +240 -0
- morphml/meta_learning/knowledge_base/__init__.py +19 -0
- morphml/meta_learning/knowledge_base/embedder.py +179 -0
- morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
- morphml/meta_learning/knowledge_base/meta_features.py +265 -0
- morphml/meta_learning/knowledge_base/vector_store.py +271 -0
- morphml/meta_learning/predictors/__init__.py +27 -0
- morphml/meta_learning/predictors/ensemble.py +221 -0
- morphml/meta_learning/predictors/gnn_predictor.py +552 -0
- morphml/meta_learning/predictors/learning_curve.py +231 -0
- morphml/meta_learning/predictors/proxy_metrics.py +261 -0
- morphml/meta_learning/strategy_evolution/__init__.py +27 -0
- morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
- morphml/meta_learning/strategy_evolution/bandit.py +276 -0
- morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
- morphml/meta_learning/transfer.py +581 -0
- morphml/meta_learning/warm_start.py +286 -0
- morphml/optimizers/__init__.py +74 -0
- morphml/optimizers/adaptive_operators.py +399 -0
- morphml/optimizers/bayesian/__init__.py +52 -0
- morphml/optimizers/bayesian/acquisition.py +387 -0
- morphml/optimizers/bayesian/base.py +319 -0
- morphml/optimizers/bayesian/gaussian_process.py +635 -0
- morphml/optimizers/bayesian/smac.py +534 -0
- morphml/optimizers/bayesian/tpe.py +411 -0
- morphml/optimizers/differential_evolution.py +220 -0
- morphml/optimizers/evolutionary/__init__.py +61 -0
- morphml/optimizers/evolutionary/cma_es.py +416 -0
- morphml/optimizers/evolutionary/differential_evolution.py +556 -0
- morphml/optimizers/evolutionary/encoding.py +426 -0
- morphml/optimizers/evolutionary/particle_swarm.py +449 -0
- morphml/optimizers/genetic_algorithm.py +486 -0
- morphml/optimizers/gradient_based/__init__.py +22 -0
- morphml/optimizers/gradient_based/darts.py +550 -0
- morphml/optimizers/gradient_based/enas.py +585 -0
- morphml/optimizers/gradient_based/operations.py +474 -0
- morphml/optimizers/gradient_based/utils.py +601 -0
- morphml/optimizers/hill_climbing.py +169 -0
- morphml/optimizers/multi_objective/__init__.py +56 -0
- morphml/optimizers/multi_objective/indicators.py +504 -0
- morphml/optimizers/multi_objective/nsga2.py +647 -0
- morphml/optimizers/multi_objective/visualization.py +427 -0
- morphml/optimizers/nsga2.py +308 -0
- morphml/optimizers/random_search.py +172 -0
- morphml/optimizers/simulated_annealing.py +181 -0
- morphml/plugins/__init__.py +35 -0
- morphml/plugins/custom_evaluator_example.py +81 -0
- morphml/plugins/custom_optimizer_example.py +63 -0
- morphml/plugins/plugin_system.py +454 -0
- morphml/reports/__init__.py +30 -0
- morphml/reports/generator.py +362 -0
- morphml/tracking/__init__.py +7 -0
- morphml/tracking/experiment.py +309 -0
- morphml/tracking/logger.py +301 -0
- morphml/tracking/reporter.py +357 -0
- morphml/utils/__init__.py +6 -0
- morphml/utils/checkpoint.py +189 -0
- morphml/utils/comparison.py +390 -0
- morphml/utils/export.py +407 -0
- morphml/utils/progress.py +392 -0
- morphml/utils/validation.py +392 -0
- morphml/version.py +7 -0
- morphml/visualization/__init__.py +50 -0
- morphml/visualization/analytics.py +423 -0
- morphml/visualization/architecture_diagrams.py +353 -0
- morphml/visualization/architecture_plot.py +223 -0
- morphml/visualization/convergence_plot.py +174 -0
- morphml/visualization/crossover_viz.py +386 -0
- morphml/visualization/graph_viz.py +338 -0
- morphml/visualization/pareto_plot.py +149 -0
- morphml/visualization/plotly_dashboards.py +422 -0
- morphml/visualization/population.py +309 -0
- morphml/visualization/progress.py +260 -0
- morphml-1.0.0.dist-info/METADATA +434 -0
- morphml-1.0.0.dist-info/RECORD +158 -0
- morphml-1.0.0.dist-info/WHEEL +4 -0
- morphml-1.0.0.dist-info/entry_points.txt +3 -0
- morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,474 @@
"""Operation primitives for gradient-based NAS (DARTS, ENAS).

This module provides efficient operation implementations used in differentiable
architecture search. All operations are PyTorch modules optimized for GPU execution.

Operations:
    - Separable Convolutions (SepConv)
    - Dilated Convolutions (DilConv)
    - Pooling operations (Max, Avg)
    - Skip connections (Identity)
    - Zero operations (None)

Reference:
    Liu, H., et al. "DARTS: Differentiable Architecture Search." ICLR 2019.
    Pham, H., et al. "Efficient Neural Architecture Search via Parameter Sharing." ICML 2018.

Author: Eshan Roy <eshanized@proton.me>
Organization: TONMOY INFRASTRUCTURE & VISION
"""

try:
    import torch
    import torch.nn as nn

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

    # Dummy classes for type checking
    class nn:
        class Module:
            pass


from morphml.logging_config import get_logger

logger = get_logger(__name__)


def check_torch_available() -> None:
    """Check if PyTorch is available and raise error if not."""
    if not TORCH_AVAILABLE:
        raise ImportError(
            "PyTorch is required for gradient-based NAS. "
            "Install with: pip install torch or poetry add torch"
        )


class SepConv(nn.Module):
    """
    Separable Convolution operation.

    Separable convolutions reduce computational cost by factorizing a standard
    convolution into depthwise and pointwise convolutions.

    Standard conv: O(C_in * C_out * k^2)
    Separable conv: O(C_in * k^2 + C_in * C_out)

    Structure:
        ReLU → Depthwise Conv → Pointwise Conv → BatchNorm

    Args:
        C_in: Input channels
        C_out: Output channels
        kernel_size: Kernel size for depthwise conv
        stride: Stride for convolution
        padding: Padding for convolution
        affine: Whether to use learnable affine params in BN

    Example:
        >>> sep_conv = SepConv(16, 32, kernel_size=3, stride=1, padding=1)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = sep_conv(x)  # Shape: (2, 32, 32, 32)
    """

    def __init__(
        self,
        C_in: int,
        C_out: int,
        kernel_size: int,
        stride: int,
        padding: int,
        affine: bool = True,
    ):
        super().__init__()

        check_torch_available()

        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            # Depthwise convolution (groups=C_in)
            nn.Conv2d(
                C_in,
                C_in,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=C_in,
                bias=False,
            ),
            # Pointwise convolution (1x1)
            nn.Conv2d(C_in, C_out, kernel_size=1, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass."""
        return self.op(x)


class DilConv(nn.Module):
    """
    Dilated (Atrous) Convolution operation.

    Dilated convolutions increase receptive field without increasing
    parameter count by inserting spaces (dilation) between kernel elements.

    Receptive field = k + (k-1) * (dilation-1)

    Structure:
        ReLU → Dilated Depthwise Conv → Pointwise Conv → BatchNorm

    Args:
        C_in: Input channels
        C_out: Output channels
        kernel_size: Kernel size
        stride: Stride
        padding: Padding (should be = dilation * (kernel_size - 1) / 2)
        dilation: Dilation rate
        affine: Whether to use learnable affine params in BN

    Example:
        >>> dil_conv = DilConv(16, 32, kernel_size=3, stride=1,
        ...                    padding=2, dilation=2)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = dil_conv(x)  # Shape: (2, 32, 32, 32)
    """

    def __init__(
        self,
        C_in: int,
        C_out: int,
        kernel_size: int,
        stride: int,
        padding: int,
        dilation: int,
        affine: bool = True,
    ):
        super().__init__()

        check_torch_available()

        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            # Dilated depthwise convolution
            nn.Conv2d(
                C_in,
                C_in,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=C_in,
                bias=False,
            ),
            # Pointwise convolution
            nn.Conv2d(C_in, C_out, kernel_size=1, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass."""
        return self.op(x)


class Identity(nn.Module):
    """
    Identity operation (skip connection).

    Simply passes input to output unchanged. Used for residual connections.

    Example:
        >>> identity = Identity()
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = identity(x)
        >>> assert torch.equal(x, out)
    """

    def __init__(self):
        super().__init__()
        check_torch_available()

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass (identity)."""
        return x


class Zero(nn.Module):
    """
    Zero operation (no connection).

    Returns a zero tensor of the same shape as input. Used to represent
    the absence of a connection in the architecture.

    Args:
        stride: Stride for output shape calculation

    Example:
        >>> zero = Zero(stride=1)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = zero(x)
        >>> assert torch.all(out == 0)
    """

    def __init__(self, stride: int = 1):
        super().__init__()
        check_torch_available()
        self.stride = stride

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass (zeros)."""
        if self.stride == 1:
            return x.mul(0.0)
        else:
            # Stride > 1: reduce spatial dimensions
            return x[:, :, :: self.stride, :: self.stride].mul(0.0)


class FactorizedReduce(nn.Module):
    """
    Factorized reduction operation.

    Reduces spatial dimensions by factor of 2 while doubling channels.
    More efficient than strided convolution for this specific task.

    Method:
        1. Two parallel 1x1 convolutions with different offsets
        2. Concatenate outputs
        3. BatchNorm

    Args:
        C_in: Input channels
        C_out: Output channels (typically 2 * C_in)
        affine: Whether to use learnable affine params in BN

    Example:
        >>> reduce = FactorizedReduce(16, 32)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = reduce(x)  # Shape: (2, 32, 16, 16)
    """

    def __init__(self, C_in: int, C_out: int, affine: bool = True):
        super().__init__()

        check_torch_available()

        assert C_out % 2 == 0, "C_out must be divisible by 2"

        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass."""
        x = self.relu(x)
        # Two parallel convolutions with offset
        out1 = self.conv_1(x)
        out2 = self.conv_2(x[:, :, 1:, 1:])  # Offset by 1 pixel
        out = torch.cat([out1, out2], dim=1)
        out = self.bn(out)
        return out


class ReLUConvBN(nn.Module):
    """
    Standard convolution block: ReLU → Conv → BatchNorm.

    Args:
        C_in: Input channels
        C_out: Output channels
        kernel_size: Kernel size
        stride: Stride
        padding: Padding
        affine: Whether to use learnable affine params in BN

    Example:
        >>> conv_block = ReLUConvBN(16, 32, 3, 1, 1)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = conv_block(x)  # Shape: (2, 32, 32, 32)
    """

    def __init__(
        self,
        C_in: int,
        C_out: int,
        kernel_size: int,
        stride: int,
        padding: int,
        affine: bool = True,
    ):
        super().__init__()

        check_torch_available()

        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass."""
        return self.op(x)


class DropPath(nn.Module):
    """
    Drop Path (Stochastic Depth) regularization.

    Randomly drops entire paths (operations) during training to prevent
    over-reliance on specific paths and improve generalization.

    Reference:
        Huang et al. "Deep Networks with Stochastic Depth." ECCV 2016.

    Args:
        drop_prob: Probability of dropping a path

    Example:
        >>> drop_path = DropPath(drop_prob=0.2)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = drop_path(x)  # Randomly zeroed during training
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        check_torch_available()
        self.drop_prob = drop_prob

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass with stochastic depth."""
        if not self.training or self.drop_prob == 0.0:
            return x

        keep_prob = 1.0 - self.drop_prob
        # Create binary mask
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # (N, 1, 1, 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # Binarize

        # Scale and apply mask
        output = x.div(keep_prob) * random_tensor
        return output


# Operation factory function
def create_operation(
    op_name: str, C_in: int, C_out: int, stride: int = 1, affine: bool = True
) -> nn.Module:
    """
    Factory function to create operations by name.

    Args:
        op_name: Operation name
        C_in: Input channels
        C_out: Output channels
        stride: Stride
        affine: Use affine params in BN

    Returns:
        Operation module

    Raises:
        ValueError: If operation name is unknown

    Example:
        >>> op = create_operation('sep_conv_3x3', C_in=16, C_out=32)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = op(x)
    """
    check_torch_available()

    if op_name == "none":
        return Zero(stride=stride)

    elif op_name == "skip_connect":
        if stride == 1:
            return Identity()
        else:
            return FactorizedReduce(C_in, C_out, affine=affine)

    elif op_name == "max_pool_3x3":
        return nn.MaxPool2d(3, stride=stride, padding=1)

    elif op_name == "avg_pool_3x3":
        return nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)

    elif op_name == "sep_conv_3x3":
        return SepConv(C_in, C_out, 3, stride, 1, affine=affine)

    elif op_name == "sep_conv_5x5":
        return SepConv(C_in, C_out, 5, stride, 2, affine=affine)

    elif op_name == "sep_conv_7x7":
        return SepConv(C_in, C_out, 7, stride, 3, affine=affine)

    elif op_name == "dil_conv_3x3":
        return DilConv(C_in, C_out, 3, stride, 2, dilation=2, affine=affine)

    elif op_name == "dil_conv_5x5":
        return DilConv(C_in, C_out, 5, stride, 4, dilation=2, affine=affine)

    elif op_name == "conv_3x3":
        return ReLUConvBN(C_in, C_out, 3, stride, 1, affine=affine)

    elif op_name == "conv_1x1":
        return ReLUConvBN(C_in, C_out, 1, stride, 0, affine=affine)

    else:
        raise ValueError(f"Unknown operation: {op_name}")


# Standard operation set for DARTS/ENAS
OPERATIONS = [
    "none",
    "max_pool_3x3",
    "avg_pool_3x3",
    "skip_connect",
    "sep_conv_3x3",
    "sep_conv_5x5",
    "dil_conv_3x3",
    "dil_conv_5x5",
]


def get_operation_names() -> list:
    """Get list of available operation names."""
    return OPERATIONS.copy()


def count_operation_parameters(op: nn.Module) -> int:
    """
    Count number of trainable parameters in an operation.

    Args:
        op: PyTorch module

    Returns:
        Number of trainable parameters
    """
    if not TORCH_AVAILABLE:
        return 0
    return sum(p.numel() for p in op.parameters() if p.requires_grad)


# Example usage and testing
if __name__ == "__main__":
    if TORCH_AVAILABLE:
        print("Testing operations...")

        # Test each operation
        x = torch.randn(2, 16, 32, 32)

        for op_name in OPERATIONS:
            op = create_operation(op_name, C_in=16, C_out=16)
            out = op(x)
            params = count_operation_parameters(op)
            print(f"{op_name:20s} | Output shape: {tuple(out.shape)} | Params: {params:,}")

        print("\nAll operations tested successfully!")
    else:
        print("PyTorch not available. Install with: pip install torch")
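For context on how these primitives are typically wired together, the sketch below builds a DARTS-style mixed operation: every candidate in OPERATIONS is instantiated through create_operation, and their outputs are blended with softmax-weighted architecture parameters. The MixedOp class shown here is illustrative only and is not part of this file; morphml's own wrapper lives in morphml/optimizers/gradient_based/darts.py (also in this release) and may be structured differently. It assumes PyTorch is installed and that morphml is importable.

import torch
import torch.nn as nn
import torch.nn.functional as F

from morphml.optimizers.gradient_based.operations import OPERATIONS, create_operation


class MixedOp(nn.Module):
    """Softmax-weighted blend of all candidate operations on one edge (illustrative)."""

    def __init__(self, C: int, stride: int) -> None:
        super().__init__()
        # One instance of every candidate; C_in == C_out keeps all output shapes compatible.
        self._ops = nn.ModuleList(
            [create_operation(name, C_in=C, C_out=C, stride=stride) for name in OPERATIONS]
        )

    def forward(self, x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
        # alpha holds one logit per candidate; softmax turns the logits into mixing weights.
        weights = F.softmax(alpha, dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self._ops))


if __name__ == "__main__":
    edge = MixedOp(C=16, stride=1)
    alpha = torch.zeros(len(OPERATIONS), requires_grad=True)  # learnable architecture logits
    x = torch.randn(2, 16, 32, 32)
    out = edge(x, alpha)
    print(out.shape)  # torch.Size([2, 16, 32, 32])

Because all candidates preserve the spatial size at stride 1 and use C_in == C_out, the weighted sum is well defined; gradients flow into alpha, which is what makes the search differentiable.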