wavedl 1.5.7__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/models/convnext.py CHANGED
@@ -11,9 +11,9 @@ Features: inverted bottleneck, LayerNorm, GELU activation, depthwise convolution
11
11
  - 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d
12
12
 
13
13
  **Variants**:
14
- - convnext_tiny: Smallest (~28M params for 2D)
15
- - convnext_small: Medium (~50M params for 2D)
16
- - convnext_base: Standard (~89M params for 2D)
14
+ - convnext_tiny: Smallest (~27.8M backbone params for 2D)
15
+ - convnext_small: Medium (~49.5M backbone params for 2D)
16
+ - convnext_base: Standard (~87.6M backbone params for 2D)
17
17
 
18
18
  References:
19
19
  Liu, Z., et al. (2022). A ConvNet for the 2020s.
@@ -26,6 +26,7 @@ from typing import Any
26
26
 
27
27
  import torch
28
28
  import torch.nn as nn
29
+ import torch.nn.functional as F
29
30
 
30
31
  from wavedl.models.base import BaseModel
31
32
  from wavedl.models.registry import register_model
@@ -51,40 +52,75 @@ class LayerNormNd(nn.Module):
51
52
  """
52
53
  LayerNorm for N-dimensional tensors (channels-first format).
53
54
 
54
- Normalizes over the channel dimension, supporting Conv1d/2d/3d outputs.
55
+ Implements channels-last LayerNorm as used in the original ConvNeXt paper.
56
+ Permutes data to channels-last, applies LayerNorm per-channel over spatial
57
+ dimensions, and permutes back to channels-first format.
58
+
59
+ This matches PyTorch's nn.LayerNorm behavior when applied to the channel
60
+ dimension, providing stable gradients for deep ConvNeXt networks.
61
+
62
+ References:
63
+ Liu, Z., et al. (2022). A ConvNet for the 2020s. CVPR 2022.
64
+ https://github.com/facebookresearch/ConvNeXt
55
65
  """
56
66
 
57
67
  def __init__(self, num_channels: int, dim: int, eps: float = 1e-6):
58
68
  super().__init__()
59
69
  self.dim = dim
70
+ self.num_channels = num_channels
60
71
  self.weight = nn.Parameter(torch.ones(num_channels))
61
72
  self.bias = nn.Parameter(torch.zeros(num_channels))
62
73
  self.eps = eps
63
74
 
64
75
  def forward(self, x: torch.Tensor) -> torch.Tensor:
65
- # x: (B, C, ..spatial..)
66
- # Normalize over channel dimension
67
- mean = x.mean(dim=1, keepdim=True)
68
- var = x.var(dim=1, keepdim=True, unbiased=False)
69
- x = (x - mean) / (var + self.eps).sqrt()
70
-
71
- # Apply learnable parameters
72
- shape = [1, -1] + [1] * self.dim # (1, C, 1, ...) for broadcasting
73
- x = x * self.weight.view(*shape) + self.bias.view(*shape)
76
+ """
77
+ Apply LayerNorm in channels-last format.
78
+
79
+ Args:
80
+ x: Input tensor in channels-first format
81
+ - 1D: (B, C, L)
82
+ - 2D: (B, C, H, W)
83
+ - 3D: (B, C, D, H, W)
84
+
85
+ Returns:
86
+ Normalized tensor in same format as input
87
+ """
88
+ if self.dim == 1:
89
+ # (B, C, L) -> (B, L, C) -> LayerNorm -> (B, C, L)
90
+ x = x.permute(0, 2, 1)
91
+ x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
92
+ x = x.permute(0, 2, 1)
93
+ elif self.dim == 2:
94
+ # (B, C, H, W) -> (B, H, W, C) -> LayerNorm -> (B, C, H, W)
95
+ x = x.permute(0, 2, 3, 1)
96
+ x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
97
+ x = x.permute(0, 3, 1, 2)
98
+ else:
99
+ # (B, C, D, H, W) -> (B, D, H, W, C) -> LayerNorm -> (B, C, D, H, W)
100
+ x = x.permute(0, 2, 3, 4, 1)
101
+ x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
102
+ x = x.permute(0, 4, 1, 2, 3)
74
103
  return x
75
104
 
76
105
 
77
106
  class ConvNeXtBlock(nn.Module):
78
107
  """
79
- ConvNeXt block with inverted bottleneck design.
80
-
81
- Architecture:
82
- - 7x7 depthwise conv
83
- - LayerNorm
84
- - 1x1 conv (expand by 4x)
85
- - GELU
86
- - 1x1 conv (reduce back)
87
- - Residual connection
108
+ ConvNeXt block matching the official Facebook implementation.
109
+
110
+ Uses the second variant from the paper, which is slightly faster in PyTorch:
111
+ 1. DwConv (channels-first)
112
+ 2. Permute to channels-last
113
+ 3. LayerNorm → Linear → GELU → Linear (all channels-last)
114
+ 4. LayerScale (gamma * x)
115
+ 5. Permute back to channels-first
116
+ 6. Residual connection
117
+
118
+ The LayerScale mechanism is critical for stable training in deep networks.
119
+ It scales the output by a learnable parameter initialized to 1e-6.
120
+
121
+ References:
122
+ Liu, Z., et al. (2022). A ConvNet for the 2020s. CVPR 2022.
123
+ https://github.com/facebookresearch/ConvNeXt
88
124
  """
89
125
 
