wavedl-1.5.7-py3-none-any.whl → wavedl-1.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +80 -4
- wavedl/models/_pretrained_utils.py +366 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +113 -51
- wavedl/models/convnext_v2.py +488 -0
- wavedl/models/densenet.py +10 -23
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +252 -0
- wavedl/models/mamba.py +555 -0
- wavedl/models/maxvit.py +254 -0
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +19 -61
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +2 -6
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +9 -9
- wavedl/train.py +1430 -1425
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/data.py +39 -6
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/efficientnetv2.py
CHANGED
@@ -1,315 +1,315 @@
 """
 EfficientNetV2: Faster Training and Better Accuracy
 ====================================================
 
 Next-generation EfficientNet with improved training efficiency and performance.
 EfficientNetV2 replaces early depthwise convolutions with fused MBConv blocks,
 enabling 2-4× faster training while achieving better accuracy.
 
 **Key Improvements over EfficientNet**:
 - Fused-MBConv in early stages (faster on accelerators)
 - Progressive learning support (start small, grow)
 - Better NAS-optimized architecture
 
 **Variants**:
 - efficientnet_v2_s: Small (21.5M params) - Recommended default
 - efficientnet_v2_m: Medium (54.1M params) - Higher accuracy
 - efficientnet_v2_l: Large (118.5M params) - Maximum accuracy
 
 **Note**: EfficientNetV2 is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
 
 References:
     Tan, M., & Le, Q. (2021). EfficientNetV2: Smaller Models and Faster Training.
     ICML 2021. https://arxiv.org/abs/2104.00298
 
 Author: Ductho Le (ductho.le@outlook.com)
 """
 
 from typing import Any
 
 import torch
 import torch.nn as nn
 
 
 try:
     from torchvision.models import (
         EfficientNet_V2_L_Weights,
         EfficientNet_V2_M_Weights,
         EfficientNet_V2_S_Weights,
         efficientnet_v2_l,
         efficientnet_v2_m,
         efficientnet_v2_s,
     )
 
     EFFICIENTNETV2_AVAILABLE = True
 except ImportError:
     EFFICIENTNETV2_AVAILABLE = False
 
 from wavedl.models.base import BaseModel
 from wavedl.models.registry import register_model
 
 
 class EfficientNetV2Base(BaseModel):
     """
     Base EfficientNetV2 class for regression tasks.
 
     Wraps torchvision EfficientNetV2 with:
     - Optional pretrained weights (ImageNet-1K)
     - Automatic input channel adaptation (grayscale → 3ch)
     - Custom multi-layer regression head
 
     Compared to EfficientNet (V1):
     - 2-4× faster training on GPU/TPU
     - Better accuracy at similar parameter counts
     - More efficient at higher resolutions
 
     Note: This is 2D-only. Input shape must be (H, W).
     """
 
     def __init__(
         self,
         in_shape: tuple[int, int],
         out_size: int,
         model_fn,
         weights_class,
         pretrained: bool = True,
         dropout_rate: float = 0.3,
         freeze_backbone: bool = False,
         regression_hidden: int = 512,
         **kwargs,
     ):
         """
         Initialize EfficientNetV2 for regression.
 
         Args:
             in_shape: (H, W) input image dimensions
             out_size: Number of regression output targets
             model_fn: torchvision model constructor
             weights_class: Pretrained weights enum class
             pretrained: Use ImageNet pretrained weights (default: True)
             dropout_rate: Dropout rate in regression head (default: 0.3)
             freeze_backbone: Freeze backbone for fine-tuning (default: False)
             regression_hidden: Hidden units in regression head (default: 512)
         """
         super().__init__(in_shape, out_size)
 
         if not EFFICIENTNETV2_AVAILABLE:
             raise ImportError(
                 "torchvision >= 0.13 is required for EfficientNetV2. "
                 "Install with: pip install torchvision>=0.13"
             )
 
         if len(in_shape) != 2:
             raise ValueError(
                 f"EfficientNetV2 requires 2D input (H, W), got {len(in_shape)}D. "
                 "For 1D data, use TCN. For 3D data, use ResNet3D."
             )
 
         self.pretrained = pretrained
         self.dropout_rate = dropout_rate
         self.freeze_backbone = freeze_backbone
         self.regression_hidden = regression_hidden
 
         # Load pretrained backbone
         weights = weights_class.IMAGENET1K_V1 if pretrained else None
         self.backbone = model_fn(weights=weights)
 
         # Get classifier input features (before the final classification layer)
         in_features = self.backbone.classifier[1].in_features
 
         # Replace classifier with regression head
         # EfficientNetV2 benefits from a deeper regression head
         self.backbone.classifier = nn.Sequential(
             nn.Dropout(dropout_rate),
             nn.Linear(in_features, regression_hidden),
             nn.SiLU(inplace=True),  # SiLU (Swish) matches EfficientNet's activation
             nn.Dropout(dropout_rate * 0.5),
             nn.Linear(regression_hidden, regression_hidden // 2),
             nn.SiLU(inplace=True),
             nn.Linear(regression_hidden // 2, out_size),
         )
 
         # Adapt first conv for single-channel input (3× memory savings vs expand)
         self._adapt_input_channels()
 
         # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
     def _adapt_input_channels(self):
         """Modify first conv to accept single-channel input.
 
         Instead of expanding 1→3 channels in forward (which triples memory),
         we replace the first conv layer with a 1-channel version and initialize
         weights as the mean of the pretrained RGB filters.
         """
         old_conv = self.backbone.features[0][0]
         new_conv = nn.Conv2d(
             1,  # Single channel input
             old_conv.out_channels,
             kernel_size=old_conv.kernel_size,
             stride=old_conv.stride,
             padding=old_conv.padding,
             dilation=old_conv.dilation,
             groups=old_conv.groups,
             padding_mode=old_conv.padding_mode,
             bias=old_conv.bias is not None,
         )
         if self.pretrained:
             with torch.no_grad():
                 new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
         self.backbone.features[0][0] = new_conv
 
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the classifier."""
         for name, param in self.backbone.named_parameters():
             if "classifier" not in name:
                 param.requires_grad = False
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass.
 
         Args:
             x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
         return self.backbone(x)
 
     @classmethod
     def get_default_config(cls) -> dict[str, Any]:
         """Return default configuration for EfficientNetV2."""
         return {
             "pretrained": True,
             "dropout_rate": 0.3,
             "freeze_backbone": False,
             "regression_hidden": 512,
         }
 
 
 # =============================================================================
 # REGISTERED MODEL VARIANTS
 # =============================================================================
 
 
 @register_model("efficientnet_v2_s")
 class EfficientNetV2S(EfficientNetV2Base):
     """
     EfficientNetV2-S: Small variant, recommended default.
 
-    ~
+    ~20.2M backbone parameters. Best balance of speed and accuracy for most tasks.
     2× faster training than EfficientNet-B4 with better accuracy.
 
     Recommended for:
     - Default choice for 2D wave data
     - Moderate compute budgets
     - When training speed matters
 
     Args:
         in_shape: (H, W) image dimensions
         out_size: Number of regression targets
         pretrained: Use ImageNet pretrained weights (default: True)
         dropout_rate: Dropout rate in head (default: 0.3)
         freeze_backbone: Freeze backbone for fine-tuning (default: False)
         regression_hidden: Hidden units in regression head (default: 512)
 
     Example:
         >>> model = EfficientNetV2S(in_shape=(500, 500), out_size=3)
         >>> x = torch.randn(4, 1, 500, 500)
         >>> out = model(x)  # (4, 3)
     """
 
     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
         super().__init__(
             in_shape=in_shape,
             out_size=out_size,
             model_fn=efficientnet_v2_s,
             weights_class=EfficientNet_V2_S_Weights,
             **kwargs,
         )
 
     def __repr__(self) -> str:
         pt = "pretrained" if self.pretrained else "scratch"
         return f"EfficientNetV2_S({pt}, in={self.in_shape}, out={self.out_size})"
 
 
 @register_model("efficientnet_v2_m")
 class EfficientNetV2M(EfficientNetV2Base):
     """
     EfficientNetV2-M: Medium variant for higher accuracy.
 
-    ~
+    ~52.9M backbone parameters. Use when accuracy is more important than speed.
 
     Recommended for:
     - Large datasets (>50k samples)
     - Complex wave patterns
     - When compute is not a bottleneck
 
     Args:
         in_shape: (H, W) image dimensions
         out_size: Number of regression targets
         pretrained: Use ImageNet pretrained weights (default: True)
         dropout_rate: Dropout rate in head (default: 0.3)
         freeze_backbone: Freeze backbone for fine-tuning (default: False)
         regression_hidden: Hidden units in regression head (default: 512)
 
     Example:
         >>> model = EfficientNetV2M(in_shape=(500, 500), out_size=3)
         >>> x = torch.randn(4, 1, 500, 500)
         >>> out = model(x)  # (4, 3)
     """
 
     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
         super().__init__(
             in_shape=in_shape,
             out_size=out_size,
             model_fn=efficientnet_v2_m,
             weights_class=EfficientNet_V2_M_Weights,
             **kwargs,
         )
 
     def __repr__(self) -> str:
         pt = "pretrained" if self.pretrained else "scratch"
         return f"EfficientNetV2_M({pt}, in={self.in_shape}, out={self.out_size})"
 
 
 @register_model("efficientnet_v2_l")
 class EfficientNetV2L(EfficientNetV2Base):
     """
     EfficientNetV2-L: Large variant for maximum accuracy.
 
-    ~
+    ~117.2M backbone parameters. Use only with large datasets and sufficient compute.
 
     Recommended for:
     - Very large datasets (>100k samples)
     - When maximum accuracy is critical
     - HPC environments with ample GPU memory
 
     Args:
         in_shape: (H, W) image dimensions
         out_size: Number of regression targets
         pretrained: Use ImageNet pretrained weights (default: True)
         dropout_rate: Dropout rate in head (default: 0.3)
         freeze_backbone: Freeze backbone for fine-tuning (default: False)
         regression_hidden: Hidden units in regression head (default: 512)
 
     Example:
         >>> model = EfficientNetV2L(in_shape=(500, 500), out_size=3)
         >>> x = torch.randn(4, 1, 500, 500)
         >>> out = model(x)  # (4, 3)
     """
 
     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
         super().__init__(
             in_shape=in_shape,
             out_size=out_size,
             model_fn=efficientnet_v2_l,
             weights_class=EfficientNet_V2_L_Weights,
             **kwargs,
         )
 
     def __repr__(self) -> str:
         pt = "pretrained" if self.pretrained else "scratch"
         return f"EfficientNetV2_L({pt}, in={self.in_shape}, out={self.out_size})"