wavedl 1.6.0-py3-none-any.whl → 1.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +28 -0
- wavedl/models/{_timm_utils.py → _pretrained_utils.py} +128 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +1 -1
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +5 -18
- wavedl/models/convnext_v2.py +6 -22
- wavedl/models/densenet.py +5 -18
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +6 -39
- wavedl/models/mamba.py +44 -24
- wavedl/models/maxvit.py +51 -48
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +14 -56
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +1 -5
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +3 -3
- wavedl/train.py +1430 -1430
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/METADATA +93 -53
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.6.0.dist-info/RECORD +0 -44
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/mobilenetv3.py
CHANGED
@@ -1,295 +1,295 @@

The diff rewrites the file wholesale (all 295 lines removed and re-added), but the removed and added contents render identically, so the file body is shown once:

```python
"""
MobileNetV3: Efficient Networks for Edge Deployment
====================================================

Lightweight architecture optimized for mobile and embedded devices.
MobileNetV3 combines neural architecture search (NAS) with hardware-aware
optimization to achieve excellent accuracy with minimal computational cost.

**Key Features**:
- Inverted residuals with depthwise separable convolutions
- Squeeze-and-Excitation (SE) attention for channel weighting
- h-swish activation: efficient approximation of swish
- Designed for real-time inference on CPUs and edge devices

**Variants**:
- mobilenet_v3_small: Ultra-lightweight (~0.9M backbone params) - Edge/embedded
- mobilenet_v3_large: Balanced (~3.0M backbone params) - Mobile deployment

**Use Cases**:
- Real-time structural health monitoring on embedded systems
- Field inspection with portable devices
- When model size and inference speed are critical

**Note**: MobileNetV3 is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.

References:
    Howard, A., et al. (2019). Searching for MobileNetV3.
    ICCV 2019. https://arxiv.org/abs/1905.02244

Author: Ductho Le (ductho.le@outlook.com)
"""

from typing import Any

import torch
import torch.nn as nn


try:
    from torchvision.models import (
        MobileNet_V3_Large_Weights,
        MobileNet_V3_Small_Weights,
        mobilenet_v3_large,
        mobilenet_v3_small,
    )

    MOBILENETV3_AVAILABLE = True
except ImportError:
    MOBILENETV3_AVAILABLE = False

from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model


class MobileNetV3Base(BaseModel):
    """
    Base MobileNetV3 class for regression tasks.

    Wraps torchvision MobileNetV3 with:
    - Optional pretrained weights (ImageNet-1K)
    - Automatic input channel adaptation (grayscale → 3ch)
    - Lightweight regression head (maintains efficiency)

    MobileNetV3 is ideal for:
    - Edge deployment (Raspberry Pi, Jetson, mobile)
    - Real-time inference requirements
    - Memory-constrained environments
    - Quick prototyping and experimentation

    Note: This is 2D-only. Input shape must be (H, W).
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        model_fn,
        weights_class,
        pretrained: bool = True,
        dropout_rate: float = 0.2,
        freeze_backbone: bool = False,
        regression_hidden: int = 256,
        **kwargs,
    ):
        """
        Initialize MobileNetV3 for regression.

        Args:
            in_shape: (H, W) input image dimensions
            out_size: Number of regression output targets
            model_fn: torchvision model constructor
            weights_class: Pretrained weights enum class
            pretrained: Use ImageNet pretrained weights (default: True)
            dropout_rate: Dropout rate in regression head (default: 0.2)
            freeze_backbone: Freeze backbone for fine-tuning (default: False)
            regression_hidden: Hidden units in regression head (default: 256)
        """
        super().__init__(in_shape, out_size)

        if not MOBILENETV3_AVAILABLE:
            raise ImportError(
                "torchvision is required for MobileNetV3. "
                "Install with: pip install torchvision"
            )

        if len(in_shape) != 2:
            raise ValueError(
                f"MobileNetV3 requires 2D input (H, W), got {len(in_shape)}D. "
                "For 1D data, use TCN. For 3D data, use ResNet3D."
            )

        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.freeze_backbone = freeze_backbone
        self.regression_hidden = regression_hidden

        # Load pretrained backbone
        weights = weights_class.IMAGENET1K_V1 if pretrained else None
        self.backbone = model_fn(weights=weights)

        # MobileNetV3 classifier structure:
        # classifier[0]: Linear (features → 1280 for Large, 1024 for Small)
        # classifier[1]: Hardswish
        # classifier[2]: Dropout
        # classifier[3]: Linear (1280/1024 → num_classes)

        # Get the input features to the final classifier
        in_features = self.backbone.classifier[0].in_features

        # Replace classifier with lightweight regression head
        # Keep it efficient to maintain MobileNet's speed advantage
        self.backbone.classifier = nn.Sequential(
            nn.Linear(in_features, regression_hidden),
            nn.Hardswish(inplace=True),  # Match MobileNetV3's activation
            nn.Dropout(dropout_rate),
            nn.Linear(regression_hidden, out_size),
        )

        # Adapt first conv for single-channel input (3× memory savings vs expand)
        self._adapt_input_channels()

        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
        if freeze_backbone:
            self._freeze_backbone()

    def _adapt_input_channels(self):
        """Modify first conv to accept single-channel input.

        Instead of expanding 1→3 channels in forward (which triples memory),
        we replace the first conv layer with a 1-channel version and initialize
        weights as the mean of the pretrained RGB filters.
        """
        old_conv = self.backbone.features[0][0]
        new_conv = nn.Conv2d(
            1,  # Single channel input
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            dilation=old_conv.dilation,
            groups=old_conv.groups,
            padding_mode=old_conv.padding_mode,
            bias=old_conv.bias is not None,
        )
        if self.pretrained:
            with torch.no_grad():
                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
        self.backbone.features[0][0] = new_conv

    def _freeze_backbone(self):
        """Freeze all backbone parameters except the classifier."""
        for name, param in self.backbone.named_parameters():
            if "classifier" not in name:
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, 1, H, W)

        Returns:
            Output tensor of shape (B, out_size)
        """
        return self.backbone(x)

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration for MobileNetV3."""
        return {
            "pretrained": True,
            "dropout_rate": 0.2,
            "freeze_backbone": False,
            "regression_hidden": 256,
        }


# =============================================================================
# REGISTERED MODEL VARIANTS
# =============================================================================


@register_model("mobilenet_v3_small")
class MobileNetV3Small(MobileNetV3Base):
    """
    MobileNetV3-Small: Ultra-lightweight for edge deployment.

    ~0.9M backbone parameters. Designed for the most constrained environments.
    Achieves ~67% ImageNet accuracy with minimal compute.

    Recommended for:
    - Embedded systems (Raspberry Pi, Arduino with accelerators)
    - Battery-powered devices
    - Ultra-low latency requirements (<10ms)
    - Quick training experiments

    Performance (approximate):
    - CPU inference: ~6ms (single core)
    - Parameters: ~0.9M backbone
    - MAdds: 56M

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        pretrained: Use ImageNet pretrained weights (default: True)
        dropout_rate: Dropout rate in head (default: 0.2)
        freeze_backbone: Freeze backbone for fine-tuning (default: False)
        regression_hidden: Hidden units in regression head (default: 256)

    Example:
        >>> model = MobileNetV3Small(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(1, 1, 224, 224)
        >>> out = model(x)  # (1, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_fn=mobilenet_v3_small,
            weights_class=MobileNet_V3_Small_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        pt = "pretrained" if self.pretrained else "scratch"
        return f"MobileNetV3_Small({pt}, in={self.in_shape}, out={self.out_size})"


@register_model("mobilenet_v3_large")
class MobileNetV3Large(MobileNetV3Base):
    """
    MobileNetV3-Large: Balanced efficiency and accuracy.

    ~3.0M backbone parameters. Best trade-off for mobile/portable deployment.
    Achieves ~75% ImageNet accuracy with efficient inference.

    Recommended for:
    - Mobile deployment (smartphones, tablets)
    - Portable inspection devices
    - Real-time processing with moderate accuracy needs
    - Default choice for edge deployment

    Performance (approximate):
    - CPU inference: ~20ms (single core)
    - Parameters: ~3.0M backbone
    - MAdds: 219M

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        pretrained: Use ImageNet pretrained weights (default: True)
        dropout_rate: Dropout rate in head (default: 0.2)
        freeze_backbone: Freeze backbone for fine-tuning (default: False)
        regression_hidden: Hidden units in regression head (default: 256)

    Example:
        >>> model = MobileNetV3Large(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(1, 1, 224, 224)
        >>> out = model(x)  # (1, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_fn=mobilenet_v3_large,
            weights_class=MobileNet_V3_Large_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        pt = "pretrained" if self.pretrained else "scratch"
        return f"MobileNetV3_Large({pt}, in={self.in_shape}, out={self.out_size})"
```