wavedl 1.5.7__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +80 -4
- wavedl/models/_pretrained_utils.py +366 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +113 -51
- wavedl/models/convnext_v2.py +488 -0
- wavedl/models/densenet.py +10 -23
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +252 -0
- wavedl/models/mamba.py +555 -0
- wavedl/models/maxvit.py +254 -0
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +19 -61
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +2 -6
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +9 -9
- wavedl/train.py +1430 -1425
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/data.py +39 -6
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/caformer.py
ADDED
@@ -0,0 +1,270 @@
+"""
+CaFormer: MetaFormer with Convolution and Attention
+====================================================
+
+CaFormer implements the MetaFormer architecture using depthwise separable
+convolutions in early stages and vanilla self-attention in later stages.
+
+**Key Features**:
+- MetaFormer principle: architecture > token mixer
+- Hybrid: Conv (early) + Attention (late)
+- StarReLU activation for efficiency
+- State-of-the-art on ImageNet (85.5%)
+
+**Variants**:
+- caformer_s18: 26M params
+- caformer_s36: 39M params
+- caformer_m36: 56M params
+
+**Related Models**:
+- PoolFormer: Uses pooling instead of attention
+- ConvFormer: Uses only convolutions
+
+**Requirements**:
+- timm >= 0.9.0 (for CaFormer models)
+
+Reference:
+    Yu, W., et al. (2023). MetaFormer Baselines for Vision.
+    TPAMI 2023. https://arxiv.org/abs/2210.13452
+
+Author: Ductho Le (ductho.le@outlook.com)
+"""
+
+import torch
+import torch.nn as nn
+
+from wavedl.models._pretrained_utils import build_regression_head
+from wavedl.models.base import BaseModel
+from wavedl.models.registry import register_model
+
+
+__all__ = [
+    "CaFormerBase",
+    "CaFormerM36",
+    "CaFormerS18",
+    "CaFormerS36",
+    "PoolFormerS12",
+]
+
+
+# =============================================================================
+# CAFORMER BASE CLASS
+# =============================================================================
+
+
+class CaFormerBase(BaseModel):
+    """
+    CaFormer base class wrapping timm implementation.
+
+    MetaFormer with conv (early) + attention (late) token mixing.
+    2D only.
+    """
+
+    def __init__(
+        self,
+        in_shape: tuple[int, int],
+        out_size: int,
+        model_name: str = "caformer_s18",
+        pretrained: bool = True,
+        freeze_backbone: bool = False,
+        dropout_rate: float = 0.3,
+        **kwargs,
+    ):
+        super().__init__(in_shape, out_size)
+
+        if len(in_shape) != 2:
+            raise ValueError(f"CaFormer requires 2D input (H, W), got {len(in_shape)}D")
+
+        self.pretrained = pretrained
+        self.freeze_backbone = freeze_backbone
+        self.model_name = model_name
+
+        # Try to load from timm
+        try:
+            import timm
+
+            self.backbone = timm.create_model(
+                model_name,
+                pretrained=pretrained,
+                num_classes=0,  # Remove classifier
+            )
+
+            # Get feature dimension
+            with torch.no_grad():
+                dummy = torch.zeros(1, 3, *in_shape)
+                features = self.backbone(dummy)
+                in_features = features.shape[-1]
+
+        except ImportError:
+            raise ImportError(
+                "timm >= 0.9.0 is required for CaFormer. "
+                "Install with: pip install timm>=0.9.0"
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to load CaFormer model '{model_name}': {e}")
+
+        # Adapt input channels (3 -> 1)
+        self._adapt_input_channels()
+
+        # Regression head
+        self.head = build_regression_head(in_features, out_size, dropout_rate)
+
+        if freeze_backbone:
+            self._freeze_backbone()
+
+    def _adapt_input_channels(self):
+        """Adapt first conv layer for single-channel input."""
+        # CaFormer uses stem for first layer
+        if hasattr(self.backbone, "stem"):
+            first_conv = None
+            # Find first conv in stem
+            for name, module in self.backbone.stem.named_modules():
+                if isinstance(module, nn.Conv2d):
+                    first_conv = (name, module)
+                    break
+
+            if first_conv is not None:
+                name, old_conv = first_conv
+                new_conv = self._make_new_conv(old_conv)
+                # Set the new conv (handle nested structure)
+                self._set_module(self.backbone.stem, name, new_conv)
+
+    def _make_new_conv(self, old_conv: nn.Conv2d) -> nn.Conv2d:
+        """Create new conv layer with 1 input channel."""
+        new_conv = nn.Conv2d(
+            1,
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                if old_conv.bias is not None:
+                    new_conv.bias.copy_(old_conv.bias)
+        return new_conv
+
+    def _set_module(self, parent: nn.Module, name: str, module: nn.Module):
+        """Set a nested module by name."""
+        parts = name.split(".")
+        for part in parts[:-1]:
+            parent = getattr(parent, part)
+        setattr(parent, parts[-1], module)
+
+    def _freeze_backbone(self):
+        """Freeze backbone parameters."""
+        for param in self.backbone.parameters():
+            param.requires_grad = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        features = self.backbone(x)
+        return self.head(features)
+
+
+# =============================================================================
+# REGISTERED VARIANTS
+# =============================================================================
+
+
+@register_model("caformer_s18")
+class CaFormerS18(CaFormerBase):
+    """
+    CaFormer-S18: ~23.2M backbone parameters.
+
+    MetaFormer with conv + attention.
+    2D only.
+
+    Example:
+        >>> model = CaFormerS18(in_shape=(224, 224), out_size=3)
+        >>> x = torch.randn(4, 1, 224, 224)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_name="caformer_s18",
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"CaFormer_S18(in_shape={self.in_shape}, out_size={self.out_size}, "
+            f"pretrained={self.pretrained})"
+        )
+
+
+@register_model("caformer_s36")
+class CaFormerS36(CaFormerBase):
+    """
+    CaFormer-S36: ~36.2M backbone parameters.
+
+    Deeper MetaFormer variant.
+    2D only.
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_name="caformer_s36",
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"CaFormer_S36(in_shape={self.in_shape}, out_size={self.out_size}, "
+            f"pretrained={self.pretrained})"
+        )
+
+
+@register_model("caformer_m36")
+class CaFormerM36(CaFormerBase):
+    """
+    CaFormer-M36: ~52.6M backbone parameters.
+
+    Medium-size MetaFormer variant.
+    2D only.
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_name="caformer_m36",
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"CaFormer_M36(in_shape={self.in_shape}, out_size={self.out_size}, "
+            f"pretrained={self.pretrained})"
+        )
+
+
+@register_model("poolformer_s12")
+class PoolFormerS12(CaFormerBase):
+    """
+    PoolFormer-S12: ~11.4M backbone parameters.
+
+    MetaFormer with simple pooling token mixer.
+    Proves that architecture matters more than complex attention.
+    2D only.
+    """
+
+    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            model_name="poolformer_s12",
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"PoolFormer_S12(in_shape={self.in_shape}, out_size={self.out_size}, "
+            f"pretrained={self.pretrained})"
+        )
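Because `_make_new_conv` averages the pretrained 3-channel stem kernel over its input dimension, the registered variants take single-channel tensors directly, as the S18 docstring example shows. A minimal usage sketch under the same assumptions (timm installed; `pretrained=False` here simply skips the weight download):

import torch
from wavedl.models.caformer import CaFormerS18

# 2D-only model: in_shape is (H, W); inputs are (B, 1, H, W)
model = CaFormerS18(in_shape=(224, 224), out_size=3, pretrained=False)
x = torch.randn(4, 1, 224, 224)
out = model(x)  # (4, 3) output from build_regression_head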
wavedl/models/cnn.py
CHANGED
@@ -24,14 +24,10 @@ from typing import Any
 import torch
 import torch.nn as nn
 
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape, compute_num_groups
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
-
 def _get_conv_layers(
     dim: int,
 ) -> tuple[type[nn.Module], type[nn.Module], type[nn.Module]]:
@@ -163,27 +159,6 @@ class CNN(BaseModel):
             nn.Linear(64, out_size),
         )
 
-    @staticmethod
-    def _compute_num_groups(num_channels: int, target_groups: int = 4) -> int:
-        """
-        Compute valid num_groups for GroupNorm that divides num_channels.
-
-        Finds the largest divisor of num_channels that is <= target_groups,
-        or falls back to 1 if no suitable divisor exists.
-
-        Args:
-            num_channels: Number of channels (must be positive)
-            target_groups: Desired number of groups (default: 4)
-
-        Returns:
-            Valid num_groups that satisfies num_channels % num_groups == 0
-        """
-        # Try target_groups down to 1, return first valid divisor
-        for g in range(min(target_groups, num_channels), 0, -1):
-            if num_channels % g == 0:
-                return g
-        return 1  # Fallback (always valid)
-
     def _make_conv_block(
         self, in_channels: int, out_channels: int, dropout: float = 0.0
     ) -> nn.Sequential:
@@ -198,7 +173,7 @@ class CNN(BaseModel):
         Returns:
             Sequential block: Conv → GroupNorm → LeakyReLU → MaxPool [→ Dropout]
         """
-        num_groups =
+        num_groups = compute_num_groups(out_channels, preferred_groups=4)
 
         layers = [
             self._Conv(in_channels, out_channels, kernel_size=3, padding=1),
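The removed `_compute_num_groups` duplicated logic that now lives in `wavedl.models.base` as the shared `compute_num_groups` helper. The shared helper's body is not shown in this diff; the sketch below reconstructs the intended behavior from the deleted docstring and the new call site (`preferred_groups` replaces the old `target_groups` name, and the shared version may differ in detail):

def compute_num_groups(num_channels: int, preferred_groups: int = 4) -> int:
    # Largest divisor of num_channels that is <= preferred_groups;
    # 1 always divides, so GroupNorm(num_groups, num_channels) stays valid.
    for g in range(min(preferred_groups, num_channels), 0, -1):
        if num_channels % g == 0:
            return g
    return 1

compute_num_groups(64)  # -> 4
compute_num_groups(6)   # -> 3 (4 does not divide 6)
compute_num_groups(7)   # -> 1 (prime channel count)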
wavedl/models/convnext.py
CHANGED
@@ -11,9 +11,9 @@ Features: inverted bottleneck, LayerNorm, GELU activation, depthwise convolution
 - 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d
 
 **Variants**:
-- convnext_tiny: Smallest (~
-- convnext_small: Medium (~
-- convnext_base: Standard (~
+- convnext_tiny: Smallest (~27.8M backbone params for 2D)
+- convnext_small: Medium (~49.5M backbone params for 2D)
+- convnext_base: Standard (~87.6M backbone params for 2D)
 
 References:
     Liu, Z., et al. (2022). A ConvNet for the 2020s.
@@ -26,15 +26,12 @@ from typing import Any
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
-
 def _get_conv_layer(dim: int) -> type[nn.Module]:
     """Get dimension-appropriate Conv class."""
     if dim == 1:
@@ -51,40 +48,75 @@ class LayerNormNd(nn.Module):
     """
     LayerNorm for N-dimensional tensors (channels-first format).
 
-
+    Implements channels-last LayerNorm as used in the original ConvNeXt paper.
+    Permutes data to channels-last, applies LayerNorm over the channel
+    dimension at each spatial location, and permutes back to channels-first.
+
+    This matches PyTorch's nn.LayerNorm behavior when applied to the channel
+    dimension, providing stable gradients for deep ConvNeXt networks.
+
+    References:
+        Liu, Z., et al. (2022). A ConvNet for the 2020s. CVPR 2022.
+        https://github.com/facebookresearch/ConvNeXt
     """
 
     def __init__(self, num_channels: int, dim: int, eps: float = 1e-6):
         super().__init__()
         self.dim = dim
+        self.num_channels = num_channels
         self.weight = nn.Parameter(torch.ones(num_channels))
         self.bias = nn.Parameter(torch.zeros(num_channels))
         self.eps = eps
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
-
-
-
-
-
-
-
-
+        """
+        Apply LayerNorm in channels-last format.
+
+        Args:
+            x: Input tensor in channels-first format
+                - 1D: (B, C, L)
+                - 2D: (B, C, H, W)
+                - 3D: (B, C, D, H, W)
+
+        Returns:
+            Normalized tensor in same format as input
+        """
+        if self.dim == 1:
+            # (B, C, L) -> (B, L, C) -> LayerNorm -> (B, C, L)
+            x = x.permute(0, 2, 1)
+            x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
+            x = x.permute(0, 2, 1)
+        elif self.dim == 2:
+            # (B, C, H, W) -> (B, H, W, C) -> LayerNorm -> (B, C, H, W)
+            x = x.permute(0, 2, 3, 1)
+            x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
+            x = x.permute(0, 3, 1, 2)
+        else:
+            # (B, C, D, H, W) -> (B, D, H, W, C) -> LayerNorm -> (B, C, D, H, W)
+            x = x.permute(0, 2, 3, 4, 1)
+            x = F.layer_norm(x, (self.num_channels,), self.weight, self.bias, self.eps)
+            x = x.permute(0, 4, 1, 2, 3)
         return x
 
 
 class ConvNeXtBlock(nn.Module):
     """
-    ConvNeXt block
-
-
-
-    -
-
-
-
+    ConvNeXt block matching the official Facebook implementation.
+
+    Uses the second variant from the paper, which is slightly faster in PyTorch:
+    1. DwConv (channels-first)
+    2. Permute to channels-last
+    3. LayerNorm → Linear → GELU → Linear (all channels-last)
+    4. LayerScale (gamma * x)
+    5. Permute back to channels-first
+    6. Residual connection
+
+    The LayerScale mechanism is critical for stable training in deep networks.
+    It scales the output by a learnable parameter initialized to 1e-6.
+
+    References:
+        Liu, Z., et al. (2022). A ConvNet for the 2020s. CVPR 2022.
+        https://github.com/facebookresearch/ConvNeXt
     """
 
     def __init__(
@@ -93,21 +125,36 @@ class ConvNeXtBlock(nn.Module):
         dim: int = 2,
         expansion_ratio: float = 4.0,
         drop_path: float = 0.0,
+        layer_scale_init_value: float = 1e-6,
     ):
         super().__init__()
+        self.dim = dim
         Conv = _get_conv_layer(dim)
         hidden_dim = int(channels * expansion_ratio)
 
-        # Depthwise conv (7x7)
+        # Depthwise conv (7x7) - operates in channels-first
         self.dwconv = Conv(
             channels, channels, kernel_size=7, padding=3, groups=channels
        )
-        self.norm = LayerNormNd(channels, dim)
 
-        #
-        self.
+        # LayerNorm (channels-last format, using standard nn.LayerNorm)
+        self.norm = nn.LayerNorm(channels, eps=1e-6)
+
+        # Pointwise convs implemented with Linear layers (channels-last)
+        # This matches the official implementation and is slightly faster
+        self.pwconv1 = nn.Linear(channels, hidden_dim)
         self.act = nn.GELU()
-        self.pwconv2 =
+        self.pwconv2 = nn.Linear(hidden_dim, channels)
+
+        # LayerScale: learnable per-channel scaling (critical for deep networks)
+        # Initialized to small value (1e-6) to prevent gradient explosion
+        self.gamma = (
+            nn.Parameter(
+                layer_scale_init_value * torch.ones(channels), requires_grad=True
+            )
+            if layer_scale_init_value > 0
+            else None
+        )
 
         # Stochastic depth (drop path) - simplified version
         self.drop_path = nn.Identity()  # Can be replaced with DropPath if needed
@@ -115,14 +162,38 @@ class ConvNeXtBlock(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         residual = x
 
+        # Depthwise conv in channels-first format
         x = self.dwconv(x)
+
+        # Permute to channels-last for LayerNorm and Linear layers
+        if self.dim == 1:
+            x = x.permute(0, 2, 1)  # (B, C, L) -> (B, L, C)
+        elif self.dim == 2:
+            x = x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
+        else:
+            x = x.permute(0, 2, 3, 4, 1)  # (B, C, D, H, W) -> (B, D, H, W, C)
+
+        # LayerNorm + MLP (all in channels-last)
         x = self.norm(x)
         x = self.pwconv1(x)
         x = self.act(x)
         x = self.pwconv2(x)
-        x = self.drop_path(x)
 
-
+        # Apply LayerScale
+        if self.gamma is not None:
+            x = self.gamma * x
+
+        # Permute back to channels-first
+        if self.dim == 1:
+            x = x.permute(0, 2, 1)  # (B, L, C) -> (B, C, L)
+        elif self.dim == 2:
+            x = x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
+        else:
+            x = x.permute(0, 4, 1, 2, 3)  # (B, D, H, W, C) -> (B, C, D, H, W)
+
+        # Residual connection with drop path
+        x = residual + self.drop_path(x)
+        return x
 
 
 class ConvNeXtBase(BaseModel):
@@ -244,7 +315,7 @@ class ConvNeXtTiny(ConvNeXtBase):
     """
     ConvNeXt-Tiny: Smallest variant.
 
-    ~
+    ~27.8M backbone parameters (2D). Good for: Limited compute, fast training.
 
     Args:
         in_shape: (L,), (H, W), or (D, H, W)
@@ -270,7 +341,7 @@ class ConvNeXtSmall(ConvNeXtBase):
     """
     ConvNeXt-Small: Medium variant.
 
-    ~
+    ~49.5M backbone parameters (2D). Good for: Balanced performance.
 
     Args:
         in_shape: (L,), (H, W), or (D, H, W)
@@ -296,7 +367,7 @@ class ConvNeXtBase_(ConvNeXtBase):
     """
     ConvNeXt-Base: Standard variant.
 
-    ~
+    ~87.6M backbone parameters (2D). Good for: High accuracy, larger datasets.
 
     Args:
         in_shape: (L,), (H, W), or (D, H, W)
@@ -337,7 +408,7 @@ class ConvNeXtTinyPretrained(BaseModel):
     """
     ConvNeXt-Tiny with ImageNet pretrained weights (2D only).
 
-    ~
+    ~27.8M backbone parameters. Good for: Transfer learning with modern CNN.
 
     Args:
         in_shape: (H, W) image dimensions
@@ -393,20 +464,11 @@ class ConvNeXtTinyPretrained(BaseModel):
         )
 
         # Modify first conv for single-channel input
-
-
-
-
-            kernel_size=old_conv.kernel_size,
-            stride=old_conv.stride,
-            padding=old_conv.padding,
-            bias=old_conv.bias is not None,
+        from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
+
+        adapt_first_conv_for_single_channel(
+            self.backbone, "features.0.0", pretrained=pretrained
         )
-        if pretrained:
-            with torch.no_grad():
-                self.backbone.features[0][0].weight = nn.Parameter(
-                    old_conv.weight.mean(dim=1, keepdim=True)
-                )
 
         if freeze_backbone:
             self._freeze_backbone()