wavedl 1.5.6__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,251 @@
+"""
+MaxViT: Multi-Axis Vision Transformer
+======================================
+
+MaxViT combines local and global attention with O(n) complexity using
+multi-axis attention: block attention (local) + grid attention (global sparse).
+
+**Key Features**:
+- Multi-axis attention for both local and global context
+- Hybrid design with MBConv + attention
+- Linear O(n) complexity
+- Hierarchical multi-scale features
+
+**Variants**:
+- maxvit_tiny: 31M params
+- maxvit_small: 69M params
+- maxvit_base: 120M params
+
+**Requirements**:
+- timm (for pretrained models and architecture)
+- torchvision (fallback, limited support)
+
+Reference:
+    Tu, Z., et al. (2022). MaxViT: Multi-Axis Vision Transformer.
+    ECCV 2022. https://arxiv.org/abs/2204.01697
+
+Author: Ductho Le (ductho.le@outlook.com)
+"""
+
+import torch
+import torch.nn as nn
+
+from wavedl.models._timm_utils import build_regression_head
+from wavedl.models.base import BaseModel
+from wavedl.models.registry import register_model
+
+
+__all__ = [
+    "MaxViTBase",
+    "MaxViTBaseLarge",
+    "MaxViTSmall",
+    "MaxViTTiny",
+]
+
+
+# =============================================================================
+# MAXVIT BASE CLASS
+# =============================================================================
+
+
+class MaxViTBase(BaseModel):
+    """
+    MaxViT base class wrapping timm implementation.
+
+    Multi-axis attention with local block and global grid attention.
+    2D only due to attention structure.
+    """
+
+    def __init__(
+        self,
+        in_shape: tuple[int, int],
+        out_size: int,
+        model_name: str = "maxvit_tiny_tf_224",
+        pretrained: bool = True,
+        freeze_backbone: bool = False,
+        dropout_rate: float = 0.3,
+        **kwargs,
+    ):
+        super().__init__(in_shape, out_size)
+
+        if len(in_shape) != 2:
+            raise ValueError(f"MaxViT requires 2D input (H, W), got {len(in_shape)}D")
+
+        self.pretrained = pretrained
+        self.freeze_backbone = freeze_backbone
+        self.model_name = model_name
+
+        # Try to load from timm
+        try:
+            import timm
+
+            self.backbone = timm.create_model(
+                model_name,
+                pretrained=pretrained,
+                num_classes=0,  # Remove classifier
+            )
+
+            # Get feature dimension
+            with torch.no_grad():
+                dummy = torch.zeros(1, 3, *in_shape)
+                features = self.backbone(dummy)
+                in_features = features.shape[-1]
+
+        except ImportError:
+            raise ImportError(
+                "timm is required for MaxViT. Install with: pip install timm"
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to load MaxViT model '{model_name}': {e}")
+
+        # Adapt input channels (3 -> 1)
+        self._adapt_input_channels()
+
+        # Regression head
+        self.head = build_regression_head(in_features, out_size, dropout_rate)
+
+        if freeze_backbone:
+            self._freeze_backbone()
+
+    def _adapt_input_channels(self):
+        """Adapt first conv layer for single-channel input."""
+        # MaxViT uses stem.conv1 (Conv2dSame from timm)
+        adapted = False
+
+        # Find the first Conv2d with 3 input channels
+        for name, module in self.backbone.named_modules():
+            if hasattr(module, "in_channels") and module.in_channels == 3:
+                # Get parent and child names
+                parts = name.split(".")
+                parent = self.backbone
+                for part in parts[:-1]:
+                    parent = getattr(parent, part)
+                child_name = parts[-1]
+
+                # Create new conv with 1 input channel
+                new_conv = self._make_new_conv(module)
+                setattr(parent, child_name, new_conv)
+                adapted = True
+                break
+
+        if not adapted:
+            import warnings
+
+            warnings.warn(
+                "Could not adapt MaxViT input channels. Model may fail.", stacklevel=2
+            )
+
+    def _make_new_conv(self, old_conv: nn.Module) -> nn.Module:
+        """Create new conv layer with 1 input channel."""
+        # Works for both nn.Conv2d and timm's Conv2dSame; a plain nn.Conv2d is rebuilt below
+
+        # Get common parameters
+        kwargs = {
+            "out_channels": old_conv.out_channels,
+            "kernel_size": old_conv.kernel_size,
+            "stride": old_conv.stride,
+            "padding": old_conv.padding if hasattr(old_conv, "padding") else 0,
+            "bias": old_conv.bias is not None,
+        }
+
+        # Create new conv (use regular Conv2d for simplicity)
+        new_conv = nn.Conv2d(1, **kwargs)
+
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                if old_conv.bias is not None:
+                    new_conv.bias.copy_(old_conv.bias)
+        return new_conv
+
+    def _freeze_backbone(self):
+        """Freeze backbone parameters."""
+        for param in self.backbone.parameters():
+            param.requires_grad = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        features = self.backbone(x)
+        return self.head(features)
+
+
+# =============================================================================
+# REGISTERED VARIANTS
+# =============================================================================
+
+
+@register_model("maxvit_tiny")
+class MaxViTTiny(MaxViTBase):
+    """
+    MaxViT Tiny: ~30.1M backbone parameters.
+
+    Multi-axis attention with local+global context.
+    2D only.
+
+    Example:
+        >>> model = MaxViTTiny(in_shape=(224, 224), out_size=3)
+        >>> x = torch.randn(4, 1, 224, 224)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_name="maxvit_tiny_tf_224",
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"MaxViT_Tiny(in_shape={self.in_shape}, out_size={self.out_size}, "
+            f"pretrained={self.pretrained})"
+        )
+
+
+@register_model("maxvit_small")
+class MaxViTSmall(MaxViTBase):
+    """
+    MaxViT Small: ~67.6M backbone parameters.
+
+    Multi-axis attention with local+global context.
+    2D only.
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_name="maxvit_small_tf_224",
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"MaxViT_Small(in_shape={self.in_shape}, out_size={self.out_size}, "
+            f"pretrained={self.pretrained})"
+        )
+
+
+@register_model("maxvit_base")
+class MaxViTBaseLarge(MaxViTBase):
+    """
+    MaxViT Base: ~118.1M backbone parameters.
+
+    Multi-axis attention with local+global context.
+    2D only.
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_name="maxvit_base_tf_224",
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"MaxViT_Base(in_shape={self.in_shape}, out_size={self.out_size}, "
+            f"pretrained={self.pretrained})"
+        )
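
For orientation, a minimal usage sketch of the new module, adapted from the Example block in MaxViTTiny's docstring. The import path and the pretrained=False flag are assumptions for illustration (the class is registered as "maxvit_tiny"; timm must be installed either way):

    import torch
    from wavedl.models.maxvit import MaxViTTiny  # assumed module path for the new file

    model = MaxViTTiny(in_shape=(224, 224), out_size=3, pretrained=False)  # pretrained=False skips the weight download
    x = torch.randn(4, 1, 224, 224)  # single-channel 2D input
    out = model(x)                   # regression output of shape (4, 3)
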
@@ -13,8 +13,8 @@ optimization to achieve excellent accuracy with minimal computational cost.
 - Designed for real-time inference on CPUs and edge devices

 **Variants**:
-- mobilenet_v3_small: Ultra-lightweight (~1.1M params) - Edge/embedded
-- mobilenet_v3_large: Balanced (~3.2M params) - Mobile deployment
+- mobilenet_v3_small: Ultra-lightweight (~0.9M backbone params) - Edge/embedded
+- mobilenet_v3_large: Balanced (~3.0M backbone params) - Mobile deployment

 **Use Cases**:
 - Real-time structural health monitoring on embedded systems
@@ -136,10 +136,37 @@ class MobileNetV3Base(BaseModel):
             nn.Linear(regression_hidden, out_size),
         )

-        # Optionally freeze backbone for fine-tuning
+        # Adapt first conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()

+    def _adapt_input_channels(self):
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        old_conv = self.backbone.features[0][0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.features[0][0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the classifier."""
         for name, param in self.backbone.named_parameters():
@@ -151,15 +178,11 @@ class MobileNetV3Base(BaseModel):
         Forward pass.

         Args:
-            x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+            x: Input tensor of shape (B, 1, H, W)

         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)

     @classmethod
@@ -183,7 +206,7 @@ class MobileNetV3Small(MobileNetV3Base):
     """
     MobileNetV3-Small: Ultra-lightweight for edge deployment.

-    ~1.1M parameters. Designed for the most constrained environments.
+    ~0.9M backbone parameters. Designed for the most constrained environments.
     Achieves ~67% ImageNet accuracy with minimal compute.

     Recommended for:
@@ -194,7 +217,7 @@ class MobileNetV3Small(MobileNetV3Base):

     Performance (approximate):
     - CPU inference: ~6ms (single core)
-    - Parameters: 2.5M
+    - Parameters: ~0.9M backbone
     - MAdds: 56M

     Args:
@@ -230,7 +253,7 @@ class MobileNetV3Large(MobileNetV3Base):
     """
     MobileNetV3-Large: Balanced efficiency and accuracy.

-    ~3.2M parameters. Best trade-off for mobile/portable deployment.
+    ~3.0M backbone parameters. Best trade-off for mobile/portable deployment.
     Achieves ~75% ImageNet accuracy with efficient inference.

     Recommended for:
@@ -241,7 +264,7 @@ class MobileNetV3Large(MobileNetV3Base):

     Performance (approximate):
     - CPU inference: ~20ms (single core)
-    - Parameters: 5.4M
+    - Parameters: ~3.0M backbone
     - MAdds: 219M

     Args:
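
The same single-channel adaptation is applied to the RegNet and Swin backbones below. A standalone sketch of the idea, using a toy conv rather than an actual torchvision backbone (the layer sizes here are illustrative, not part of wavedl):

    import torch
    import torch.nn as nn

    # Toy stand-in for a pretrained RGB stem conv
    old_conv = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)

    # 1-channel replacement, initialized as the mean of the RGB filters
    new_conv = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)
    with torch.no_grad():
        new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))

    x = torch.randn(2, 1, 64, 64)  # grayscale batch
    assert new_conv(x).shape == (2, 16, 32, 32)

Averaging the filters preserves a pretrained gray-level response, and removing the 1→3 channel expansion from forward avoids the 3× larger input tensor that the old code produced.
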
wavedl/models/regnet.py CHANGED
@@ -13,11 +13,11 @@ Models scale smoothly from mobile to server deployments.
 - Optional Squeeze-and-Excitation (SE) attention

 **Variants** (RegNetY includes SE attention):
-- regnet_y_400mf: Ultra-light (~4.0M params, 0.4 GFLOPs)
-- regnet_y_800mf: Light (~5.8M params, 0.8 GFLOPs)
-- regnet_y_1_6gf: Medium (~10.5M params, 1.6 GFLOPs) - Recommended
-- regnet_y_3_2gf: Large (~18.3M params, 3.2 GFLOPs)
-- regnet_y_8gf: Very large (~37.9M params, 8.0 GFLOPs)
+- regnet_y_400mf: Ultra-light (~3.9M backbone params, 0.4 GFLOPs)
+- regnet_y_800mf: Light (~5.7M backbone params, 0.8 GFLOPs)
+- regnet_y_1_6gf: Medium (~10.3M backbone params, 1.6 GFLOPs) - Recommended
+- regnet_y_3_2gf: Large (~17.9M backbone params, 3.2 GFLOPs)
+- regnet_y_8gf: Very large (~37.4M backbone params, 8.0 GFLOPs)

 **When to Use RegNet**:
 - When you need predictable performance at a given compute budget
@@ -140,10 +140,37 @@ class RegNetBase(BaseModel):
             nn.Linear(regression_hidden, out_size),
         )

-        # Optionally freeze backbone for fine-tuning
+        # Adapt first conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()

+    def _adapt_input_channels(self):
+        """Modify first conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the first conv layer with a 1-channel version and initialize
+        weights as the mean of the pretrained RGB filters.
+        """
+        old_conv = self.backbone.stem[0]
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+        self.backbone.stem[0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the fc layer."""
         for name, param in self.backbone.named_parameters():
@@ -155,15 +182,11 @@ class RegNetBase(BaseModel):
         Forward pass.

         Args:
-            x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+            x: Input tensor of shape (B, 1, H, W)

         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)

     @classmethod
@@ -187,7 +210,7 @@ class RegNetY400MF(RegNetBase):
     """
     RegNetY-400MF: Ultra-lightweight for constrained environments.

-    ~4.0M parameters, 0.4 GFLOPs. Smallest RegNet variant with SE attention.
+    ~3.9M backbone parameters, 0.4 GFLOPs. Smallest RegNet variant with SE attention.

     Recommended for:
     - Edge deployment with moderate accuracy needs
@@ -227,7 +250,7 @@ class RegNetY800MF(RegNetBase):
     """
     RegNetY-800MF: Light variant with good accuracy.

-    ~6.4M parameters, 0.8 GFLOPs. Good balance for mobile deployment.
+    ~5.7M backbone parameters, 0.8 GFLOPs. Good balance for mobile deployment.

     Recommended for:
     - Mobile/portable devices
@@ -267,7 +290,7 @@ class RegNetY1_6GF(RegNetBase):
     """
     RegNetY-1.6GF: Recommended default for balanced performance.

-    ~11.2M parameters, 1.6 GFLOPs. Best trade-off of accuracy and efficiency.
+    ~10.3M backbone parameters, 1.6 GFLOPs. Best trade-off of accuracy and efficiency.
     Comparable to ResNet50 but more efficient.

     Recommended for:
@@ -308,7 +331,7 @@ class RegNetY3_2GF(RegNetBase):
     """
     RegNetY-3.2GF: Higher accuracy for demanding tasks.

-    ~19.4M parameters, 3.2 GFLOPs. Use when 1.6GF isn't sufficient.
+    ~17.9M backbone parameters, 3.2 GFLOPs. Use when 1.6GF isn't sufficient.

     Recommended for:
     - Larger datasets requiring more capacity
@@ -348,7 +371,7 @@ class RegNetY8GF(RegNetBase):
     """
     RegNetY-8GF: High capacity for large-scale tasks.

-    ~39.2M parameters, 8.0 GFLOPs. Use for maximum accuracy needs.
+    ~37.4M backbone parameters, 8.0 GFLOPs. Use for maximum accuracy needs.

     Recommended for:
     - Very large datasets (>50k samples)
wavedl/models/resnet.py CHANGED
@@ -11,9 +11,9 @@ Provides multiple depth variants (18, 34, 50) with optional pretrained weights f
 - 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d

 **Variants**:
-- resnet18: Lightweight, fast training (~11M params)
-- resnet34: Balanced capacity (~21M params)
-- resnet50: Higher capacity with bottleneck blocks (~25M params)
+- resnet18: Lightweight, fast training (~11.2M backbone params)
+- resnet34: Balanced capacity (~21.3M backbone params)
+- resnet50: Higher capacity with bottleneck blocks (~23.5M backbone params)

 References:
     He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning
@@ -534,7 +534,7 @@ class ResNet18Pretrained(PretrainedResNetBase):
     """
     ResNet-18 with ImageNet pretrained weights (2D only).

-    ~11M parameters. Good for: Transfer learning, faster convergence.
+    ~11.2M backbone parameters. Good for: Transfer learning, faster convergence.

     Args:
         in_shape: (H, W) image dimensions
@@ -563,7 +563,7 @@ class ResNet50Pretrained(PretrainedResNetBase):
     """
     ResNet-50 with ImageNet pretrained weights (2D only).

-    ~25M parameters. Good for: High accuracy with transfer learning.
+    ~23.5M backbone parameters. Good for: High accuracy with transfer learning.

     Args:
         in_shape: (H, W) image dimensions
wavedl/models/resnet3d.py CHANGED
@@ -179,7 +179,7 @@ class ResNet3D18(ResNet3DBase):
     """
     ResNet3D-18: Lightweight 3D ResNet for volumetric data.

-    ~33M parameters. Uses 3D convolutions throughout for true volumetric processing.
+    ~33.2M backbone parameters. Uses 3D convolutions throughout for true volumetric processing.
     Pretrained on Kinetics-400 (video action recognition).

     Recommended for:
@@ -221,7 +221,7 @@ class MC3_18(ResNet3DBase):
     """
     MC3-18: Mixed Convolution 3D ResNet (3D stem + 2D residual blocks).

-    ~11M parameters. More efficient than pure 3D ResNet while maintaining
+    ~11.5M backbone parameters. More efficient than pure 3D ResNet while maintaining
     good spatiotemporal modeling. Uses 3D convolutions in early layers
     and 2D convolutions in later layers.

wavedl/models/swin.py CHANGED
@@ -141,10 +141,46 @@ class SwinTransformerBase(BaseModel):
             nn.Linear(regression_hidden // 2, out_size),
         )

-        # Optionally freeze backbone for fine-tuning
+        # Adapt patch embedding conv for single-channel input (3× memory savings vs expand)
+        self._adapt_input_channels()
+
+        # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
         if freeze_backbone:
             self._freeze_backbone()

+    def _adapt_input_channels(self):
+        """Modify patch embedding conv to accept single-channel input.
+
+        Instead of expanding 1→3 channels in forward (which triples memory),
+        we replace the patch embedding conv with a 1-channel version and
+        initialize weights as the mean of the pretrained RGB filters.
+        """
+        # Swin's patch embedding is at features[0][0]
+        try:
+            old_conv = self.backbone.features[0][0]
+        except (IndexError, AttributeError, TypeError) as e:
+            raise RuntimeError(
+                f"Swin patch embed structure changed in this torchvision version. "
+                f"Cannot adapt input channels. Error: {e}"
+            ) from e
+        new_conv = nn.Conv2d(
+            1,  # Single channel input
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            dilation=old_conv.dilation,
+            groups=old_conv.groups,
+            padding_mode=old_conv.padding_mode,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                if old_conv.bias is not None:
+                    new_conv.bias.copy_(old_conv.bias)
+        self.backbone.features[0][0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the head."""
         for name, param in self.backbone.named_parameters():
@@ -156,15 +192,11 @@ class SwinTransformerBase(BaseModel):
         Forward pass.

         Args:
-            x: Input tensor of shape (B, C, H, W) where C is 1 or 3
+            x: Input tensor of shape (B, 1, H, W)

         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1)
-
         return self.backbone(x)

     @classmethod
@@ -272,7 +304,7 @@ class SwinTiny(SwinTransformerBase):
     """
     Swin-T (Tiny): Efficient default for most wave-based tasks.

-    ~28M parameters. Good balance of accuracy and computational cost.
+    ~27.5M backbone parameters. Good balance of accuracy and computational cost.
     Outperforms ResNet50 while being more efficient.

     Recommended for:
@@ -321,7 +353,7 @@ class SwinSmall(SwinTransformerBase):
     """
     Swin-S (Small): Higher accuracy with moderate compute.

-    ~50M parameters. Better accuracy than Swin-T for larger datasets.
+    ~48.8M backbone parameters. Better accuracy than Swin-T for larger datasets.

     Recommended for:
     - Larger datasets (>20k samples)
@@ -368,7 +400,7 @@ class SwinBase(SwinTransformerBase):
     """
     Swin-B (Base): Maximum accuracy for large-scale tasks.

-    ~88M parameters. Best accuracy but requires more compute and data.
+    ~86.7M backbone parameters. Best accuracy but requires more compute and data.

     Recommended for:
     - Very large datasets (>50k samples)
wavedl/models/tcn.py CHANGED
@@ -45,6 +45,26 @@ from wavedl.models.base import BaseModel
 from wavedl.models.registry import register_model


+def _find_group_count(channels: int, max_groups: int = 8) -> int:
+    """
+    Find largest valid group count for GroupNorm.
+
+    GroupNorm requires channels to be divisible by num_groups.
+    This finds the largest divisor up to max_groups.
+
+    Args:
+        channels: Number of channels
+        max_groups: Maximum group count to consider (default: 8)
+
+    Returns:
+        Largest valid group count (always >= 1)
+    """
+    for g in range(min(max_groups, channels), 0, -1):
+        if channels % g == 0:
+            return g
+    return 1
+
+
 class CausalConv1d(nn.Module):
     """
     Causal 1D convolution with dilation.
@@ -101,13 +121,13 @@ class TemporalBlock(nn.Module):

         # First causal convolution
         self.conv1 = CausalConv1d(in_channels, out_channels, kernel_size, dilation)
-        self.norm1 = nn.GroupNorm(min(8, out_channels), out_channels)
+        self.norm1 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
         self.act1 = nn.GELU()
         self.dropout1 = nn.Dropout(dropout)

         # Second causal convolution
         self.conv2 = CausalConv1d(out_channels, out_channels, kernel_size, dilation)
-        self.norm2 = nn.GroupNorm(min(8, out_channels), out_channels)
+        self.norm2 = nn.GroupNorm(_find_group_count(out_channels), out_channels)
         self.act2 = nn.GELU()
         self.dropout2 = nn.Dropout(dropout)

@@ -276,7 +296,7 @@ class TCN(TCNBase):
     """
     TCN: Standard Temporal Convolutional Network.

-    ~7.0M parameters. 8 temporal blocks with channels [64→128→256→256→512→512→512→512].
+    ~6.9M backbone parameters. 8 temporal blocks with channels [64→128→256→256→512→512→512→512].
     Receptive field: 511 samples with kernel_size=3.

     Recommended for:
@@ -318,7 +338,7 @@ class TCNSmall(TCNBase):
     """
     TCN-Small: Lightweight variant for quick experiments.

-    ~1.0M parameters. 6 temporal blocks with channels [32→64→128→128→256→256].
+    ~0.9M backbone parameters. 6 temporal blocks with channels [32→64→128→128→256→256].
     Receptive field: 127 samples with kernel_size=3.

     Recommended for:
@@ -356,7 +376,7 @@ class TCNLarge(TCNBase):
     """
     TCN-Large: High-capacity variant for complex patterns.

-    ~10.2M parameters. 10 temporal blocks with channels [64→128→256→256→512→512→512→512→512→512].
+    ~10.0M backbone parameters. 10 temporal blocks with channels [64→128→256→256→512→512→512→512→512→512].
     Receptive field: 2047 samples with kernel_size=3.

     Recommended for:
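
The GroupNorm change above guards against invalid group counts: nn.GroupNorm requires the channel count to be divisible by num_groups, so the old min(8, out_channels) raises a ValueError for any channel width above 8 that is not a multiple of 8 (the default TCN widths are all multiples of 8, so this matters for custom channel configurations). A self-contained sketch of the helper's behavior, with the function body duplicated from the diff so the snippet runs on its own:

    import torch.nn as nn

    def _find_group_count(channels: int, max_groups: int = 8) -> int:
        for g in range(min(max_groups, channels), 0, -1):
            if channels % g == 0:
                return g
        return 1

    # nn.GroupNorm(min(8, 12), 12) would raise ValueError: 12 is not divisible by 8
    assert _find_group_count(12) == 6
    assert _find_group_count(64) == 8
    assert _find_group_count(7) == 7
    norm = nn.GroupNorm(_find_group_count(12), 12)  # 6 groups, valid
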
wavedl/models/unet.py CHANGED
@@ -119,7 +119,7 @@ class UNetRegression(BaseModel):
     Uses U-Net encoder-decoder architecture with skip connections,
     then applies global pooling for standard vector regression output.

-    ~31.1M parameters (2D). Good for leveraging multi-scale features
+    ~31.0M backbone parameters (2D). Good for leveraging multi-scale features
     and skip connections for regression tasks.

     Args: