wavedl 1.5.6__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/models/__init__.py +52 -4
- wavedl/models/_timm_utils.py +238 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/convnext.py +108 -33
- wavedl/models/convnext_v2.py +504 -0
- wavedl/models/densenet.py +5 -5
- wavedl/models/efficientnet.py +30 -13
- wavedl/models/efficientnetv2.py +32 -9
- wavedl/models/fastvit.py +285 -0
- wavedl/models/mamba.py +535 -0
- wavedl/models/maxvit.py +251 -0
- wavedl/models/mobilenetv3.py +35 -12
- wavedl/models/regnet.py +39 -16
- wavedl/models/resnet.py +5 -5
- wavedl/models/resnet3d.py +2 -2
- wavedl/models/swin.py +41 -9
- wavedl/models/tcn.py +25 -5
- wavedl/models/unet.py +1 -1
- wavedl/models/vit.py +6 -6
- wavedl/test.py +7 -3
- wavedl/train.py +57 -23
- wavedl/utils/constraints.py +11 -5
- wavedl/utils/data.py +120 -18
- wavedl/utils/metrics.py +287 -326
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/METADATA +104 -67
- wavedl-1.6.0.dist-info/RECORD +44 -0
- wavedl-1.5.6.dist-info/RECORD +0 -38
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/LICENSE +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/WHEEL +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/top_level.txt +0 -0
wavedl/models/convnext.py
CHANGED
@@ -11,9 +11,9 @@ Features: inverted bottleneck, LayerNorm, GELU activation, depthwise convolution
 - 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d
 
 **Variants**:
-- convnext_tiny: Smallest (~
-- convnext_small: Medium (~
-- convnext_base: Standard (~
+- convnext_tiny: Smallest (~27.8M backbone params for 2D)
+- convnext_small: Medium (~49.5M backbone params for 2D)
+- convnext_base: Standard (~87.6M backbone params for 2D)
 
 References:
     Liu, Z., et al. (2022). A ConvNet for the 2020s.
@@ -26,6 +26,7 @@ from typing import Any
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from wavedl.models.base import BaseModel
 from wavedl.models.registry import register_model
@@ -51,40 +52,75 @@ class LayerNormNd(nn.Module):
     """
     LayerNorm for N-dimensional tensors (channels-first format).
 
-
+    Implements channels-last LayerNorm as used in the original ConvNeXt paper.
+    Permutes data to channels-last, applies LayerNorm per-channel over spatial
+    dimensions, and permutes back to channels-first format.
+
+    This matches PyTorch's nn.LayerNorm behavior when applied to the channel
+    dimension, providing stable gradients for deep ConvNeXt networks.
+
+    References:
+        Liu, Z., et al. (2022). A ConvNet for the 2020s. CVPR 2022.
+        https://github.com/facebookresearch/ConvNeXt
     """
 
     def __init__(self, num_channels: int, dim: int, eps: float = 1e-6):
         super().__init__()
         self.dim = dim
+        self.num_channels = num_channels
         self.weight = nn.Parameter(torch.ones(num_channels))
         self.bias = nn.Parameter(torch.zeros(num_channels))
         self.eps = eps
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
-
-
-
-
-
-
-
-
+        """
+        Apply LayerNorm in channels-last format.
+
+        Args:
+            x: Input tensor in channels-first format
+                - 1D: (B, C, L)
+                - 2D: (B, C, H, W)
+                - 3D: (B, C, D, H, W)
+
+        Returns:
+            Normalized tensor in same format as input
+        """
+        if self.dim == 1:
+            # (B, C, L) -> (B, L, C) -> LayerNorm -> (B, C, L)
+            x = x.permute(0, 2, 1)
+            x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
+            x = x.permute(0, 2, 1)
+        elif self.dim == 2:
+            # (B, C, H, W) -> (B, H, W, C) -> LayerNorm -> (B, C, H, W)
+            x = x.permute(0, 2, 3, 1)
+            x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
+            x = x.permute(0, 3, 1, 2)
+        else:
+            # (B, C, D, H, W) -> (B, D, H, W, C) -> LayerNorm -> (B, C, D, H, W)
+            x = x.permute(0, 2, 3, 4, 1)
+            x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
+            x = x.permute(0, 4, 1, 2, 3)
         return x
 
 
 class ConvNeXtBlock(nn.Module):
     """
-    ConvNeXt block
-
-
-
-
-
-
-
-
+    ConvNeXt block matching the official Facebook implementation.
+
+    Uses the second variant from the paper which is slightly faster in PyTorch:
+        1. DwConv (channels-first)
+        2. Permute to channels-last
+        3. LayerNorm → Linear → GELU → Linear (all channels-last)
+        4. LayerScale (gamma * x)
+        5. Permute back to channels-first
+        6. Residual connection
+
+    The LayerScale mechanism is critical for stable training in deep networks.
+    It scales the output by a learnable parameter initialized to 1e-6.
+
+    References:
+        Liu, Z., et al. (2022). A ConvNet for the 2020s. CVPR 2022.
+        https://github.com/facebookresearch/ConvNeXt
     """
 
     def __init__(
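The rewritten `LayerNormNd.forward` is equivalent to applying `nn.LayerNorm` over the channel dimension after a round trip through channels-last. A minimal sketch of that equivalence for the 2D case, assuming the class is importable from `wavedl.models.convnext` (import path inferred from this diff):

```python
import torch
import torch.nn as nn
from wavedl.models.convnext import LayerNormNd  # import path assumed from this diff

x = torch.randn(2, 64, 32, 32)              # (B, C, H, W), channels-first
norm = LayerNormNd(num_channels=64, dim=2)

# Reference: nn.LayerNorm over C in channels-last format. Both modules
# initialize weight=1 and bias=0, so fresh instances should agree.
ref = nn.LayerNorm(64, eps=1e-6)
expected = ref(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

torch.testing.assert_close(norm(x), expected)
```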
@@ -93,21 +129,36 @@ class ConvNeXtBlock(nn.Module):
         dim: int = 2,
         expansion_ratio: float = 4.0,
         drop_path: float = 0.0,
+        layer_scale_init_value: float = 1e-6,
     ):
         super().__init__()
+        self.dim = dim
         Conv = _get_conv_layer(dim)
         hidden_dim = int(channels * expansion_ratio)
 
-        # Depthwise conv (7x7)
+        # Depthwise conv (7x7) - operates in channels-first
         self.dwconv = Conv(
             channels, channels, kernel_size=7, padding=3, groups=channels
         )
-        self.norm = LayerNormNd(channels, dim)
 
-        #
-        self.
+        # LayerNorm (channels-last format, using standard nn.LayerNorm)
+        self.norm = nn.LayerNorm(channels, eps=1e-6)
+
+        # Pointwise convs implemented with Linear layers (channels-last)
+        # This matches the official implementation and is slightly faster
+        self.pwconv1 = nn.Linear(channels, hidden_dim)
         self.act = nn.GELU()
-        self.pwconv2 =
+        self.pwconv2 = nn.Linear(hidden_dim, channels)
+
+        # LayerScale: learnable per-channel scaling (critical for deep networks)
+        # Initialized to small value (1e-6) to prevent gradient explosion
+        self.gamma = (
+            nn.Parameter(
+                layer_scale_init_value * torch.ones(channels), requires_grad=True
+            )
+            if layer_scale_init_value > 0
+            else None
+        )
 
         # Stochastic depth (drop path) - simplified version
         self.drop_path = nn.Identity() # Can be replaced with DropPath if needed
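The `layer_scale_init_value` default means each block starts close to the identity: with `gamma` initialized at 1e-6 the residual branch contributes almost nothing until training scales it up, and a non-positive value disables LayerScale entirely (`gamma = None`). A quick illustration of the init behavior:

```python
import torch

gamma = 1e-6 * torch.ones(96)          # LayerScale parameter at init
branch = torch.randn(2, 14, 14, 96)    # channels-last block output
print((gamma * branch).abs().max())    # ~1e-6, so residual + branch ≈ residual
```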
@@ -115,14 +166,38 @@ class ConvNeXtBlock(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         residual = x
 
+        # Depthwise conv in channels-first format
         x = self.dwconv(x)
+
+        # Permute to channels-last for LayerNorm and Linear layers
+        if self.dim == 1:
+            x = x.permute(0, 2, 1)  # (B, C, L) -> (B, L, C)
+        elif self.dim == 2:
+            x = x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
+        else:
+            x = x.permute(0, 2, 3, 4, 1)  # (B, C, D, H, W) -> (B, D, H, W, C)
+
+        # LayerNorm + MLP (all in channels-last)
         x = self.norm(x)
         x = self.pwconv1(x)
         x = self.act(x)
         x = self.pwconv2(x)
-        x = self.drop_path(x)
 
-
+        # Apply LayerScale
+        if self.gamma is not None:
+            x = self.gamma * x
+
+        # Permute back to channels-first
+        if self.dim == 1:
+            x = x.permute(0, 2, 1)  # (B, L, C) -> (B, C, L)
+        elif self.dim == 2:
+            x = x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
+        else:
+            x = x.permute(0, 4, 1, 2, 3)  # (B, D, H, W, C) -> (B, C, D, H, W)
+
+        # Residual connection with drop path
+        x = residual + self.drop_path(x)
+        return x
 
 
 class ConvNeXtBase(BaseModel):
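Put together, the rewritten forward pass for the 2D case reduces to the following self-contained sketch (freestanding modules stand in for the block's attributes; the real block additionally dispatches on `self.dim` for 1D/3D inputs):

```python
import torch
import torch.nn as nn

B, C, H, W = 2, 96, 14, 14
x = torch.randn(B, C, H, W)

dwconv = nn.Conv2d(C, C, kernel_size=7, padding=3, groups=C)  # depthwise, channels-first
norm = nn.LayerNorm(C, eps=1e-6)                              # channels-last
pwconv1, act, pwconv2 = nn.Linear(C, 4 * C), nn.GELU(), nn.Linear(4 * C, C)
gamma = nn.Parameter(1e-6 * torch.ones(C))                    # LayerScale

residual = x
y = dwconv(x).permute(0, 2, 3, 1)      # (B, C, H, W) -> (B, H, W, C)
y = pwconv2(act(pwconv1(norm(y))))     # LayerNorm -> Linear -> GELU -> Linear
y = (gamma * y).permute(0, 3, 1, 2)    # scale, then back to (B, C, H, W)
out = residual + y                     # drop_path is nn.Identity() here
assert out.shape == x.shape
```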
@@ -244,7 +319,7 @@ class ConvNeXtTiny(ConvNeXtBase):
     """
     ConvNeXt-Tiny: Smallest variant.
 
-    ~
+    ~27.8M backbone parameters (2D). Good for: Limited compute, fast training.
 
     Args:
         in_shape: (L,), (H, W), or (D, H, W)
@@ -270,7 +345,7 @@ class ConvNeXtSmall(ConvNeXtBase):
     """
     ConvNeXt-Small: Medium variant.
 
-    ~
+    ~49.5M backbone parameters (2D). Good for: Balanced performance.
 
     Args:
         in_shape: (L,), (H, W), or (D, H, W)
@@ -296,7 +371,7 @@ class ConvNeXtBase_(ConvNeXtBase):
     """
    ConvNeXt-Base: Standard variant.
 
-    ~
+    ~87.6M backbone parameters (2D). Good for: High accuracy, larger datasets.
 
     Args:
         in_shape: (L,), (H, W), or (D, H, W)
@@ -337,7 +412,7 @@ class ConvNeXtTinyPretrained(BaseModel):
     """
     ConvNeXt-Tiny with ImageNet pretrained weights (2D only).
 
-    ~
+    ~27.8M backbone parameters. Good for: Transfer learning with modern CNN.
 
     Args:
         in_shape: (H, W) image dimensions
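The parameter counts quoted in the variant docstrings can be sanity-checked by instantiating a model and summing tensor sizes. A hedged sketch, assuming `ConvNeXtTiny` is importable as shown and that `in_shape` alone is enough to construct it (any task head the model attaches will push the total above the quoted backbone figure):

```python
from wavedl.models.convnext import ConvNeXtTiny  # import path assumed from this diff

model = ConvNeXtTiny(in_shape=(224, 224))  # constructor args assumed; more may be required
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M total parameters")  # expect ≈27.8M backbone plus head
```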