wavedl 1.5.7-py3-none-any.whl → 1.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/models/__init__.py +80 -4
  4. wavedl/models/_pretrained_utils.py +366 -0
  5. wavedl/models/base.py +48 -0
  6. wavedl/models/caformer.py +270 -0
  7. wavedl/models/cnn.py +2 -27
  8. wavedl/models/convnext.py +113 -51
  9. wavedl/models/convnext_v2.py +488 -0
  10. wavedl/models/densenet.py +10 -23
  11. wavedl/models/efficientnet.py +6 -6
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +252 -0
  15. wavedl/models/mamba.py +555 -0
  16. wavedl/models/maxvit.py +254 -0
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +19 -61
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +2 -6
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +9 -9
  26. wavedl/train.py +1430 -1425
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/data.py +39 -6
  30. wavedl/utils/losses.py +216 -216
  31. wavedl/utils/optimizers.py +216 -216
  32. wavedl/utils/schedulers.py +251 -251
  33. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
  34. wavedl-1.6.1.dist-info/RECORD +46 -0
  35. wavedl-1.5.7.dist-info/RECORD +0 -38
  36. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
  37. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
  38. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
  39. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/caformer.py ADDED
@@ -0,0 +1,270 @@
+ """
+ CaFormer: MetaFormer with Convolution and Attention
+ ====================================================
+
+ CaFormer implements the MetaFormer architecture using depthwise separable
+ convolutions in early stages and vanilla self-attention in later stages.
+
+ **Key Features**:
+ - MetaFormer principle: architecture > token mixer
+ - Hybrid: Conv (early) + Attention (late)
+ - StarReLU activation for efficiency
+ - State-of-the-art on ImageNet (85.5%)
+
+ **Variants**:
+ - caformer_s18: 26M params
+ - caformer_s36: 39M params
+ - caformer_m36: 56M params
+
+ **Related Models**:
+ - PoolFormer: Uses pooling instead of attention
+ - ConvFormer: Uses only convolutions
+
+ **Requirements**:
+ - timm >= 0.9.0 (for CaFormer models)
+
+ Reference:
+     Yu, W., et al. (2023). MetaFormer Baselines for Vision.
+     TPAMI 2023. https://arxiv.org/abs/2210.13452
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ """
+
+ import torch
+ import torch.nn as nn
+
+ from wavedl.models._pretrained_utils import build_regression_head
+ from wavedl.models.base import BaseModel
+ from wavedl.models.registry import register_model
+
+
+ __all__ = [
+     "CaFormerBase",
+     "CaFormerM36",
+     "CaFormerS18",
+     "CaFormerS36",
+     "PoolFormerS12",
+ ]
+
+
+ # =============================================================================
+ # CAFORMER BASE CLASS
+ # =============================================================================
+
+
+ class CaFormerBase(BaseModel):
+     """
+     CaFormer base class wrapping timm implementation.
+
+     MetaFormer with conv (early) + attention (late) token mixing.
+     2D only.
+     """
+
+     def __init__(
+         self,
+         in_shape: tuple[int, int],
+         out_size: int,
+         model_name: str = "caformer_s18",
+         pretrained: bool = True,
+         freeze_backbone: bool = False,
+         dropout_rate: float = 0.3,
+         **kwargs,
+     ):
+         super().__init__(in_shape, out_size)
+
+         if len(in_shape) != 2:
+             raise ValueError(f"CaFormer requires 2D input (H, W), got {len(in_shape)}D")
+
+         self.pretrained = pretrained
+         self.freeze_backbone = freeze_backbone
+         self.model_name = model_name
+
+         # Try to load from timm
+         try:
+             import timm
+
+             self.backbone = timm.create_model(
+                 model_name,
+                 pretrained=pretrained,
+                 num_classes=0,  # Remove classifier
+             )
+
+             # Get feature dimension
+             with torch.no_grad():
+                 dummy = torch.zeros(1, 3, *in_shape)
+                 features = self.backbone(dummy)
+                 in_features = features.shape[-1]
+
+         except ImportError:
+             raise ImportError(
+                 "timm >= 0.9.0 is required for CaFormer. "
+                 "Install with: pip install timm>=0.9.0"
+             )
+         except Exception as e:
+             raise RuntimeError(f"Failed to load CaFormer model '{model_name}': {e}")
+
+         # Adapt input channels (3 -> 1)
+         self._adapt_input_channels()
+
+         # Regression head
+         self.head = build_regression_head(in_features, out_size, dropout_rate)
+
+         if freeze_backbone:
+             self._freeze_backbone()
+
+     def _adapt_input_channels(self):
+         """Adapt first conv layer for single-channel input."""
+         # CaFormer uses stem for first layer
+         if hasattr(self.backbone, "stem"):
+             first_conv = None
+             # Find first conv in stem
+             for name, module in self.backbone.stem.named_modules():
+                 if isinstance(module, nn.Conv2d):
+                     first_conv = (name, module)
+                     break
+
+             if first_conv is not None:
+                 name, old_conv = first_conv
+                 new_conv = self._make_new_conv(old_conv)
+                 # Set the new conv (handle nested structure)
+                 self._set_module(self.backbone.stem, name, new_conv)
+
+     def _make_new_conv(self, old_conv: nn.Conv2d) -> nn.Conv2d:
+         """Create new conv layer with 1 input channel."""
+         new_conv = nn.Conv2d(
+             1,
+             old_conv.out_channels,
+             kernel_size=old_conv.kernel_size,
+             stride=old_conv.stride,
+             padding=old_conv.padding,
+             bias=old_conv.bias is not None,
+         )
+         if self.pretrained:
+             with torch.no_grad():
+                 new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                 if old_conv.bias is not None:
+                     new_conv.bias.copy_(old_conv.bias)
+         return new_conv
+
+     def _set_module(self, parent: nn.Module, name: str, module: nn.Module):
+         """Set a nested module by name."""
+         parts = name.split(".")
+         for part in parts[:-1]:
+             parent = getattr(parent, part)
+         setattr(parent, parts[-1], module)
+
+     def _freeze_backbone(self):
+         """Freeze backbone parameters."""
+         for param in self.backbone.parameters():
+             param.requires_grad = False
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         features = self.backbone(x)
+         return self.head(features)
+
+
+ # =============================================================================
+ # REGISTERED VARIANTS
+ # =============================================================================
+
+
+ @register_model("caformer_s18")
+ class CaFormerS18(CaFormerBase):
+     """
+     CaFormer-S18: ~23.2M backbone parameters.
+
+     MetaFormer with conv + attention.
+     2D only.
+
+     Example:
+         >>> model = CaFormerS18(in_shape=(224, 224), out_size=3)
+         >>> x = torch.randn(4, 1, 224, 224)
+         >>> out = model(x)  # (4, 3)
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="caformer_s18",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"CaFormer_S18(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("caformer_s36")
+ class CaFormerS36(CaFormerBase):
+     """
+     CaFormer-S36: ~36.2M backbone parameters.
+
+     Deeper MetaFormer variant.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="caformer_s36",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"CaFormer_S36(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("caformer_m36")
+ class CaFormerM36(CaFormerBase):
+     """
+     CaFormer-M36: ~52.6M backbone parameters.
+
+     Medium-size MetaFormer variant.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="caformer_m36",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"CaFormer_M36(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("poolformer_s12")
+ class PoolFormerS12(CaFormerBase):
+     """
+     PoolFormer-S12: ~11.4M backbone parameters.
+
+     MetaFormer with simple pooling token mixer.
+     Shows that architecture matters more than complex attention.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="poolformer_s12",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"PoolFormer_S12(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
wavedl/models/cnn.py CHANGED
@@ -24,14 +24,10 @@ from typing import Any
  import torch
  import torch.nn as nn

- from wavedl.models.base import BaseModel
+ from wavedl.models.base import BaseModel, SpatialShape, compute_num_groups
  from wavedl.models.registry import register_model


- # Type alias for spatial shapes
- SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
-
  def _get_conv_layers(
      dim: int,
  ) -> tuple[type[nn.Module], type[nn.Module], type[nn.Module]]:
@@ -163,27 +159,6 @@ class CNN(BaseModel):
              nn.Linear(64, out_size),
          )

-     @staticmethod
-     def _compute_num_groups(num_channels: int, target_groups: int = 4) -> int:
-         """
-         Compute valid num_groups for GroupNorm that divides num_channels.
-
-         Finds the largest divisor of num_channels that is <= target_groups,
-         or falls back to 1 if no suitable divisor exists.
-
-         Args:
-             num_channels: Number of channels (must be positive)
-             target_groups: Desired number of groups (default: 4)
-
-         Returns:
-             Valid num_groups that satisfies num_channels % num_groups == 0
-         """
-         # Try target_groups down to 1, return first valid divisor
-         for g in range(min(target_groups, num_channels), 0, -1):
-             if num_channels % g == 0:
-                 return g
-         return 1  # Fallback (always valid)
-
      def _make_conv_block(
          self, in_channels: int, out_channels: int, dropout: float = 0.0
      ) -> nn.Sequential:
@@ -198,7 +173,7 @@ class CNN(BaseModel):
          Returns:
              Sequential block: Conv → GroupNorm → LeakyReLU → MaxPool [→ Dropout]
          """
-         num_groups = self._compute_num_groups(out_channels, target_groups=4)
+         num_groups = compute_num_groups(out_channels, preferred_groups=4)

          layers = [
              self._Conv(in_channels, out_channels, kernel_size=3, padding=1),
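
The deleted staticmethod above now lives in wavedl.models.base as a shared compute_num_groups helper (with target_groups renamed to preferred_groups, per the new call site). A minimal sketch of the divisor search, assuming the shared helper keeps the deleted method's logic:

    def compute_num_groups(num_channels: int, preferred_groups: int = 4) -> int:
        """Largest divisor of num_channels that is <= preferred_groups, else 1."""
        for g in range(min(preferred_groups, num_channels), 0, -1):
            if num_channels % g == 0:
                return g
        return 1  # unreachable for positive num_channels; kept as a safe fallback

    # GroupNorm requires num_channels % num_groups == 0:
    assert compute_num_groups(64) == 4  # 64 % 4 == 0
    assert compute_num_groups(6) == 3   # 4 does not divide 6; 3 does
    assert compute_num_groups(7) == 1   # prime channel count falls back to 1
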
wavedl/models/convnext.py CHANGED
@@ -11,9 +11,9 @@ Features: inverted bottleneck, LayerNorm, GELU activation, depthwise convolution
11
11
  - 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d
12
12
 
13
13
  **Variants**:
14
- - convnext_tiny: Smallest (~28M params for 2D)
15
- - convnext_small: Medium (~50M params for 2D)
16
- - convnext_base: Standard (~89M params for 2D)
14
+ - convnext_tiny: Smallest (~27.8M backbone params for 2D)
15
+ - convnext_small: Medium (~49.5M backbone params for 2D)
16
+ - convnext_base: Standard (~87.6M backbone params for 2D)
17
17
 
18
18
  References:
19
19
  Liu, Z., et al. (2022). A ConvNet for the 2020s.
@@ -26,15 +26,12 @@ from typing import Any
26
26
 
27
27
  import torch
28
28
  import torch.nn as nn
29
+ import torch.nn.functional as F
29
30
 
30
- from wavedl.models.base import BaseModel
31
+ from wavedl.models.base import BaseModel, SpatialShape
31
32
  from wavedl.models.registry import register_model
32
33
 
33
34
 
34
- # Type alias for spatial shapes
35
- SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
36
-
37
-
38
35
  def _get_conv_layer(dim: int) -> type[nn.Module]:
39
36
  """Get dimension-appropriate Conv class."""
40
37
  if dim == 1:
@@ -51,40 +48,75 @@ class LayerNormNd(nn.Module):
51
48
  """
52
49
  LayerNorm for N-dimensional tensors (channels-first format).
53
50
 
54
- Normalizes over the channel dimension, supporting Conv1d/2d/3d outputs.
51
+ Implements channels-last LayerNorm as used in the original ConvNeXt paper.
52
+ Permutes data to channels-last, applies LayerNorm per-channel over spatial
53
+ dimensions, and permutes back to channels-first format.
54
+
55
+ This matches PyTorch's nn.LayerNorm behavior when applied to the channel
56
+ dimension, providing stable gradients for deep ConvNeXt networks.
57
+
58
+ References:
59
+ Liu, Z., et al. (2022). A ConvNet for the 2020s. CVPR 2022.
60
+ https://github.com/facebookresearch/ConvNeXt
55
61
  """
56
62
 
57
63
  def __init__(self, num_channels: int, dim: int, eps: float = 1e-6):
58
64
  super().__init__()
59
65
  self.dim = dim
66
+ self.num_channels = num_channels
60
67
  self.weight = nn.Parameter(torch.ones(num_channels))
61
68
  self.bias = nn.Parameter(torch.zeros(num_channels))
62
69
  self.eps = eps
63
70
 
64
71
  def forward(self, x: torch.Tensor) -> torch.Tensor:
65
- # x: (B, C, ..spatial..)
66
- # Normalize over channel dimension
67
- mean = x.mean(dim=1, keepdim=True)
68
- var = x.var(dim=1, keepdim=True, unbiased=False)
69
- x = (x - mean) / (var + self.eps).sqrt()
70
-
71
- # Apply learnable parameters
72
- shape = [1, -1] + [1] * self.dim # (1, C, 1, ...) for broadcasting
73
- x = x * self.weight.view(*shape) + self.bias.view(*shape)
72
+ """
73
+ Apply LayerNorm in channels-last format.
74
+
75
+ Args:
76
+ x: Input tensor in channels-first format
77
+ - 1D: (B, C, L)
78
+ - 2D: (B, C, H, W)
79
+ - 3D: (B, C, D, H, W)
80
+
81
+ Returns:
82
+ Normalized tensor in same format as input
83
+ """
84
+ if self.dim == 1:
85
+ # (B, C, L) -> (B, L, C) -> LayerNorm -> (B, C, L)
86
+ x = x.permute(0, 2, 1)
87
+ x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
88
+ x = x.permute(0, 2, 1)
89
+ elif self.dim == 2:
90
+ # (B, C, H, W) -> (B, H, W, C) -> LayerNorm -> (B, C, H, W)
91
+ x = x.permute(0, 2, 3, 1)
92
+ x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
93
+ x = x.permute(0, 3, 1, 2)
94
+ else:
95
+ # (B, C, D, H, W) -> (B, D, H, W, C) -> LayerNorm -> (B, C, D, H, W)
96
+ x = x.permute(0, 2, 3, 4, 1)
97
+ x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
98
+ x = x.permute(0, 4, 1, 2, 3)
74
99
  return x
75
100
 
76
101
 
77
102
  class ConvNeXtBlock(nn.Module):
78
103
  """
79
- ConvNeXt block with inverted bottleneck design.
80
-
81
- Architecture:
82
- - 7x7 depthwise conv
83
- - LayerNorm
84
- - 1x1 conv (expand by 4x)
85
- - GELU
86
- - 1x1 conv (reduce back)
87
- - Residual connection
104
+ ConvNeXt block matching the official Facebook implementation.
105
+
106
+ Uses the second variant from the paper which is slightly faster in PyTorch:
107
+ 1. DwConv (channels-first)
108
+ 2. Permute to channels-last
109
+ 3. LayerNorm Linear → GELU → Linear (all channels-last)
110
+ 4. LayerScale (gamma * x)
111
+ 5. Permute back to channels-first
112
+ 6. Residual connection
113
+
114
+ The LayerScale mechanism is critical for stable training in deep networks.
115
+ It scales the output by a learnable parameter initialized to 1e-6.
116
+
117
+ References:
118
+ Liu, Z., et al. (2022). A ConvNet for the 2020s. CVPR 2022.
119
+ https://github.com/facebookresearch/ConvNeXt
88
120
  """
89
121
 
@@ -93,21 +125,36 @@ class ConvNeXtBlock(nn.Module):
          dim: int = 2,
          expansion_ratio: float = 4.0,
          drop_path: float = 0.0,
+         layer_scale_init_value: float = 1e-6,
      ):
          super().__init__()
+         self.dim = dim
          Conv = _get_conv_layer(dim)
          hidden_dim = int(channels * expansion_ratio)

-         # Depthwise conv (7x7)
+         # Depthwise conv (7x7) - operates in channels-first
          self.dwconv = Conv(
              channels, channels, kernel_size=7, padding=3, groups=channels
          )
-         self.norm = LayerNormNd(channels, dim)

-         # Pointwise convs (1x1)
-         self.pwconv1 = Conv(channels, hidden_dim, kernel_size=1)
+         # LayerNorm (channels-last format, using standard nn.LayerNorm)
+         self.norm = nn.LayerNorm(channels, eps=1e-6)
+
+         # Pointwise convs implemented with Linear layers (channels-last)
+         # This matches the official implementation and is slightly faster
+         self.pwconv1 = nn.Linear(channels, hidden_dim)
          self.act = nn.GELU()
-         self.pwconv2 = Conv(hidden_dim, channels, kernel_size=1)
+         self.pwconv2 = nn.Linear(hidden_dim, channels)
+
+         # LayerScale: learnable per-channel scaling (critical for deep networks)
+         # Initialized to small value (1e-6) to prevent gradient explosion
+         self.gamma = (
+             nn.Parameter(
+                 layer_scale_init_value * torch.ones(channels), requires_grad=True
+             )
+             if layer_scale_init_value > 0
+             else None
+         )

          # Stochastic depth (drop path) - simplified version
          self.drop_path = nn.Identity()  # Can be replaced with DropPath if needed
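
Swapping the 1x1 convs for nn.Linear is behavior-preserving: a pointwise conv in channels-first computes the same per-position linear map as nn.Linear applied in channels-last. A quick sketch of the equivalence (illustrative only):

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(16, 64, kernel_size=1)  # pointwise conv, channels-first
    lin = nn.Linear(16, 64)                  # same map, channels-last
    with torch.no_grad():
        lin.weight.copy_(conv.weight.squeeze(-1).squeeze(-1))  # (64, 16, 1, 1) -> (64, 16)
        lin.bias.copy_(conv.bias)

    x = torch.randn(2, 16, 7, 7)
    y_conv = conv(x)                                        # (2, 64, 7, 7)
    y_lin = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # same values
    print(torch.allclose(y_conv, y_lin, atol=1e-5))         # True
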
@@ -115,14 +162,38 @@ class ConvNeXtBlock(nn.Module):
      def forward(self, x: torch.Tensor) -> torch.Tensor:
          residual = x

+         # Depthwise conv in channels-first format
          x = self.dwconv(x)
+
+         # Permute to channels-last for LayerNorm and Linear layers
+         if self.dim == 1:
+             x = x.permute(0, 2, 1)  # (B, C, L) -> (B, L, C)
+         elif self.dim == 2:
+             x = x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
+         else:
+             x = x.permute(0, 2, 3, 4, 1)  # (B, C, D, H, W) -> (B, D, H, W, C)
+
+         # LayerNorm + MLP (all in channels-last)
          x = self.norm(x)
          x = self.pwconv1(x)
          x = self.act(x)
          x = self.pwconv2(x)
-         x = self.drop_path(x)

-         return residual + x
+         # Apply LayerScale
+         if self.gamma is not None:
+             x = self.gamma * x
+
+         # Permute back to channels-first
+         if self.dim == 1:
+             x = x.permute(0, 2, 1)  # (B, L, C) -> (B, C, L)
+         elif self.dim == 2:
+             x = x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
+         else:
+             x = x.permute(0, 4, 1, 2, 3)  # (B, D, H, W, C) -> (B, C, D, H, W)
+
+         # Residual connection with drop path
+         x = residual + self.drop_path(x)
+         return x


  class ConvNeXtBase(BaseModel):
@@ -244,7 +315,7 @@ class ConvNeXtTiny(ConvNeXtBase):
      """
      ConvNeXt-Tiny: Smallest variant.

-     ~28M parameters (2D). Good for: Limited compute, fast training.
+     ~27.8M backbone parameters (2D). Good for: Limited compute, fast training.

      Args:
          in_shape: (L,), (H, W), or (D, H, W)
@@ -270,7 +341,7 @@ class ConvNeXtSmall(ConvNeXtBase):
      """
      ConvNeXt-Small: Medium variant.

-     ~50M parameters (2D). Good for: Balanced performance.
+     ~49.5M backbone parameters (2D). Good for: Balanced performance.

      Args:
          in_shape: (L,), (H, W), or (D, H, W)
@@ -296,7 +367,7 @@ class ConvNeXtBase_(ConvNeXtBase):
      """
      ConvNeXt-Base: Standard variant.

-     ~89M parameters (2D). Good for: High accuracy, larger datasets.
+     ~87.6M backbone parameters (2D). Good for: High accuracy, larger datasets.

      Args:
          in_shape: (L,), (H, W), or (D, H, W)
@@ -337,7 +408,7 @@ class ConvNeXtTinyPretrained(BaseModel):
      """
      ConvNeXt-Tiny with ImageNet pretrained weights (2D only).

-     ~28M parameters. Good for: Transfer learning with modern CNN.
+     ~27.8M backbone parameters. Good for: Transfer learning with modern CNN.

      Args:
          in_shape: (H, W) image dimensions
@@ -393,20 +464,11 @@ class ConvNeXtTinyPretrained(BaseModel):
          )

          # Modify first conv for single-channel input
-         old_conv = self.backbone.features[0][0]
-         self.backbone.features[0][0] = nn.Conv2d(
-             1,
-             old_conv.out_channels,
-             kernel_size=old_conv.kernel_size,
-             stride=old_conv.stride,
-             padding=old_conv.padding,
-             bias=old_conv.bias is not None,
+         from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
+
+         adapt_first_conv_for_single_channel(
+             self.backbone, "features.0.0", pretrained=pretrained
          )
-         if pretrained:
-             with torch.no_grad():
-                 self.backbone.features[0][0].weight = nn.Parameter(
-                     old_conv.weight.mean(dim=1, keepdim=True)
-                 )

          if freeze_backbone:
              self._freeze_backbone()
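
The inline surgery removed above (mirrored by _make_new_conv in caformer.py) is what the new helper centralizes: rebuild the stem conv with in_channels=1 and, when pretrained, initialize it with the RGB kernels averaged over the input-channel dimension, a standard grayscale adaptation that roughly preserves activation scale. A sketch of the pattern, assuming adapt_first_conv_for_single_channel follows the logic it replaced (only the helper's name and call signature are taken from the call site above):

    import torch
    import torch.nn as nn

    def make_single_channel_conv(old_conv: nn.Conv2d, pretrained: bool = True) -> nn.Conv2d:
        """Rebuild a 3-channel conv as 1-channel, averaging the RGB kernel weights."""
        new_conv = nn.Conv2d(
            1,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=old_conv.bias is not None,
        )
        if pretrained:
            with torch.no_grad():
                # (O, 3, kH, kW) -> (O, 1, kH, kW): mean over the RGB input dim
                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
                if old_conv.bias is not None:
                    new_conv.bias.copy_(old_conv.bias)
        return new_conv
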