wavedl 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/{hpc.py → launcher.py} +135 -61
  4. wavedl/models/__init__.py +28 -0
  5. wavedl/models/{_timm_utils.py → _pretrained_utils.py} +128 -0
  6. wavedl/models/base.py +48 -0
  7. wavedl/models/caformer.py +1 -1
  8. wavedl/models/cnn.py +2 -27
  9. wavedl/models/convnext.py +5 -18
  10. wavedl/models/convnext_v2.py +6 -22
  11. wavedl/models/densenet.py +5 -18
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +6 -39
  15. wavedl/models/mamba.py +44 -24
  16. wavedl/models/maxvit.py +51 -48
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +14 -56
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +1 -5
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +3 -3
  26. wavedl/train.py +1427 -1430
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/losses.py +216 -216
  30. wavedl/utils/optimizers.py +216 -216
  31. wavedl/utils/schedulers.py +251 -251
  32. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/METADATA +150 -113
  33. wavedl-1.6.2.dist-info/RECORD +46 -0
  34. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/entry_points.txt +2 -2
  35. wavedl-1.6.0.dist-info/RECORD +0 -44
  36. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/LICENSE +0 -0
  37. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/WHEEL +0 -0
  38. {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/top_level.txt +0 -0
wavedl/models/regnet.py CHANGED
@@ -1,406 +1,406 @@
1
- """
2
- RegNet: Designing Network Design Spaces
3
- ========================================
4
-
5
- RegNet provides a family of models with predictable scaling behavior,
6
- designed through systematic exploration of network design spaces.
7
- Models scale smoothly from mobile to server deployments.
8
-
9
- **Key Features**:
10
- - Predictable scaling: accuracy increases linearly with compute
11
- - Simple, uniform architecture (no complex compound scaling)
12
- - Group convolutions for efficiency
13
- - Optional Squeeze-and-Excitation (SE) attention
14
-
15
- **Variants** (RegNetY includes SE attention):
16
- - regnet_y_400mf: Ultra-light (~3.9M backbone params, 0.4 GFLOPs)
17
- - regnet_y_800mf: Light (~5.7M backbone params, 0.8 GFLOPs)
18
- - regnet_y_1_6gf: Medium (~10.3M backbone params, 1.6 GFLOPs) - Recommended
19
- - regnet_y_3_2gf: Large (~17.9M backbone params, 3.2 GFLOPs)
20
- - regnet_y_8gf: Very large (~37.4M backbone params, 8.0 GFLOPs)
21
-
22
- **When to Use RegNet**:
23
- - When you need predictable performance at a given compute budget
24
- - For systematic model selection experiments
25
- - When interpretability of design choices matters
26
- - As an efficient alternative to ResNet
27
-
28
- **Note**: RegNet is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
29
-
30
- References:
31
- Radosavovic, I., et al. (2020). Designing Network Design Spaces.
32
- CVPR 2020. https://arxiv.org/abs/2003.13678
33
-
34
- Author: Ductho Le (ductho.le@outlook.com)
35
- """
36
-
37
- from typing import Any
38
-
39
- import torch
40
- import torch.nn as nn
41
-
42
-
43
- try:
44
- from torchvision.models import (
45
- RegNet_Y_1_6GF_Weights,
46
- RegNet_Y_3_2GF_Weights,
47
- RegNet_Y_8GF_Weights,
48
- RegNet_Y_400MF_Weights,
49
- RegNet_Y_800MF_Weights,
50
- regnet_y_1_6gf,
51
- regnet_y_3_2gf,
52
- regnet_y_8gf,
53
- regnet_y_400mf,
54
- regnet_y_800mf,
55
- )
56
-
57
- REGNET_AVAILABLE = True
58
- except ImportError:
59
- REGNET_AVAILABLE = False
60
-
61
- from wavedl.models.base import BaseModel
62
- from wavedl.models.registry import register_model
63
-
64
-
65
class RegNetBase(BaseModel):
    """
    Base RegNet class for regression tasks.

    Wraps torchvision RegNetY (with SE attention) with:
    - Optional pretrained weights (ImageNet-1K)
    - Single-channel input: the stem conv is replaced by a 1-channel conv
      (initialized from the mean of the pretrained RGB filters)
    - Custom regression head

    RegNet advantages:
    - Simple, uniform design (easy to understand and modify)
    - Predictable accuracy/compute trade-off
    - Efficient group convolutions
    - SE attention for channel weighting (RegNetY variants)

    Note: This is 2D-only. Input shape must be (H, W).
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        model_fn,
        weights_class,
        pretrained: bool = True,
        dropout_rate: float = 0.2,
        freeze_backbone: bool = False,
        regression_hidden: int = 256,
        **kwargs,
    ):
        """
        Initialize RegNet for regression.

        Args:
            in_shape: (H, W) input image dimensions
            out_size: Number of regression output targets
            model_fn: torchvision model constructor, called as ``model_fn(weights=...)``
            weights_class: Pretrained weights enum class (``IMAGENET1K_V1`` is used)
            pretrained: Use ImageNet pretrained weights (default: True)
            dropout_rate: Dropout rate in regression head (default: 0.2)
            freeze_backbone: Freeze backbone for fine-tuning (default: False)
            regression_hidden: Hidden units in regression head (default: 256)
            **kwargs: Accepted for config compatibility; not used here.

        Raises:
            ImportError: If the torchvision RegNet builders could not be imported.
            ValueError: If ``in_shape`` is not 2D.
        """
        super().__init__(in_shape, out_size)

        if not REGNET_AVAILABLE:
            raise ImportError(
                "torchvision is required for RegNet. "
                "Install with: pip install torchvision"
            )

        if len(in_shape) != 2:
            raise ValueError(
                f"RegNet requires 2D input (H, W), got {len(in_shape)}D. "
                "For 1D data, use TCN. For 3D data, use ResNet3D."
            )

        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.freeze_backbone = freeze_backbone
        self.regression_hidden = regression_hidden

        # Load the backbone; weights=None builds a randomly initialized network.
        weights = weights_class.IMAGENET1K_V1 if pretrained else None
        self.backbone = model_fn(weights=weights)

        # RegNet uses .fc as the classification head
        in_features = self.backbone.fc.in_features

        # Replace fc with a two-layer regression head (half dropout before the output layer)
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(in_features, regression_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(regression_hidden, out_size),
        )

        # Adapt first conv for single-channel input (3× memory savings vs expand)
        self._adapt_input_channels()

        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
        if freeze_backbone:
            self._freeze_backbone()

    def _adapt_input_channels(self):
        """Modify first conv to accept single-channel input.

        Instead of expanding 1→3 channels in forward (which triples memory),
        we replace the first conv layer with a 1-channel version. With
        pretrained weights, its kernel is the mean of the pretrained RGB
        filters; otherwise the new conv keeps PyTorch's default initialization.
        """
        old_conv = self.backbone.stem[0]
        new_conv = nn.Conv2d(
            1,  # Single channel input
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            dilation=old_conv.dilation,
            groups=old_conv.groups,
            padding_mode=old_conv.padding_mode,
            bias=old_conv.bias is not None,
        )
        if self.pretrained:
            with torch.no_grad():
                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
        self.backbone.stem[0] = new_conv

    def _freeze_backbone(self):
        """Freeze all backbone parameters except the regression head (``fc``).

        Matches the ``fc.`` prefix instead of testing for the substring
        ``"fc"``: torchvision's SqueezeExcitation modules (used in every
        RegNetY block) name their convolutions ``fc1``/``fc2``, so a substring
        test would leave all SE parameters trainable.
        """
        for name, param in self.backbone.named_parameters():
            if not name.startswith("fc."):
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, 1, H, W)

        Returns:
            Output tensor of shape (B, out_size)
        """
        return self.backbone(x)

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration for RegNet."""
        return {
            "pretrained": True,
            "dropout_rate": 0.2,
            "freeze_backbone": False,
            "regression_hidden": 256,
        }
201
-
202
-
203
- # =============================================================================
204
- # REGISTERED MODEL VARIANTS
205
- # =============================================================================
206
-
207
-
208
@register_model("regnet_y_400mf")
class RegNetY400MF(RegNetBase):
    """
    RegNetY-400MF: the smallest RegNetY (SE attention) variant.

    ~3.9M backbone parameters, 0.4 GFLOPs.

    Good fit for:
    - Edge deployment with moderate accuracy needs
    - Quick training experiments
    - Baseline comparisons

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        **kwargs: Forwarded to RegNetBase — pretrained (default: True),
            dropout_rate (default: 0.2), freeze_backbone (default: False),
            regression_hidden (default: 256)

    Example:
        >>> model = RegNetY400MF(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Bind this size's torchvision builder and matching weights enum.
        super().__init__(
            in_shape=in_shape, out_size=out_size,
            model_fn=regnet_y_400mf, weights_class=RegNet_Y_400MF_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        if self.pretrained:
            init_mode = "pretrained"
        else:
            init_mode = "scratch"
        return f"RegNetY_400MF({init_mode}, in={self.in_shape}, out={self.out_size})"
246
-
247
-
248
@register_model("regnet_y_800mf")
class RegNetY800MF(RegNetBase):
    """
    RegNetY-800MF: light RegNetY variant with good accuracy.

    ~5.7M backbone parameters, 0.8 GFLOPs.

    Good fit for:
    - Mobile/portable devices
    - When MobileNet isn't accurate enough
    - Moderate compute budgets

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        **kwargs: Forwarded to RegNetBase — pretrained (default: True),
            dropout_rate (default: 0.2), freeze_backbone (default: False),
            regression_hidden (default: 256)

    Example:
        >>> model = RegNetY800MF(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Bind this size's torchvision builder and matching weights enum.
        super().__init__(
            in_shape=in_shape, out_size=out_size,
            model_fn=regnet_y_800mf, weights_class=RegNet_Y_800MF_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        if self.pretrained:
            init_mode = "pretrained"
        else:
            init_mode = "scratch"
        return f"RegNetY_800MF({init_mode}, in={self.in_shape}, out={self.out_size})"
286
-
287
-
288
@register_model("regnet_y_1_6gf")
class RegNetY1_6GF(RegNetBase):
    """
    RegNetY-1.6GF: recommended default with balanced accuracy and cost.

    ~10.3M backbone parameters, 1.6 GFLOPs. Comparable to ResNet50 but
    more efficient.

    Good fit for:
    - Default choice for general wave-based tasks
    - When you want predictable scaling
    - Server deployment with efficiency needs

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        **kwargs: Forwarded to RegNetBase — pretrained (default: True),
            dropout_rate (default: 0.2), freeze_backbone (default: False),
            regression_hidden (default: 256)

    Example:
        >>> model = RegNetY1_6GF(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Bind this size's torchvision builder and matching weights enum.
        super().__init__(
            in_shape=in_shape, out_size=out_size,
            model_fn=regnet_y_1_6gf, weights_class=RegNet_Y_1_6GF_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        if self.pretrained:
            init_mode = "pretrained"
        else:
            init_mode = "scratch"
        return f"RegNetY_1.6GF({init_mode}, in={self.in_shape}, out={self.out_size})"
327
-
328
-
329
@register_model("regnet_y_3_2gf")
class RegNetY3_2GF(RegNetBase):
    """
    RegNetY-3.2GF: higher-capacity variant for demanding tasks.

    ~17.9M backbone parameters, 3.2 GFLOPs. Use when 1.6GF isn't sufficient.

    Good fit for:
    - Larger datasets requiring more capacity
    - When accuracy is more important than efficiency
    - Research experiments with multiple model sizes

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        **kwargs: Forwarded to RegNetBase — pretrained (default: True),
            dropout_rate (default: 0.2), freeze_backbone (default: False),
            regression_hidden (default: 256)

    Example:
        >>> model = RegNetY3_2GF(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Bind this size's torchvision builder and matching weights enum.
        super().__init__(
            in_shape=in_shape, out_size=out_size,
            model_fn=regnet_y_3_2gf, weights_class=RegNet_Y_3_2GF_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        if self.pretrained:
            init_mode = "pretrained"
        else:
            init_mode = "scratch"
        return f"RegNetY_3.2GF({init_mode}, in={self.in_shape}, out={self.out_size})"
367
-
368
-
369
@register_model("regnet_y_8gf")
class RegNetY8GF(RegNetBase):
    """
    RegNetY-8GF: largest variant here, for large-scale tasks.

    ~37.4M backbone parameters, 8.0 GFLOPs. Use for maximum accuracy needs.

    Good fit for:
    - Very large datasets (>50k samples)
    - Complex wave patterns
    - HPC environments with ample GPU memory

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        **kwargs: Forwarded to RegNetBase — pretrained (default: True),
            dropout_rate (default: 0.2), freeze_backbone (default: False),
            regression_hidden (default: 256)

    Example:
        >>> model = RegNetY8GF(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Bind this size's torchvision builder and matching weights enum.
        super().__init__(
            in_shape=in_shape, out_size=out_size,
            model_fn=regnet_y_8gf, weights_class=RegNet_Y_8GF_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        if self.pretrained:
            init_mode = "pretrained"
        else:
            init_mode = "scratch"
        return f"RegNetY_8GF({init_mode}, in={self.in_shape}, out={self.out_size})"
1
+ """
2
+ RegNet: Designing Network Design Spaces
3
+ ========================================
4
+
5
+ RegNet provides a family of models with predictable scaling behavior,
6
+ designed through systematic exploration of network design spaces.
7
+ Models scale smoothly from mobile to server deployments.
8
+
9
+ **Key Features**:
10
+ - Predictable scaling: accuracy increases linearly with compute
11
+ - Simple, uniform architecture (no complex compound scaling)
12
+ - Group convolutions for efficiency
13
+ - Optional Squeeze-and-Excitation (SE) attention
14
+
15
+ **Variants** (RegNetY includes SE attention):
16
+ - regnet_y_400mf: Ultra-light (~3.9M backbone params, 0.4 GFLOPs)
17
+ - regnet_y_800mf: Light (~5.7M backbone params, 0.8 GFLOPs)
18
+ - regnet_y_1_6gf: Medium (~10.3M backbone params, 1.6 GFLOPs) - Recommended
19
+ - regnet_y_3_2gf: Large (~17.9M backbone params, 3.2 GFLOPs)
20
+ - regnet_y_8gf: Very large (~37.4M backbone params, 8.0 GFLOPs)
21
+
22
+ **When to Use RegNet**:
23
+ - When you need predictable performance at a given compute budget
24
+ - For systematic model selection experiments
25
+ - When interpretability of design choices matters
26
+ - As an efficient alternative to ResNet
27
+
28
+ **Note**: RegNet is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
29
+
30
+ References:
31
+ Radosavovic, I., et al. (2020). Designing Network Design Spaces.
32
+ CVPR 2020. https://arxiv.org/abs/2003.13678
33
+
34
+ Author: Ductho Le (ductho.le@outlook.com)
35
+ """
36
+
37
+ from typing import Any
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+
42
+
43
+ try:
44
+ from torchvision.models import (
45
+ RegNet_Y_1_6GF_Weights,
46
+ RegNet_Y_3_2GF_Weights,
47
+ RegNet_Y_8GF_Weights,
48
+ RegNet_Y_400MF_Weights,
49
+ RegNet_Y_800MF_Weights,
50
+ regnet_y_1_6gf,
51
+ regnet_y_3_2gf,
52
+ regnet_y_8gf,
53
+ regnet_y_400mf,
54
+ regnet_y_800mf,
55
+ )
56
+
57
+ REGNET_AVAILABLE = True
58
+ except ImportError:
59
+ REGNET_AVAILABLE = False
60
+
61
+ from wavedl.models.base import BaseModel
62
+ from wavedl.models.registry import register_model
63
+
64
+
65
class RegNetBase(BaseModel):
    """
    Base RegNet class for regression tasks.

    Wraps torchvision RegNetY (with SE attention) with:
    - Optional pretrained weights (ImageNet-1K)
    - Single-channel input: the stem conv is replaced by a 1-channel conv
      (initialized from the mean of the pretrained RGB filters)
    - Custom regression head

    RegNet advantages:
    - Simple, uniform design (easy to understand and modify)
    - Predictable accuracy/compute trade-off
    - Efficient group convolutions
    - SE attention for channel weighting (RegNetY variants)

    Note: This is 2D-only. Input shape must be (H, W).
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        model_fn,
        weights_class,
        pretrained: bool = True,
        dropout_rate: float = 0.2,
        freeze_backbone: bool = False,
        regression_hidden: int = 256,
        **kwargs,
    ):
        """
        Initialize RegNet for regression.

        Args:
            in_shape: (H, W) input image dimensions
            out_size: Number of regression output targets
            model_fn: torchvision model constructor, called as ``model_fn(weights=...)``
            weights_class: Pretrained weights enum class (``IMAGENET1K_V1`` is used)
            pretrained: Use ImageNet pretrained weights (default: True)
            dropout_rate: Dropout rate in regression head (default: 0.2)
            freeze_backbone: Freeze backbone for fine-tuning (default: False)
            regression_hidden: Hidden units in regression head (default: 256)
            **kwargs: Accepted for config compatibility; not used here.

        Raises:
            ImportError: If the torchvision RegNet builders could not be imported.
            ValueError: If ``in_shape`` is not 2D.
        """
        super().__init__(in_shape, out_size)

        if not REGNET_AVAILABLE:
            raise ImportError(
                "torchvision is required for RegNet. "
                "Install with: pip install torchvision"
            )

        if len(in_shape) != 2:
            raise ValueError(
                f"RegNet requires 2D input (H, W), got {len(in_shape)}D. "
                "For 1D data, use TCN. For 3D data, use ResNet3D."
            )

        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.freeze_backbone = freeze_backbone
        self.regression_hidden = regression_hidden

        # Load the backbone; weights=None builds a randomly initialized network.
        weights = weights_class.IMAGENET1K_V1 if pretrained else None
        self.backbone = model_fn(weights=weights)

        # RegNet uses .fc as the classification head
        in_features = self.backbone.fc.in_features

        # Replace fc with a two-layer regression head (half dropout before the output layer)
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(in_features, regression_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(regression_hidden, out_size),
        )

        # Adapt first conv for single-channel input (3× memory savings vs expand)
        self._adapt_input_channels()

        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
        if freeze_backbone:
            self._freeze_backbone()

    def _adapt_input_channels(self):
        """Modify first conv to accept single-channel input.

        Instead of expanding 1→3 channels in forward (which triples memory),
        we replace the first conv layer with a 1-channel version. With
        pretrained weights, its kernel is the mean of the pretrained RGB
        filters; otherwise the new conv keeps PyTorch's default initialization.
        """
        old_conv = self.backbone.stem[0]
        new_conv = nn.Conv2d(
            1,  # Single channel input
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            dilation=old_conv.dilation,
            groups=old_conv.groups,
            padding_mode=old_conv.padding_mode,
            bias=old_conv.bias is not None,
        )
        if self.pretrained:
            with torch.no_grad():
                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
        self.backbone.stem[0] = new_conv

    def _freeze_backbone(self):
        """Freeze all backbone parameters except the regression head (``fc``).

        Matches the ``fc.`` prefix instead of testing for the substring
        ``"fc"``: torchvision's SqueezeExcitation modules (used in every
        RegNetY block) name their convolutions ``fc1``/``fc2``, so a substring
        test would leave all SE parameters trainable.
        """
        for name, param in self.backbone.named_parameters():
            if not name.startswith("fc."):
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, 1, H, W)

        Returns:
            Output tensor of shape (B, out_size)
        """
        return self.backbone(x)

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration for RegNet."""
        return {
            "pretrained": True,
            "dropout_rate": 0.2,
            "freeze_backbone": False,
            "regression_hidden": 256,
        }
201
+
202
+
203
+ # =============================================================================
204
+ # REGISTERED MODEL VARIANTS
205
+ # =============================================================================
206
+
207
+
208
@register_model("regnet_y_400mf")
class RegNetY400MF(RegNetBase):
    """
    RegNetY-400MF: the smallest RegNetY (SE attention) variant.

    ~3.9M backbone parameters, 0.4 GFLOPs.

    Good fit for:
    - Edge deployment with moderate accuracy needs
    - Quick training experiments
    - Baseline comparisons

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        **kwargs: Forwarded to RegNetBase — pretrained (default: True),
            dropout_rate (default: 0.2), freeze_backbone (default: False),
            regression_hidden (default: 256)

    Example:
        >>> model = RegNetY400MF(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Bind this size's torchvision builder and matching weights enum.
        super().__init__(
            in_shape=in_shape, out_size=out_size,
            model_fn=regnet_y_400mf, weights_class=RegNet_Y_400MF_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        if self.pretrained:
            init_mode = "pretrained"
        else:
            init_mode = "scratch"
        return f"RegNetY_400MF({init_mode}, in={self.in_shape}, out={self.out_size})"
246
+
247
+
248
@register_model("regnet_y_800mf")
class RegNetY800MF(RegNetBase):
    """
    RegNetY-800MF: light RegNetY variant with good accuracy.

    ~5.7M backbone parameters, 0.8 GFLOPs.

    Good fit for:
    - Mobile/portable devices
    - When MobileNet isn't accurate enough
    - Moderate compute budgets

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        **kwargs: Forwarded to RegNetBase — pretrained (default: True),
            dropout_rate (default: 0.2), freeze_backbone (default: False),
            regression_hidden (default: 256)

    Example:
        >>> model = RegNetY800MF(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Bind this size's torchvision builder and matching weights enum.
        super().__init__(
            in_shape=in_shape, out_size=out_size,
            model_fn=regnet_y_800mf, weights_class=RegNet_Y_800MF_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        if self.pretrained:
            init_mode = "pretrained"
        else:
            init_mode = "scratch"
        return f"RegNetY_800MF({init_mode}, in={self.in_shape}, out={self.out_size})"
286
+
287
+
288
@register_model("regnet_y_1_6gf")
class RegNetY1_6GF(RegNetBase):
    """
    RegNetY-1.6GF: recommended default with balanced accuracy and cost.

    ~10.3M backbone parameters, 1.6 GFLOPs. Comparable to ResNet50 but
    more efficient.

    Good fit for:
    - Default choice for general wave-based tasks
    - When you want predictable scaling
    - Server deployment with efficiency needs

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        **kwargs: Forwarded to RegNetBase — pretrained (default: True),
            dropout_rate (default: 0.2), freeze_backbone (default: False),
            regression_hidden (default: 256)

    Example:
        >>> model = RegNetY1_6GF(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Bind this size's torchvision builder and matching weights enum.
        super().__init__(
            in_shape=in_shape, out_size=out_size,
            model_fn=regnet_y_1_6gf, weights_class=RegNet_Y_1_6GF_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        if self.pretrained:
            init_mode = "pretrained"
        else:
            init_mode = "scratch"
        return f"RegNetY_1.6GF({init_mode}, in={self.in_shape}, out={self.out_size})"
327
+
328
+
329
@register_model("regnet_y_3_2gf")
class RegNetY3_2GF(RegNetBase):
    """
    RegNetY-3.2GF: higher-capacity variant for demanding tasks.

    ~17.9M backbone parameters, 3.2 GFLOPs. Use when 1.6GF isn't sufficient.

    Good fit for:
    - Larger datasets requiring more capacity
    - When accuracy is more important than efficiency
    - Research experiments with multiple model sizes

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        **kwargs: Forwarded to RegNetBase — pretrained (default: True),
            dropout_rate (default: 0.2), freeze_backbone (default: False),
            regression_hidden (default: 256)

    Example:
        >>> model = RegNetY3_2GF(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Bind this size's torchvision builder and matching weights enum.
        super().__init__(
            in_shape=in_shape, out_size=out_size,
            model_fn=regnet_y_3_2gf, weights_class=RegNet_Y_3_2GF_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        if self.pretrained:
            init_mode = "pretrained"
        else:
            init_mode = "scratch"
        return f"RegNetY_3.2GF({init_mode}, in={self.in_shape}, out={self.out_size})"
367
+
368
+
369
@register_model("regnet_y_8gf")
class RegNetY8GF(RegNetBase):
    """
    RegNetY-8GF: largest variant here, for large-scale tasks.

    ~37.4M backbone parameters, 8.0 GFLOPs. Use for maximum accuracy needs.

    Good fit for:
    - Very large datasets (>50k samples)
    - Complex wave patterns
    - HPC environments with ample GPU memory

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        **kwargs: Forwarded to RegNetBase — pretrained (default: True),
            dropout_rate (default: 0.2), freeze_backbone (default: False),
            regression_hidden (default: 256)

    Example:
        >>> model = RegNetY8GF(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        # Bind this size's torchvision builder and matching weights enum.
        super().__init__(
            in_shape=in_shape, out_size=out_size,
            model_fn=regnet_y_8gf, weights_class=RegNet_Y_8GF_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        if self.pretrained:
            init_mode = "pretrained"
        else:
            init_mode = "scratch"
        return f"RegNetY_8GF({init_mode}, in={self.in_shape}, out={self.out_size})"