wavedl 1.5.7__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +80 -4
- wavedl/models/_pretrained_utils.py +366 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +113 -51
- wavedl/models/convnext_v2.py +488 -0
- wavedl/models/densenet.py +10 -23
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +252 -0
- wavedl/models/mamba.py +555 -0
- wavedl/models/maxvit.py +254 -0
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +19 -61
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +2 -6
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +9 -9
- wavedl/train.py +1430 -1425
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/data.py +39 -6
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/convnext_v2.py
ADDED
@@ -0,0 +1,488 @@
+"""
+ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
+========================================================================
+
+ConvNeXt V2 improves upon V1 by replacing LayerScale with Global Response
+Normalization (GRN), which enhances inter-channel feature competition.
+
+**Key Changes from V1**:
+- GRN layer replaces LayerScale
+- Better compatibility with masked autoencoder pretraining
+- Prevents feature collapse in deep networks
+
+**Variants**:
+- convnext_v2_tiny: 28M params, depths [3,3,9,3], dims [96,192,384,768]
+- convnext_v2_small: 50M params, depths [3,3,27,3], dims [96,192,384,768]
+- convnext_v2_base: 89M params, depths [3,3,27,3], dims [128,256,512,1024]
+- convnext_v2_tiny_pretrained: 2D only, ImageNet weights
+
+**Supports**: 1D, 2D, 3D inputs
+
+Reference:
+    Woo, S., et al. (2023). ConvNeXt V2: Co-designing and Scaling ConvNets
+    with Masked Autoencoders. CVPR 2023.
+    https://arxiv.org/abs/2301.00808
+
+Author: Ductho Le (ductho.le@outlook.com)
+"""
+
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+from wavedl.models._pretrained_utils import (
+    LayerNormNd,
+    build_regression_head,
+    get_conv_layer,
+    get_grn_layer,
+    get_pool_layer,
+)
+from wavedl.models.base import BaseModel, SpatialShape
+from wavedl.models.registry import register_model
+
+
+__all__ = [
+    "ConvNeXtV2Base",
+    "ConvNeXtV2BaseLarge",
+    "ConvNeXtV2Small",
+    "ConvNeXtV2Tiny",
+    "ConvNeXtV2TinyPretrained",
+]
+
+
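For reference, the GRN transform the docstring leans on amounts to only a few lines. Below is a minimal 2D sketch of the formulation from the ConvNeXt V2 paper; the standalone `GRN2d` module is illustrative only, since the package's actual layer comes from `get_grn_layer` in `_pretrained_utils.py`, which this diff does not show:

```python
import torch
import torch.nn as nn


class GRN2d(nn.Module):
    """Global Response Normalization for channels-first (B, C, H, W) tensors."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # gx: L2 norm over the spatial dims -> one statistic per channel
        gx = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)
        # nx: divisive normalization across channels (the "competition")
        nx = gx / (gx.mean(dim=1, keepdim=True) + self.eps)
        # gamma/beta start at zero, so the layer is initially an identity
        return self.gamma * (x * nx) + self.beta + x
```

The zero initialization of `gamma` and `beta` means a fresh GRN behaves as an identity at the start of training, which is what makes it a safe drop-in replacement for LayerScale.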
+# =============================================================================
+# CONVNEXT V2 BLOCK
+# =============================================================================
+
+
+class ConvNeXtV2Block(nn.Module):
+    """
+    ConvNeXt V2 Block with GRN instead of LayerScale.
+
+    Architecture:
+        Input → DwConv → LayerNorm → Linear → GELU → GRN → Linear → Residual
+
+    The GRN layer is the key difference from V1, placed after the
+    dimension-expansion in the MLP, replacing LayerScale.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        spatial_dim: int,
+        drop_path: float = 0.0,
+        mlp_ratio: float = 4.0,
+    ):
+        super().__init__()
+        self.spatial_dim = spatial_dim
+
+        Conv = get_conv_layer(spatial_dim)
+        GRN = get_grn_layer(spatial_dim)
+
+        # Depthwise convolution
+        kernel_size = 7
+        padding = 3
+        self.dwconv = Conv(
+            dim, dim, kernel_size=kernel_size, padding=padding, groups=dim
+        )
+
+        # LayerNorm (applied in forward with permutation)
+        self.norm = nn.LayerNorm(dim, eps=1e-6)
+
+        # MLP with expansion
+        hidden_dim = int(dim * mlp_ratio)
+        self.pwconv1 = nn.Linear(dim, hidden_dim)  # Expansion
+        self.act = nn.GELU()
+        self.grn = GRN(hidden_dim)  # GRN after expansion (key V2 change)
+        self.pwconv2 = nn.Linear(hidden_dim, dim)  # Projection
+
+        # Stochastic depth
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        residual = x
+
+        # Depthwise conv
+        x = self.dwconv(x)
+
+        # Move channels to last for LayerNorm and Linear layers
+        if self.spatial_dim == 1:
+            x = x.permute(0, 2, 1)  # (B, C, L) -> (B, L, C)
+        elif self.spatial_dim == 2:
+            x = x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
+        else:  # 3D
+            x = x.permute(0, 2, 3, 4, 1)  # (B, C, D, H, W) -> (B, D, H, W, C)
+
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+
+        # Move back to channels-first for GRN
+        if self.spatial_dim == 1:
+            x = x.permute(0, 2, 1)  # (B, L, C) -> (B, C, L)
+        elif self.spatial_dim == 2:
+            x = x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
+        else:  # 3D
+            x = x.permute(0, 4, 1, 2, 3)  # (B, D, H, W, C) -> (B, C, D, H, W)
+
+        # Apply GRN (the key V2 innovation)
+        x = self.grn(x)
+
+        # Move to channels-last for final projection
+        if self.spatial_dim == 1:
+            x = x.permute(0, 2, 1)
+        elif self.spatial_dim == 2:
+            x = x.permute(0, 2, 3, 1)
+        else:
+            x = x.permute(0, 2, 3, 4, 1)
+
+        x = self.pwconv2(x)
+
+        # Move back to channels-first
+        if self.spatial_dim == 1:
+            x = x.permute(0, 2, 1)
+        elif self.spatial_dim == 2:
+            x = x.permute(0, 3, 1, 2)
+        else:
+            x = x.permute(0, 4, 1, 2, 3)
+
+        x = residual + self.drop_path(x)
+        return x
+
+
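The block is shape-preserving: each permutation pair cancels and the residual add requires matching shapes. (Note that `DropPath` is defined below the block; that is fine in Python because the name is resolved when `__init__` runs, not when the class body is parsed.) A quick smoke test, assuming wavedl 1.6.1 is installed and `get_grn_layer` returns a channels-first GRN module as the permutations above expect:

```python
import torch
from wavedl.models.convnext_v2 import ConvNeXtV2Block

block = ConvNeXtV2Block(dim=96, spatial_dim=1, drop_path=0.1)
x = torch.randn(8, 96, 256)  # (B, C, L) waveform features
y = block(x)                 # depthwise conv + MLP + GRN, residual added
assert y.shape == x.shape    # the block never changes resolution or width
```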
+class DropPath(nn.Module):
+    """Stochastic Depth (drop path) regularization."""
+
+    def __init__(self, drop_prob: float = 0.0):
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.drop_prob == 0.0 or not self.training:
+            return x
+
+        keep_prob = 1 - self.drop_prob
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+        random_tensor.floor_()
+        return x.div(keep_prob) * random_tensor
+
+
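The `x.div(keep_prob)` rescaling is what makes this a drop-in regularizer: each sample's residual branch is zeroed with probability `drop_prob`, and the survivors are scaled up so the expected output equals the input. A small sanity check:

```python
import torch
from wavedl.models.convnext_v2 import DropPath

dp = DropPath(drop_prob=0.5)
dp.train()  # drop path is a no-op in eval mode

x = torch.ones(4, 3)
samples = torch.stack([dp(x) for _ in range(10_000)])
print(samples.mean().item())  # ≈ 1.0: the 1/keep_prob scaling preserves the mean
```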
+# =============================================================================
+# CONVNEXT V2 BASE CLASS
+# =============================================================================
+
+
+class ConvNeXtV2Base(BaseModel):
+    """
+    ConvNeXt V2 base class for regression.
+
+    Dimension-agnostic implementation supporting 1D, 2D, and 3D inputs.
+    Uses GRN (Global Response Normalization) instead of LayerScale.
+    """
+
+    def __init__(
+        self,
+        in_shape: SpatialShape,
+        out_size: int,
+        depths: list[int],
+        dims: list[int],
+        drop_path_rate: float = 0.0,
+        dropout_rate: float = 0.3,
+        **kwargs,
+    ):
+        super().__init__(in_shape, out_size)
+
+        self.dim = len(in_shape)
+        self.depths = depths
+        self.dims = dims
+
+        Conv = get_conv_layer(self.dim)
+        Pool = get_pool_layer(self.dim)
+
+        # Stem: aggressive downsampling (4x stride like ConvNeXt)
+        self.stem = nn.Sequential(
+            Conv(1, dims[0], kernel_size=4, stride=4),
+            LayerNormNd(dims[0], self.dim),
+        )
+
+        # Stochastic depth decay rule
+        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+
+        # Build stages
+        self.stages = nn.ModuleList()
+        self.downsamples = nn.ModuleList()
+        cur = 0
+
+        for i in range(len(depths)):
+            # Stage: sequence of ConvNeXt V2 blocks
+            stage = nn.Sequential(
+                *[
+                    ConvNeXtV2Block(
+                        dim=dims[i],
+                        spatial_dim=self.dim,
+                        drop_path=dp_rates[cur + j],
+                    )
+                    for j in range(depths[i])
+                ]
+            )
+            self.stages.append(stage)
+            cur += depths[i]
+
+            # Downsample between stages (except after last)
+            if i < len(depths) - 1:
+                downsample = nn.Sequential(
+                    LayerNormNd(dims[i], self.dim),
+                    Conv(dims[i], dims[i + 1], kernel_size=2, stride=2),
+                )
+                self.downsamples.append(downsample)
+
+        # Global pooling and head
+        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
+        self.global_pool = Pool(1)
+        self.head = nn.Sequential(
+            nn.Dropout(dropout_rate),
+            nn.Linear(dims[-1], dims[-1] // 2),
+            nn.GELU(),
+            nn.Dropout(dropout_rate * 0.5),
+            nn.Linear(dims[-1] // 2, out_size),
+        )
+
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        """Initialize weights with truncated normal."""
+        for m in self.modules():
+            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass.
+
+        Args:
+            x: Input tensor (B, 1, *in_shape)
+
+        Returns:
+            Output tensor (B, out_size)
+        """
+        x = self.stem(x)
+
+        for i, stage in enumerate(self.stages):
+            x = stage(x)
+            if i < len(self.downsamples):
+                x = self.downsamples[i](x)
+
+        # Global pooling
+        x = self.global_pool(x)
+        x = x.flatten(1)
+
+        # Final norm and head
+        x = self.norm(x)
+        x = self.head(x)
+
+        return x
+
+    @classmethod
+    def get_default_config(cls) -> dict[str, Any]:
+        return {
+            "depths": [3, 3, 9, 3],
+            "dims": [96, 192, 384, 768],
+            "drop_path_rate": 0.1,
+            "dropout_rate": 0.3,
+        }
+
+
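One practical consequence of the stem plus the three stride-2 transitions: the overall downsampling factor is 4 × 2 × 2 × 2 = 32 per spatial dimension, so every entry of `in_shape` should be comfortably above 32 (smaller inputs can shrink to zero-size feature maps, and the stride-2 convs will then raise an error). A sketch, assuming wavedl 1.6.1 is installed:

```python
import torch
from wavedl.models.convnext_v2 import ConvNeXtV2Base

# Mirrors get_default_config(); 64 / 32 = 2, so the final feature map is 2x2.
model = ConvNeXtV2Base(
    in_shape=(64, 64), out_size=3,
    depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
    drop_path_rate=0.1,
)
out = model(torch.randn(2, 1, 64, 64))
print(out.shape)  # torch.Size([2, 3])
```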
+# =============================================================================
+# REGISTERED VARIANTS
+# =============================================================================
+
+
+@register_model("convnext_v2_tiny")
+class ConvNeXtV2Tiny(ConvNeXtV2Base):
+    """
+    ConvNeXt V2 Tiny: ~27.9M backbone parameters.
+
+    Depths [3,3,9,3], Dims [96,192,384,768].
+    Supports 1D, 2D, 3D inputs.
+
+    Example:
+        >>> model = ConvNeXtV2Tiny(in_shape=(64, 64), out_size=3)
+        >>> x = torch.randn(4, 1, 64, 64)
+        >>> out = model(x)  # (4, 3)
+    """
+
+    def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            depths=[3, 3, 9, 3],
+            dims=[96, 192, 384, 768],
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"ConvNeXtV2_Tiny({self.dim}D, in_shape={self.in_shape}, "
+            f"out_size={self.out_size})"
+        )
+
+
+@register_model("convnext_v2_small")
+class ConvNeXtV2Small(ConvNeXtV2Base):
+    """
+    ConvNeXt V2 Small: ~49.6M backbone parameters.
+
+    Depths [3,3,27,3], Dims [96,192,384,768].
+    Supports 1D, 2D, 3D inputs.
+    """
+
+    def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            depths=[3, 3, 27, 3],
+            dims=[96, 192, 384, 768],
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"ConvNeXtV2_Small({self.dim}D, in_shape={self.in_shape}, "
+            f"out_size={self.out_size})"
+        )
+
+
+@register_model("convnext_v2_base")
+class ConvNeXtV2BaseLarge(ConvNeXtV2Base):
+    """
+    ConvNeXt V2 Base: ~87.7M backbone parameters.
+
+    Depths [3,3,27,3], Dims [128,256,512,1024].
+    Supports 1D, 2D, 3D inputs.
+    """
+
+    def __init__(self, in_shape: SpatialShape, out_size: int, **kwargs):
+        super().__init__(
+            in_shape=in_shape,
+            out_size=out_size,
+            depths=[3, 3, 27, 3],
+            dims=[128, 256, 512, 1024],
+            **kwargs,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"ConvNeXtV2_Base({self.dim}D, in_shape={self.in_shape}, "
+            f"out_size={self.out_size})"
+        )
+
+
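The same variants accept 1D inputs; only `in_shape` changes. A sketch mirroring the doctest in `ConvNeXtV2Tiny`, assuming wavedl 1.6.1 is installed:

```python
import torch
from wavedl.models.convnext_v2 import ConvNeXtV2Tiny

model = ConvNeXtV2Tiny(in_shape=(4096,), out_size=2)  # 1D: in_shape is (L,)
print(model)  # ConvNeXtV2_Tiny(1D, in_shape=(4096,), out_size=2)

out = model(torch.randn(8, 1, 4096))  # (B, 1, L) single-channel waveforms
print(out.shape)  # torch.Size([8, 2])
```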
+# =============================================================================
+# PRETRAINED VARIANT (2D ONLY)
+# =============================================================================
+
+
+@register_model("convnext_v2_tiny_pretrained")
+class ConvNeXtV2TinyPretrained(BaseModel):
+    """
+    ConvNeXt V2 Tiny with ImageNet pretrained weights (2D only).
+
+    Uses torchvision's ConvNeXt implementation (V1; see note below) with:
+    - Adapted input layer for single-channel input
+    - Replaced classifier for regression
+
+    Args:
+        in_shape: (H, W) input shape (2D only)
+        out_size: Number of regression targets
+        pretrained: Whether to load pretrained weights
+        freeze_backbone: Whether to freeze backbone for fine-tuning
+    """
+
+    def __init__(
+        self,
+        in_shape: tuple[int, int],
+        out_size: int,
+        pretrained: bool = True,
+        freeze_backbone: bool = False,
+        dropout_rate: float = 0.3,
+        **kwargs,
+    ):
+        super().__init__(in_shape, out_size)
+
+        if len(in_shape) != 2:
+            raise ValueError(
+                f"ConvNeXtV2TinyPretrained requires 2D input (H, W), "
+                f"got {len(in_shape)}D"
+            )
+
+        self.pretrained = pretrained
+        self.freeze_backbone = freeze_backbone
+
+        # Try to load from torchvision (if available)
+        try:
+            from torchvision.models import (
+                ConvNeXt_Tiny_Weights,
+                convnext_tiny,
+            )
+
+            weights = ConvNeXt_Tiny_Weights.IMAGENET1K_V1 if pretrained else None
+            self.backbone = convnext_tiny(weights=weights)
+
+            # Note: torchvision's ConvNeXt is V1, not V2
+            # For true V2, we'd need custom implementation or timm
+            # This is a fallback using V1 architecture
+
+        except ImportError:
+            raise ImportError(
+                "torchvision is required for pretrained ConvNeXt. "
+                "Install with: pip install torchvision"
+            )
+
+        # Adapt input layer (3 channels -> 1 channel)
+        self._adapt_input_channels()
+
+        # Replace classifier with regression head
+        # Keep the LayerNorm2d (idx 0) and Flatten (idx 1), only replace Linear (idx 2)
+        in_features = self.backbone.classifier[2].in_features
+        new_head = build_regression_head(in_features, out_size, dropout_rate)
+
+        # Build new classifier keeping LayerNorm2d and Flatten
+        self.backbone.classifier = nn.Sequential(
+            self.backbone.classifier[0],  # LayerNorm2d
+            self.backbone.classifier[1],  # Flatten
+            new_head,  # Our regression head
+        )
+
+        if freeze_backbone:
+            self._freeze_backbone()
+
+    def _adapt_input_channels(self):
+        """Adapt first conv layer for single-channel input."""
+        from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
+
+        adapt_first_conv_for_single_channel(
+            self.backbone, "features.0.0", pretrained=self.pretrained
+        )
+
+    def _freeze_backbone(self):
+        """Freeze all backbone parameters except classifier."""
+        for name, param in self.backbone.named_parameters():
+            if "classifier" not in name:
+                param.requires_grad = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.backbone(x)
+
+    def __repr__(self) -> str:
+        return (
+            f"ConvNeXtV2_Tiny_Pretrained(in_shape={self.in_shape}, "
+            f"out_size={self.out_size}, pretrained={self.pretrained})"
+        )
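A typical transfer-learning setup with this class, assuming wavedl 1.6.1 and torchvision are installed (`pretrained=True` downloads ImageNet weights on first use):

```python
import torch
from wavedl.models.convnext_v2 import ConvNeXtV2TinyPretrained

model = ConvNeXtV2TinyPretrained(
    in_shape=(224, 224), out_size=4,
    pretrained=False,      # True fetches ImageNet weights via torchvision
    freeze_backbone=True,  # train only the regression head
)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(all("classifier" in n for n in trainable))  # True: backbone is frozen

out = model(torch.randn(2, 1, 224, 224))  # single-channel input, per the adapter
print(out.shape)  # torch.Size([2, 4])
```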
wavedl/models/densenet.py
CHANGED
@@ -11,8 +11,8 @@ Features: feature reuse, efficient gradient flow, compact model.
 - 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d
 
 **Variants**:
-- densenet121: Standard (121 layers, ~
-- densenet169: Deeper (169 layers, ~
+- densenet121: Standard (121 layers, ~7.0M backbone params for 2D)
+- densenet169: Deeper (169 layers, ~12.5M backbone params for 2D)
 
 References:
     Huang, G., et al. (2017). Densely Connected Convolutional Networks.
@@ -26,14 +26,10 @@ from typing import Any
 import torch
 import torch.nn as nn
 
-from wavedl.models.base import BaseModel
+from wavedl.models.base import BaseModel, SpatialShape
 from wavedl.models.registry import register_model
 
 
-# Type alias for spatial shapes
-SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
-
-
 def _get_layers(dim: int):
     """Get dimension-appropriate layer classes."""
     if dim == 1:
@@ -262,7 +258,7 @@ class DenseNet121(DenseNetBase):
     """
     DenseNet-121: Standard variant with 121 layers.
 
-    ~
+    ~7.0M backbone parameters (2D). Good for: Balanced accuracy, efficient training.
 
     Args:
         in_shape: (L,), (H, W), or (D, H, W)
@@ -285,7 +281,7 @@ class DenseNet169(DenseNetBase):
     """
     DenseNet-169: Deeper variant with 169 layers.
 
-    ~
+    ~12.5M backbone parameters (2D). Good for: Higher capacity, more complex patterns.
 
     Args:
         in_shape: (L,), (H, W), or (D, H, W)
@@ -320,7 +316,7 @@ class DenseNet121Pretrained(BaseModel):
     """
    DenseNet-121 with ImageNet pretrained weights (2D only).
 
-    ~
+    ~7.0M backbone parameters. Good for: Transfer learning with efficient feature reuse.
 
     Args:
         in_shape: (H, W) image dimensions
@@ -374,20 +370,11 @@ class DenseNet121Pretrained(BaseModel):
         )
 
         # Modify first conv for single-channel input
-
-
-
-
-            kernel_size=old_conv.kernel_size,
-            stride=old_conv.stride,
-            padding=old_conv.padding,
-            bias=False,
+        from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
+
+        adapt_first_conv_for_single_channel(
+            self.backbone, "features.conv0", pretrained=pretrained
         )
-        if pretrained:
-            with torch.no_grad():
-                self.backbone.features.conv0.weight = nn.Parameter(
-                    old_conv.weight.mean(dim=1, keepdim=True)
-                )
 
         if freeze_backbone:
             self._freeze_backbone()
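This hunk replaces inline single-channel adaptation with the shared `adapt_first_conv_for_single_channel` helper (also used by the new ConvNeXt V2 pretrained wrapper above). The helper lives in `wavedl/models/_pretrained_utils.py`, which this excerpt does not show; a sketch of what it likely does, inferred from the inline logic it replaces here (function name and signature below are hypothetical):

```python
import torch
import torch.nn as nn


def adapt_first_conv(old_conv: nn.Conv2d, pretrained: bool) -> nn.Conv2d:
    """Rebuild a 3-channel first conv as a 1-channel conv, reusing its weights."""
    new_conv = nn.Conv2d(
        1, old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=old_conv.bias is not None,
    )
    if pretrained:
        with torch.no_grad():
            # Average the RGB filters into one channel to keep ImageNet features.
            new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
    return new_conv
```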
wavedl/models/efficientnet.py
CHANGED
@@ -6,9 +6,9 @@ Wrapper around torchvision's EfficientNet with a regression head.
 Provides optional ImageNet pretrained weights for transfer learning.
 
 **Variants**:
-- efficientnet_b0: Smallest, fastest (~4.
-- efficientnet_b1: Light (~
-- efficientnet_b2: Balanced (~
+- efficientnet_b0: Smallest, fastest (~4.0M backbone params)
+- efficientnet_b1: Light (~6.5M backbone params)
+- efficientnet_b2: Balanced (~7.7M backbone params)
 
 **Note**: EfficientNet is 2D-only. For 1D data, use TCN. For 3D data, use ResNet3D.
 
@@ -169,7 +169,7 @@ class EfficientNetB0(EfficientNetBase):
     """
     EfficientNet-B0: Smallest, most efficient variant.
 
-    ~
+    ~4.0M backbone parameters. Good for: Quick training, limited compute, baseline.
 
     Args:
         in_shape: (H, W) image dimensions
@@ -200,7 +200,7 @@ class EfficientNetB1(EfficientNetBase):
     """
     EfficientNet-B1: Slightly larger variant.
 
-    ~
+    ~6.5M backbone parameters. Good for: Better accuracy with moderate compute.
 
     Args:
         in_shape: (H, W) image dimensions
@@ -231,7 +231,7 @@ class EfficientNetB2(EfficientNetBase):
     """
     EfficientNet-B2: Best balance of size and performance.
 
-    ~
+    ~7.7M backbone parameters. Good for: High accuracy without excessive compute.
 
     Args:
         in_shape: (H, W) image dimensions