wavedl 1.5.6__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/models/__init__.py +52 -4
- wavedl/models/_timm_utils.py +238 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/convnext.py +108 -33
- wavedl/models/convnext_v2.py +504 -0
- wavedl/models/densenet.py +5 -5
- wavedl/models/efficientnet.py +30 -13
- wavedl/models/efficientnetv2.py +32 -9
- wavedl/models/fastvit.py +285 -0
- wavedl/models/mamba.py +535 -0
- wavedl/models/maxvit.py +251 -0
- wavedl/models/mobilenetv3.py +35 -12
- wavedl/models/regnet.py +39 -16
- wavedl/models/resnet.py +5 -5
- wavedl/models/resnet3d.py +2 -2
- wavedl/models/swin.py +41 -9
- wavedl/models/tcn.py +25 -5
- wavedl/models/unet.py +1 -1
- wavedl/models/vit.py +6 -6
- wavedl/test.py +7 -3
- wavedl/train.py +57 -23
- wavedl/utils/constraints.py +11 -5
- wavedl/utils/data.py +120 -18
- wavedl/utils/metrics.py +287 -326
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/METADATA +104 -67
- wavedl-1.6.0.dist-info/RECORD +44 -0
- wavedl-1.5.6.dist-info/RECORD +0 -38
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/LICENSE +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/WHEEL +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/top_level.txt +0 -0
wavedl/models/maxvit.py
ADDED
@@ -0,0 +1,251 @@
+"""
+MaxViT: Multi-Axis Vision Transformer
+======================================
+
+MaxViT combines local and global attention with O(n) complexity using
+multi-axis attention: block attention (local) + grid attention (global sparse).
+
+**Key Features**:
+- Multi-axis attention for both local and global context
+- Hybrid design with MBConv + attention
+- Linear O(n) complexity
+- Hierarchical multi-scale features
+
+**Variants**:
+- maxvit_tiny: 31M params
+- maxvit_small: 69M params
+- maxvit_base: 120M params
+
+**Requirements**:
+- timm (for pretrained models and architecture)
+- torchvision (fallback, limited support)
+
+Reference:
+    Tu, Z., et al. (2022). MaxViT: Multi-Axis Vision Transformer.
+    ECCV 2022. https://arxiv.org/abs/2204.01697
+
+Author: Ductho Le (ductho.le@outlook.com)
+"""
+
+import torch
+import torch.nn as nn
+
+from wavedl.models._timm_utils import build_regression_head
+from wavedl.models.base import BaseModel
+from wavedl.models.registry import register_model
+
+
+__all__ = [
+    "MaxViTBase",
+    "MaxViTBaseLarge",
+    "MaxViTSmall",
+    "MaxViTTiny",
+]
+
+
+# =============================================================================
+# MAXVIT BASE CLASS
+# =============================================================================
+
+
+class MaxViTBase(BaseModel):
+    """
+    MaxViT base class wrapping timm implementation.
+
+    Multi-axis attention with local block and global grid attention.
+    2D only due to attention structure.
+    """
+
+    def __init__(
+        self,
+        in_shape: tuple[int, int],
+        out_size: int,
+        model_name: str = "maxvit_tiny_tf_224",
+        pretrained: bool = True,
+        freeze_backbone: bool = False,
+        dropout_rate: float = 0.3,
+        **kwargs,
+    ):
+        super().__init__(in_shape, out_size)
+
+        if len(in_shape) != 2:
+            raise ValueError(f"MaxViT requires 2D input (H, W), got {len(in_shape)}D")
+
+        self.pretrained = pretrained
+        self.freeze_backbone = freeze_backbone
+        self.model_name = model_name
+
+        # Try to load from timm
+        try:
+            import timm
+
+            self.backbone = timm.create_model(
+                model_name,
+                pretrained=pretrained,
+                num_classes=0,  # Remove classifier
+            )
+
+            # Get feature dimension
+            with torch.no_grad():
+                dummy = torch.zeros(1, 3, *in_shape)
+                features = self.backbone(dummy)
+                in_features = features.shape[-1]
+
+        except ImportError:
+            raise ImportError(
+                "timm is required for MaxViT. Install with: pip install timm"
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to load MaxViT model '{model_name}': {e}")
+
+        # Adapt input channels (3 -> 1)
+        self._adapt_input_channels()
+
+        # Regression head
+        self.head = build_regression_head(in_features, out_size, dropout_rate)
+
+        if freeze_backbone:
+            self._freeze_backbone()
+
+    def _adapt_input_channels(self):
+        """Adapt first conv layer for single-channel input."""
+        # MaxViT uses stem.conv1 (Conv2dSame from timm)
+        adapted = False
+
+        # Find the first Conv2d with 3 input channels
+        for name, module in self.backbone.named_modules():
+            if hasattr(module, "in_channels") and module.in_channels == 3:
+                # Get parent and child names
+                parts = name.split(".")
+                parent = self.backbone
+                for part in parts[:-1]:
+                    parent = getattr(parent, part)
+                child_name = parts[-1]
+
+                # Create new conv with 1 input channel
+                new_conv = self._make_new_conv(module)
+                setattr(parent, child_name, new_conv)
+                adapted = True
+                break
+
+        if not adapted:
+            import warnings
+
+            warnings.warn(
+                "Could not adapt MaxViT input channels. Model may fail.", stacklevel=2
+            )
+
+    def _make_new_conv(self, old_conv: nn.Module) -> nn.Module:
+        """Create new conv layer with 1 input channel."""
+        # Handle both Conv2d and Conv2dSame from timm
+        type(old_conv)
+
+        # Get common parameters
+        kwargs = {
+            "out_channels": old_conv.out_channels,
+            "kernel_size": old_conv.kernel_size,
+            "stride": old_conv.stride,
+            "padding": old_conv.padding if hasattr(old_conv, "padding") else 0,
+            "bias": old_conv.bias is not None,
+        }
+
+        # Create new conv (use regular Conv2d for simplicity)
+        new_conv = nn.Conv2d(1, **kwargs)
+
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                if old_conv.bias is not None:
+                    new_conv.bias.copy_(old_conv.bias)
+        return new_conv
+
+    def _freeze_backbone(self):
+        """Freeze backbone parameters."""
+        for param in self.backbone.parameters():
+            param.requires_grad = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        features = self.backbone(x)
+        return self.head(features)
+
+
+# =============================================================================
+# REGISTERED VARIANTS
+# =============================================================================
+
+
+@register_model("maxvit_tiny")
+class MaxViTTiny(MaxViTBase):
+    """
+    MaxViT Tiny: ~30.1M backbone parameters.
+
+    Multi-axis attention with local+global context.
+    2D only.
+
+    Example:
+        >>> model = MaxViTTiny(in_shape=(224, 224), out_size=3)
+        >>> x = torch.randn(4, 1, 224, 224)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_name="maxvit_tiny_tf_224",
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"MaxViT_Tiny(in_shape={self.in_shape}, out_size={self.out_size}, "
+            f"pretrained={self.pretrained})"
+        )
+
+
+@register_model("maxvit_small")
+class MaxViTSmall(MaxViTBase):
+    """
+    MaxViT Small: ~67.6M backbone parameters.
+
+    Multi-axis attention with local+global context.
+    2D only.
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_name="maxvit_small_tf_224",
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"MaxViT_Small(in_shape={self.in_shape}, out_size={self.out_size}, "
+            f"pretrained={self.pretrained})"
+        )
+
+
+@register_model("maxvit_base")
+class MaxViTBaseLarge(MaxViTBase):
+    """
+    MaxViT Base: ~118.1M backbone parameters.
+
+    Multi-axis attention with local+global context.
+    2D only.
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_name="maxvit_base_tf_224",
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"MaxViT_Base(in_shape={self.in_shape}, out_size={self.out_size}, "
+            f"pretrained={self.pretrained})"
+        )
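For orientation, a minimal usage sketch of the new wrapper, following the Example block in the MaxViTTiny docstring above. It assumes timm is installed; pretrained=False is used here only to avoid a weight download, and the constructor rebuilds the stem conv so inputs are single-channel:

    import torch
    from wavedl.models.maxvit import MaxViTTiny

    # Stem conv is rebuilt for 1-channel input; the timm classifier is replaced
    # by the regression head from wavedl.models._timm_utils.
    model = MaxViTTiny(in_shape=(224, 224), out_size=3, pretrained=False)
    x = torch.randn(4, 1, 224, 224)  # single-channel wave "images"
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # torch.Size([4, 3])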
wavedl/models/mobilenetv3.py
CHANGED
@@ -13,8 +13,8 @@ optimization to achieve excellent accuracy with minimal computational cost.
 - Designed for real-time inference on CPUs and edge devices
 
 **Variants**:
-- mobilenet_v3_small: Ultra-lightweight (~
-- mobilenet_v3_large: Balanced (~3.
+- mobilenet_v3_small: Ultra-lightweight (~0.9M backbone params) - Edge/embedded
+- mobilenet_v3_large: Balanced (~3.0M backbone params) - Mobile deployment
 
 **Use Cases**:
 - Real-time structural health monitoring on embedded systems
@@ -136,10 +136,37 @@ class MobileNetV3Base(BaseModel):
             nn.Linear(regression_hidden, out_size),
         )
 
-        #
+        # Adapt first conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        old_conv = self.backbone.features[0][0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.features[0][0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the classifier."""
         for name, param in self.backbone.named_parameters():
@@ -151,15 +178,11 @@ class MobileNetV3Base(BaseModel):
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B,
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
@@ -183,7 +206,7 @@ class MobileNetV3Small(MobileNetV3Base):
     """
     MobileNetV3-Small: Ultra-lightweight for edge deployment.
 
-    ~
+    ~0.9M backbone parameters. Designed for the most constrained environments.
     Achieves ~67% ImageNet accuracy with minimal compute.
 
     Recommended for:
@@ -194,7 +217,7 @@ class MobileNetV3Small(MobileNetV3Base):
 
     Performance (approximate):
     - CPU inference: ~6ms (single core)
-    - Parameters:
+    - Parameters: ~0.9M backbone
    - MAdds: 56M
 
     Args:
@@ -230,7 +253,7 @@ class MobileNetV3Large(MobileNetV3Base):
     """
     MobileNetV3-Large: Balanced efficiency and accuracy.
 
-    ~3.
+    ~3.0M backbone parameters. Best trade-off for mobile/portable deployment.
     Achieves ~75% ImageNet accuracy with efficient inference.
 
     Recommended for:
@@ -241,7 +264,7 @@ class MobileNetV3Large(MobileNetV3Base):
 
     Performance (approximate):
     - CPU inference: ~20ms (single core)
-    - Parameters:
+    - Parameters: ~3.0M backbone
     - MAdds: 219M
 
     Args:
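The first-conv replacement added in this file can be reproduced standalone against plain torchvision, which makes the "(3× memory savings vs expand)" comment concrete: the 1-channel conv is built once, so the forward pass never materializes a replicated 3-channel copy of the input. A minimal sketch, assuming torchvision ≥ 0.13 (string weights= API) and the current MobileNetV3 layout where the stem conv sits at features[0][0]:

    import torch
    import torch.nn as nn
    from torchvision.models import mobilenet_v3_large

    backbone = mobilenet_v3_large(weights="IMAGENET1K_V1")
    old_conv = backbone.features[0][0]  # 3-channel stem conv (3x3, stride 2, no bias)
    new_conv = nn.Conv2d(
        1,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=old_conv.bias is not None,
    )
    with torch.no_grad():
        # Collapse the pretrained RGB filters into one grayscale filter
        new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
    backbone.features[0][0] = new_conv

    y = backbone(torch.randn(2, 1, 224, 224))  # single-channel input now accepted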
wavedl/models/regnet.py
CHANGED
@@ -13,11 +13,11 @@ Models scale smoothly from mobile to server deployments.
 - Optional Squeeze-and-Excitation (SE) attention
 
 **Variants** (RegNetY includes SE attention):
-- regnet_y_400mf: Ultra-light (~
-- regnet_y_800mf: Light (~5.
-- regnet_y_1_6gf: Medium (~10.
-- regnet_y_3_2gf: Large (~
-- regnet_y_8gf: Very large (~37.
+- regnet_y_400mf: Ultra-light (~3.9M backbone params, 0.4 GFLOPs)
+- regnet_y_800mf: Light (~5.7M backbone params, 0.8 GFLOPs)
+- regnet_y_1_6gf: Medium (~10.3M backbone params, 1.6 GFLOPs) - Recommended
+- regnet_y_3_2gf: Large (~17.9M backbone params, 3.2 GFLOPs)
+- regnet_y_8gf: Very large (~37.4M backbone params, 8.0 GFLOPs)
 
 **When to Use RegNet**:
 - When you need predictable performance at a given compute budget
@@ -140,10 +140,37 @@ class RegNetBase(BaseModel):
             nn.Linear(regression_hidden, out_size),
         )
 
-        #
+        # Adapt first conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        old_conv = self.backbone.stem[0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.stem[0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the fc layer."""
         for name, param in self.backbone.named_parameters():
@@ -155,15 +182,11 @@ class RegNetBase(BaseModel):
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B,
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
@@ -187,7 +210,7 @@ class RegNetY400MF(RegNetBase):
     """
     RegNetY-400MF: Ultra-lightweight for constrained environments.
 
-    ~
+    ~3.9M backbone parameters, 0.4 GFLOPs. Smallest RegNet variant with SE attention.
 
     Recommended for:
     - Edge deployment with moderate accuracy needs
@@ -227,7 +250,7 @@ class RegNetY800MF(RegNetBase):
     """
     RegNetY-800MF: Light variant with good accuracy.
 
-    ~
+    ~5.7M backbone parameters, 0.8 GFLOPs. Good balance for mobile deployment.
 
     Recommended for:
     - Mobile/portable devices
@@ -267,7 +290,7 @@ class RegNetY1_6GF(RegNetBase):
     """
     RegNetY-1.6GF: Recommended default for balanced performance.
 
-    ~
+    ~10.3M backbone parameters, 1.6 GFLOPs. Best trade-off of accuracy and efficiency.
     Comparable to ResNet50 but more efficient.
 
     Recommended for:
@@ -308,7 +331,7 @@ class RegNetY3_2GF(RegNetBase):
     """
     RegNetY-3.2GF: Higher accuracy for demanding tasks.
 
-    ~
+    ~17.9M backbone parameters, 3.2 GFLOPs. Use when 1.6GF isn't sufficient.
 
     Recommended for:
     - Larger datasets requiring more capacity
@@ -348,7 +371,7 @@ class RegNetY8GF(RegNetBase):
     """
     RegNetY-8GF: High capacity for large-scale tasks.
 
-    ~
+    ~37.4M backbone parameters, 8.0 GFLOPs. Use for maximum accuracy needs.
 
     Recommended for:
     - Very large datasets (>50k samples)
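Since mobilenetv3.py, regnet.py, and swin.py all switch from channel replication to a mean-initialized single-channel stem, a quick numerical sketch of how the two paths relate may help. Layer sizes below are illustrative (torchvision's RegNet stem is a 3×3, stride-2 conv without bias); the point is that averaging the RGB filters yields exactly one third of the activation the removed expand-to-3-channels forward produced, a constant scale that fine-tuning can absorb:

    import torch
    import torch.nn as nn

    # Illustrative stem-like conv (sizes are examples, not the exact RegNet widths)
    old_conv = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
    new_conv = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
    with torch.no_grad():
        new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))

    x = torch.randn(2, 1, 64, 64)
    out_old = old_conv(x.expand(-1, 3, -1, -1))  # removed path: replicate the channel
    out_new = new_conv(x)                        # new path: 1-channel, mean-initialized
    print(torch.allclose(out_new, out_old / 3, atol=1e-5))  # True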
wavedl/models/resnet.py
CHANGED
@@ -11,9 +11,9 @@ Provides multiple depth variants (18, 34, 50) with optional pretrained weights f
 - 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d
 
 **Variants**:
-- resnet18: Lightweight, fast training (~
-- resnet34: Balanced capacity (~
-- resnet50: Higher capacity with bottleneck blocks (~
+- resnet18: Lightweight, fast training (~11.2M backbone params)
+- resnet34: Balanced capacity (~21.3M backbone params)
+- resnet50: Higher capacity with bottleneck blocks (~23.5M backbone params)
 
 References:
     He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning
@@ -534,7 +534,7 @@ class ResNet18Pretrained(PretrainedResNetBase):
     """
     ResNet-18 with ImageNet pretrained weights (2D only).
 
-    ~
+    ~11.2M backbone parameters. Good for: Transfer learning, faster convergence.
 
     Args:
         in_shape: (H, W) image dimensions
@@ -563,7 +563,7 @@ class ResNet50Pretrained(PretrainedResNetBase):
     """
     ResNet-50 with ImageNet pretrained weights (2D only).
 
-    ~
+    ~23.5M backbone parameters. Good for: High accuracy with transfer learning.
 
     Args:
         in_shape: (H, W) image dimensions
wavedl/models/resnet3d.py
CHANGED
@@ -179,7 +179,7 @@ class ResNet3D18(ResNet3DBase):
     """
     ResNet3D-18: Lightweight 3D ResNet for volumetric data.
 
-    ~
+    ~33.2M backbone parameters. Uses 3D convolutions throughout for true volumetric processing.
     Pretrained on Kinetics-400 (video action recognition).
 
     Recommended for:
@@ -221,7 +221,7 @@ class MC3_18(ResNet3DBase):
     """
     MC3-18: Mixed Convolution 3D ResNet (3D stem + 2D residual blocks).
 
-    ~
+    ~11.5M backbone parameters. More efficient than pure 3D ResNet while maintaining
     good spatiotemporal modeling. Uses 3D convolutions in early layers
     and 2D convolutions in later layers.
 
wavedl/models/swin.py
CHANGED
@@ -141,10 +141,46 @@ class SwinTransformerBase(BaseModel):
             nn.Linear(regression_hidden // 2, out_size),
         )
 
-        #
+        # Adapt patch embedding conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()
 
+    def _adapt_input_channels(self):
+        """Modify patch embedding conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the patch embedding conv with a 1-channel version and
+        initialize weights as the mean of the pretrained RGB filters.
+        """
+        # Swin's patch embedding is at features[0][0]
+        try:
+            old_conv = self.backbone.features[0][0]
+        except (IndexError, AttributeError, TypeError) as e:
+            raise RuntimeError(
+                f"Swin patch embed structure changed in this torchvision version. "
+                f"Cannot adapt input channels. Error: {e}"
+            ) from e
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                if old_conv.bias is not None:
+                    new_conv.bias.copy_(old_conv.bias)
+        self.backbone.features[0][0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the head."""
         for name, param in self.backbone.named_parameters():
@@ -156,15 +192,11 @@ class SwinTransformerBase(BaseModel):
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B,
+            x: Input tensor of shape (B, 1, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
@@ -272,7 +304,7 @@ class SwinTiny(SwinTransformerBase):
     """
     Swin-T (Tiny): Efficient default for most wave-based tasks.
 
-    ~
+    ~27.5M backbone parameters. Good balance of accuracy and computational cost.
     Outperforms ResNet50 while being more efficient.
 
     Recommended for:
@@ -321,7 +353,7 @@ class SwinSmall(SwinTransformerBase):
     """
     Swin-S (Small): Higher accuracy with moderate compute.
 
-    ~
+    ~48.8M backbone parameters. Better accuracy than Swin-T for larger datasets.
 
     Recommended for:
     - Larger datasets (>20k samples)
@@ -368,7 +400,7 @@ class SwinBase(SwinTransformerBase):
     """
     Swin-B (Base): Maximum accuracy for large-scale tasks.
 
-    ~
+    ~86.7M backbone parameters. Best accuracy but requires more compute and data.
 
     Recommended for:
     - Very large datasets (>50k samples)
wavedl/models/tcn.py
CHANGED
@@ -45,6 +45,26 @@ from wavedl.models.base import BaseModel
 from wavedl.models.registry import register_model
 
 
+def _find_group_count(channels: int, max_groups: int = 8) -> int:
+    """
+    Find largest valid group count for GroupNorm.
+
+    GroupNorm requires channels to be divisible by num_groups.
+    This finds the largest divisor up to max_groups.
+
+    Args:
+        channels: Number of channels
+        max_groups: Maximum group count to consider (default: 8)
+
+    Returns:
+        Largest valid group count (always >= 1)
+    """
+    for g in range(min(max_groups, channels), 0, -1):
+        if channels % g == 0:
+            return g
+    return 1
+
+
 class CausalConv1d(nn.Module):
     """
     Causal 1D convolution with dilation.
@@ -101,13 +121,13 @@ class TemporalBlock(nn.Module):
 
         # First causal convolution
         self.conv1 = CausalConv1d(in_channels, out_channels, kernel_size, dilation)
-        self.norm1 = nn.GroupNorm(
+        self.norm1 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
         self.act1 = nn.GELU()
         self.dropout1 = nn.Dropout(dropout)
 
         # Second causal convolution
         self.conv2 = CausalConv1d(out_channels, out_channels, kernel_size, dilation)
-        self.norm2 = nn.GroupNorm(
+        self.norm2 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
         self.act2 = nn.GELU()
         self.dropout2 = nn.Dropout(dropout)
 
@@ -276,7 +296,7 @@ class TCN(TCNBase):
     """
     TCN: Standard Temporal Convolutional Network.
 
-    ~
+    ~6.9M backbone parameters. 8 temporal blocks with channels [64→128→256→256→512→512→512→512].
     Receptive field: 511 samples with kernel_size=3.
 
     Recommended for:
@@ -318,7 +338,7 @@ class TCNSmall(TCNBase):
     """
     TCN-Small: Lightweight variant for quick experiments.
 
-    ~
+    ~0.9M backbone parameters. 6 temporal blocks with channels [32→64→128→128→256→256].
     Receptive field: 127 samples with kernel_size=3.
 
     Recommended for:
@@ -356,7 +376,7 @@ class TCNLarge(TCNBase):
     """
     TCN-Large: High-capacity variant for complex patterns.
 
-    ~10.
+    ~10.0M backbone parameters. 10 temporal blocks with channels [64→128→256→256→512→512→512→512→512→512].
     Receptive field: 2047 samples with kernel_size=3.
 
     Recommended for:
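The new _find_group_count helper exists because nn.GroupNorm raises when the channel count is not divisible by num_groups; it simply walks down from max_groups to the largest divisor. A small sketch of its behaviour (the channel counts below are arbitrary examples, not the TCN's actual widths):

    import torch.nn as nn

    def _find_group_count(channels: int, max_groups: int = 8) -> int:
        # Largest divisor of `channels` that is <= max_groups (always >= 1)
        for g in range(min(max_groups, channels), 0, -1):
            if channels % g == 0:
                return g
        return 1

    for c in (64, 48, 10, 7, 3):
        print(c, "->", _find_group_count(c))
    # 64 -> 8, 48 -> 8, 10 -> 5, 7 -> 7, 3 -> 3

    norm = nn.GroupNorm(_find_group_count(10), 10)  # valid: 10 % 5 == 0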
wavedl/models/unet.py
CHANGED
@@ -119,7 +119,7 @@ class UNetRegression(BaseModel):
     Uses U-Net encoder-decoder architecture with skip connections,
     then applies global pooling for standard vector regression output.
 
-    ~31.
+    ~31.0M backbone parameters (2D). Good for leveraging multi-scale features
     and skip connections for regression tasks.
 
     Args: