wavedl 1.5.7-py3-none-any.whl → 1.6.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/models/__init__.py +52 -4
- wavedl/models/_timm_utils.py +238 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/convnext.py +108 -33
- wavedl/models/convnext_v2.py +504 -0
- wavedl/models/densenet.py +5 -5
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +3 -3
- wavedl/models/fastvit.py +285 -0
- wavedl/models/mamba.py +535 -0
- wavedl/models/maxvit.py +251 -0
- wavedl/models/mobilenetv3.py +6 -6
- wavedl/models/regnet.py +10 -10
- wavedl/models/resnet.py +5 -5
- wavedl/models/resnet3d.py +2 -2
- wavedl/models/swin.py +3 -3
- wavedl/models/tcn.py +3 -3
- wavedl/models/unet.py +1 -1
- wavedl/models/vit.py +6 -6
- wavedl/train.py +21 -16
- wavedl/utils/data.py +39 -6
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/METADATA +90 -62
- wavedl-1.6.0.dist-info/RECORD +44 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/top_level.txt +0 -0
wavedl/models/maxvit.py
ADDED
@@ -0,0 +1,251 @@
+"""
+MaxViT: Multi-Axis Vision Transformer
+======================================
+
+MaxViT combines local and global attention with O(n) complexity using
+multi-axis attention: block attention (local) + grid attention (global sparse).
+
+**Key Features**:
+- Multi-axis attention for both local and global context
+- Hybrid design with MBConv + attention
+- Linear O(n) complexity
+- Hierarchical multi-scale features
+
+**Variants**:
+- maxvit_tiny: 31M params
+- maxvit_small: 69M params
+- maxvit_base: 120M params
+
+**Requirements**:
+- timm (for pretrained models and architecture)
+- torchvision (fallback, limited support)
+
+Reference:
+    Tu, Z., et al. (2022). MaxViT: Multi-Axis Vision Transformer.
+    ECCV 2022. https://arxiv.org/abs/2204.01697
+
+Author: Ductho Le (ductho.le@outlook.com)
+"""
+
+import torch
+import torch.nn as nn
+
+from wavedl.models._timm_utils import build_regression_head
+from wavedl.models.base import BaseModel
+from wavedl.models.registry import register_model
+
+
+__all__ = [
+    "MaxViTBase",
+    "MaxViTBaseLarge",
+    "MaxViTSmall",
+    "MaxViTTiny",
+]
+
+
+# =============================================================================
+# MAXVIT BASE CLASS
+# =============================================================================
+
+
+class MaxViTBase(BaseModel):
+    """
+    MaxViT base class wrapping timm implementation.
+
+    Multi-axis attention with local block and global grid attention.
+    2D only due to attention structure.
+    """
+
+    def __init__(
+        self,
+        in_shape: tuple[int, int],
+        out_size: int,
+        model_name: str = "maxvit_tiny_tf_224",
+        pretrained: bool = True,
+        freeze_backbone: bool = False,
+        dropout_rate: float = 0.3,
+        **kwargs,
+    ):
+        super().__init__(in_shape, out_size)
+
+        if len(in_shape) != 2:
+            raise ValueError(f"MaxViT requires 2D input (H, W), got {len(in_shape)}D")
+
+        self.pretrained = pretrained
+        self.freeze_backbone = freeze_backbone
+        self.model_name = model_name
+
+        # Try to load from timm
+        try:
+            import timm
+
+            self.backbone = timm.create_model(
+                model_name,
+                pretrained=pretrained,
+                num_classes=0,  # Remove classifier
+            )
+
+            # Get feature dimension
+            with torch.no_grad():
+                dummy = torch.zeros(1, 3, *in_shape)
+                features = self.backbone(dummy)
+                in_features = features.shape[-1]
+
+        except ImportError:
+            raise ImportError(
+                "timm is required for MaxViT. Install with: pip install timm"
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to load MaxViT model '{model_name}': {e}")
+
+        # Adapt input channels (3 -> 1)
+        self._adapt_input_channels()
+
+        # Regression head
+        self.head = build_regression_head(in_features, out_size, dropout_rate)
+
+        if freeze_backbone:
+            self._freeze_backbone()
+
+    def _adapt_input_channels(self):
+        """Adapt first conv layer for single-channel input."""
+        # MaxViT uses stem.conv1 (Conv2dSame from timm)
+        adapted = False
+
+        # Find the first Conv2d with 3 input channels
+        for name, module in self.backbone.named_modules():
+            if hasattr(module, "in_channels") and module.in_channels == 3:
+                # Get parent and child names
+                parts = name.split(".")
+                parent = self.backbone
+                for part in parts[:-1]:
+                    parent = getattr(parent, part)
+                child_name = parts[-1]
+
+                # Create new conv with 1 input channel
+                new_conv = self._make_new_conv(module)
+                setattr(parent, child_name, new_conv)
+                adapted = True
+                break
+
+        if not adapted:
+            import warnings
+
+            warnings.warn(
+                "Could not adapt MaxViT input channels. Model may fail.", stacklevel=2
+            )
+
+    def _make_new_conv(self, old_conv: nn.Module) -> nn.Module:
+        """Create new conv layer with 1 input channel."""
+        # Handle both Conv2d and Conv2dSame from timm
+        type(old_conv)
+
+        # Get common parameters
+        kwargs = {
+            "out_channels": old_conv.out_channels,
+            "kernel_size": old_conv.kernel_size,
+            "stride": old_conv.stride,
+            "padding": old_conv.padding if hasattr(old_conv, "padding") else 0,
+            "bias": old_conv.bias is not None,
+        }
+
+        # Create new conv (use regular Conv2d for simplicity)
+        new_conv = nn.Conv2d(1, **kwargs)
+
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                if old_conv.bias is not None:
+                    new_conv.bias.copy_(old_conv.bias)
+        return new_conv
+
+    def _freeze_backbone(self):
+        """Freeze backbone parameters."""
+        for param in self.backbone.parameters():
+            param.requires_grad = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        features = self.backbone(x)
+        return self.head(features)
+
+
+# =============================================================================
+# REGISTERED VARIANTS
+# =============================================================================
+
+
+@register_model("maxvit_tiny")
+class MaxViTTiny(MaxViTBase):
+    """
+    MaxViT Tiny: ~30.1M backbone parameters.
+
+    Multi-axis attention with local+global context.
+    2D only.
+
+    Example:
+        >>> model = MaxViTTiny(in_shape=(224, 224), out_size=3)
+        >>> x = torch.randn(4, 1, 224, 224)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_name="maxvit_tiny_tf_224",
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"MaxViT_Tiny(in_shape={self.in_shape}, out_size={self.out_size}, "
+            f"pretrained={self.pretrained})"
+        )
+
+
+@register_model("maxvit_small")
+class MaxViTSmall(MaxViTBase):
+    """
+    MaxViT Small: ~67.6M backbone parameters.
+
+    Multi-axis attention with local+global context.
+    2D only.
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_name="maxvit_small_tf_224",
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"MaxViT_Small(in_shape={self.in_shape}, out_size={self.out_size}, "
+            f"pretrained={self.pretrained})"
+        )
+
+
+@register_model("maxvit_base")
+class MaxViTBaseLarge(MaxViTBase):
+    """
+    MaxViT Base: ~118.1M backbone parameters.
+
+    Multi-axis attention with local+global context.
+    2D only.
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_name="maxvit_base_tf_224",
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"MaxViT_Base(in_shape={self.in_shape}, out_size={self.out_size}, "
+            f"pretrained={self.pretrained})"
+        )
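The `_make_new_conv` adaptation above reuses pretrained RGB stem weights for single-channel wavefield input by averaging over the input-channel dimension. A minimal standalone sketch of that weight-collapsing step (plain PyTorch; the layer shapes are chosen for illustration only and are not taken from the package):

import torch
import torch.nn as nn

# Stand-in for a pretrained stem conv that expects 3-channel (RGB) input
rgb_conv = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)

# Replacement conv accepting 1-channel input, same output channels/kernel/stride
mono_conv = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1, bias=False)

with torch.no_grad():
    # (64, 3, 3, 3) -> (64, 1, 3, 3): average the three per-channel filters
    mono_conv.weight.copy_(rgb_conv.weight.mean(dim=1, keepdim=True))

x = torch.randn(4, 1, 224, 224)  # single-channel batch, as in the MaxViTTiny docstring
print(mono_conv(x).shape)        # torch.Size([4, 64, 112, 112])

Averaging keeps the spatial structure of the pretrained filters; it is an approximation rather than an exact equivalence for grayscale inputs.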
wavedl/models/mobilenetv3.py
CHANGED
@@ -13,8 +13,8 @@ optimization to achieve excellent accuracy with minimal computational cost.
 - Designed for real-time inference on CPUs and edge devices
 
 **Variants**:
-- mobilenet_v3_small: Ultra-lightweight (~
-- mobilenet_v3_large: Balanced (~3.
+- mobilenet_v3_small: Ultra-lightweight (~0.9M backbone params) - Edge/embedded
+- mobilenet_v3_large: Balanced (~3.0M backbone params) - Mobile deployment
 
 **Use Cases**:
 - Real-time structural health monitoring on embedded systems
@@ -206,7 +206,7 @@ class MobileNetV3Small(MobileNetV3Base):
     """
     MobileNetV3-Small: Ultra-lightweight for edge deployment.
 
-    ~
+    ~0.9M backbone parameters. Designed for the most constrained environments.
     Achieves ~67% ImageNet accuracy with minimal compute.
 
     Recommended for:
@@ -217,7 +217,7 @@ class MobileNetV3Small(MobileNetV3Base):
 
     Performance (approximate):
     - CPU inference: ~6ms (single core)
-    - Parameters: ~
+    - Parameters: ~0.9M backbone
     - MAdds: 56M
 
     Args:
@@ -253,7 +253,7 @@ class MobileNetV3Large(MobileNetV3Base):
     """
     MobileNetV3-Large: Balanced efficiency and accuracy.
 
-    ~3.
+    ~3.0M backbone parameters. Best trade-off for mobile/portable deployment.
     Achieves ~75% ImageNet accuracy with efficient inference.
 
     Recommended for:
@@ -264,7 +264,7 @@ class MobileNetV3Large(MobileNetV3Base):
 
     Performance (approximate):
     - CPU inference: ~20ms (single core)
-    - Parameters: ~3.
+    - Parameters: ~3.0M backbone
     - MAdds: 219M
 
     Args:
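The parameter figures in these docstring updates count the backbone only (the ImageNet classifier is replaced by the regression head). A hedged sketch of how such a figure can be reproduced for torchvision's MobileNetV3, assuming the stock `classifier` attribute; `backbone_millions` is a hypothetical helper, not part of wavedl:

import torch.nn as nn
from torchvision.models import mobilenet_v3_small

def backbone_millions(model: nn.Module, head_attr: str = "classifier") -> float:
    """Parameter count in millions, excluding the named classification head."""
    total = sum(p.numel() for p in model.parameters())
    head = sum(p.numel() for p in getattr(model, head_attr).parameters())
    return (total - head) / 1e6

print(f"{backbone_millions(mobilenet_v3_small()):.1f}M backbone params")  # ~0.9M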
wavedl/models/regnet.py
CHANGED
@@ -13,11 +13,11 @@ Models scale smoothly from mobile to server deployments.
 - Optional Squeeze-and-Excitation (SE) attention
 
 **Variants** (RegNetY includes SE attention):
-- regnet_y_400mf: Ultra-light (~
-- regnet_y_800mf: Light (~5.
-- regnet_y_1_6gf: Medium (~10.
-- regnet_y_3_2gf: Large (~
-- regnet_y_8gf: Very large (~37.
+- regnet_y_400mf: Ultra-light (~3.9M backbone params, 0.4 GFLOPs)
+- regnet_y_800mf: Light (~5.7M backbone params, 0.8 GFLOPs)
+- regnet_y_1_6gf: Medium (~10.3M backbone params, 1.6 GFLOPs) - Recommended
+- regnet_y_3_2gf: Large (~17.9M backbone params, 3.2 GFLOPs)
+- regnet_y_8gf: Very large (~37.4M backbone params, 8.0 GFLOPs)
 
 **When to Use RegNet**:
 - When you need predictable performance at a given compute budget
@@ -210,7 +210,7 @@ class RegNetY400MF(RegNetBase):
     """
     RegNetY-400MF: Ultra-lightweight for constrained environments.
 
-    ~
+    ~3.9M backbone parameters, 0.4 GFLOPs. Smallest RegNet variant with SE attention.
 
     Recommended for:
     - Edge deployment with moderate accuracy needs
@@ -250,7 +250,7 @@ class RegNetY800MF(RegNetBase):
     """
     RegNetY-800MF: Light variant with good accuracy.
 
-    ~
+    ~5.7M backbone parameters, 0.8 GFLOPs. Good balance for mobile deployment.
 
     Recommended for:
     - Mobile/portable devices
@@ -290,7 +290,7 @@ class RegNetY1_6GF(RegNetBase):
     """
     RegNetY-1.6GF: Recommended default for balanced performance.
 
-    ~
+    ~10.3M backbone parameters, 1.6 GFLOPs. Best trade-off of accuracy and efficiency.
     Comparable to ResNet50 but more efficient.
 
     Recommended for:
@@ -331,7 +331,7 @@ class RegNetY3_2GF(RegNetBase):
     """
     RegNetY-3.2GF: Higher accuracy for demanding tasks.
 
-    ~
+    ~17.9M backbone parameters, 3.2 GFLOPs. Use when 1.6GF isn't sufficient.
 
     Recommended for:
     - Larger datasets requiring more capacity
@@ -371,7 +371,7 @@ class RegNetY8GF(RegNetBase):
     """
     RegNetY-8GF: High capacity for large-scale tasks.
 
-    ~
+    ~37.4M backbone parameters, 8.0 GFLOPs. Use for maximum accuracy needs.
 
     Recommended for:
     - Very large datasets (>50k samples)
wavedl/models/resnet.py
CHANGED
@@ -11,9 +11,9 @@ Provides multiple depth variants (18, 34, 50) with optional pretrained weights f
 - 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d
 
 **Variants**:
-- resnet18: Lightweight, fast training (~
-- resnet34: Balanced capacity (~
-- resnet50: Higher capacity with bottleneck blocks (~
+- resnet18: Lightweight, fast training (~11.2M backbone params)
+- resnet34: Balanced capacity (~21.3M backbone params)
+- resnet50: Higher capacity with bottleneck blocks (~23.5M backbone params)
 
 References:
     He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning
@@ -534,7 +534,7 @@ class ResNet18Pretrained(PretrainedResNetBase):
     """
     ResNet-18 with ImageNet pretrained weights (2D only).
 
-    ~
+    ~11.2M backbone parameters. Good for: Transfer learning, faster convergence.
 
     Args:
         in_shape: (H, W) image dimensions
@@ -563,7 +563,7 @@ class ResNet50Pretrained(PretrainedResNetBase):
     """
    ResNet-50 with ImageNet pretrained weights (2D only).
 
-    ~
+    ~23.5M backbone parameters. Good for: High accuracy with transfer learning.
 
     Args:
         in_shape: (H, W) image dimensions
wavedl/models/resnet3d.py
CHANGED
@@ -179,7 +179,7 @@ class ResNet3D18(ResNet3DBase):
     """
     ResNet3D-18: Lightweight 3D ResNet for volumetric data.
 
-    ~
+    ~33.2M backbone parameters. Uses 3D convolutions throughout for true volumetric processing.
     Pretrained on Kinetics-400 (video action recognition).
 
     Recommended for:
@@ -221,7 +221,7 @@ class MC3_18(ResNet3DBase):
     """
     MC3-18: Mixed Convolution 3D ResNet (3D stem + 2D residual blocks).
 
-    ~
+    ~11.5M backbone parameters. More efficient than pure 3D ResNet while maintaining
     good spatiotemporal modeling. Uses 3D convolutions in early layers
     and 2D convolutions in later layers.
 
wavedl/models/swin.py
CHANGED
@@ -304,7 +304,7 @@ class SwinTiny(SwinTransformerBase):
     """
     Swin-T (Tiny): Efficient default for most wave-based tasks.
 
-    ~
+    ~27.5M backbone parameters. Good balance of accuracy and computational cost.
     Outperforms ResNet50 while being more efficient.
 
     Recommended for:
@@ -353,7 +353,7 @@ class SwinSmall(SwinTransformerBase):
     """
     Swin-S (Small): Higher accuracy with moderate compute.
 
-    ~
+    ~48.8M backbone parameters. Better accuracy than Swin-T for larger datasets.
 
     Recommended for:
     - Larger datasets (>20k samples)
@@ -400,7 +400,7 @@ class SwinBase(SwinTransformerBase):
     """
     Swin-B (Base): Maximum accuracy for large-scale tasks.
 
-    ~
+    ~86.7M backbone parameters. Best accuracy but requires more compute and data.
 
     Recommended for:
     - Very large datasets (>50k samples)
wavedl/models/tcn.py
CHANGED
@@ -296,7 +296,7 @@ class TCN(TCNBase):
     """
     TCN: Standard Temporal Convolutional Network.
 
-    ~
+    ~6.9M backbone parameters. 8 temporal blocks with channels [64→128→256→256→512→512→512→512].
     Receptive field: 511 samples with kernel_size=3.
 
     Recommended for:
@@ -338,7 +338,7 @@ class TCNSmall(TCNBase):
     """
     TCN-Small: Lightweight variant for quick experiments.
 
-    ~
+    ~0.9M backbone parameters. 6 temporal blocks with channels [32→64→128→128→256→256].
     Receptive field: 127 samples with kernel_size=3.
 
     Recommended for:
@@ -376,7 +376,7 @@ class TCNLarge(TCNBase):
     """
     TCN-Large: High-capacity variant for complex patterns.
 
-    ~10.
+    ~10.0M backbone parameters. 10 temporal blocks with channels [64→128→256→256→512→512→512→512→512→512].
     Receptive field: 2047 samples with kernel_size=3.
 
     Recommended for:
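The receptive-field figures quoted above (511, 127, and 2047 samples) are consistent with a dilated stack whose dilation doubles every block. A small sketch of that arithmetic, assuming one dilated convolution per temporal block; `tcn_receptive_field` is a hypothetical helper for illustration:

def tcn_receptive_field(num_blocks: int, kernel_size: int = 3) -> int:
    """Receptive field of a dilated conv stack with dilations 1, 2, 4, ..., 2**(L-1)."""
    # Each block adds (kernel_size - 1) * dilation samples of context.
    return 1 + (kernel_size - 1) * (2**num_blocks - 1)

print(tcn_receptive_field(8))   # 511  -> TCN (8 blocks)
print(tcn_receptive_field(6))   # 127  -> TCN-Small (6 blocks)
print(tcn_receptive_field(10))  # 2047 -> TCN-Large (10 blocks)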
wavedl/models/unet.py
CHANGED
@@ -119,7 +119,7 @@ class UNetRegression(BaseModel):
     Uses U-Net encoder-decoder architecture with skip connections,
     then applies global pooling for standard vector regression output.
 
-    ~31.
+    ~31.0M backbone parameters (2D). Good for leveraging multi-scale features
     and skip connections for regression tasks.
 
     Args:
wavedl/models/vit.py
CHANGED
@@ -10,9 +10,9 @@ Supports both 1D (signals) and 2D (images) inputs via configurable patch embeddi
 - 2D: Images/spectrograms → patches are grid squares
 
 **Variants**:
-- vit_tiny: Smallest (~5.
-- vit_small: Light (~
-- vit_base: Standard (~
+- vit_tiny: Smallest (~5.4M backbone params, embed_dim=192, depth=12, heads=3)
+- vit_small: Light (~21.4M backbone params, embed_dim=384, depth=12, heads=6)
+- vit_base: Standard (~85.3M backbone params, embed_dim=768, depth=12, heads=12)
 
 References:
     Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words:
@@ -365,7 +365,7 @@ class ViTTiny(ViTBase):
     """
     ViT-Tiny: Smallest Vision Transformer variant.
 
-    ~5.
+    ~5.4M backbone parameters. Good for: Quick experiments, smaller datasets.
 
     Args:
         in_shape: (L,) for 1D or (H, W) for 2D
@@ -398,7 +398,7 @@ class ViTSmall(ViTBase):
     """
     ViT-Small: Light Vision Transformer variant.
 
-    ~
+    ~21.4M backbone parameters. Good for: Balanced performance.
 
     Args:
         in_shape: (L,) for 1D or (H, W) for 2D
@@ -429,7 +429,7 @@ class ViTBase_(ViTBase):
     """
     ViT-Base: Standard Vision Transformer variant.
 
-    ~
+    ~85.3M backbone parameters. Good for: High accuracy, larger datasets.
 
     Args:
         in_shape: (L,) for 1D or (H, W) for 2D
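The ViT figures line up with the usual back-of-envelope estimate of ~12·d² weights per transformer block (4·d² in attention, 8·d² in an MLP with expansion ratio 4). A sketch of that estimate which ignores patch/positional embeddings and layer norms; `vit_encoder_millions` is a hypothetical helper, not a wavedl function:

def vit_encoder_millions(embed_dim: int, depth: int) -> float:
    """Rough transformer-encoder parameter count in millions (~12 * d^2 per block)."""
    return depth * 12 * embed_dim**2 / 1e6

print(f"{vit_encoder_millions(192, 12):.1f}M")  # ~5.3M  (vit_tiny)
print(f"{vit_encoder_millions(384, 12):.1f}M")  # ~21.2M (vit_small)
print(f"{vit_encoder_millions(768, 12):.1f}M")  # ~84.9M (vit_base)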
wavedl/train.py
CHANGED
@@ -239,11 +239,12 @@ def parse_args() -> argparse.Namespace:
         help="Python modules to import before training (for custom models)",
     )
     parser.add_argument(
-        "--
-
-
-        help="
+        "--no_pretrained",
+        dest="pretrained",
+        action="store_false",
+        help="Train from scratch without pretrained weights (default: use pretrained)",
     )
+    parser.set_defaults(pretrained=True)
 
     # Configuration File
     parser.add_argument(
@@ -1028,12 +1029,14 @@ def main():
 
         for x, y in pbar:
             with accelerator.accumulate(model):
-
-
-
-
-
-
+                # Use mixed precision for forward pass (respects --precision flag)
+                with accelerator.autocast():
+                    pred = model(x)
+                    # Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
+                    if isinstance(criterion, PhysicsConstrainedLoss):
+                        loss = criterion(pred, y, x)
+                    else:
+                        loss = criterion(pred, y)
 
                 accelerator.backward(loss)
 
@@ -1082,12 +1085,14 @@ def main():
 
         with torch.inference_mode():
             for x, y in val_dl:
-
-
-
-
-
-
+                # Use mixed precision for validation (consistent with training)
+                with accelerator.autocast():
+                    pred = model(x)
+                    # Pass inputs for input-dependent constraints
+                    if isinstance(criterion, PhysicsConstrainedLoss):
+                        loss = criterion(pred, y, x)
+                    else:
+                        loss = criterion(pred, y)
 
                 val_loss_sum += loss.detach() * x.size(0)
                 val_samples += x.size(0)
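The new CLI option follows argparse's inverted-flag pattern: the user passes a negative switch (`--no_pretrained`) while the code keeps reading a positive attribute (`args.pretrained`). A self-contained sketch of the same pattern, outside the wavedl parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no_pretrained",
    dest="pretrained",          # stored attribute stays positive
    action="store_false",       # passing the flag sets it to False
    help="Train from scratch without pretrained weights (default: use pretrained)",
)
parser.set_defaults(pretrained=True)

print(parser.parse_args([]).pretrained)                   # True
print(parser.parse_args(["--no_pretrained"]).pretrained)  # False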
wavedl/utils/data.py
CHANGED
@@ -474,9 +474,18 @@ class _TransposedH5Dataset:
         self.shape = tuple(reversed(h5_dataset.shape))
         self.dtype = h5_dataset.dtype
 
-
-
-
+    @property
+    def ndim(self) -> int:
+        """Number of dimensions (derived from shape for numpy compatibility)."""
+        return len(self.shape)
+
+    @property
+    def _transpose_axes(self) -> tuple[int, ...]:
+        """Transpose axis order for reversing dimensions.
+
+        For shape (A, B, C) -> reversed (C, B, A), transpose axes are (2, 1, 0).
+        """
+        return tuple(range(len(self._dataset.shape) - 1, -1, -1))
 
     def __len__(self) -> int:
         return self.shape[0]
@@ -965,8 +974,17 @@ def load_test_data(
         else:
            # Fallback to default source.load() for unknown formats
             inp, outp = source.load(path)
-    except KeyError:
-        #
+    except KeyError as e:
+        # IMPORTANT: Only fall back to inference-only mode if outputs are
+        # genuinely missing (auto-detection failed). If user explicitly
+        # provided --output_key, they expect it to exist - don't silently drop.
+        if output_key is not None:
+            raise KeyError(
+                f"Explicit --output_key '{output_key}' not found in file. "
+                f"Available keys depend on file format. Original error: {e}"
+            ) from e
+
+        # Legitimate fallback: no explicit output_key, outputs just not present
         if format == "npz":
            # First pass to find keys
            with np.load(path, allow_pickle=False) as probe:
@@ -1083,11 +1101,26 @@
             raise ValueError(
                 f"Input appears to be channels-last format: {tuple(X.shape)}. "
                 "WaveDL expects channels-first (N, C, H, W). "
-                "Convert your data using: X = X.permute(0, 3, 1, 2)"
+                "Convert your data using: X = X.permute(0, 3, 1, 2). "
+                "If this is actually a 3D volume with small depth, "
+                "use --input_channels 1 to add a channel dimension."
             )
         elif X.shape[1] > 16:
             # Heuristic fallback: large dim 1 suggests 3D volume needing channel
             X = X.unsqueeze(1)  # 3D volume: (N, D, H, W) → (N, 1, D, H, W)
+        else:
+            # Ambiguous case: shallow 3D volume (D <= 16) or multi-channel 2D
+            # Default to treating as multi-channel 2D (no modification needed)
+            # Log a warning so users know about the --input_channels option
+            import warnings
+
+            warnings.warn(
+                f"Ambiguous 4D input shape: {tuple(X.shape)}. "
+                f"Assuming {X.shape[1]} channels (multi-channel 2D). "
+                f"For 3D volumes with depth={X.shape[1]}, use --input_channels 1.",
+                UserWarning,
+                stacklevel=2,
+            )
     # X.ndim >= 5: assume channel dimension already exists
 
     return X, y
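The new `_transpose_axes` property simply reverses the stored dimension order, so an HDF5 dataset whose axes were written in reverse can be presented the other way around. A small NumPy sketch of that axis arithmetic (illustrative array only):

import numpy as np

data = np.arange(24).reshape(2, 3, 4)             # stored on disk as (A, B, C)
axes = tuple(range(data.ndim - 1, -1, -1))        # (2, 1, 0)

view = np.transpose(data, axes)                   # logical shape (C, B, A) = (4, 3, 2)
print(view.shape == tuple(reversed(data.shape)))  # True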
|