90
126
  def __init__(
@@ -93,21 +129,36 @@ class ConvNeXtBlock(nn.Module):
93
129
  dim: int = 2,
94
130
  expansion_ratio: float = 4.0,
95
131
  drop_path: float = 0.0,
132
+ layer_scale_init_value: float = 1e-6,
96
133
  ):
97
134
  super().__init__()
135
+ self.dim = dim
98
136
  Conv = _get_conv_layer(dim)
99
137
  hidden_dim = int(channels * expansion_ratio)
100
138
 
101
- # Depthwise conv (7x7)
139
+ # Depthwise conv (7x7) - operates in channels-first
102
140
  self.dwconv = Conv(
103
141
  channels, channels, kernel_size=7, padding=3, groups=channels
104
142
  )
105
- self.norm = LayerNormNd(channels, dim)
106
143
 
107
- # Pointwise convs (1x1)
108
- self.pwconv1 = Conv(channels, hidden_dim, kernel_size=1)
144
+ # LayerNorm (channels-last format, using standard nn.LayerNorm)
145
+ self.norm = nn.LayerNorm(channels, eps=1e-6)
146
+
147
+ # Pointwise convs implemented with Linear layers (channels-last)
148
+ # This matches the official implementation and is slightly faster
149
+ self.pwconv1 = nn.Linear(channels, hidden_dim)
109
150
  self.act = nn.GELU()
110
- self.pwconv2 = Conv(hidden_dim, channels, kernel_size=1)
151
+ self.pwconv2 = nn.Linear(hidden_dim, channels)
152
+
153
+ # LayerScale: learnable per-channel scaling (critical for deep networks)
154
+ # Initialized to a small value (1e-6) to prevent gradient explosion
155
+ self.gamma = (
156
+ nn.Parameter(
157
+ layer_scale_init_value * torch.ones(channels), requires_grad=True
158
+ )
159
+ if layer_scale_init_value > 0
160
+ else None
161
+ )
111
162
 
112
163
  # Stochastic depth (drop path) - simplified version
113
164
  self.drop_path = nn.Identity() # Can be replaced with DropPath if needed
@@ -115,14 +166,38 @@ class ConvNeXtBlock(nn.Module):
115
166
  def forward(self, x: torch.Tensor) -> torch.Tensor:
116
167
  residual = x
117
168
 
169
+ # Depthwise conv in channels-first format
118
170
  x = self.dwconv(x)
171
+
172
+ # Permute to channels-last for LayerNorm and Linear layers
173
+ if self.dim == 1:
174
+ x = x.permute(0, 2, 1) # (B, C, L) -> (B, L, C)
175
+ elif self.dim == 2:
176
+ x = x.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C)
177
+ else:
178
+ x = x.permute(0, 2, 3, 4, 1) # (B, C, D, H, W) -> (B, D, H, W, C)
179
+
180
+ # LayerNorm + MLP (all in channels-last)
119
181
  x = self.norm(x)
120
182
  x = self.pwconv1(x)
121
183
  x = self.act(x)
122
184
  x = self.pwconv2(x)
123
- x = self.drop_path(x)
124
185
 
125
- return residual + x
186
+ # Apply LayerScale
187
+ if self.gamma is not None:
188
+ x = self.gamma * x
189
+
190
+ # Permute back to channels-first
191
+ if self.dim == 1:
192
+ x = x.permute(0, 2, 1) # (B, L, C) -> (B, C, L)
193
+ elif self.dim == 2:
194
+ x = x.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W)
195
+ else:
196
+ x = x.permute(0, 4, 1, 2, 3) # (B, D, H, W, C) -> (B, C, D, H, W)
197
+
198
+ # Residual connection with drop path
199
+ x = residual + self.drop_path(x)
200
+ return x
126
201
 
127
202
 
128
203
  class ConvNeXtBase(BaseModel):
@@ -244,7 +319,7 @@ class ConvNeXtTiny(ConvNeXtBase):
244
319
  """
245
320
  ConvNeXt-Tiny: Smallest variant.
246
321
 
247
- ~28M parameters (2D). Good for: Limited compute, fast training.
322
+ ~27.8M backbone parameters (2D). Good for: Limited compute, fast training.
248
323
 
249
324
  Args:
250
325
  in_shape: (L,), (H, W), or (D, H, W)
@@ -270,7 +345,7 @@ class ConvNeXtSmall(ConvNeXtBase):
270
345
  """
271
346
  ConvNeXt-Small: Medium variant.
272
347
 
273
- ~50M parameters (2D). Good for: Balanced performance.
348
+ ~49.5M backbone parameters (2D). Good for: Balanced performance.
274
349
 
275
350
  Args:
276
351
  in_shape: (L,), (H, W), or (D, H, W)
@@ -296,7 +371,7 @@ class ConvNeXtBase_(ConvNeXtBase):
296
371
  """
297
372
  ConvNeXt-Base: Standard variant.
298
373
 
299
- ~89M parameters (2D). Good for: High accuracy, larger datasets.
374
+ ~87.6M backbone parameters (2D). Good for: High accuracy, larger datasets.
300
375
 
301
376
  Args:
302
377
  in_shape: (L,), (H, W), or (D, H, W)
@@ -337,7 +412,7 @@ class ConvNeXtTinyPretrained(BaseModel):
337
412
  """
338
413
  ConvNeXt-Tiny with ImageNet pretrained weights (2D only).
339
414
 
340
- ~28M parameters. Good for: Transfer learning with modern CNN.
415
+ ~27.8M backbone parameters. Good for: Transfer learning with modern CNN.
341
416
 
342
417
  Args:
343
418
  in_shape: (H, W) image dimensions