wavedl 1.5.7__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/models/__init__.py +80 -4
  4. wavedl/models/_pretrained_utils.py +366 -0
  5. wavedl/models/base.py +48 -0
  6. wavedl/models/caformer.py +270 -0
  7. wavedl/models/cnn.py +2 -27
  8. wavedl/models/convnext.py +113 -51
  9. wavedl/models/convnext_v2.py +488 -0
  10. wavedl/models/densenet.py +10 -23
  11. wavedl/models/efficientnet.py +6 -6
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +252 -0
  15. wavedl/models/mamba.py +555 -0
  16. wavedl/models/maxvit.py +254 -0
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +19 -61
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +2 -6
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +9 -9
  26. wavedl/train.py +1430 -1425
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/data.py +39 -6
  30. wavedl/utils/losses.py +216 -216
  31. wavedl/utils/optimizers.py +216 -216
  32. wavedl/utils/schedulers.py +251 -251
  33. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
  34. wavedl-1.6.1.dist-info/RECORD +46 -0
  35. wavedl-1.5.7.dist-info/RECORD +0 -38
  36. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
  37. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
  38. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
  39. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/swin.py CHANGED
@@ -1,443 +1,443 @@
1
- """
2
- Swin Transformer: Hierarchical Vision Transformer with Shifted Windows
3
- =======================================================================
4
-
5
- State-of-the-art vision transformer that computes self-attention within
6
- local windows while enabling cross-window connections via shifting.
7
- Achieves excellent accuracy with linear computational complexity.
8
-
9
- **Key Innovations**:
10
- - Hierarchical feature maps (like CNNs) for multi-scale processing
11
- - Shifted window attention: O(n) complexity vs O(n²) for vanilla ViT
12
- - Local attention with global receptive field through layer stacking
13
- - Strong inductive bias for structured data
14
-
15
- **Variants**:
16
- - swin_t: Tiny (28M params) - Efficient default
17
- - swin_s: Small (50M params) - Better accuracy
18
- - swin_b: Base (88M params) - High accuracy
19
-
20
- **Why Swin over ViT?**:
21
- - Better for smaller datasets (stronger inductive bias)
22
- - Handles higher resolution inputs efficiently
23
- - Produces hierarchical features (useful for multi-scale patterns)
24
- - More efficient memory usage
25
-
26
- **Note**: Swin Transformer is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
27
-
28
- References:
29
- Liu, Z., et al. (2021). Swin Transformer: Hierarchical Vision Transformer
30
- using Shifted Windows. ICCV 2021 (Best Paper). https://arxiv.org/abs/2103.14030
31
-
32
- Author: Ductho Le (ductho.le@outlook.com)
33
- """
34
-
35
- from typing import Any
36
-
37
- import torch
38
- import torch.nn as nn
39
-
40
-
41
- try:
42
- from torchvision.models import (
43
- Swin_B_Weights,
44
- Swin_S_Weights,
45
- Swin_T_Weights,
46
- swin_b,
47
- swin_s,
48
- swin_t,
49
- )
50
-
51
- SWIN_AVAILABLE = True
52
- except ImportError:
53
- SWIN_AVAILABLE = False
54
-
55
- from wavedl.models.base import BaseModel
56
- from wavedl.models.registry import register_model
57
-
58
-
59
- class SwinTransformerBase(BaseModel):
60
- """
61
- Base Swin Transformer class for regression tasks.
62
-
63
- Wraps torchvision Swin Transformer with:
64
- - Optional pretrained weights (ImageNet-1K or ImageNet-22K)
65
- - Automatic input channel adaptation (grayscale → 3ch)
66
- - Custom regression head with layer normalization
67
-
68
- Swin Transformer excels at:
69
- - Multi-scale feature extraction (dispersion curves, spectrograms)
70
- - High-resolution inputs (efficient O(n) attention)
71
- - Tasks requiring both local and global context
72
- - Transfer learning from pretrained weights
73
-
74
- Note: This is 2D-only. Input shape must be (H, W).
75
- """
76
-
77
- def __init__(
78
- self,
79
- in_shape: tuple[int, int],
80
- out_size: int,
81
- model_fn,
82
- weights_class,
83
- pretrained: bool = True,
84
- dropout_rate: float = 0.3,
85
- freeze_backbone: bool = False,
86
- regression_hidden: int = 512,
87
- **kwargs,
88
- ):
89
- """
90
- Initialize Swin Transformer for regression.
91
-
92
- Args:
93
- in_shape: (H, W) input image dimensions
94
- out_size: Number of regression output targets
95
- model_fn: torchvision model constructor
96
- weights_class: Pretrained weights enum class
97
- pretrained: Use ImageNet pretrained weights (default: True)
98
- dropout_rate: Dropout rate in regression head (default: 0.3)
99
- freeze_backbone: Freeze backbone for fine-tuning (default: False)
100
- regression_hidden: Hidden units in regression head (default: 512)
101
- """
102
- super().__init__(in_shape, out_size)
103
-
104
- if not SWIN_AVAILABLE:
105
- raise ImportError(
106
- "torchvision >= 0.12 is required for Swin Transformer. "
107
- "Install with: pip install torchvision>=0.12"
108
- )
109
-
110
- if len(in_shape) != 2:
111
- raise ValueError(
112
- f"Swin Transformer requires 2D input (H, W), got {len(in_shape)}D. "
113
- "For 1D data, use TCN. For 3D data, use ResNet3D."
114
- )
115
-
116
- self.pretrained = pretrained
117
- self.dropout_rate = dropout_rate
118
- self.freeze_backbone = freeze_backbone
119
- self.regression_hidden = regression_hidden
120
-
121
- # Load pretrained backbone
122
- weights = weights_class.IMAGENET1K_V1 if pretrained else None
123
- self.backbone = model_fn(weights=weights)
124
-
125
- # Swin Transformer head structure:
126
- # head: Linear (embed_dim → num_classes)
127
- # We need to get the embedding dimension from the head
128
-
129
- in_features = self.backbone.head.in_features
130
-
131
- # Replace head with regression head
132
- # Use LayerNorm for stability (matches Transformer architecture)
133
- self.backbone.head = nn.Sequential(
134
- nn.LayerNorm(in_features),
135
- nn.Dropout(dropout_rate),
136
- nn.Linear(in_features, regression_hidden),
137
- nn.GELU(), # GELU matches Transformer's activation
138
- nn.Dropout(dropout_rate * 0.5),
139
- nn.Linear(regression_hidden, regression_hidden // 2),
140
- nn.GELU(),
141
- nn.Linear(regression_hidden // 2, out_size),
142
- )
143
-
144
- # Adapt patch embedding conv for single-channel input (3× memory savings vs expand)
145
- self._adapt_input_channels()
146
-
147
- # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
148
- if freeze_backbone:
149
- self._freeze_backbone()
150
-
151
- def _adapt_input_channels(self):
152
- """Modify patch embedding conv to accept single-channel input.
153
-
154
- Instead of expanding 1→3 channels in forward (which triples memory),
155
- we replace the patch embedding conv with a 1-channel version and
156
- initialize weights as the mean of the pretrained RGB filters.
157
- """
158
- # Swin's patch embedding is at features[0][0]
159
- try:
160
- old_conv = self.backbone.features[0][0]
161
- except (IndexError, AttributeError, TypeError) as e:
162
- raise RuntimeError(
163
- f"Swin patch embed structure changed in this torchvision version. "
164
- f"Cannot adapt input channels. Error: {e}"
165
- ) from e
166
- new_conv = nn.Conv2d(
167
- 1, # Single channel input
168
- old_conv.out_channels,
169
- kernel_size=old_conv.kernel_size,
170
- stride=old_conv.stride,
171
- padding=old_conv.padding,
172
- dilation=old_conv.dilation,
173
- groups=old_conv.groups,
174
- padding_mode=old_conv.padding_mode,
175
- bias=old_conv.bias is not None,
176
- )
177
- if self.pretrained:
178
- with torch.no_grad():
179
- new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
180
- if old_conv.bias is not None:
181
- new_conv.bias.copy_(old_conv.bias)
182
- self.backbone.features[0][0] = new_conv
183
-
184
- def _freeze_backbone(self):
185
- """Freeze all backbone parameters except the head."""
186
- for name, param in self.backbone.named_parameters():
187
- if "head" not in name:
188
- param.requires_grad = False
189
-
190
- def forward(self, x: torch.Tensor) -> torch.Tensor:
191
- """
192
- Forward pass.
193
-
194
- Args:
195
- x: Input tensor of shape (B, 1, H, W)
196
-
197
- Returns:
198
- Output tensor of shape (B, out_size)
199
- """
200
- return self.backbone(x)
201
-
202
- @classmethod
203
- def get_default_config(cls) -> dict[str, Any]:
204
- """Return default configuration for Swin Transformer."""
205
- return {
206
- "pretrained": True,
207
- "dropout_rate": 0.3,
208
- "freeze_backbone": False,
209
- "regression_hidden": 512,
210
- }
211
-
212
- def get_optimizer_groups(self, base_lr: float, weight_decay: float = 0.05) -> list:
213
- """
214
- Get parameter groups with layer-wise learning rate decay.
215
-
216
- Swin Transformer benefits from decaying learning rate for earlier layers.
217
- This is a common practice for fine-tuning vision transformers.
218
-
219
- Args:
220
- base_lr: Base learning rate (applied to head)
221
- weight_decay: Weight decay coefficient
222
-
223
- Returns:
224
- List of parameter group dictionaries
225
- """
226
- # Separate parameters into 4 groups for proper LR decay:
227
- # 1. Head params with decay (full LR)
228
- # 2. Backbone params with decay (0.1× LR)
229
- # 3. Head bias/norm without decay (full LR)
230
- # 4. Backbone bias/norm without decay (0.1× LR)
231
- head_params = []
232
- backbone_params = []
233
- head_no_decay = []
234
- backbone_no_decay = []
235
-
236
- for name, param in self.backbone.named_parameters():
237
- if not param.requires_grad:
238
- continue
239
-
240
- is_head = "head" in name
241
- is_no_decay = "bias" in name or "norm" in name
242
-
243
- if is_head:
244
- if is_no_decay:
245
- head_no_decay.append(param)
246
- else:
247
- head_params.append(param)
248
- else:
249
- if is_no_decay:
250
- backbone_no_decay.append(param)
251
- else:
252
- backbone_params.append(param)
253
-
254
- groups = []
255
-
256
- if head_params:
257
- groups.append(
258
- {
259
- "params": head_params,
260
- "lr": base_lr,
261
- "weight_decay": weight_decay,
262
- }
263
- )
264
-
265
- if backbone_params:
266
- # Apply 0.1x learning rate to backbone (common for fine-tuning)
267
- groups.append(
268
- {
269
- "params": backbone_params,
270
- "lr": base_lr * 0.1,
271
- "weight_decay": weight_decay,
272
- }
273
- )
274
-
275
- if head_no_decay:
276
- groups.append(
277
- {
278
- "params": head_no_decay,
279
- "lr": base_lr,
280
- "weight_decay": 0.0,
281
- }
282
- )
283
-
284
- if backbone_no_decay:
285
- # Backbone bias/norm also gets 0.1× LR to match intended decay
286
- groups.append(
287
- {
288
- "params": backbone_no_decay,
289
- "lr": base_lr * 0.1,
290
- "weight_decay": 0.0,
291
- }
292
- )
293
-
294
- return groups if groups else [{"params": self.parameters(), "lr": base_lr}]
295
-
296
-
297
- # =============================================================================
298
- # REGISTERED MODEL VARIANTS
299
- # =============================================================================
300
-
301
-
302
- @register_model("swin_t")
303
- class SwinTiny(SwinTransformerBase):
304
- """
305
- Swin-T (Tiny): Efficient default for most wave-based tasks.
306
-
307
- ~28M parameters. Good balance of accuracy and computational cost.
308
- Outperforms ResNet50 while being more efficient.
309
-
310
- Recommended for:
311
- - Default choice for 2D wave data
312
- - Dispersion curves, spectrograms, B-scans
313
- - When hierarchical features matter
314
- - Transfer learning with limited data
315
-
316
- Architecture:
317
- - Patch size: 4×4
318
- - Window size: 7×7
319
- - Embed dim: 96
320
- - Depths: [2, 2, 6, 2]
321
- - Heads: [3, 6, 12, 24]
322
-
323
- Args:
324
- in_shape: (H, W) image dimensions
325
- out_size: Number of regression targets
326
- pretrained: Use ImageNet pretrained weights (default: True)
327
- dropout_rate: Dropout rate in head (default: 0.3)
328
- freeze_backbone: Freeze backbone for fine-tuning (default: False)
329
- regression_hidden: Hidden units in regression head (default: 512)
330
-
331
- Example:
332
- >>> model = SwinTiny(in_shape=(224, 224), out_size=3)
333
- >>> x = torch.randn(4, 1, 224, 224)
334
- >>> out = model(x) # (4, 3)
335
- """
336
-
337
- def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
338
- super().__init__(
339
- in_shape=in_shape,
340
- out_size=out_size,
341
- model_fn=swin_t,
342
- weights_class=Swin_T_Weights,
343
- **kwargs,
344
- )
345
-
346
- def __repr__(self) -> str:
347
- pt = "pretrained" if self.pretrained else "scratch"
348
- return f"Swin_Tiny({pt}, in={self.in_shape}, out={self.out_size})"
349
-
350
-
351
- @register_model("swin_s")
352
- class SwinSmall(SwinTransformerBase):
353
- """
354
- Swin-S (Small): Higher accuracy with moderate compute.
355
-
356
- ~50M parameters. Better accuracy than Swin-T for larger datasets.
357
-
358
- Recommended for:
359
- - Larger datasets (>20k samples)
360
- - When Swin-T doesn't provide enough capacity
361
- - Complex multi-scale patterns
362
-
363
- Architecture:
364
- - Patch size: 4×4
365
- - Window size: 7×7
366
- - Embed dim: 96
367
- - Depths: [2, 2, 18, 2] (deeper stage 3)
368
- - Heads: [3, 6, 12, 24]
369
-
370
- Args:
371
- in_shape: (H, W) image dimensions
372
- out_size: Number of regression targets
373
- pretrained: Use ImageNet pretrained weights (default: True)
374
- dropout_rate: Dropout rate in head (default: 0.3)
375
- freeze_backbone: Freeze backbone for fine-tuning (default: False)
376
- regression_hidden: Hidden units in regression head (default: 512)
377
-
378
- Example:
379
- >>> model = SwinSmall(in_shape=(224, 224), out_size=3)
380
- >>> x = torch.randn(4, 1, 224, 224)
381
- >>> out = model(x) # (4, 3)
382
- """
383
-
384
- def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
385
- super().__init__(
386
- in_shape=in_shape,
387
- out_size=out_size,
388
- model_fn=swin_s,
389
- weights_class=Swin_S_Weights,
390
- **kwargs,
391
- )
392
-
393
- def __repr__(self) -> str:
394
- pt = "pretrained" if self.pretrained else "scratch"
395
- return f"Swin_Small({pt}, in={self.in_shape}, out={self.out_size})"
396
-
397
-
398
- @register_model("swin_b")
399
- class SwinBase(SwinTransformerBase):
400
- """
401
- Swin-B (Base): Maximum accuracy for large-scale tasks.
402
-
403
- ~88M parameters. Best accuracy but requires more compute and data.
404
-
405
- Recommended for:
406
- - Very large datasets (>50k samples)
407
- - When accuracy is more important than efficiency
408
- - HPC environments with ample GPU memory
409
- - Research experiments
410
-
411
- Architecture:
412
- - Patch size: 4×4
413
- - Window size: 7×7
414
- - Embed dim: 128
415
- - Depths: [2, 2, 18, 2]
416
- - Heads: [4, 8, 16, 32]
417
-
418
- Args:
419
- in_shape: (H, W) image dimensions
420
- out_size: Number of regression targets
421
- pretrained: Use ImageNet pretrained weights (default: True)
422
- dropout_rate: Dropout rate in head (default: 0.3)
423
- freeze_backbone: Freeze backbone for fine-tuning (default: False)
424
- regression_hidden: Hidden units in regression head (default: 512)
425
-
426
- Example:
427
- >>> model = SwinBase(in_shape=(224, 224), out_size=3)
428
- >>> x = torch.randn(4, 1, 224, 224)
429
- >>> out = model(x) # (4, 3)
430
- """
431
-
432
- def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
433
- super().__init__(
434
- in_shape=in_shape,
435
- out_size=out_size,
436
- model_fn=swin_b,
437
- weights_class=Swin_B_Weights,
438
- **kwargs,
439
- )
440
-
441
- def __repr__(self) -> str:
442
- pt = "pretrained" if self.pretrained else "scratch"
443
- return f"Swin_Base({pt}, in={self.in_shape}, out={self.out_size})"
1
+ """
2
+ Swin Transformer: Hierarchical Vision Transformer with Shifted Windows
3
+ =======================================================================
4
+
5
+ State-of-the-art vision transformer that computes self-attention within
6
+ local windows while enabling cross-window connections via shifting.
7
+ Achieves excellent accuracy with linear computational complexity.
8
+
9
+ **Key Innovations**:
10
+ - Hierarchical feature maps (like CNNs) for multi-scale processing
11
+ - Shifted window attention: O(n) complexity vs O(n²) for vanilla ViT
12
+ - Local attention with global receptive field through layer stacking
13
+ - Strong inductive bias for structured data
14
+
15
+ **Variants**:
16
+ - swin_t: Tiny (28M params) - Efficient default
17
+ - swin_s: Small (50M params) - Better accuracy
18
+ - swin_b: Base (88M params) - High accuracy
19
+
20
+ **Why Swin over ViT?**:
21
+ - Better for smaller datasets (stronger inductive bias)
22
+ - Handles higher resolution inputs efficiently
23
+ - Produces hierarchical features (useful for multi-scale patterns)
24
+ - More efficient memory usage
25
+
26
+ **Note**: Swin Transformer is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
27
+
28
+ References:
29
+ Liu, Z., et al. (2021). Swin Transformer: Hierarchical Vision Transformer
30
+ using Shifted Windows. ICCV 2021 (Best Paper). https://arxiv.org/abs/2103.14030
31
+
32
+ Author: Ductho Le (ductho.le@outlook.com)
33
+ """
34
+
35
+ from typing import Any
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+
40
+
41
+ try:
42
+ from torchvision.models import (
43
+ Swin_B_Weights,
44
+ Swin_S_Weights,
45
+ Swin_T_Weights,
46
+ swin_b,
47
+ swin_s,
48
+ swin_t,
49
+ )
50
+
51
+ SWIN_AVAILABLE = True
52
+ except ImportError:
53
+ SWIN_AVAILABLE = False
54
+
55
+ from wavedl.models.base import BaseModel
56
+ from wavedl.models.registry import register_model
57
+
58
+
59
+ class SwinTransformerBase(BaseModel):
60
+ """
61
+ Base Swin Transformer class for regression tasks.
62
+
63
+ Wraps torchvision Swin Transformer with:
64
+ - Optional pretrained weights (ImageNet-1K)
65
+ - Automatic input channel adaptation (RGB patch embed → 1-channel, weights averaged)
66
+ - Custom regression head with layer normalization
67
+
68
+ Swin Transformer excels at:
69
+ - Multi-scale feature extraction (dispersion curves, spectrograms)
70
+ - High-resolution inputs (efficient O(n) attention)
71
+ - Tasks requiring both local and global context
72
+ - Transfer learning from pretrained weights
73
+
74
+ Note: This is 2D-only. Input shape must be (H, W).
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ in_shape: tuple[int, int],
80
+ out_size: int,
81
+ model_fn,
82
+ weights_class,
83
+ pretrained: bool = True,
84
+ dropout_rate: float = 0.3,
85
+ freeze_backbone: bool = False,
86
+ regression_hidden: int = 512,
87
+ **kwargs,
88
+ ):
89
+ """
90
+ Initialize Swin Transformer for regression.
91
+
92
+ Args:
93
+ in_shape: (H, W) input image dimensions
94
+ out_size: Number of regression output targets
95
+ model_fn: torchvision model constructor
96
+ weights_class: Pretrained weights enum class
97
+ pretrained: Use ImageNet pretrained weights (default: True)
98
+ dropout_rate: Dropout rate in regression head (default: 0.3)
99
+ freeze_backbone: Freeze backbone for fine-tuning (default: False)
100
+ regression_hidden: Hidden units in regression head (default: 512)
101
+ """
102
+ super().__init__(in_shape, out_size)
103
+
104
+ if not SWIN_AVAILABLE:
105
+ raise ImportError(
106
+ "torchvision >= 0.12 is required for Swin Transformer. "
107
+ "Install with: pip install torchvision>=0.12"
108
+ )
109
+
110
+ if len(in_shape) != 2:
111
+ raise ValueError(
112
+ f"Swin Transformer requires 2D input (H, W), got {len(in_shape)}D. "
113
+ "For 1D data, use TCN. For 3D data, use ResNet3D."
114
+ )
115
+
116
+ self.pretrained = pretrained
117
+ self.dropout_rate = dropout_rate
118
+ self.freeze_backbone = freeze_backbone
119
+ self.regression_hidden = regression_hidden
120
+
121
+ # Load pretrained backbone
122
+ weights = weights_class.IMAGENET1K_V1 if pretrained else None
123
+ self.backbone = model_fn(weights=weights)
124
+
125
+ # Swin Transformer head structure:
126
+ # head: Linear (embed_dim → num_classes)
127
+ # We need to get the embedding dimension from the head
128
+
129
+ in_features = self.backbone.head.in_features
130
+
131
+ # Replace head with regression head
132
+ # Use LayerNorm for stability (matches Transformer architecture)
133
+ self.backbone.head = nn.Sequential(
134
+ nn.LayerNorm(in_features),
135
+ nn.Dropout(dropout_rate),
136
+ nn.Linear(in_features, regression_hidden),
137
+ nn.GELU(), # GELU matches Transformer's activation
138
+ nn.Dropout(dropout_rate * 0.5),
139
+ nn.Linear(regression_hidden, regression_hidden // 2),
140
+ nn.GELU(),
141
+ nn.Linear(regression_hidden // 2, out_size),
142
+ )
143
+
144
+ # Adapt patch embedding conv for single-channel input (3× memory savings vs expand)
145
+ self._adapt_input_channels()
146
+
147
+ # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
148
+ if freeze_backbone:
149
+ self._freeze_backbone()
150
+
151
+ def _adapt_input_channels(self):
152
+ """Modify patch embedding conv to accept single-channel input.
153
+
154
+ Instead of expanding 1→3 channels in forward (which triples memory),
155
+ we replace the patch embedding conv with a 1-channel version and
156
+ initialize weights as the mean of the pretrained RGB filters.
157
+ """
158
+ # Swin's patch embedding is at features[0][0]
159
+ try:
160
+ old_conv = self.backbone.features[0][0]
161
+ except (IndexError, AttributeError, TypeError) as e:
162
+ raise RuntimeError(
163
+ f"Swin patch embed structure changed in this torchvision version. "
164
+ f"Cannot adapt input channels. Error: {e}"
165
+ ) from e
166
+ new_conv = nn.Conv2d(
167
+ 1, # Single channel input
168
+ old_conv.out_channels,
169
+ kernel_size=old_conv.kernel_size,
170
+ stride=old_conv.stride,
171
+ padding=old_conv.padding,
172
+ dilation=old_conv.dilation,
173
+ groups=old_conv.groups,
174
+ padding_mode=old_conv.padding_mode,
175
+ bias=old_conv.bias is not None,
176
+ )
177
+ if self.pretrained:
178
+ with torch.no_grad():
179
+ new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
180
+ if old_conv.bias is not None:
181
+ new_conv.bias.copy_(old_conv.bias)
182
+ self.backbone.features[0][0] = new_conv
183
+
184
+ def _freeze_backbone(self):
185
+ """Freeze all backbone parameters except the head."""
186
+ for name, param in self.backbone.named_parameters():
187
+ if "head" not in name:
188
+ param.requires_grad = False
189
+
190
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
191
+ """
192
+ Forward pass.
193
+
194
+ Args:
195
+ x: Input tensor of shape (B, 1, H, W)
196
+
197
+ Returns:
198
+ Output tensor of shape (B, out_size)
199
+ """
200
+ return self.backbone(x)
201
+
202
+ @classmethod
203
+ def get_default_config(cls) -> dict[str, Any]:
204
+ """Return default configuration for Swin Transformer."""
205
+ return {
206
+ "pretrained": True,
207
+ "dropout_rate": 0.3,
208
+ "freeze_backbone": False,
209
+ "regression_hidden": 512,
210
+ }
211
+
212
+ def get_optimizer_groups(self, base_lr: float, weight_decay: float = 0.05) -> list:
213
+ """
214
+ Get parameter groups with layer-wise learning rate decay.
215
+
216
+ Swin Transformer benefits from decaying learning rate for earlier layers.
217
+ This is a common practice for fine-tuning vision transformers.
218
+
219
+ Args:
220
+ base_lr: Base learning rate (applied to head)
221
+ weight_decay: Weight decay coefficient
222
+
223
+ Returns:
224
+ List of parameter group dictionaries
225
+ """
226
+ # Separate parameters into 4 groups for proper LR decay:
227
+ # 1. Head params with decay (full LR)
228
+ # 2. Backbone params with decay (0.1× LR)
229
+ # 3. Head bias/norm without decay (full LR)
230
+ # 4. Backbone bias/norm without decay (0.1× LR)
231
+ head_params = []
232
+ backbone_params = []
233
+ head_no_decay = []
234
+ backbone_no_decay = []
235
+
236
+ for name, param in self.backbone.named_parameters():
237
+ if not param.requires_grad:
238
+ continue
239
+
240
+ is_head = "head" in name
241
+ is_no_decay = "bias" in name or "norm" in name
242
+
243
+ if is_head:
244
+ if is_no_decay:
245
+ head_no_decay.append(param)
246
+ else:
247
+ head_params.append(param)
248
+ else:
249
+ if is_no_decay:
250
+ backbone_no_decay.append(param)
251
+ else:
252
+ backbone_params.append(param)
253
+
254
+ groups = []
255
+
256
+ if head_params:
257
+ groups.append(
258
+ {
259
+ "params": head_params,
260
+ "lr": base_lr,
261
+ "weight_decay": weight_decay,
262
+ }
263
+ )
264
+
265
+ if backbone_params:
266
+ # Apply 0.1x learning rate to backbone (common for fine-tuning)
267
+ groups.append(
268
+ {
269
+ "params": backbone_params,
270
+ "lr": base_lr * 0.1,
271
+ "weight_decay": weight_decay,
272
+ }
273
+ )
274
+
275
+ if head_no_decay:
276
+ groups.append(
277
+ {
278
+ "params": head_no_decay,
279
+ "lr": base_lr,
280
+ "weight_decay": 0.0,
281
+ }
282
+ )
283
+
284
+ if backbone_no_decay:
285
+ # Backbone bias/norm also gets 0.1× LR to match intended decay
286
+ groups.append(
287
+ {
288
+ "params": backbone_no_decay,
289
+ "lr": base_lr * 0.1,
290
+ "weight_decay": 0.0,
291
+ }
292
+ )
293
+
294
+ return groups if groups else [{"params": self.parameters(), "lr": base_lr}]
295
+
296
+
297
+ # =============================================================================
298
+ # REGISTERED MODEL VARIANTS
299
+ # =============================================================================
300
+
301
+
302
+ @register_model("swin_t")
303
+ class SwinTiny(SwinTransformerBase):
304
+ """
305
+ Swin-T (Tiny): Efficient default for most wave-based tasks.
306
+
307
+ ~27.5M backbone parameters. Good balance of accuracy and computational cost.
308
+ Outperforms ResNet50 while being more efficient.
309
+
310
+ Recommended for:
311
+ - Default choice for 2D wave data
312
+ - Dispersion curves, spectrograms, B-scans
313
+ - When hierarchical features matter
314
+ - Transfer learning with limited data
315
+
316
+ Architecture:
317
+ - Patch size: 4×4
318
+ - Window size: 7×7
319
+ - Embed dim: 96
320
+ - Depths: [2, 2, 6, 2]
321
+ - Heads: [3, 6, 12, 24]
322
+
323
+ Args:
324
+ in_shape: (H, W) image dimensions
325
+ out_size: Number of regression targets
326
+ pretrained: Use ImageNet pretrained weights (default: True)
327
+ dropout_rate: Dropout rate in head (default: 0.3)
328
+ freeze_backbone: Freeze backbone for fine-tuning (default: False)
329
+ regression_hidden: Hidden units in regression head (default: 512)
330
+
331
+ Example:
332
+ >>> model = SwinTiny(in_shape=(224, 224), out_size=3)
333
+ >>> x = torch.randn(4, 1, 224, 224)
334
+ >>> out = model(x) # (4, 3)
335
+ """
336
+
337
+ def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
338
+ super().__init__(
339
+ in_shape=in_shape,
340
+ out_size=out_size,
341
+ model_fn=swin_t,
342
+ weights_class=Swin_T_Weights,
343
+ **kwargs,
344
+ )
345
+
346
+ def __repr__(self) -> str:
347
+ pt = "pretrained" if self.pretrained else "scratch"
348
+ return f"Swin_Tiny({pt}, in={self.in_shape}, out={self.out_size})"
349
+
350
+
351
+ @register_model("swin_s")
352
+ class SwinSmall(SwinTransformerBase):
353
+ """
354
+ Swin-S (Small): Higher accuracy with moderate compute.
355
+
356
+ ~48.8M backbone parameters. Better accuracy than Swin-T for larger datasets.
357
+
358
+ Recommended for:
359
+ - Larger datasets (>20k samples)
360
+ - When Swin-T doesn't provide enough capacity
361
+ - Complex multi-scale patterns
362
+
363
+ Architecture:
364
+ - Patch size: 4×4
365
+ - Window size: 7×7
366
+ - Embed dim: 96
367
+ - Depths: [2, 2, 18, 2] (deeper stage 3)
368
+ - Heads: [3, 6, 12, 24]
369
+
370
+ Args:
371
+ in_shape: (H, W) image dimensions
372
+ out_size: Number of regression targets
373
+ pretrained: Use ImageNet pretrained weights (default: True)
374
+ dropout_rate: Dropout rate in head (default: 0.3)
375
+ freeze_backbone: Freeze backbone for fine-tuning (default: False)
376
+ regression_hidden: Hidden units in regression head (default: 512)
377
+
378
+ Example:
379
+ >>> model = SwinSmall(in_shape=(224, 224), out_size=3)
380
+ >>> x = torch.randn(4, 1, 224, 224)
381
+ >>> out = model(x) # (4, 3)
382
+ """
383
+
384
+ def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
385
+ super().__init__(
386
+ in_shape=in_shape,
387
+ out_size=out_size,
388
+ model_fn=swin_s,
389
+ weights_class=Swin_S_Weights,
390
+ **kwargs,
391
+ )
392
+
393
+ def __repr__(self) -> str:
394
+ pt = "pretrained" if self.pretrained else "scratch"
395
+ return f"Swin_Small({pt}, in={self.in_shape}, out={self.out_size})"
396
+
397
+
398
+ @register_model("swin_b")
399
+ class SwinBase(SwinTransformerBase):
400
+ """
401
+ Swin-B (Base): Maximum accuracy for large-scale tasks.
402
+
403
+ ~86.7M backbone parameters. Best accuracy but requires more compute and data.
404
+
405
+ Recommended for:
406
+ - Very large datasets (>50k samples)
407
+ - When accuracy is more important than efficiency
408
+ - HPC environments with ample GPU memory
409
+ - Research experiments
410
+
411
+ Architecture:
412
+ - Patch size: 4×4
413
+ - Window size: 7×7
414
+ - Embed dim: 128
415
+ - Depths: [2, 2, 18, 2]
416
+ - Heads: [4, 8, 16, 32]
417
+
418
+ Args:
419
+ in_shape: (H, W) image dimensions
420
+ out_size: Number of regression targets
421
+ pretrained: Use ImageNet pretrained weights (default: True)
422
+ dropout_rate: Dropout rate in head (default: 0.3)
423
+ freeze_backbone: Freeze backbone for fine-tuning (default: False)
424
+ regression_hidden: Hidden units in regression head (default: 512)
425
+
426
+ Example:
427
+ >>> model = SwinBase(in_shape=(224, 224), out_size=3)
428
+ >>> x = torch.randn(4, 1, 224, 224)
429
+ >>> out = model(x) # (4, 3)
430
+ """
431
+
432
+ def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
433
+ super().__init__(
434
+ in_shape=in_shape,
435
+ out_size=out_size,
436
+ model_fn=swin_b,
437
+ weights_class=Swin_B_Weights,
438
+ **kwargs,
439
+ )
440
+
441
+ def __repr__(self) -> str:
442
+ pt = "pretrained" if self.pretrained else "scratch"
443
+ return f"Swin_Base({pt}, in={self.in_shape}, out={self.out_size})"