wavedl 1.5.7__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/models/__init__.py +52 -4
- wavedl/models/_timm_utils.py +238 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/convnext.py +108 -33
- wavedl/models/convnext_v2.py +504 -0
- wavedl/models/densenet.py +5 -5
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +3 -3
- wavedl/models/fastvit.py +285 -0
- wavedl/models/mamba.py +535 -0
- wavedl/models/maxvit.py +251 -0
- wavedl/models/mobilenetv3.py +6 -6
- wavedl/models/regnet.py +10 -10
- wavedl/models/resnet.py +5 -5
- wavedl/models/resnet3d.py +2 -2
- wavedl/models/swin.py +3 -3
- wavedl/models/tcn.py +3 -3
- wavedl/models/unet.py +1 -1
- wavedl/models/vit.py +6 -6
- wavedl/train.py +21 -16
- wavedl/utils/data.py +39 -6
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/METADATA +90 -62
- wavedl-1.6.0.dist-info/RECORD +44 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,504 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
|
|
3
|
+
========================================================================
|
|
4
|
+
|
|
5
|
+
ConvNeXt V2 improves upon V1 by replacing LayerScale with Global Response
|
|
6
|
+
Normalization (GRN), which enhances inter-channel feature competition.
|
|
7
|
+
|
|
8
|
+
**Key Changes from V1**:
|
|
9
|
+
- GRN layer replaces LayerScale
|
|
10
|
+
- Better compatibility with masked autoencoder pretraining
|
|
11
|
+
- Prevents feature collapse in deep networks
|
|
12
|
+
|
|
13
|
+
**Variants**:
|
|
14
|
+
- convnext_v2_tiny: 28M params, depths [3,3,9,3], dims [96,192,384,768]
|
|
15
|
+
- convnext_v2_small: 50M params, depths [3,3,27,3], dims [96,192,384,768]
|
|
16
|
+
- convnext_v2_base: 89M params, depths [3,3,27,3], dims [128,256,512,1024]
|
|
17
|
+
- convnext_v2_tiny_pretrained: 2D only, ImageNet weights (torchvision ConvNeXt V1 backbone; torchvision does not ship V2)
|
|
18
|
+
|
|
19
|
+
**Supports**: 1D, 2D, 3D inputs
|
|
20
|
+
|
|
21
|
+
Reference:
|
|
22
|
+
Woo, S., et al. (2023). ConvNeXt V2: Co-designing and Scaling ConvNets
|
|
23
|
+
with Masked Autoencoders. CVPR 2023.
|
|
24
|
+
https://arxiv.org/abs/2301.00808
|
|
25
|
+
|
|
26
|
+
Author: Ductho Le (ductho.le@outlook.com)
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from typing import Any
|
|
30
|
+
|
|
31
|
+
import torch
|
|
32
|
+
import torch.nn as nn
|
|
33
|
+
|
|
34
|
+
from wavedl.models._timm_utils import (
|
|
35
|
+
LayerNormNd,
|
|
36
|
+
build_regression_head,
|
|
37
|
+
get_conv_layer,
|
|
38
|
+
get_grn_layer,
|
|
39
|
+
get_pool_layer,
|
|
40
|
+
)
|
|
41
|
+
from wavedl.models.base import BaseModel
|
|
42
|
+
from wavedl.models.registry import register_model
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Type alias for spatial input shapes, excluding batch and channel axes:
# (L,) for 1D, (H, W) for 2D, (D, H, W) for 3D inputs. The length of the
# tuple selects the conv/pool/norm dimensionality of the models below.
SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]

# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    "ConvNeXtV2Base",
    "ConvNeXtV2BaseLarge",
    "ConvNeXtV2Small",
    "ConvNeXtV2Tiny",
    "ConvNeXtV2TinyPretrained",
]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# =============================================================================
|
|
58
|
+
# CONVNEXT V2 BLOCK
|
|
59
|
+
# =============================================================================
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class ConvNeXtV2Block(nn.Module):
    """
    ConvNeXt V2 Block with GRN instead of LayerScale.

    Architecture:
        Input → DwConv → LayerNorm → Linear → GELU → GRN → Linear → Residual

    The GRN layer is the key difference from V1, placed after the
    dimension-expansion in the MLP, replacing LayerScale.

    LayerNorm and the two Linear layers operate on channels-last tensors,
    while the depthwise conv and GRN receive channels-first tensors. The
    layout is switched by moving the channel axis with ``torch.movedim``, so
    one code path serves 1D, 2D and 3D inputs.

    Args:
        dim: Number of input (and output) channels.
        spatial_dim: Number of spatial dimensions (1, 2 or 3); selects the
            conv and GRN layer variants.
        drop_path: Stochastic-depth probability for the residual branch.
            ``0.0`` disables it (an ``nn.Identity`` is used instead).
        mlp_ratio: Expansion factor of the pointwise MLP hidden width.
        kernel_size: Depthwise conv kernel size. Must be a positive odd
            integer; padding is ``kernel_size // 2`` so the spatial size is
            preserved for the residual addition. Defaults to 7.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(
        self,
        dim: int,
        spatial_dim: int,
        drop_path: float = 0.0,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
    ):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            # An even kernel with k // 2 padding would change the spatial
            # size and break the residual addition.
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.spatial_dim = spatial_dim

        Conv = get_conv_layer(spatial_dim)
        GRN = get_grn_layer(spatial_dim)

        # Depthwise convolution ("same" padding, one group per channel)
        self.dwconv = Conv(
            dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim
        )

        # LayerNorm over the channel axis (applied channels-last in forward)
        self.norm = nn.LayerNorm(dim, eps=1e-6)

        # MLP with expansion
        hidden_dim = int(dim * mlp_ratio)
        self.pwconv1 = nn.Linear(dim, hidden_dim)  # Expansion
        self.act = nn.GELU()
        self.grn = GRN(hidden_dim)  # GRN after expansion (key V2 change)
        self.pwconv2 = nn.Linear(hidden_dim, dim)  # Projection

        # Stochastic depth
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    @staticmethod
    def _to_channels_last(x: torch.Tensor) -> torch.Tensor:
        """Move the channel axis to the end: (B, C, *S) -> (B, *S, C)."""
        return x.movedim(1, -1)

    @staticmethod
    def _to_channels_first(x: torch.Tensor) -> torch.Tensor:
        """Move the channel axis back to dim 1: (B, *S, C) -> (B, C, *S)."""
        return x.movedim(-1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the block.

        Args:
            x: Channels-first tensor (B, dim, *spatial).

        Returns:
            Tensor of the same shape as ``x``.
        """
        residual = x

        x = self.dwconv(x)

        # LayerNorm / Linear act on the last axis → channels-last
        x = self._to_channels_last(x)
        x = self.act(self.pwconv1(self.norm(x)))

        # GRN expects channels-first
        x = self.grn(self._to_channels_first(x))

        # Final projection channels-last, then restore channels-first
        x = self.pwconv2(self._to_channels_last(x))
        x = self._to_channels_first(x)

        return residual + self.drop_path(x)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class DropPath(nn.Module):
    """
    Stochastic Depth (drop path) regularization.

    In training mode, the whole input of each sample is zeroed with
    probability ``drop_prob``, and the kept samples are scaled by
    ``1 / keep_prob`` so the expected value is unchanged. In eval mode, or
    when ``drop_prob == 0``, the input is returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample's branch, in ``[0, 1]``.

    Raises:
        ValueError: If ``drop_prob`` is outside ``[0, 1]`` (or NaN).
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x

        keep_prob = 1 - self.drop_prob
        # One keep/drop decision per sample, broadcast over all other axes:
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        if keep_prob == 0.0:
            # drop_prob == 1: every sample is dropped. Skip the rescale, since
            # x / 0 would give inf and inf * 0 = NaN.
            return x * random_tensor
        return x.div(keep_prob) * random_tensor
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
# =============================================================================
|
|
176
|
+
# CONVNEXT V2 BASE CLASS
|
|
177
|
+
# =============================================================================
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class ConvNeXtV2Base(BaseModel):
    """
    ConvNeXt V2 base class for regression.

    Dimension-agnostic implementation supporting 1D, 2D, and 3D inputs.
    Uses GRN (Global Response Normalization) instead of LayerScale.

    Layout: stride-4 patchify stem (Conv + LayerNormNd) → ``len(depths)``
    stages of :class:`ConvNeXtV2Block`, separated by LayerNormNd + stride-2
    conv downsampling → global pooling → LayerNorm → MLP regression head.

    Args:
        in_shape: Spatial input shape (excluding batch/channel); its length
            (1, 2 or 3) selects the conv/pool/norm dimensionality.
        out_size: Number of regression outputs.
        depths: Number of blocks in each stage.
        dims: Channel width of each stage; must have the same length as
            ``depths``.
        drop_path_rate: Maximum stochastic-depth rate. Per-block rates ramp
            linearly from 0 to this value over all blocks.
        dropout_rate: Dropout before the first head Linear; half of it is
            used before the second.
        **kwargs: Accepted for interface compatibility and ignored.

    Raises:
        ValueError: If ``depths`` is empty or its length differs from
            ``dims``.
    """

    def __init__(
        self,
        in_shape: SpatialShape,
        out_size: int,
        depths: list[int],
        dims: list[int],
        drop_path_rate: float = 0.0,
        dropout_rate: float = 0.3,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        # A length mismatch would either index past the end of ``dims`` or
        # size ``norm``/``head`` by a width the last stage never produces.
        if not depths or len(depths) != len(dims):
            raise ValueError(
                "depths and dims must be non-empty and of equal length, "
                f"got depths={list(depths)} and dims={list(dims)}"
            )

        self.dim = len(in_shape)
        # Copies, so later mutation of the caller's lists cannot desync
        # these attributes from the built modules.
        self.depths = list(depths)
        self.dims = list(dims)

        Conv = get_conv_layer(self.dim)
        Pool = get_pool_layer(self.dim)

        # Stem: aggressive downsampling (4x stride like ConvNeXt).
        # Expects a single input channel.
        self.stem = nn.Sequential(
            Conv(1, dims[0], kernel_size=4, stride=4),
            LayerNormNd(dims[0], self.dim),
        )

        # Stochastic depth decay rule: linear ramp over all blocks
        dp_rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        # Build stages, interleaving each stage with its downsample
        self.stages = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        cur = 0
        last = len(depths) - 1

        for i, (depth, width) in enumerate(zip(depths, dims)):
            stage = nn.Sequential(
                *[
                    ConvNeXtV2Block(
                        dim=width,
                        spatial_dim=self.dim,
                        drop_path=dp_rates[cur + j],
                    )
                    for j in range(depth)
                ]
            )
            self.stages.append(stage)
            cur += depth

            # Downsample between stages (except after last)
            if i < last:
                self.downsamples.append(
                    nn.Sequential(
                        LayerNormNd(width, self.dim),
                        Conv(width, dims[i + 1], kernel_size=2, stride=2),
                    )
                )

        # Global pooling and head
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
        self.global_pool = Pool(1)
        self.head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(dims[-1], dims[-1] // 2),
            nn.GELU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(dims[-1] // 2, out_size),
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """
        Initialize weights: truncated normal (std 0.02) for conv/linear
        weights, zeros for their biases, and ones/zeros for ``nn.LayerNorm``.

        NOTE(review): modules that are not one of these types (e.g. the
        project's LayerNormNd / GRN layers) keep their own initialization.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor (B, 1, *in_shape)

        Returns:
            Output tensor (B, out_size)
        """
        x = self.stem(x)

        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i < len(self.downsamples):
                x = self.downsamples[i](x)

        # Global pooling
        x = self.global_pool(x)
        x = x.flatten(1)

        # Final norm and head
        x = self.norm(x)
        x = self.head(x)

        return x

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """
        Return the default hyper-parameters (ConvNeXt V2 Tiny layout).

        NOTE(review): ``drop_path_rate`` here is 0.1, while the constructor
        default is 0.0. Confirm which one callers are meant to rely on.
        """
        return {
            "depths": [3, 3, 9, 3],
            "dims": [96, 192, 384, 768],
            "drop_path_rate": 0.1,
            "dropout_rate": 0.3,
        }
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
# =============================================================================
|
|
307
|
+
# REGISTERED VARIANTS
|
|
308
|
+
# =============================================================================
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
@register_model("convnext_v2_tiny")
class ConvNeXtV2Tiny(ConvNeXtV2Base):
    """
    ConvNeXt V2 Tiny variant (~27.9M backbone parameters).

    Stage depths [3, 3, 9, 3] with channel widths [96, 192, 384, 768].
    Works with 1D, 2D and 3D inputs.

    Example:
        >>> model = ConvNeXtV2Tiny(in_shape=(64, 64), out_size=3)
        >>> x = torch.randn(4, 1, 64, 64)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
        stage_depths = [3, 3, 9, 3]
        stage_widths = [96, 192, 384, 768]
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            depths=stage_depths,
            dims=stage_widths,
            **kwargs,
        )

    def __repr__(self) -> str:
        details = f"{self.dim}D, in_shape={self.in_shape}, out_size={self.out_size}"
        return f"ConvNeXtV2_Tiny({details})"
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
@register_model("convnext_v2_small")
class ConvNeXtV2Small(ConvNeXtV2Base):
    """
    ConvNeXt V2 Small variant (~49.6M backbone parameters).

    Stage depths [3, 3, 27, 3] with channel widths [96, 192, 384, 768].
    Works with 1D, 2D and 3D inputs.
    """

    def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
        stage_depths = [3, 3, 27, 3]
        stage_widths = [96, 192, 384, 768]
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            depths=stage_depths,
            dims=stage_widths,
            **kwargs,
        )

    def __repr__(self) -> str:
        details = f"{self.dim}D, in_shape={self.in_shape}, out_size={self.out_size}"
        return f"ConvNeXtV2_Small({details})"
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
@register_model("convnext_v2_base")
class ConvNeXtV2BaseLarge(ConvNeXtV2Base):
    """
    ConvNeXt V2 Base variant (~87.7M backbone parameters).

    Stage depths [3, 3, 27, 3] with channel widths [128, 256, 512, 1024].
    Works with 1D, 2D and 3D inputs. Named ``BaseLarge`` in Python to avoid
    clashing with the ``ConvNeXtV2Base`` parent class; registered as
    ``convnext_v2_base``.
    """

    def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
        stage_depths = [3, 3, 27, 3]
        stage_widths = [128, 256, 512, 1024]
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            depths=stage_depths,
            dims=stage_widths,
            **kwargs,
        )

    def __repr__(self) -> str:
        details = f"{self.dim}D, in_shape={self.in_shape}, out_size={self.out_size}"
        return f"ConvNeXtV2_Base({details})"
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
# =============================================================================
|
|
392
|
+
# PRETRAINED VARIANT (2D ONLY)
|
|
393
|
+
# =============================================================================
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
@register_model("convnext_v2_tiny_pretrained")
class ConvNeXtV2TinyPretrained(BaseModel):
    """
    ConvNeXt Tiny with ImageNet pretrained weights (2D only).

    NOTE: despite the registered name, the backbone is torchvision's
    ``convnext_tiny`` with ``ConvNeXt_Tiny_Weights.IMAGENET1K_V1``. That is
    the ConvNeXt **V1** architecture (LayerScale, no GRN), because torchvision
    does not ship ConvNeXt V2. A true V2 pretrained model would need timm or
    a custom weight port.

    Changes made to the torchvision model:
        - First conv adapted for single-channel input (RGB kernels averaged
          when pretrained weights are loaded)
        - Final Linear of the classifier replaced by a regression head

    Args:
        in_shape: (H, W) input shape (2D only)
        out_size: Number of regression targets
        pretrained: Whether to load pretrained weights
        freeze_backbone: Whether to freeze backbone for fine-tuning
        dropout_rate: Dropout rate passed to the regression head builder
        **kwargs: Accepted for interface compatibility and ignored.

    Raises:
        ValueError: If ``in_shape`` is not 2D.
        ImportError: If torchvision (with the multi-weight API) is missing.
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        pretrained: bool = True,
        freeze_backbone: bool = False,
        dropout_rate: float = 0.3,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        if len(in_shape) != 2:
            raise ValueError(
                "ConvNeXtV2TinyPretrained requires 2D input (H, W), "
                f"got {len(in_shape)}D"
            )

        self.pretrained = pretrained
        self.freeze_backbone = freeze_backbone

        # Guard only the import. ImportErrors raised later (e.g. from inside
        # the model builder) are then not misreported as a missing
        # torchvision, and the original cause is chained for diagnosis.
        try:
            from torchvision.models import ConvNeXt_Tiny_Weights, convnext_tiny
        except ImportError as err:
            raise ImportError(
                "torchvision is required for pretrained ConvNeXt. "
                "Install with: pip install torchvision"
            ) from err

        # torchvision only provides ConvNeXt V1 (see class docstring).
        weights = ConvNeXt_Tiny_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = convnext_tiny(weights=weights)

        # Adapt input layer (3 channels -> 1 channel)
        self._adapt_input_channels()

        # Replace classifier with regression head.
        # Keep the LayerNorm2d (idx 0) and Flatten (idx 1), only replace Linear (idx 2)
        in_features = self.backbone.classifier[2].in_features
        new_head = build_regression_head(in_features, out_size, dropout_rate)

        self.backbone.classifier = nn.Sequential(
            self.backbone.classifier[0],  # LayerNorm2d
            self.backbone.classifier[1],  # Flatten
            new_head,  # Our regression head
        )

        if freeze_backbone:
            self._freeze_backbone()

    def _adapt_input_channels(self):
        """
        Replace the stem conv (``backbone.features[0][0]``) with a copy that
        takes a single input channel.

        The new conv keeps the old one's kernel size, stride, padding,
        dilation, padding mode and bias setting. When pretrained weights are
        loaded, its kernels are the mean of the RGB kernels and the bias is
        copied; otherwise it keeps PyTorch's default initialization.
        """
        old_conv = self.backbone.features[0][0]
        new_conv = nn.Conv2d(
            1,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            dilation=old_conv.dilation,
            padding_mode=old_conv.padding_mode,
            bias=old_conv.bias is not None,
        )

        if self.pretrained:
            with torch.no_grad():
                # Average RGB weights for grayscale
                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
                if old_conv.bias is not None:
                    new_conv.bias.copy_(old_conv.bias)

        self.backbone.features[0][0] = new_conv

    def _freeze_backbone(self):
        """
        Freeze all backbone parameters except those under ``classifier``.

        The classifier keeps the LayerNorm2d and the new regression head
        trainable.
        """
        for name, param in self.backbone.named_parameters():
            if "classifier" not in name:
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the adapted torchvision backbone; x is (B, 1, H, W)."""
        return self.backbone(x)

    def __repr__(self) -> str:
        return (
            f"ConvNeXtV2_Tiny_Pretrained(in_shape={self.in_shape}, "
            f"out_size={self.out_size}, pretrained={self.pretrained})"
        )
|
wavedl/models/densenet.py
CHANGED
|
@@ -11,8 +11,8 @@ Features: feature reuse, efficient gradient flow, compact model.
|
|
|
11
11
|
- 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d
|
|
12
12
|
|
|
13
13
|
**Variants**:
|
|
14
|
-
- densenet121: Standard (121 layers, ~
|
|
15
|
-
- densenet169: Deeper (169 layers, ~
|
|
14
|
+
- densenet121: Standard (121 layers, ~7.0M backbone params for 2D)
|
|
15
|
+
- densenet169: Deeper (169 layers, ~12.5M backbone params for 2D)
|
|
16
16
|
|
|
17
17
|
References:
|
|
18
18
|
Huang, G., et al. (2017). Densely Connected Convolutional Networks.
|
|
@@ -262,7 +262,7 @@ class DenseNet121(DenseNetBase):
|
|
|
262
262
|
"""
|
|
263
263
|
DenseNet-121: Standard variant with 121 layers.
|
|
264
264
|
|
|
265
|
-
~
|
|
265
|
+
~7.0M backbone parameters (2D). Good for: Balanced accuracy, efficient training.
|
|
266
266
|
|
|
267
267
|
Args:
|
|
268
268
|
in_shape: (L,), (H, W), or (D, H, W)
|
|
@@ -285,7 +285,7 @@ class DenseNet169(DenseNetBase):
|
|
|
285
285
|
"""
|
|
286
286
|
DenseNet-169: Deeper variant with 169 layers.
|
|
287
287
|
|
|
288
|
-
~
|
|
288
|
+
~12.5M backbone parameters (2D). Good for: Higher capacity, more complex patterns.
|
|
289
289
|
|
|
290
290
|
Args:
|
|
291
291
|
in_shape: (L,), (H, W), or (D, H, W)
|
|
@@ -320,7 +320,7 @@ class DenseNet121Pretrained(BaseModel):
|
|
|
320
320
|
"""
|
|
321
321
|
DenseNet-121 with ImageNet pretrained weights (2D only).
|
|
322
322
|
|
|
323
|
-
~
|
|
323
|
+
~7.0M backbone parameters. Good for: Transfer learning with efficient feature reuse.
|
|
324
324
|
|
|
325
325
|
Args:
|
|
326
326
|
in_shape: (H, W) image dimensions
|
wavedl/models/efficientnet.py
CHANGED
|
@@ -6,9 +6,9 @@ Wrapper around torchvision's EfficientNet with a regression head.
|
|
|
6
6
|
Provides optional ImageNet pretrained weights for transfer learning.
|
|
7
7
|
|
|
8
8
|
**Variants**:
|
|
9
|
-
- efficientnet_b0: Smallest, fastest (~4.
|
|
10
|
-
- efficientnet_b1: Light (~
|
|
11
|
-
- efficientnet_b2: Balanced (~
|
|
9
|
+
- efficientnet_b0: Smallest, fastest (~4.0M backbone params)
|
|
10
|
+
- efficientnet_b1: Light (~6.5M backbone params)
|
|
11
|
+
- efficientnet_b2: Balanced (~7.7M backbone params)
|
|
12
12
|
|
|
13
13
|
**Note**: EfficientNet is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
|
|
14
14
|
|
|
@@ -169,7 +169,7 @@ class EfficientNetB0(EfficientNetBase):
|
|
|
169
169
|
"""
|
|
170
170
|
EfficientNet-B0: Smallest, most efficient variant.
|
|
171
171
|
|
|
172
|
-
~
|
|
172
|
+
~4.0M backbone parameters. Good for: Quick training, limited compute, baseline.
|
|
173
173
|
|
|
174
174
|
Args:
|
|
175
175
|
in_shape: (H, W) image dimensions
|
|
@@ -200,7 +200,7 @@ class EfficientNetB1(EfficientNetBase):
|
|
|
200
200
|
"""
|
|
201
201
|
EfficientNet-B1: Slightly larger variant.
|
|
202
202
|
|
|
203
|
-
~
|
|
203
|
+
~6.5M backbone parameters. Good for: Better accuracy with moderate compute.
|
|
204
204
|
|
|
205
205
|
Args:
|
|
206
206
|
in_shape: (H, W) image dimensions
|
|
@@ -231,7 +231,7 @@ class EfficientNetB2(EfficientNetBase):
|
|
|
231
231
|
"""
|
|
232
232
|
EfficientNet-B2: Best balance of size and performance.
|
|
233
233
|
|
|
234
|
-
~
|
|
234
|
+
~7.7M backbone parameters. Good for: High accuracy without excessive compute.
|
|
235
235
|
|
|
236
236
|
Args:
|
|
237
237
|
in_shape: (H, W) image dimensions
|
wavedl/models/efficientnetv2.py
CHANGED
|
@@ -199,7 +199,7 @@ class EfficientNetV2S(EfficientNetV2Base):
|
|
|
199
199
|
"""
|
|
200
200
|
EfficientNetV2-S: Small variant, recommended default.
|
|
201
201
|
|
|
202
|
-
~
|
|
202
|
+
~20.2M backbone parameters. Best balance of speed and accuracy for most tasks.
|
|
203
203
|
2× faster training than EfficientNet-B4 with better accuracy.
|
|
204
204
|
|
|
205
205
|
Recommended for:
|
|
@@ -240,7 +240,7 @@ class EfficientNetV2M(EfficientNetV2Base):
|
|
|
240
240
|
"""
|
|
241
241
|
EfficientNetV2-M: Medium variant for higher accuracy.
|
|
242
242
|
|
|
243
|
-
~
|
|
243
|
+
~52.9M backbone parameters. Use when accuracy is more important than speed.
|
|
244
244
|
|
|
245
245
|
Recommended for:
|
|
246
246
|
- Large datasets (>50k samples)
|
|
@@ -280,7 +280,7 @@ class EfficientNetV2L(EfficientNetV2Base):
|
|
|
280
280
|
"""
|
|
281
281
|
EfficientNetV2-L: Large variant for maximum accuracy.
|
|
282
282
|
|
|
283
|
-
~
|
|
283
|
+
~117.2M backbone parameters. Use only with large datasets and sufficient compute.
|
|
284
284
|
|
|
285
285
|
Recommended for:
|
|
286
286
|
- Very large datasets (>100k samples)
|