wavedl-1.3.1-py3-none-any.whl → wavedl-1.4.1-py3-none-any.whl
- wavedl/__init__.py +1 -1
- wavedl/hpc.py +48 -28
- wavedl/models/__init__.py +33 -7
- wavedl/models/_template.py +28 -41
- wavedl/models/base.py +49 -2
- wavedl/models/cnn.py +0 -1
- wavedl/models/convnext.py +4 -1
- wavedl/models/densenet.py +4 -1
- wavedl/models/efficientnet.py +9 -5
- wavedl/models/efficientnetv2.py +292 -0
- wavedl/models/mobilenetv3.py +272 -0
- wavedl/models/registry.py +0 -1
- wavedl/models/regnet.py +383 -0
- wavedl/models/resnet.py +7 -4
- wavedl/models/resnet3d.py +258 -0
- wavedl/models/swin.py +390 -0
- wavedl/models/tcn.py +389 -0
- wavedl/models/unet.py +44 -110
- wavedl/models/vit.py +8 -4
- wavedl/train.py +1144 -1116
- wavedl/utils/config.py +88 -2
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/METADATA +136 -98
- wavedl-1.4.1.dist-info/RECORD +37 -0
- wavedl-1.3.1.dist-info/RECORD +0 -31
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/LICENSE +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/WHEEL +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.3.1.dist-info → wavedl-1.4.1.dist-info}/top_level.txt +0 -0
wavedl/models/swin.py
ADDED
@@ -0,0 +1,390 @@
"""
Swin Transformer: Hierarchical Vision Transformer with Shifted Windows
=======================================================================

State-of-the-art vision transformer that computes self-attention within
local windows while enabling cross-window connections via shifting.
Achieves excellent accuracy with linear computational complexity.

**Key Innovations**:
- Hierarchical feature maps (like CNNs) for multi-scale processing
- Shifted window attention: O(n) complexity vs O(n²) for vanilla ViT
- Local attention with global receptive field through layer stacking
- Strong inductive bias for structured data

**Variants**:
- swin_t: Tiny (28M params) - Efficient default
- swin_s: Small (50M params) - Better accuracy
- swin_b: Base (88M params) - High accuracy

**Why Swin over ViT?**:
- Better for smaller datasets (stronger inductive bias)
- Handles higher resolution inputs efficiently
- Produces hierarchical features (useful for multi-scale patterns)
- More efficient memory usage

**Note**: Swin Transformer is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.

References:
    Liu, Z., et al. (2021). Swin Transformer: Hierarchical Vision Transformer
    using Shifted Windows. ICCV 2021 (Best Paper). https://arxiv.org/abs/2103.14030

Author: Ductho Le (ductho.le@outlook.com)
"""

from typing import Any

import torch
import torch.nn as nn


try:
    from torchvision.models import (
        Swin_B_Weights,
        Swin_S_Weights,
        Swin_T_Weights,
        swin_b,
        swin_s,
        swin_t,
    )

    SWIN_AVAILABLE = True
except ImportError:
    SWIN_AVAILABLE = False

from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model


class SwinTransformerBase(BaseModel):
    """
    Base Swin Transformer class for regression tasks.

    Wraps torchvision Swin Transformer with:
    - Optional pretrained weights (ImageNet-1K or ImageNet-22K)
    - Automatic input channel adaptation (grayscale → 3ch)
    - Custom regression head with layer normalization

    Swin Transformer excels at:
    - Multi-scale feature extraction (dispersion curves, spectrograms)
    - High-resolution inputs (efficient O(n) attention)
    - Tasks requiring both local and global context
    - Transfer learning from pretrained weights

    Note: This is 2D-only. Input shape must be (H, W).
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        model_fn,
        weights_class,
        pretrained: bool = True,
        dropout_rate: float = 0.3,
        freeze_backbone: bool = False,
        regression_hidden: int = 512,
        **kwargs,
    ):
        """
        Initialize Swin Transformer for regression.

        Args:
            in_shape: (H, W) input image dimensions
            out_size: Number of regression output targets
            model_fn: torchvision model constructor
            weights_class: Pretrained weights enum class
            pretrained: Use ImageNet pretrained weights (default: True)
            dropout_rate: Dropout rate in regression head (default: 0.3)
            freeze_backbone: Freeze backbone for fine-tuning (default: False)
            regression_hidden: Hidden units in regression head (default: 512)
        """
        super().__init__(in_shape, out_size)

        if not SWIN_AVAILABLE:
            raise ImportError(
                "torchvision >= 0.12 is required for Swin Transformer. "
                "Install with: pip install torchvision>=0.12"
            )

        if len(in_shape) != 2:
            raise ValueError(
                f"Swin Transformer requires 2D input (H, W), got {len(in_shape)}D. "
                "For 1D data, use TCN. For 3D data, use ResNet3D."
            )

        self.pretrained = pretrained
        self.dropout_rate = dropout_rate
        self.freeze_backbone = freeze_backbone
        self.regression_hidden = regression_hidden

        # Load pretrained backbone
        weights = weights_class.IMAGENET1K_V1 if pretrained else None
        self.backbone = model_fn(weights=weights)

        # Swin Transformer head structure:
        # head: Linear (embed_dim → num_classes)
        # We need to get the embedding dimension from the head

        in_features = self.backbone.head.in_features

        # Replace head with regression head
        # Use LayerNorm for stability (matches Transformer architecture)
        self.backbone.head = nn.Sequential(
            nn.LayerNorm(in_features),
            nn.Dropout(dropout_rate),
            nn.Linear(in_features, regression_hidden),
            nn.GELU(),  # GELU matches Transformer's activation
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(regression_hidden, regression_hidden // 2),
            nn.GELU(),
            nn.Linear(regression_hidden // 2, out_size),
        )

        # Optionally freeze backbone for fine-tuning
        if freeze_backbone:
            self._freeze_backbone()

    def _freeze_backbone(self):
        """Freeze all backbone parameters except the head."""
        for name, param in self.backbone.named_parameters():
            if "head" not in name:
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, C, H, W) where C is 1 or 3

        Returns:
            Output tensor of shape (B, out_size)
        """
        # Expand single channel to 3 channels for pretrained weights compatibility
        if x.size(1) == 1:
            x = x.expand(-1, 3, -1, -1)

        return self.backbone(x)

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration for Swin Transformer."""
        return {
            "pretrained": True,
            "dropout_rate": 0.3,
            "freeze_backbone": False,
            "regression_hidden": 512,
        }

    def get_optimizer_groups(self, base_lr: float, weight_decay: float = 0.05) -> list:
        """
        Get parameter groups with layer-wise learning rate decay.

        Swin Transformer benefits from decaying learning rate for earlier layers.
        This is a common practice for fine-tuning vision transformers.

        Args:
            base_lr: Base learning rate (applied to head)
            weight_decay: Weight decay coefficient

        Returns:
            List of parameter group dictionaries
        """
        # Separate parameters: head (full LR) vs backbone (decayed LR)
        head_params = []
        backbone_params = []
        no_decay_params = []

        for name, param in self.backbone.named_parameters():
            if not param.requires_grad:
                continue

            # No weight decay for bias and normalization
            if "bias" in name or "norm" in name:
                no_decay_params.append(param)
            elif "head" in name:
                head_params.append(param)
            else:
                backbone_params.append(param)

        groups = []

        if head_params:
            groups.append(
                {
                    "params": head_params,
                    "lr": base_lr,
                    "weight_decay": weight_decay,
                }
            )

        if backbone_params:
            # Apply 0.1x learning rate to backbone (common for fine-tuning)
            groups.append(
                {
                    "params": backbone_params,
                    "lr": base_lr * 0.1,
                    "weight_decay": weight_decay,
                }
            )

        if no_decay_params:
            groups.append(
                {
                    "params": no_decay_params,
                    "lr": base_lr,
                    "weight_decay": 0.0,
                }
            )

        return groups if groups else [{"params": self.parameters(), "lr": base_lr}]


# =============================================================================
# REGISTERED MODEL VARIANTS
# =============================================================================


@register_model("swin_t")
class SwinTiny(SwinTransformerBase):
    """
    Swin-T (Tiny): Efficient default for most wave-based tasks.

    ~28M parameters. Good balance of accuracy and computational cost.
    Outperforms ResNet50 while being more efficient.

    Recommended for:
    - Default choice for 2D wave data
    - Dispersion curves, spectrograms, B-scans
    - When hierarchical features matter
    - Transfer learning with limited data

    Architecture:
    - Patch size: 4×4
    - Window size: 7×7
    - Embed dim: 96
    - Depths: [2, 2, 6, 2]
    - Heads: [3, 6, 12, 24]

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        pretrained: Use ImageNet pretrained weights (default: True)
        dropout_rate: Dropout rate in head (default: 0.3)
        freeze_backbone: Freeze backbone for fine-tuning (default: False)
        regression_hidden: Hidden units in regression head (default: 512)

    Example:
        >>> model = SwinTiny(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_fn=swin_t,
            weights_class=Swin_T_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        pt = "pretrained" if self.pretrained else "scratch"
        return f"Swin_Tiny({pt}, in={self.in_shape}, out={self.out_size})"


@register_model("swin_s")
class SwinSmall(SwinTransformerBase):
    """
    Swin-S (Small): Higher accuracy with moderate compute.

    ~50M parameters. Better accuracy than Swin-T for larger datasets.

    Recommended for:
    - Larger datasets (>20k samples)
    - When Swin-T doesn't provide enough capacity
    - Complex multi-scale patterns

    Architecture:
    - Patch size: 4×4
    - Window size: 7×7
    - Embed dim: 96
    - Depths: [2, 2, 18, 2] (deeper stage 3)
    - Heads: [3, 6, 12, 24]

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        pretrained: Use ImageNet pretrained weights (default: True)
        dropout_rate: Dropout rate in head (default: 0.3)
        freeze_backbone: Freeze backbone for fine-tuning (default: False)
        regression_hidden: Hidden units in regression head (default: 512)

    Example:
        >>> model = SwinSmall(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_fn=swin_s,
            weights_class=Swin_S_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        pt = "pretrained" if self.pretrained else "scratch"
        return f"Swin_Small({pt}, in={self.in_shape}, out={self.out_size})"


@register_model("swin_b")
class SwinBase(SwinTransformerBase):
    """
    Swin-B (Base): Maximum accuracy for large-scale tasks.

    ~88M parameters. Best accuracy but requires more compute and data.

    Recommended for:
    - Very large datasets (>50k samples)
    - When accuracy is more important than efficiency
    - HPC environments with ample GPU memory
    - Research experiments

    Architecture:
    - Patch size: 4×4
    - Window size: 7×7
    - Embed dim: 128
    - Depths: [2, 2, 18, 2]
    - Heads: [4, 8, 16, 32]

    Args:
        in_shape: (H, W) image dimensions
        out_size: Number of regression targets
        pretrained: Use ImageNet pretrained weights (default: True)
        dropout_rate: Dropout rate in head (default: 0.3)
        freeze_backbone: Freeze backbone for fine-tuning (default: False)
        regression_hidden: Hidden units in regression head (default: 512)

    Example:
        >>> model = SwinBase(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_fn=swin_b,
            weights_class=Swin_B_Weights,
            **kwargs,
        )

    def __repr__(self) -> str:
        pt = "pretrained" if self.pretrained else "scratch"
        return f"Swin_Base({pt}, in={self.in_shape}, out={self.out_size})"
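
For reference, a minimal usage sketch of the newly added Swin variants (not part of the diff), assuming the 1.4.1 wheel is installed as wavedl and torchvision >= 0.12 is available; the class and method names come from the file above, the hyperparameter values are illustrative.

import torch
from wavedl.models.swin import SwinTiny

# pretrained=False skips the ImageNet weight download for this illustration
model = SwinTiny(in_shape=(224, 224), out_size=3, pretrained=False)

x = torch.randn(4, 1, 224, 224)  # single-channel input is expanded to 3 channels internally
y = model(x)                     # -> shape (4, 3)

# Head gets the full learning rate, backbone 0.1x, bias/norm params no weight decay
groups = model.get_optimizer_groups(base_lr=1e-4, weight_decay=0.05)
optimizer = torch.optim.AdamW(groups)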