wavedl 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff shows the content of publicly released package versions as published to their respective public registries; it is provided for informational purposes only.
@@ -0,0 +1,272 @@
1
+ """
2
+ MobileNetV3: Efficient Networks for Edge Deployment
3
+ ====================================================
4
+
5
+ Lightweight architecture optimized for mobile and embedded devices.
6
+ MobileNetV3 combines neural architecture search (NAS) with hardware-aware
7
+ optimization to achieve excellent accuracy with minimal computational cost.
8
+
9
+ **Key Features**:
10
+ - Inverted residuals with depthwise separable convolutions
11
+ - Squeeze-and-Excitation (SE) attention for channel weighting
12
+ - h-swish activation: x * ReLU6(x + 3) / 6, an efficient approximation of swish
13
+ - Designed for real-time inference on CPUs and edge devices
14
+
15
+ **Variants**:
16
+ - mobilenet_v3_small: Ultra-lightweight (~2.5M params) - Edge/embedded
17
+ - mobilenet_v3_large: Balanced (~5.4M params) - Mobile deployment
18
+
19
+ **Use Cases**:
20
+ - Real-time structural health monitoring on embedded systems
21
+ - Field inspection with portable devices
22
+ - When model size and inference speed are critical
23
+
24
+ **Note**: MobileNetV3 is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
25
+
26
+ References:
27
+ Howard, A., et al. (2019). Searching for MobileNetV3.
28
+ ICCV 2019. https://arxiv.org/abs/1905.02244
29
+
30
+ Author: Ductho Le (ductho.le@outlook.com)
31
+ """
32
+
33
+ from typing import Any
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+
38
+
39
+ try:
40
+ from torchvision.models import (
41
+ MobileNet_V3_Large_Weights,
42
+ MobileNet_V3_Small_Weights,
43
+ mobilenet_v3_large,
44
+ mobilenet_v3_small,
45
+ )
46
+
47
+ MOBILENETV3_AVAILABLE = True
48
+ except ImportError:
49
+ MOBILENETV3_AVAILABLE = False
50
+
51
+ from wavedl.models.base import BaseModel
52
+ from wavedl.models.registry import register_model
53
+
54
+
55
+ class MobileNetV3Base(BaseModel):
56
+ """
57
+ Base MobileNetV3 class for regression tasks.
58
+
59
+ Wraps torchvision MobileNetV3 with:
60
+ - Optional pretrained weights (ImageNet-1K)
61
+ - Automatic input channel adaptation (grayscale → 3ch)
62
+ - Lightweight regression head (maintains efficiency)
63
+
64
+ MobileNetV3 is ideal for:
65
+ - Edge deployment (Raspberry Pi, Jetson, mobile)
66
+ - Real-time inference requirements
67
+ - Memory-constrained environments
68
+ - Quick prototyping and experimentation
69
+
70
+ Note: This is 2D-only. Input shape must be (H, W).
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ in_shape: tuple[int, int],
76
+ out_size: int,
77
+ model_fn,
78
+ weights_class,
79
+ pretrained: bool = True,
80
+ dropout_rate: float = 0.2,
81
+ freeze_backbone: bool = False,
82
+ regression_hidden: int = 256,
83
+ **kwargs,
84
+ ):
85
+ """
86
+ Initialize MobileNetV3 for regression.
87
+
88
+ Args:
89
+ in_shape: (H, W) input image dimensions
90
+ out_size: Number of regression output targets
91
+ model_fn: torchvision model constructor
92
+ weights_class: Pretrained weights enum class
93
+ pretrained: Use ImageNet pretrained weights (default: True)
94
+ dropout_rate: Dropout rate in regression head (default: 0.2)
95
+ freeze_backbone: Freeze backbone for fine-tuning (default: False)
96
+ regression_hidden: Hidden units in regression head (default: 256)
97
+ """
98
+ super().__init__(in_shape, out_size)
99
+
100
+ if not MOBILENETV3_AVAILABLE:
101
+ raise ImportError(
102
+ "torchvision is required for MobileNetV3. "
103
+ "Install with: pip install torchvision"
104
+ )
105
+
106
+ if len(in_shape) != 2:
107
+ raise ValueError(
108
+ f"MobileNetV3 requires 2D input (H, W), got {len(in_shape)}D. "
109
+ "For 1D data, use TCN. For 3D data, use ResNet3D."
110
+ )
111
+
112
+ self.pretrained = pretrained
113
+ self.dropout_rate = dropout_rate
114
+ self.freeze_backbone = freeze_backbone
115
+ self.regression_hidden = regression_hidden
116
+
117
+ # Load pretrained backbone
118
+ weights = weights_class.IMAGENET1K_V1 if pretrained else None
119
+ self.backbone = model_fn(weights=weights)
120
+
121
+ # MobileNetV3 classifier structure:
122
+ # classifier[0]: Linear (features → 1280 for Large, 1024 for Small)
123
+ # classifier[1]: Hardswish
124
+ # classifier[2]: Dropout
125
+ # classifier[3]: Linear (1280/1024 → num_classes)
126
+
127
+ # Get the input features to the final classifier
128
+ in_features = self.backbone.classifier[0].in_features
129
+
130
+ # Replace classifier with lightweight regression head
131
+ # Keep it efficient to maintain MobileNet's speed advantage
132
+ self.backbone.classifier = nn.Sequential(
133
+ nn.Linear(in_features, regression_hidden),
134
+ nn.Hardswish(inplace=True), # Match MobileNetV3's activation
135
+ nn.Dropout(dropout_rate),
136
+ nn.Linear(regression_hidden, out_size),
137
+ )
138
+
139
+ # Optionally freeze backbone for fine-tuning
140
+ if freeze_backbone:
141
+ self._freeze_backbone()
142
+
143
+ def _freeze_backbone(self):
144
+ """Freeze all backbone parameters except the classifier."""
145
+ for name, param in self.backbone.named_parameters():
146
+ if "classifier" not in name:
147
+ param.requires_grad = False
148
+
149
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
150
+ """
151
+ Forward pass.
152
+
153
+ Args:
154
+ x: Input tensor of shape (B, C, H, W) where C is 1 or 3
155
+
156
+ Returns:
157
+ Output tensor of shape (B, out_size)
158
+ """
159
+ # Expand single channel to 3 channels for pretrained weights compatibility
160
+ if x.size(1) == 1:
161
+ x = x.expand(-1, 3, -1, -1)
162
+
163
+ return self.backbone(x)
164
+
165
+ @classmethod
166
+ def get_default_config(cls) -> dict[str, Any]:
167
+ """Return default configuration for MobileNetV3."""
168
+ return {
169
+ "pretrained": True,
170
+ "dropout_rate": 0.2,
171
+ "freeze_backbone": False,
172
+ "regression_hidden": 256,
173
+ }
174
+
175
+
176
+ # =============================================================================
177
+ # REGISTERED MODEL VARIANTS
178
+ # =============================================================================
179
+
180
+
181
+ @register_model("mobilenet_v3_small")
182
+ class MobileNetV3Small(MobileNetV3Base):
183
+ """
184
+ MobileNetV3-Small: Ultra-lightweight for edge deployment.
185
+
186
+ ~2.5M parameters. Designed for the most constrained environments.
187
+ Achieves ~67% ImageNet accuracy with minimal compute.
188
+
189
+ Recommended for:
190
+ - Embedded systems (Raspberry Pi, Arduino with accelerators)
191
+ - Battery-powered devices
192
+ - Ultra-low latency requirements (<10ms)
193
+ - Quick training experiments
194
+
195
+ Performance (approximate):
196
+ - CPU inference: ~6ms (single core)
197
+ - Parameters: 2.5M
198
+ - MAdds: 56M
199
+
200
+ Args:
201
+ in_shape: (H, W) image dimensions
202
+ out_size: Number of regression targets
203
+ pretrained: Use ImageNet pretrained weights (default: True)
204
+ dropout_rate: Dropout rate in head (default: 0.2)
205
+ freeze_backbone: Freeze backbone for fine-tuning (default: False)
206
+ regression_hidden: Hidden units in regression head (default: 256)
207
+
208
+ Example:
209
+ >>> model = MobileNetV3Small(in_shape=(224, 224), out_size=3)
210
+ >>> x = torch.randn(1, 1, 224, 224)
211
+ >>> out = model(x) # (1, 3)
212
+ """
213
+
214
+ def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
215
+ super().__init__(
216
+ in_shape=in_shape,
217
+ out_size=out_size,
218
+ model_fn=mobilenet_v3_small,
219
+ weights_class=MobileNet_V3_Small_Weights,
220
+ **kwargs,
221
+ )
222
+
223
+ def __repr__(self) -> str:
224
+ pt = "pretrained" if self.pretrained else "scratch"
225
+ return f"MobileNetV3_Small({pt}, in={self.in_shape}, out={self.out_size})"
226
+
227
+
228
+ @register_model("mobilenet_v3_large")
229
+ class MobileNetV3Large(MobileNetV3Base):
230
+ """
231
+ MobileNetV3-Large: Balanced efficiency and accuracy.
232
+
233
+ ~5.4M parameters. Best trade-off for mobile/portable deployment.
234
+ Achieves ~75% ImageNet accuracy with efficient inference.
235
+
236
+ Recommended for:
237
+ - Mobile deployment (smartphones, tablets)
238
+ - Portable inspection devices
239
+ - Real-time processing with moderate accuracy needs
240
+ - Default choice for edge deployment
241
+
242
+ Performance (approximate):
243
+ - CPU inference: ~20ms (single core)
244
+ - Parameters: 5.4M
245
+ - MAdds: 219M
246
+
247
+ Args:
248
+ in_shape: (H, W) image dimensions
249
+ out_size: Number of regression targets
250
+ pretrained: Use ImageNet pretrained weights (default: True)
251
+ dropout_rate: Dropout rate in head (default: 0.2)
252
+ freeze_backbone: Freeze backbone for fine-tuning (default: False)
253
+ regression_hidden: Hidden units in regression head (default: 256)
254
+
255
+ Example:
256
+ >>> model = MobileNetV3Large(in_shape=(224, 224), out_size=3)
257
+ >>> x = torch.randn(1, 1, 224, 224)
258
+ >>> out = model(x) # (1, 3)
259
+ """
260
+
261
+ def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
262
+ super().__init__(
263
+ in_shape=in_shape,
264
+ out_size=out_size,
265
+ model_fn=mobilenet_v3_large,
266
+ weights_class=MobileNet_V3_Large_Weights,
267
+ **kwargs,
268
+ )
269
+
270
+ def __repr__(self) -> str:
271
+ pt = "pretrained" if self.pretrained else "scratch"
272
+ return f"MobileNetV3_Large({pt}, in={self.in_shape}, out={self.out_size})"
wavedl/models/registry.py CHANGED
@@ -6,7 +6,6 @@ Provides the core model registration and factory functionality.
6
6
  This module has no dependencies on other model modules to prevent circular imports.
7
7
 
8
8
  Author: Ductho Le (ductho.le@outlook.com)
9
- Version: 1.0.0
10
9
  """
11
10
 
12
11
  import torch.nn as nn
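The registry.py change above only drops a stale Version field. For context, the decorator-based registry that register_model implies can be sketched as follows; the actual wavedl internals are not shown in this diff, so _MODEL_REGISTRY and build_model are illustrative names:

import torch.nn as nn

_MODEL_REGISTRY: dict[str, type[nn.Module]] = {}

def register_model(name: str):
    """Class decorator that files a model class under a string key."""
    def decorator(cls: type[nn.Module]) -> type[nn.Module]:
        _MODEL_REGISTRY[name] = cls
        return cls
    return decorator

def build_model(name: str, **kwargs) -> nn.Module:
    # Illustrative factory: look up a registered class by name and instantiate it.
    return _MODEL_REGISTRY[name](**kwargs)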
@@ -0,0 +1,383 @@
1
+ """
2
+ RegNet: Designing Network Design Spaces
3
+ ========================================
4
+
5
+ RegNet provides a family of models with predictable scaling behavior,
6
+ designed through systematic exploration of network design spaces.
7
+ Models scale smoothly from mobile to server deployments.
8
+
9
+ **Key Features**:
10
+ - Predictable scaling: accuracy increases linearly with compute
11
+ - Simple, uniform architecture (no complex compound scaling)
12
+ - Group convolutions for efficiency
13
+ - Optional Squeeze-and-Excitation (SE) attention
14
+
15
+ **Variants** (RegNetY includes SE attention):
16
+ - regnet_y_400mf: Ultra-light (~4.3M params, 0.4 GFLOPs)
17
+ - regnet_y_800mf: Light (~6.4M params, 0.8 GFLOPs)
18
+ - regnet_y_1_6gf: Medium (~11.2M params, 1.6 GFLOPs) - Recommended
19
+ - regnet_y_3_2gf: Large (~19.4M params, 3.2 GFLOPs)
20
+ - regnet_y_8gf: Very large (~39.4M params, 8.0 GFLOPs)
21
+
22
+ **When to Use RegNet**:
23
+ - When you need predictable performance at a given compute budget
24
+ - For systematic model selection experiments
25
+ - When interpretability of design choices matters
26
+ - As an efficient alternative to ResNet
27
+
28
+ **Note**: RegNet is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
29
+
30
+ References:
31
+ Radosavovic, I., et al. (2020). Designing Network Design Spaces.
32
+ CVPR 2020. https://arxiv.org/abs/2003.13678
33
+
34
+ Author: Ductho Le (ductho.le@outlook.com)
35
+ """
36
+
37
+ from typing import Any
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+
42
+
43
+ try:
44
+ from torchvision.models import (
45
+ RegNet_Y_1_6GF_Weights,
46
+ RegNet_Y_3_2GF_Weights,
47
+ RegNet_Y_8GF_Weights,
48
+ RegNet_Y_400MF_Weights,
49
+ RegNet_Y_800MF_Weights,
50
+ regnet_y_1_6gf,
51
+ regnet_y_3_2gf,
52
+ regnet_y_8gf,
53
+ regnet_y_400mf,
54
+ regnet_y_800mf,
55
+ )
56
+
57
+ REGNET_AVAILABLE = True
58
+ except ImportError:
59
+ REGNET_AVAILABLE = False
60
+
61
+ from wavedl.models.base import BaseModel
62
+ from wavedl.models.registry import register_model
63
+
64
+
65
+ class RegNetBase(BaseModel):
66
+ """
67
+ Base RegNet class for regression tasks.
68
+
69
+ Wraps torchvision RegNetY (with SE attention) with:
70
+ - Optional pretrained weights (ImageNet-1K)
71
+ - Automatic input channel adaptation (grayscale → 3ch)
72
+ - Custom regression head
73
+
74
+ RegNet advantages:
75
+ - Simple, uniform design (easy to understand and modify)
76
+ - Predictable accuracy/compute trade-off
77
+ - Efficient group convolutions
78
+ - SE attention for channel weighting (RegNetY variants)
79
+
80
+ Note: This is 2D-only. Input shape must be (H, W).
81
+ """
82
+
83
+ def __init__(
84
+ self,
85
+ in_shape: tuple[int, int],
86
+ out_size: int,
87
+ model_fn,
88
+ weights_class,
89
+ pretrained: bool = True,
90
+ dropout_rate: float = 0.2,
91
+ freeze_backbone: bool = False,
92
+ regression_hidden: int = 256,
93
+ **kwargs,
94
+ ):
95
+ """
96
+ Initialize RegNet for regression.
97
+
98
+ Args:
99
+ in_shape: (H, W) input image dimensions
100
+ out_size: Number of regression output targets
101
+ model_fn: torchvision model constructor
102
+ weights_class: Pretrained weights enum class
103
+ pretrained: Use ImageNet pretrained weights (default: True)
104
+ dropout_rate: Dropout rate in regression head (default: 0.2)
105
+ freeze_backbone: Freeze backbone for fine-tuning (default: False)
106
+ regression_hidden: Hidden units in regression head (default: 256)
107
+ """
108
+ super().__init__(in_shape, out_size)
109
+
110
+ if not REGNET_AVAILABLE:
111
+ raise ImportError(
112
+ "torchvision is required for RegNet. "
113
+ "Install with: pip install torchvision"
114
+ )
115
+
116
+ if len(in_shape) != 2:
117
+ raise ValueError(
118
+ f"RegNet requires 2D input (H, W), got {len(in_shape)}D. "
119
+ "For 1D data, use TCN. For 3D data, use ResNet3D."
120
+ )
121
+
122
+ self.pretrained = pretrained
123
+ self.dropout_rate = dropout_rate
124
+ self.freeze_backbone = freeze_backbone
125
+ self.regression_hidden = regression_hidden
126
+
127
+ # Load pretrained backbone
128
+ weights = weights_class.IMAGENET1K_V1 if pretrained else None
129
+ self.backbone = model_fn(weights=weights)
130
+
131
+ # RegNet uses .fc as the classification head
132
+ in_features = self.backbone.fc.in_features
133
+
134
+ # Replace fc with regression head
135
+ self.backbone.fc = nn.Sequential(
136
+ nn.Dropout(dropout_rate),
137
+ nn.Linear(in_features, regression_hidden),
138
+ nn.ReLU(inplace=True),
139
+ nn.Dropout(dropout_rate * 0.5),
140
+ nn.Linear(regression_hidden, out_size),
141
+ )
142
+
143
+ # Optionally freeze backbone for fine-tuning
144
+ if freeze_backbone:
145
+ self._freeze_backbone()
146
+
147
+ def _freeze_backbone(self):
148
+ """Freeze all backbone parameters except the fc layer."""
149
+ for name, param in self.backbone.named_parameters():
150
+ if "fc" not in name:
151
+ param.requires_grad = False
152
+
153
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
154
+ """
155
+ Forward pass.
156
+
157
+ Args:
158
+ x: Input tensor of shape (B, C, H, W) where C is 1 or 3
159
+
160
+ Returns:
161
+ Output tensor of shape (B, out_size)
162
+ """
163
+ # Expand single channel to 3 channels for pretrained weights compatibility
164
+ if x.size(1) == 1:
165
+ x = x.expand(-1, 3, -1, -1)
166
+
167
+ return self.backbone(x)
168
+
169
+ @classmethod
170
+ def get_default_config(cls) -> dict[str, Any]:
171
+ """Return default configuration for RegNet."""
172
+ return {
173
+ "pretrained": True,
174
+ "dropout_rate": 0.2,
175
+ "freeze_backbone": False,
176
+ "regression_hidden": 256,
177
+ }
178
+
179
+
180
+ # =============================================================================
181
+ # REGISTERED MODEL VARIANTS
182
+ # =============================================================================
183
+
184
+
185
+ @register_model("regnet_y_400mf")
186
+ class RegNetY400MF(RegNetBase):
187
+ """
188
+ RegNetY-400MF: Ultra-lightweight for constrained environments.
189
+
190
+ ~4.3M parameters, 0.4 GFLOPs. Smallest RegNet variant with SE attention.
191
+
192
+ Recommended for:
193
+ - Edge deployment with moderate accuracy needs
194
+ - Quick training experiments
195
+ - Baseline comparisons
196
+
197
+ Args:
198
+ in_shape: (H, W) image dimensions
199
+ out_size: Number of regression targets
200
+ pretrained: Use ImageNet pretrained weights (default: True)
201
+ dropout_rate: Dropout rate in head (default: 0.2)
202
+ freeze_backbone: Freeze backbone for fine-tuning (default: False)
203
+ regression_hidden: Hidden units in regression head (default: 256)
204
+
205
+ Example:
206
+ >>> model = RegNetY400MF(in_shape=(224, 224), out_size=3)
207
+ >>> x = torch.randn(4, 1, 224, 224)
208
+ >>> out = model(x) # (4, 3)
209
+ """
210
+
211
+ def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
212
+ super().__init__(
213
+ in_shape=in_shape,
214
+ out_size=out_size,
215
+ model_fn=regnet_y_400mf,
216
+ weights_class=RegNet_Y_400MF_Weights,
217
+ **kwargs,
218
+ )
219
+
220
+ def __repr__(self) -> str:
221
+ pt = "pretrained" if self.pretrained else "scratch"
222
+ return f"RegNetY_400MF({pt}, in={self.in_shape}, out={self.out_size})"
223
+
224
+
225
+ @register_model("regnet_y_800mf")
226
+ class RegNetY800MF(RegNetBase):
227
+ """
228
+ RegNetY-800MF: Light variant with good accuracy.
229
+
230
+ ~6.4M parameters, 0.8 GFLOPs. Good balance for mobile deployment.
231
+
232
+ Recommended for:
233
+ - Mobile/portable devices
234
+ - When MobileNet isn't accurate enough
235
+ - Moderate compute budgets
236
+
237
+ Args:
238
+ in_shape: (H, W) image dimensions
239
+ out_size: Number of regression targets
240
+ pretrained: Use ImageNet pretrained weights (default: True)
241
+ dropout_rate: Dropout rate in head (default: 0.2)
242
+ freeze_backbone: Freeze backbone for fine-tuning (default: False)
243
+ regression_hidden: Hidden units in regression head (default: 256)
244
+
245
+ Example:
246
+ >>> model = RegNetY800MF(in_shape=(224, 224), out_size=3)
247
+ >>> x = torch.randn(4, 1, 224, 224)
248
+ >>> out = model(x) # (4, 3)
249
+ """
250
+
251
+ def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
252
+ super().__init__(
253
+ in_shape=in_shape,
254
+ out_size=out_size,
255
+ model_fn=regnet_y_800mf,
256
+ weights_class=RegNet_Y_800MF_Weights,
257
+ **kwargs,
258
+ )
259
+
260
+ def __repr__(self) -> str:
261
+ pt = "pretrained" if self.pretrained else "scratch"
262
+ return f"RegNetY_800MF({pt}, in={self.in_shape}, out={self.out_size})"
263
+
264
+
265
+ @register_model("regnet_y_1_6gf")
266
+ class RegNetY1_6GF(RegNetBase):
267
+ """
268
+ RegNetY-1.6GF: Recommended default for balanced performance.
269
+
270
+ ~11.2M parameters, 1.6 GFLOPs. Best trade-off of accuracy and efficiency.
271
+ Accuracy comparable to ResNet50 at well under half the FLOPs.
272
+
273
+ Recommended for:
274
+ - Default choice for general wave-based tasks
275
+ - When you want predictable scaling
276
+ - Server deployment with efficiency needs
277
+
278
+ Args:
279
+ in_shape: (H, W) image dimensions
280
+ out_size: Number of regression targets
281
+ pretrained: Use ImageNet pretrained weights (default: True)
282
+ dropout_rate: Dropout rate in head (default: 0.2)
283
+ freeze_backbone: Freeze backbone for fine-tuning (default: False)
284
+ regression_hidden: Hidden units in regression head (default: 256)
285
+
286
+ Example:
287
+ >>> model = RegNetY1_6GF(in_shape=(224, 224), out_size=3)
288
+ >>> x = torch.randn(4, 1, 224, 224)
289
+ >>> out = model(x) # (4, 3)
290
+ """
291
+
292
+ def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
293
+ super().__init__(
294
+ in_shape=in_shape,
295
+ out_size=out_size,
296
+ model_fn=regnet_y_1_6gf,
297
+ weights_class=RegNet_Y_1_6GF_Weights,
298
+ **kwargs,
299
+ )
300
+
301
+ def __repr__(self) -> str:
302
+ pt = "pretrained" if self.pretrained else "scratch"
303
+ return f"RegNetY_1.6GF({pt}, in={self.in_shape}, out={self.out_size})"
304
+
305
+
306
+ @register_model("regnet_y_3_2gf")
307
+ class RegNetY3_2GF(RegNetBase):
308
+ """
309
+ RegNetY-3.2GF: Higher accuracy for demanding tasks.
310
+
311
+ ~19.4M parameters, 3.2 GFLOPs. Use when 1.6GF isn't sufficient.
312
+
313
+ Recommended for:
314
+ - Larger datasets requiring more capacity
315
+ - When accuracy is more important than efficiency
316
+ - Research experiments with multiple model sizes
317
+
318
+ Args:
319
+ in_shape: (H, W) image dimensions
320
+ out_size: Number of regression targets
321
+ pretrained: Use ImageNet pretrained weights (default: True)
322
+ dropout_rate: Dropout rate in head (default: 0.2)
323
+ freeze_backbone: Freeze backbone for fine-tuning (default: False)
324
+ regression_hidden: Hidden units in regression head (default: 256)
325
+
326
+ Example:
327
+ >>> model = RegNetY3_2GF(in_shape=(224, 224), out_size=3)
328
+ >>> x = torch.randn(4, 1, 224, 224)
329
+ >>> out = model(x) # (4, 3)
330
+ """
331
+
332
+ def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
333
+ super().__init__(
334
+ in_shape=in_shape,
335
+ out_size=out_size,
336
+ model_fn=regnet_y_3_2gf,
337
+ weights_class=RegNet_Y_3_2GF_Weights,
338
+ **kwargs,
339
+ )
340
+
341
+ def __repr__(self) -> str:
342
+ pt = "pretrained" if self.pretrained else "scratch"
343
+ return f"RegNetY_3.2GF({pt}, in={self.in_shape}, out={self.out_size})"
344
+
345
+
346
+ @register_model("regnet_y_8gf")
347
+ class RegNetY8GF(RegNetBase):
348
+ """
349
+ RegNetY-8GF: High capacity for large-scale tasks.
350
+
351
+ ~39.4M parameters, 8.0 GFLOPs. Use when maximum accuracy is the priority.
352
+
353
+ Recommended for:
354
+ - Very large datasets (>50k samples)
355
+ - Complex wave patterns
356
+ - HPC environments with ample GPU memory
357
+
358
+ Args:
359
+ in_shape: (H, W) image dimensions
360
+ out_size: Number of regression targets
361
+ pretrained: Use ImageNet pretrained weights (default: True)
362
+ dropout_rate: Dropout rate in head (default: 0.2)
363
+ freeze_backbone: Freeze backbone for fine-tuning (default: False)
364
+ regression_hidden: Hidden units in regression head (default: 256)
365
+
366
+ Example:
367
+ >>> model = RegNetY8GF(in_shape=(224, 224), out_size=3)
368
+ >>> x = torch.randn(4, 1, 224, 224)
369
+ >>> out = model(x) # (4, 3)
370
+ """
371
+
372
+ def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
373
+ super().__init__(
374
+ in_shape=in_shape,
375
+ out_size=out_size,
376
+ model_fn=regnet_y_8gf,
377
+ weights_class=RegNet_Y_8GF_Weights,
378
+ **kwargs,
379
+ )
380
+
381
+ def __repr__(self) -> str:
382
+ pt = "pretrained" if self.pretrained else "scratch"
383
+ return f"RegNetY_8GF({pt}, in={self.in_shape}, out={self.out_size})"
wavedl/models/resnet.py CHANGED
@@ -11,12 +11,15 @@ Provides multiple depth variants (18, 34, 50) with optional pretrained weights f
11
11
  - 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d
12
12
 
13
13
  **Variants**:
14
- - resnet18: Lightweight, fast training
15
- - resnet34: Balanced capacity
16
- - resnet50: Higher capacity with bottleneck blocks
14
+ - resnet18: Lightweight, fast training (~11M params)
15
+ - resnet34: Balanced capacity (~21M params)
16
+ - resnet50: Higher capacity with bottleneck blocks (~25M params)
17
+
18
+ References:
19
+ He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning
20
+ for Image Recognition. CVPR 2016. https://arxiv.org/abs/1512.03385
17
21
 
18
22
  Author: Ductho Le (ductho.le@outlook.com)
19
- Version: 1.0.0
20
23
  """
21
24
 
22
25
  from typing import Any