wavedl-1.3.1-py3-none-any.whl → wavedl-1.4.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpc.py +48 -28
- wavedl/models/__init__.py +33 -7
- wavedl/models/_template.py +28 -41
- wavedl/models/base.py +49 -2
- wavedl/models/cnn.py +0 -1
- wavedl/models/convnext.py +4 -1
- wavedl/models/densenet.py +4 -1
- wavedl/models/efficientnet.py +9 -5
- wavedl/models/efficientnetv2.py +292 -0
- wavedl/models/mobilenetv3.py +272 -0
- wavedl/models/registry.py +0 -1
- wavedl/models/regnet.py +383 -0
- wavedl/models/resnet.py +7 -4
- wavedl/models/resnet3d.py +258 -0
- wavedl/models/swin.py +390 -0
- wavedl/models/tcn.py +389 -0
- wavedl/models/unet.py +44 -110
- wavedl/models/vit.py +8 -4
- wavedl/train.py +1144 -1116
- wavedl/utils/config.py +88 -2
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/METADATA +136 -98
- wavedl-1.4.1.dist-info/RECORD +37 -0
- wavedl-1.3.1.dist-info/RECORD +0 -31
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/LICENSE +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/WHEEL +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/top_level.txt +0 -0
wavedl/models/efficientnetv2.py
@@ -0,0 +1,292 @@
+"""
+EfficientNetV2: Faster Training and Better Accuracy
+====================================================
+
+Next-generation EfficientNet with improved training efficiency and performance.
+EfficientNetV2 replaces early depthwise convolutions with fused MBConv blocks,
+enabling 2-4× faster training while achieving better accuracy.
+
+**Key Improvements over EfficientNet**:
+- Fused-MBConv in early stages (faster on accelerators)
+- Progressive learning support (start small, grow)
+- Better NAS-optimized architecture
+
+**Variants**:
+- efficientnet_v2_s: Small (21.5M params) - Recommended default
+- efficientnet_v2_m: Medium (54.1M params) - Higher accuracy
+- efficientnet_v2_l: Large (118.5M params) - Maximum accuracy
+
+**Note**: EfficientNetV2 is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
+
+References:
+    Tan, M., & Le, Q. (2021). EfficientNetV2: Smaller Models and Faster Training.
+    ICML 2021. https://arxiv.org/abs/2104.00298
+
+Author: Ductho Le (ductho.le@outlook.com)
+"""
+
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+
+try:
+    from torchvision.models import (
+        EfficientNet_V2_L_Weights,
+        EfficientNet_V2_M_Weights,
+        EfficientNet_V2_S_Weights,
+        efficientnet_v2_l,
+        efficientnet_v2_m,
+        efficientnet_v2_s,
+    )
+
+    EFFICIENTNETV2_AVAILABLE = True
+except ImportError:
+    EFFICIENTNETV2_AVAILABLE = False
+
+from wavedl.models.base import BaseModel
+from wavedl.models.registry import register_model
+
+
+class EfficientNetV2Base(BaseModel):
+    """
+    Base EfficientNetV2 class for regression tasks.
+
+    Wraps torchvision EfficientNetV2 with:
+    - Optional pretrained weights (ImageNet-1K)
+    - Automatic input channel adaptation (grayscale → 3ch)
+    - Custom multi-layer regression head
+
+    Compared to EfficientNet (V1):
+    - 2-4× faster training on GPU/TPU
+    - Better accuracy at similar parameter counts
+    - More efficient at higher resolutions
+
+    Note: This is 2D-only. Input shape must be (H, W).
+    """
+
+    def __init__(
+        self,
+        in_shape: tuple[int, int],
+        out_size: int,
+        model_fn,
+        weights_class,
+        pretrained: bool = True,
+        dropout_rate: float = 0.3,
+        freeze_backbone: bool = False,
+        regression_hidden: int = 512,
+        **kwargs,
+    ):
+        """
+        Initialize EfficientNetV2 for regression.
+
+        Args:
+            in_shape: (H, W) input image dimensions
+            out_size: Number of regression output targets
+            model_fn: torchvision model constructor
+            weights_class: Pretrained weights enum class
+            pretrained: Use ImageNet pretrained weights (default: True)
+            dropout_rate: Dropout rate in regression head (default: 0.3)
+            freeze_backbone: Freeze backbone for fine-tuning (default: False)
+            regression_hidden: Hidden units in regression head (default: 512)
+        """
+        super().__init__(in_shape, out_size)
+
+        if not EFFICIENTNETV2_AVAILABLE:
+            raise ImportError(
+                "torchvision >= 0.13 is required for EfficientNetV2. "
+                'Install with: pip install "torchvision>=0.13"'
+            )
+
+        if len(in_shape) != 2:
+            raise ValueError(
+                f"EfficientNetV2 requires 2D input (H, W), got {len(in_shape)}D. "
+                "For 1D data, use TCN. For 3D data, use ResNet3D."
+            )
+
+        self.pretrained = pretrained
+        self.dropout_rate = dropout_rate
+        self.freeze_backbone = freeze_backbone
+        self.regression_hidden = regression_hidden
+
+        # Load pretrained backbone
+        weights = weights_class.IMAGENET1K_V1 if pretrained else None
+        self.backbone = model_fn(weights=weights)
+
+        # Get classifier input features (before the final classification layer)
+        in_features = self.backbone.classifier[1].in_features
+
+        # Replace classifier with regression head
+        # EfficientNetV2 benefits from a deeper regression head
+        self.backbone.classifier = nn.Sequential(
+            nn.Dropout(dropout_rate),
+            nn.Linear(in_features, regression_hidden),
+            nn.SiLU(inplace=True),  # SiLU (Swish) matches EfficientNet's activation
+            nn.Dropout(dropout_rate * 0.5),
+            nn.Linear(regression_hidden, regression_hidden // 2),
+            nn.SiLU(inplace=True),
+            nn.Linear(regression_hidden // 2, out_size),
+        )
+
+        # Optionally freeze backbone for fine-tuning
+        if freeze_backbone:
+            self._freeze_backbone()
+
+    def _freeze_backbone(self):
+        """Freeze all backbone parameters except the classifier."""
+        for name, param in self.backbone.named_parameters():
+            if "classifier" not in name:
+                param.requires_grad = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass.
+
+        Args:
+            x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+
+        Returns:
+            Output tensor of shape (B, out_size)
+        """
+        # Expand single channel to 3 channels for pretrained weights compatibility
+        if x.size(1) == 1:
+            x = x.expand(-1, 3, -1, -1)
+
+        return self.backbone(x)
+
+    @classmethod
+    def get_default_config(cls) -> dict[str, Any]:
+        """Return default configuration for EfficientNetV2."""
+        return {
+            "pretrained": True,
+            "dropout_rate": 0.3,
+            "freeze_backbone": False,
+            "regression_hidden": 512,
+        }
+
+
+# =============================================================================
+# REGISTERED MODEL VARIANTS
+# =============================================================================
+
+
+@register_model("efficientnet_v2_s")
+class EfficientNetV2S(EfficientNetV2Base):
+    """
+    EfficientNetV2-S: Small variant, recommended default.
+
+    ~21.5M parameters. Best balance of speed and accuracy for most tasks.
+    2× faster training than EfficientNet-B4 with better accuracy.
+
+    Recommended for:
+    - Default choice for 2D wave data
+    - Moderate compute budgets
+    - When training speed matters
+
+    Args:
+        in_shape: (H, W) image dimensions
+        out_size: Number of regression targets
+        pretrained: Use ImageNet pretrained weights (default: True)
+        dropout_rate: Dropout rate in head (default: 0.3)
+        freeze_backbone: Freeze backbone for fine-tuning (default: False)
+        regression_hidden: Hidden units in regression head (default: 512)
+
+    Example:
+        >>> model = EfficientNetV2S(in_shape=(500, 500), out_size=3)
+        >>> x = torch.randn(4, 1, 500, 500)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_fn=efficientnet_v2_s,
+            weights_class=EfficientNet_V2_S_Weights,
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        pt = "pretrained" if self.pretrained else "scratch"
+        return f"EfficientNetV2_S({pt}, in={self.in_shape}, out={self.out_size})"
+
+
+@register_model("efficientnet_v2_m")
+class EfficientNetV2M(EfficientNetV2Base):
+    """
+    EfficientNetV2-M: Medium variant for higher accuracy.
+
+    ~54.1M parameters. Use when accuracy is more important than speed.
+
+    Recommended for:
+    - Large datasets (>50k samples)
+    - Complex wave patterns
+    - When compute is not a bottleneck
+
+    Args:
+        in_shape: (H, W) image dimensions
+        out_size: Number of regression targets
+        pretrained: Use ImageNet pretrained weights (default: True)
+        dropout_rate: Dropout rate in head (default: 0.3)
+        freeze_backbone: Freeze backbone for fine-tuning (default: False)
+        regression_hidden: Hidden units in regression head (default: 512)
+
+    Example:
+        >>> model = EfficientNetV2M(in_shape=(500, 500), out_size=3)
+        >>> x = torch.randn(4, 1, 500, 500)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_fn=efficientnet_v2_m,
+            weights_class=EfficientNet_V2_M_Weights,
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        pt = "pretrained" if self.pretrained else "scratch"
+        return f"EfficientNetV2_M({pt}, in={self.in_shape}, out={self.out_size})"
+
+
+@register_model("efficientnet_v2_l")
+class EfficientNetV2L(EfficientNetV2Base):
+    """
+    EfficientNetV2-L: Large variant for maximum accuracy.
+
+    ~118.5M parameters. Use only with large datasets and sufficient compute.
+
+    Recommended for:
+    - Very large datasets (>100k samples)
+    - When maximum accuracy is critical
+    - HPC environments with ample GPU memory
+
+    Args:
+        in_shape: (H, W) image dimensions
+        out_size: Number of regression targets
+        pretrained: Use ImageNet pretrained weights (default: True)
+        dropout_rate: Dropout rate in head (default: 0.3)
+        freeze_backbone: Freeze backbone for fine-tuning (default: False)
+        regression_hidden: Hidden units in regression head (default: 512)
+
+    Example:
+        >>> model = EfficientNetV2L(in_shape=(500, 500), out_size=3)
+        >>> x = torch.randn(4, 1, 500, 500)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_fn=efficientnet_v2_l,
+            weights_class=EfficientNet_V2_L_Weights,
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        pt = "pretrained" if self.pretrained else "scratch"
+        return f"EfficientNetV2_L({pt}, in={self.in_shape}, out={self.out_size})"
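
Taken together, the file adds a drop-in 2D regression backbone. The following smoke test is a minimal sketch based on the docstring examples above; the (500, 500) input size and 3-target head are illustrative, and pretrained=False is used here only to skip the ImageNet weight download (the package default is pretrained=True):

    import torch
    from wavedl.models.efficientnetv2 import EfficientNetV2S

    # Small variant, 3 regression targets, trained from scratch.
    model = EfficientNetV2S(in_shape=(500, 500), out_size=3, pretrained=False)

    # Grayscale batch: forward() expands the single channel to 3 internally.
    x = torch.randn(4, 1, 500, 500)
    out = model(x)
    print(out.shape)  # torch.Size([4, 3])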
wavedl/models/mobilenetv3.py
@@ -0,0 +1,272 @@
+"""
+MobileNetV3: Efficient Networks for Edge Deployment
+====================================================
+
+Lightweight architecture optimized for mobile and embedded devices.
+MobileNetV3 combines neural architecture search (NAS) with hardware-aware
+optimization to achieve excellent accuracy with minimal computational cost.
+
+**Key Features**:
+- Inverted residuals with depthwise separable convolutions
+- Squeeze-and-Excitation (SE) attention for channel weighting
+- h-swish activation: efficient approximation of swish
+- Designed for real-time inference on CPUs and edge devices
+
+**Variants**:
+- mobilenet_v3_small: Ultra-lightweight (~1.1M params) - Edge/embedded
+- mobilenet_v3_large: Balanced (~3.2M params) - Mobile deployment
+
+**Use Cases**:
+- Real-time structural health monitoring on embedded systems
+- Field inspection with portable devices
+- When model size and inference speed are critical
+
+**Note**: MobileNetV3 is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
+
+References:
+    Howard, A., et al. (2019). Searching for MobileNetV3.
+    ICCV 2019. https://arxiv.org/abs/1905.02244
+
+Author: Ductho Le (ductho.le@outlook.com)
+"""
+
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+
+try:
+    from torchvision.models import (
+        MobileNet_V3_Large_Weights,
+        MobileNet_V3_Small_Weights,
+        mobilenet_v3_large,
+        mobilenet_v3_small,
+    )
+
+    MOBILENETV3_AVAILABLE = True
+except ImportError:
+    MOBILENETV3_AVAILABLE = False
+
+from wavedl.models.base import BaseModel
+from wavedl.models.registry import register_model
+
+
+class MobileNetV3Base(BaseModel):
+    """
+    Base MobileNetV3 class for regression tasks.
+
+    Wraps torchvision MobileNetV3 with:
+    - Optional pretrained weights (ImageNet-1K)
+    - Automatic input channel adaptation (grayscale → 3ch)
+    - Lightweight regression head (maintains efficiency)
+
+    MobileNetV3 is ideal for:
+    - Edge deployment (Raspberry Pi, Jetson, mobile)
+    - Real-time inference requirements
+    - Memory-constrained environments
+    - Quick prototyping and experimentation
+
+    Note: This is 2D-only. Input shape must be (H, W).
+    """
+
+    def __init__(
+        self,
+        in_shape: tuple[int, int],
+        out_size: int,
+        model_fn,
+        weights_class,
+        pretrained: bool = True,
+        dropout_rate: float = 0.2,
+        freeze_backbone: bool = False,
+        regression_hidden: int = 256,
+        **kwargs,
+    ):
+        """
+        Initialize MobileNetV3 for regression.
+
+        Args:
+            in_shape: (H, W) input image dimensions
+            out_size: Number of regression output targets
+            model_fn: torchvision model constructor
+            weights_class: Pretrained weights enum class
+            pretrained: Use ImageNet pretrained weights (default: True)
+            dropout_rate: Dropout rate in regression head (default: 0.2)
+            freeze_backbone: Freeze backbone for fine-tuning (default: False)
+            regression_hidden: Hidden units in regression head (default: 256)
+        """
+        super().__init__(in_shape, out_size)
+
+        if not MOBILENETV3_AVAILABLE:
+            raise ImportError(
+                "torchvision is required for MobileNetV3. "
+                "Install with: pip install torchvision"
+            )
+
+        if len(in_shape) != 2:
+            raise ValueError(
+                f"MobileNetV3 requires 2D input (H, W), got {len(in_shape)}D. "
+                "For 1D data, use TCN. For 3D data, use ResNet3D."
+            )
+
+        self.pretrained = pretrained
+        self.dropout_rate = dropout_rate
+        self.freeze_backbone = freeze_backbone
+        self.regression_hidden = regression_hidden
+
+        # Load pretrained backbone
+        weights = weights_class.IMAGENET1K_V1 if pretrained else None
+        self.backbone = model_fn(weights=weights)
+
+        # MobileNetV3 classifier structure:
+        #   classifier[0]: Linear (features → 1280 for Large, 1024 for Small)
+        #   classifier[1]: Hardswish
+        #   classifier[2]: Dropout
+        #   classifier[3]: Linear (1280/1024 → num_classes)
+
+        # Get the backbone feature dimension (input to classifier[0])
+        in_features = self.backbone.classifier[0].in_features
+
+        # Replace classifier with lightweight regression head
+        # Keep it efficient to maintain MobileNet's speed advantage
+        self.backbone.classifier = nn.Sequential(
+            nn.Linear(in_features, regression_hidden),
+            nn.Hardswish(inplace=True),  # Match MobileNetV3's activation
+            nn.Dropout(dropout_rate),
+            nn.Linear(regression_hidden, out_size),
+        )
+
+        # Optionally freeze backbone for fine-tuning
+        if freeze_backbone:
+            self._freeze_backbone()
+
+    def _freeze_backbone(self):
+        """Freeze all backbone parameters except the classifier."""
+        for name, param in self.backbone.named_parameters():
+            if "classifier" not in name:
+                param.requires_grad = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass.

+        Args:
+            x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+
+        Returns:
+            Output tensor of shape (B, out_size)
+        """
+        # Expand single channel to 3 channels for pretrained weights compatibility
+        if x.size(1) == 1:
+            x = x.expand(-1, 3, -1, -1)
+
+        return self.backbone(x)
+
+    @classmethod
+    def get_default_config(cls) -> dict[str, Any]:
+        """Return default configuration for MobileNetV3."""
+        return {
+            "pretrained": True,
+            "dropout_rate": 0.2,
+            "freeze_backbone": False,
+            "regression_hidden": 256,
+        }
+
+
+# =============================================================================
+# REGISTERED MODEL VARIANTS
+# =============================================================================
+
+
+@register_model("mobilenet_v3_small")
+class MobileNetV3Small(MobileNetV3Base):
+    """
+    MobileNetV3-Small: Ultra-lightweight for edge deployment.
+
+    ~1.1M parameters. Designed for the most constrained environments.
+    Achieves ~67% ImageNet accuracy with minimal compute.
+
+    Recommended for:
+    - Embedded systems (Raspberry Pi, Arduino with accelerators)
+    - Battery-powered devices
+    - Ultra-low latency requirements (<10ms)
+    - Quick training experiments
+
+    Performance (approximate, original ImageNet classifier):
+    - CPU inference: ~6ms (single core)
+    - Parameters: 2.5M
+    - MAdds: 56M
+
+    Args:
+        in_shape: (H, W) image dimensions
+        out_size: Number of regression targets
+        pretrained: Use ImageNet pretrained weights (default: True)
+        dropout_rate: Dropout rate in head (default: 0.2)
+        freeze_backbone: Freeze backbone for fine-tuning (default: False)
+        regression_hidden: Hidden units in regression head (default: 256)
+
+    Example:
+        >>> model = MobileNetV3Small(in_shape=(224, 224), out_size=3)
+        >>> x = torch.randn(1, 1, 224, 224)
+        >>> out = model(x)  # (1, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_fn=mobilenet_v3_small,
+            weights_class=MobileNet_V3_Small_Weights,
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        pt = "pretrained" if self.pretrained else "scratch"
+        return f"MobileNetV3_Small({pt}, in={self.in_shape}, out={self.out_size})"
+
+
+@register_model("mobilenet_v3_large")
+class MobileNetV3Large(MobileNetV3Base):
+    """
+    MobileNetV3-Large: Balanced efficiency and accuracy.
+
+    ~3.2M parameters. Best trade-off for mobile/portable deployment.
+    Achieves ~75% ImageNet accuracy with efficient inference.
+
+    Recommended for:
+    - Mobile deployment (smartphones, tablets)
+    - Portable inspection devices
+    - Real-time processing with moderate accuracy needs
+    - Default choice for edge deployment
+
+    Performance (approximate, original ImageNet classifier):
+    - CPU inference: ~20ms (single core)
+    - Parameters: 5.4M
+    - MAdds: 219M
+
+    Args:
+        in_shape: (H, W) image dimensions
+        out_size: Number of regression targets
+        pretrained: Use ImageNet pretrained weights (default: True)
+        dropout_rate: Dropout rate in head (default: 0.2)
+        freeze_backbone: Freeze backbone for fine-tuning (default: False)
+        regression_hidden: Hidden units in regression head (default: 256)
+
+    Example:
+        >>> model = MobileNetV3Large(in_shape=(224, 224), out_size=3)
+        >>> x = torch.randn(1, 1, 224, 224)
+        >>> out = model(x)  # (1, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_fn=mobilenet_v3_large,
+            weights_class=MobileNet_V3_Large_Weights,
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        pt = "pretrained" if self.pretrained else "scratch"
+        return f"MobileNetV3_Large({pt}, in={self.in_shape}, out={self.out_size})"
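
Mirroring the EfficientNetV2 file, the freeze_backbone flag here enables head-only fine-tuning. A minimal sketch of that workflow follows; the AdamW optimizer, learning rate, and dummy targets are illustrative choices rather than wavedl defaults, and the default pretrained=True means the first run downloads ImageNet weights:

    import torch
    from wavedl.models.mobilenetv3 import MobileNetV3Large

    # Pretrained backbone, frozen; only the new regression head trains.
    model = MobileNetV3Large(in_shape=(224, 224), out_size=3, freeze_backbone=True)

    # Optimize only the parameters left trainable by _freeze_backbone().
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-3)

    x = torch.randn(8, 1, 224, 224)
    loss = torch.nn.functional.mse_loss(model(x), torch.zeros(8, 3))  # dummy targets
    loss.backward()
    optimizer.step()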
wavedl/models/registry.py CHANGED