morphml-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of morphml has been flagged as potentially problematic.

Files changed (158)
  1. morphml/__init__.py +14 -0
  2. morphml/api/__init__.py +26 -0
  3. morphml/api/app.py +326 -0
  4. morphml/api/auth.py +193 -0
  5. morphml/api/client.py +338 -0
  6. morphml/api/models.py +132 -0
  7. morphml/api/rate_limit.py +192 -0
  8. morphml/benchmarking/__init__.py +36 -0
  9. morphml/benchmarking/comparison.py +430 -0
  10. morphml/benchmarks/__init__.py +56 -0
  11. morphml/benchmarks/comparator.py +409 -0
  12. morphml/benchmarks/datasets.py +280 -0
  13. morphml/benchmarks/metrics.py +199 -0
  14. morphml/benchmarks/openml_suite.py +201 -0
  15. morphml/benchmarks/problems.py +289 -0
  16. morphml/benchmarks/suite.py +318 -0
  17. morphml/cli/__init__.py +5 -0
  18. morphml/cli/commands/experiment.py +329 -0
  19. morphml/cli/main.py +457 -0
  20. morphml/cli/quickstart.py +312 -0
  21. morphml/config.py +278 -0
  22. morphml/constraints/__init__.py +19 -0
  23. morphml/constraints/handler.py +205 -0
  24. morphml/constraints/predicates.py +285 -0
  25. morphml/core/__init__.py +3 -0
  26. morphml/core/crossover.py +449 -0
  27. morphml/core/dsl/README.md +359 -0
  28. morphml/core/dsl/__init__.py +72 -0
  29. morphml/core/dsl/ast_nodes.py +364 -0
  30. morphml/core/dsl/compiler.py +318 -0
  31. morphml/core/dsl/layers.py +368 -0
  32. morphml/core/dsl/lexer.py +336 -0
  33. morphml/core/dsl/parser.py +455 -0
  34. morphml/core/dsl/search_space.py +386 -0
  35. morphml/core/dsl/syntax.py +199 -0
  36. morphml/core/dsl/type_system.py +361 -0
  37. morphml/core/dsl/validator.py +386 -0
  38. morphml/core/graph/__init__.py +40 -0
  39. morphml/core/graph/edge.py +124 -0
  40. morphml/core/graph/graph.py +507 -0
  41. morphml/core/graph/mutations.py +409 -0
  42. morphml/core/graph/node.py +196 -0
  43. morphml/core/graph/serialization.py +361 -0
  44. morphml/core/graph/visualization.py +431 -0
  45. morphml/core/objectives/__init__.py +20 -0
  46. morphml/core/search/__init__.py +33 -0
  47. morphml/core/search/individual.py +252 -0
  48. morphml/core/search/parameters.py +453 -0
  49. morphml/core/search/population.py +375 -0
  50. morphml/core/search/search_engine.py +340 -0
  51. morphml/distributed/__init__.py +76 -0
  52. morphml/distributed/fault_tolerance.py +497 -0
  53. morphml/distributed/health_monitor.py +348 -0
  54. morphml/distributed/master.py +709 -0
  55. morphml/distributed/proto/README.md +224 -0
  56. morphml/distributed/proto/__init__.py +74 -0
  57. morphml/distributed/proto/worker.proto +170 -0
  58. morphml/distributed/proto/worker_pb2.py +79 -0
  59. morphml/distributed/proto/worker_pb2_grpc.py +423 -0
  60. morphml/distributed/resource_manager.py +416 -0
  61. morphml/distributed/scheduler.py +567 -0
  62. morphml/distributed/storage/__init__.py +33 -0
  63. morphml/distributed/storage/artifacts.py +381 -0
  64. morphml/distributed/storage/cache.py +366 -0
  65. morphml/distributed/storage/checkpointing.py +329 -0
  66. morphml/distributed/storage/database.py +459 -0
  67. morphml/distributed/worker.py +549 -0
  68. morphml/evaluation/__init__.py +5 -0
  69. morphml/evaluation/heuristic.py +237 -0
  70. morphml/exceptions.py +55 -0
  71. morphml/execution/__init__.py +5 -0
  72. morphml/execution/local_executor.py +350 -0
  73. morphml/integrations/__init__.py +28 -0
  74. morphml/integrations/jax_adapter.py +206 -0
  75. morphml/integrations/pytorch_adapter.py +530 -0
  76. morphml/integrations/sklearn_adapter.py +206 -0
  77. morphml/integrations/tensorflow_adapter.py +230 -0
  78. morphml/logging_config.py +93 -0
  79. morphml/meta_learning/__init__.py +66 -0
  80. morphml/meta_learning/architecture_similarity.py +277 -0
  81. morphml/meta_learning/experiment_database.py +240 -0
  82. morphml/meta_learning/knowledge_base/__init__.py +19 -0
  83. morphml/meta_learning/knowledge_base/embedder.py +179 -0
  84. morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
  85. morphml/meta_learning/knowledge_base/meta_features.py +265 -0
  86. morphml/meta_learning/knowledge_base/vector_store.py +271 -0
  87. morphml/meta_learning/predictors/__init__.py +27 -0
  88. morphml/meta_learning/predictors/ensemble.py +221 -0
  89. morphml/meta_learning/predictors/gnn_predictor.py +552 -0
  90. morphml/meta_learning/predictors/learning_curve.py +231 -0
  91. morphml/meta_learning/predictors/proxy_metrics.py +261 -0
  92. morphml/meta_learning/strategy_evolution/__init__.py +27 -0
  93. morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
  94. morphml/meta_learning/strategy_evolution/bandit.py +276 -0
  95. morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
  96. morphml/meta_learning/transfer.py +581 -0
  97. morphml/meta_learning/warm_start.py +286 -0
  98. morphml/optimizers/__init__.py +74 -0
  99. morphml/optimizers/adaptive_operators.py +399 -0
  100. morphml/optimizers/bayesian/__init__.py +52 -0
  101. morphml/optimizers/bayesian/acquisition.py +387 -0
  102. morphml/optimizers/bayesian/base.py +319 -0
  103. morphml/optimizers/bayesian/gaussian_process.py +635 -0
  104. morphml/optimizers/bayesian/smac.py +534 -0
  105. morphml/optimizers/bayesian/tpe.py +411 -0
  106. morphml/optimizers/differential_evolution.py +220 -0
  107. morphml/optimizers/evolutionary/__init__.py +61 -0
  108. morphml/optimizers/evolutionary/cma_es.py +416 -0
  109. morphml/optimizers/evolutionary/differential_evolution.py +556 -0
  110. morphml/optimizers/evolutionary/encoding.py +426 -0
  111. morphml/optimizers/evolutionary/particle_swarm.py +449 -0
  112. morphml/optimizers/genetic_algorithm.py +486 -0
  113. morphml/optimizers/gradient_based/__init__.py +22 -0
  114. morphml/optimizers/gradient_based/darts.py +550 -0
  115. morphml/optimizers/gradient_based/enas.py +585 -0
  116. morphml/optimizers/gradient_based/operations.py +474 -0
  117. morphml/optimizers/gradient_based/utils.py +601 -0
  118. morphml/optimizers/hill_climbing.py +169 -0
  119. morphml/optimizers/multi_objective/__init__.py +56 -0
  120. morphml/optimizers/multi_objective/indicators.py +504 -0
  121. morphml/optimizers/multi_objective/nsga2.py +647 -0
  122. morphml/optimizers/multi_objective/visualization.py +427 -0
  123. morphml/optimizers/nsga2.py +308 -0
  124. morphml/optimizers/random_search.py +172 -0
  125. morphml/optimizers/simulated_annealing.py +181 -0
  126. morphml/plugins/__init__.py +35 -0
  127. morphml/plugins/custom_evaluator_example.py +81 -0
  128. morphml/plugins/custom_optimizer_example.py +63 -0
  129. morphml/plugins/plugin_system.py +454 -0
  130. morphml/reports/__init__.py +30 -0
  131. morphml/reports/generator.py +362 -0
  132. morphml/tracking/__init__.py +7 -0
  133. morphml/tracking/experiment.py +309 -0
  134. morphml/tracking/logger.py +301 -0
  135. morphml/tracking/reporter.py +357 -0
  136. morphml/utils/__init__.py +6 -0
  137. morphml/utils/checkpoint.py +189 -0
  138. morphml/utils/comparison.py +390 -0
  139. morphml/utils/export.py +407 -0
  140. morphml/utils/progress.py +392 -0
  141. morphml/utils/validation.py +392 -0
  142. morphml/version.py +7 -0
  143. morphml/visualization/__init__.py +50 -0
  144. morphml/visualization/analytics.py +423 -0
  145. morphml/visualization/architecture_diagrams.py +353 -0
  146. morphml/visualization/architecture_plot.py +223 -0
  147. morphml/visualization/convergence_plot.py +174 -0
  148. morphml/visualization/crossover_viz.py +386 -0
  149. morphml/visualization/graph_viz.py +338 -0
  150. morphml/visualization/pareto_plot.py +149 -0
  151. morphml/visualization/plotly_dashboards.py +422 -0
  152. morphml/visualization/population.py +309 -0
  153. morphml/visualization/progress.py +260 -0
  154. morphml-1.0.0.dist-info/METADATA +434 -0
  155. morphml-1.0.0.dist-info/RECORD +158 -0
  156. morphml-1.0.0.dist-info/WHEEL +4 -0
  157. morphml-1.0.0.dist-info/entry_points.txt +3 -0
  158. morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
morphml/optimizers/gradient_based/operations.py
@@ -0,0 +1,474 @@
"""Operation primitives for gradient-based NAS (DARTS, ENAS).

This module provides efficient operation implementations used in differentiable
architecture search. All operations are PyTorch modules optimized for GPU execution.

Operations:
- Separable Convolutions (SepConv)
- Dilated Convolutions (DilConv)
- Pooling operations (Max, Avg)
- Skip connections (Identity)
- Zero operations ("none")

Reference:
    Liu, H., et al. "DARTS: Differentiable Architecture Search." ICLR 2019.
    Pham, H., et al. "Efficient Neural Architecture Search via Parameter Sharing." ICML 2018.

Author: Eshan Roy <eshanized@proton.me>
Organization: TONMOY INFRASTRUCTURE & VISION
"""


try:
    import torch
    import torch.nn as nn

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

    # Minimal stand-ins so the class definitions below still import (and
    # type-check) when PyTorch is not installed.
    class nn:
        class Module:
            pass


from morphml.logging_config import get_logger

logger = get_logger(__name__)


def check_torch_available() -> None:
    """Check if PyTorch is available and raise an error if not."""
    if not TORCH_AVAILABLE:
        raise ImportError(
            "PyTorch is required for gradient-based NAS. "
            "Install with: pip install torch or poetry add torch"
        )


class SepConv(nn.Module):
    """
    Separable Convolution operation.

    Separable convolutions reduce computational cost by factorizing a standard
    convolution into depthwise and pointwise convolutions.

    Standard conv: O(C_in * C_out * k^2)
    Separable conv: O(C_in * k^2 + C_in * C_out)

    Structure:
        ReLU → Depthwise Conv → Pointwise Conv → BatchNorm

    Args:
        C_in: Input channels
        C_out: Output channels
        kernel_size: Kernel size for depthwise conv
        stride: Stride for convolution
        padding: Padding for convolution
        affine: Whether to use learnable affine params in BN

    Example:
        >>> sep_conv = SepConv(16, 32, kernel_size=3, stride=1, padding=1)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = sep_conv(x)  # Shape: (2, 32, 32, 32)
    """

    def __init__(
        self,
        C_in: int,
        C_out: int,
        kernel_size: int,
        stride: int,
        padding: int,
        affine: bool = True,
    ):
        super().__init__()

        check_torch_available()

        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            # Depthwise convolution (groups=C_in)
            nn.Conv2d(
                C_in,
                C_in,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=C_in,
                bias=False,
            ),
            # Pointwise convolution (1x1)
            nn.Conv2d(C_in, C_out, kernel_size=1, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass."""
        return self.op(x)
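
# Worked cost check for the factorization claim in the docstring above, using
# SepConv(16, 32, kernel_size=3). This is an illustrative sketch, not part of
# operations.py: both convolutions are bias-free, so the weight counts follow
# directly from the layer shapes.
_depthwise = 16 * 3 * 3      # C_in * k^2 depthwise weights   = 144
_pointwise = 16 * 32         # C_in * C_out pointwise weights = 512
_standard = 16 * 32 * 3 * 3  # equivalent dense 3x3 conv      = 4608
assert _depthwise + _pointwise == 656  # roughly 7x fewer weights than 4608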


class DilConv(nn.Module):
    """
    Dilated (Atrous) Convolution operation.

    Dilated convolutions increase the receptive field without increasing
    parameter count by inserting spaces (dilation) between kernel elements.

    Effective kernel size = k + (k - 1) * (dilation - 1)

    Structure:
        ReLU → Dilated Depthwise Conv → Pointwise Conv → BatchNorm

    Args:
        C_in: Input channels
        C_out: Output channels
        kernel_size: Kernel size
        stride: Stride
        padding: Padding (should be dilation * (kernel_size - 1) / 2)
        dilation: Dilation rate
        affine: Whether to use learnable affine params in BN

    Example:
        >>> dil_conv = DilConv(16, 32, kernel_size=3, stride=1,
        ...                    padding=2, dilation=2)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = dil_conv(x)  # Shape: (2, 32, 32, 32)
    """

    def __init__(
        self,
        C_in: int,
        C_out: int,
        kernel_size: int,
        stride: int,
        padding: int,
        dilation: int,
        affine: bool = True,
    ):
        super().__init__()

        check_torch_available()

        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            # Dilated depthwise convolution
            nn.Conv2d(
                C_in,
                C_in,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=C_in,
                bias=False,
            ),
            # Pointwise convolution
            nn.Conv2d(C_in, C_out, kernel_size=1, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass."""
        return self.op(x)
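
# Receptive-field arithmetic for the docstring's formula, with kernel_size=3
# and dilation=2 (illustrative only, not part of operations.py): the dilated
# kernel spans 5 pixels, so padding = dilation * (kernel_size - 1) / 2 = 2
# keeps the spatial size unchanged at stride 1.
_k, _d = 3, 2
assert _k + (_k - 1) * (_d - 1) == 5  # effective kernel extent
assert _d * (_k - 1) // 2 == 2        # stride-1 "same" padding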


class Identity(nn.Module):
    """
    Identity operation (skip connection).

    Simply passes input to output unchanged. Used for residual connections.

    Example:
        >>> identity = Identity()
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = identity(x)
        >>> assert torch.equal(x, out)
    """

    def __init__(self):
        super().__init__()
        check_torch_available()

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass (identity)."""
        return x


class Zero(nn.Module):
    """
    Zero operation (no connection).

    Returns a zero tensor with the shape the operation would produce at the
    given stride. Used to represent the absence of a connection in the
    architecture.

    Args:
        stride: Stride for output shape calculation

    Example:
        >>> zero = Zero(stride=1)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = zero(x)
        >>> assert torch.all(out == 0)
    """

    def __init__(self, stride: int = 1):
        super().__init__()
        check_torch_available()
        self.stride = stride

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass (zeros)."""
        if self.stride == 1:
            return x.mul(0.0)
        # Stride > 1: reduce spatial dimensions to match strided operations
        return x[:, :, :: self.stride, :: self.stride].mul(0.0)


class FactorizedReduce(nn.Module):
    """
    Factorized reduction operation.

    Halves spatial dimensions while mapping C_in to C_out channels (typically
    doubling them). The two offset branches sample complementary pixel
    positions, avoiding some of the information loss of a single strided
    1x1 convolution.

    Method:
        1. Two parallel stride-2 1x1 convolutions, the second on a one-pixel offset
        2. Concatenate outputs along the channel dimension
        3. BatchNorm

    Args:
        C_in: Input channels
        C_out: Output channels (typically 2 * C_in)
        affine: Whether to use learnable affine params in BN

    Example:
        >>> reduce = FactorizedReduce(16, 32)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = reduce(x)  # Shape: (2, 32, 16, 16)
    """

    def __init__(self, C_in: int, C_out: int, affine: bool = True):
        super().__init__()

        check_torch_available()

        assert C_out % 2 == 0, "C_out must be divisible by 2"

        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass."""
        x = self.relu(x)
        # Two parallel convolutions, the second offset by one pixel so that
        # both even and odd spatial positions are sampled
        out1 = self.conv_1(x)
        out2 = self.conv_2(x[:, :, 1:, 1:])
        out = torch.cat([out1, out2], dim=1)
        out = self.bn(out)
        return out
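
# Shape bookkeeping for FactorizedReduce(16, 32) on a (2, 16, 32, 32) input
# (an illustrative check, not part of operations.py): each stride-2 1x1 conv
# emits C_out // 2 = 16 channels at half resolution, and the one-pixel offset
# on the second branch means even and odd spatial positions are both sampled.
_C_out, _H = 32, 32
assert _C_out // 2 + _C_out // 2 == 32  # channels after torch.cat(dim=1)
assert _H // 2 == 16                    # spatial size after stride 2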


class ReLUConvBN(nn.Module):
    """
    Standard convolution block: ReLU → Conv → BatchNorm.

    Args:
        C_in: Input channels
        C_out: Output channels
        kernel_size: Kernel size
        stride: Stride
        padding: Padding
        affine: Whether to use learnable affine params in BN

    Example:
        >>> conv_block = ReLUConvBN(16, 32, 3, 1, 1)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = conv_block(x)  # Shape: (2, 32, 32, 32)
    """

    def __init__(
        self,
        C_in: int,
        C_out: int,
        kernel_size: int,
        stride: int,
        padding: int,
        affine: bool = True,
    ):
        super().__init__()

        check_torch_available()

        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass."""
        return self.op(x)


class DropPath(nn.Module):
    """
    Drop Path (Stochastic Depth) regularization.

    Randomly drops entire paths (operations) during training to prevent
    over-reliance on specific paths and improve generalization.

    Reference:
        Huang et al. "Deep Networks with Stochastic Depth." ECCV 2016.

    Args:
        drop_prob: Probability of dropping a path

    Example:
        >>> drop_path = DropPath(drop_prob=0.2)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = drop_path(x)  # Randomly zeroed per sample during training
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        check_torch_available()
        self.drop_prob = drop_prob

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Forward pass with stochastic depth."""
        if not self.training or self.drop_prob == 0.0:
            return x

        keep_prob = 1.0 - self.drop_prob
        # Per-sample binary mask, broadcast over all non-batch dimensions
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # (N, 1, 1, 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # Binarize: 1 with probability keep_prob

        # Scale surviving paths so the expected output matches the input
        output = x.div(keep_prob) * random_tensor
        return output
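
# Expectation check for the keep-prob scaling above (illustrative, not part of
# operations.py): each per-sample mask entry is 1 with probability keep_prob,
# so E[x / keep_prob * mask] = x / keep_prob * keep_prob = x, matching the
# identity behavior the layer falls back to at inference time.
_keep_prob = 0.8
assert abs((1.0 / _keep_prob) * _keep_prob - 1.0) < 1e-12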


# Operation factory function
def create_operation(
    op_name: str, C_in: int, C_out: int, stride: int = 1, affine: bool = True
) -> nn.Module:
    """
    Factory function to create operations by name.

    Args:
        op_name: Operation name
        C_in: Input channels
        C_out: Output channels
        stride: Stride
        affine: Use affine params in BN

    Returns:
        Operation module

    Raises:
        ValueError: If operation name is unknown

    Example:
        >>> op = create_operation('sep_conv_3x3', C_in=16, C_out=32)
        >>> x = torch.randn(2, 16, 32, 32)
        >>> out = op(x)
    """
    check_torch_available()

    if op_name == "none":
        return Zero(stride=stride)

    elif op_name == "skip_connect":
        if stride == 1:
            return Identity()
        else:
            return FactorizedReduce(C_in, C_out, affine=affine)

    elif op_name == "max_pool_3x3":
        return nn.MaxPool2d(3, stride=stride, padding=1)

    elif op_name == "avg_pool_3x3":
        return nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)

    elif op_name == "sep_conv_3x3":
        return SepConv(C_in, C_out, 3, stride, 1, affine=affine)

    elif op_name == "sep_conv_5x5":
        return SepConv(C_in, C_out, 5, stride, 2, affine=affine)

    elif op_name == "sep_conv_7x7":
        return SepConv(C_in, C_out, 7, stride, 3, affine=affine)

    elif op_name == "dil_conv_3x3":
        return DilConv(C_in, C_out, 3, stride, 2, dilation=2, affine=affine)

    elif op_name == "dil_conv_5x5":
        return DilConv(C_in, C_out, 5, stride, 4, dilation=2, affine=affine)

    elif op_name == "conv_3x3":
        return ReLUConvBN(C_in, C_out, 3, stride, 1, affine=affine)

    elif op_name == "conv_1x1":
        return ReLUConvBN(C_in, C_out, 1, stride, 0, affine=affine)

    else:
        raise ValueError(f"Unknown operation: {op_name}")
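
# Padding sanity check for the factory defaults above (illustrative, not part
# of operations.py): at stride 1 every operation preserves spatial size, since
# padding is k // 2 for the separable convs and dilation * (k - 1) / 2 for the
# dilated convs.
for _k, _pad in [(3, 1), (5, 2), (7, 3)]:  # sep_conv_{3x3,5x5,7x7}
    assert _pad == _k // 2
for _k, _pad in [(3, 2), (5, 4)]:          # dil_conv_{3x3,5x5}, dilation=2
    assert _pad == 2 * (_k - 1) // 2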


# Standard operation set for DARTS/ENAS
OPERATIONS = [
    "none",
    "max_pool_3x3",
    "avg_pool_3x3",
    "skip_connect",
    "sep_conv_3x3",
    "sep_conv_5x5",
    "dil_conv_3x3",
    "dil_conv_5x5",
]


def get_operation_names() -> list:
    """Get list of available operation names."""
    return OPERATIONS.copy()


def count_operation_parameters(op: nn.Module) -> int:
    """
    Count the number of trainable parameters in an operation.

    Args:
        op: PyTorch module

    Returns:
        Number of trainable parameters
    """
    if not TORCH_AVAILABLE:
        return 0
    return sum(p.numel() for p in op.parameters() if p.requires_grad)


# Example usage and testing
if __name__ == "__main__":
    if TORCH_AVAILABLE:
        print("Testing operations...")

        # Test each operation
        x = torch.randn(2, 16, 32, 32)

        for op_name in OPERATIONS:
            op = create_operation(op_name, C_in=16, C_out=16)
            out = op(x)
            params = count_operation_parameters(op)
            print(f"{op_name:20s} | Output shape: {tuple(out.shape)} | Params: {params:,}")

        print("\nAll operations tested successfully!")
    else:
        print("PyTorch not available. Install with: pip install torch")