wavedl 1.5.7__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +80 -4
- wavedl/models/_pretrained_utils.py +366 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +113 -51
- wavedl/models/convnext_v2.py +488 -0
- wavedl/models/densenet.py +10 -23
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +252 -0
- wavedl/models/mamba.py +555 -0
- wavedl/models/maxvit.py +254 -0
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +19 -61
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +2 -6
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +9 -9
- wavedl/train.py +1430 -1425
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/data.py +39 -6
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/resnet.py
CHANGED
|
@@ -11,9 +11,9 @@ Provides multiple depth variants (18, 34, 50) with optional pretrained weights f
|
|
|
11
11
|
- 3D: Volumetric data, CT/MRI (N, 1, D, H, W) → Conv3d
|
|
12
12
|
|
|
13
13
|
**Variants**:
|
|
14
|
-
- resnet18: Lightweight, fast training (~
|
|
15
|
-
- resnet34: Balanced capacity (~
|
|
16
|
-
- resnet50: Higher capacity with bottleneck blocks (~
|
|
14
|
+
- resnet18: Lightweight, fast training (~11.2M backbone params)
|
|
15
|
+
- resnet34: Balanced capacity (~21.3M backbone params)
|
|
16
|
+
- resnet50: Higher capacity with bottleneck blocks (~23.5M backbone params)
|
|
17
17
|
|
|
18
18
|
References:
|
|
19
19
|
He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning
|
|
@@ -27,14 +27,10 @@ from typing import Any
|
|
|
27
27
|
import torch
|
|
28
28
|
import torch.nn as nn
|
|
29
29
|
|
|
30
|
-
from wavedl.models.base import BaseModel
|
|
30
|
+
from wavedl.models.base import BaseModel, SpatialShape, compute_num_groups
|
|
31
31
|
from wavedl.models.registry import register_model
|
|
32
32
|
|
|
33
33
|
|
|
34
|
-
# Type alias for spatial shapes
|
|
35
|
-
SpatialShape = tuple[int] | tuple[int, int] | tuple[int, int, int]
|
|
36
|
-
|
|
37
|
-
|
|
38
34
|
def _get_conv_layers(
|
|
39
35
|
dim: int,
|
|
40
36
|
) -> tuple[type[nn.Module], type[nn.Module], type[nn.Module]]:
|
|
@@ -49,36 +45,6 @@ def _get_conv_layers(
|
|
|
49
45
|
raise ValueError(f"Unsupported dimensionality: {dim}D. Supported: 1D, 2D, 3D.")
|
|
50
46
|
|
|
51
47
|
|
|
52
|
-
def _get_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
|
|
53
|
-
"""
|
|
54
|
-
Get valid num_groups for GroupNorm that divides num_channels evenly.
|
|
55
|
-
|
|
56
|
-
Args:
|
|
57
|
-
num_channels: Number of channels to normalize
|
|
58
|
-
preferred_groups: Preferred number of groups (default: 32)
|
|
59
|
-
|
|
60
|
-
Returns:
|
|
61
|
-
Valid num_groups that divides num_channels
|
|
62
|
-
|
|
63
|
-
Raises:
|
|
64
|
-
ValueError: If no valid divisor found (shouldn't happen with power-of-2 channels)
|
|
65
|
-
"""
|
|
66
|
-
# Try preferred groups first, then decrease
|
|
67
|
-
for groups in [preferred_groups, 16, 8, 4, 2, 1]:
|
|
68
|
-
if groups <= num_channels and num_channels % groups == 0:
|
|
69
|
-
return groups
|
|
70
|
-
|
|
71
|
-
# Fallback: find any valid divisor
|
|
72
|
-
for groups in range(min(32, num_channels), 0, -1):
|
|
73
|
-
if num_channels % groups == 0:
|
|
74
|
-
return groups
|
|
75
|
-
|
|
76
|
-
raise ValueError(
|
|
77
|
-
f"Cannot find valid num_groups for {num_channels} channels. "
|
|
78
|
-
f"Consider using base_width that is a power of 2 (e.g., 32, 64, 128)."
|
|
79
|
-
)
|
|
80
|
-
|
|
81
|
-
|
|
82
48
|
class BasicBlock(nn.Module):
|
|
83
49
|
"""
|
|
84
50
|
Basic residual block for ResNet-18/34.
|
|
@@ -107,12 +73,12 @@ class BasicBlock(nn.Module):
|
|
|
107
73
|
padding=1,
|
|
108
74
|
bias=False,
|
|
109
75
|
)
|
|
110
|
-
self.gn1 = nn.GroupNorm(
|
|
76
|
+
self.gn1 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
|
|
111
77
|
self.relu = nn.ReLU(inplace=True)
|
|
112
78
|
self.conv2 = Conv(
|
|
113
79
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
|
|
114
80
|
)
|
|
115
|
-
self.gn2 = nn.GroupNorm(
|
|
81
|
+
self.gn2 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
|
|
116
82
|
self.downsample = downsample
|
|
117
83
|
|
|
118
84
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
@@ -155,7 +121,7 @@ class Bottleneck(nn.Module):
|
|
|
155
121
|
|
|
156
122
|
# 1x1 reduce
|
|
157
123
|
self.conv1 = Conv(in_channels, out_channels, kernel_size=1, bias=False)
|
|
158
|
-
self.gn1 = nn.GroupNorm(
|
|
124
|
+
self.gn1 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
|
|
159
125
|
|
|
160
126
|
# 3x3 conv
|
|
161
127
|
self.conv2 = Conv(
|
|
@@ -166,14 +132,16 @@ class Bottleneck(nn.Module):
|
|
|
166
132
|
padding=1,
|
|
167
133
|
bias=False,
|
|
168
134
|
)
|
|
169
|
-
self.gn2 = nn.GroupNorm(
|
|
135
|
+
self.gn2 = nn.GroupNorm(compute_num_groups(out_channels), out_channels)
|
|
170
136
|
|
|
171
137
|
# 1x1 expand
|
|
172
138
|
self.conv3 = Conv(
|
|
173
139
|
out_channels, out_channels * self.expansion, kernel_size=1, bias=False
|
|
174
140
|
)
|
|
175
141
|
expanded_channels = out_channels * self.expansion
|
|
176
|
-
self.gn3 = nn.GroupNorm(
|
|
142
|
+
self.gn3 = nn.GroupNorm(
|
|
143
|
+
compute_num_groups(expanded_channels), expanded_channels
|
|
144
|
+
)
|
|
177
145
|
|
|
178
146
|
self.relu = nn.ReLU(inplace=True)
|
|
179
147
|
self.downsample = downsample
|
|
@@ -229,7 +197,7 @@ class ResNetBase(BaseModel):
|
|
|
229
197
|
|
|
230
198
|
# Stem: 7x7 conv (or equivalent for 1D/3D)
|
|
231
199
|
self.conv1 = Conv(1, base_width, kernel_size=7, stride=2, padding=3, bias=False)
|
|
232
|
-
self.gn1 = nn.GroupNorm(
|
|
200
|
+
self.gn1 = nn.GroupNorm(compute_num_groups(base_width), base_width)
|
|
233
201
|
self.relu = nn.ReLU(inplace=True)
|
|
234
202
|
self.maxpool = MaxPool(kernel_size=3, stride=2, padding=1)
|
|
235
203
|
|
|
@@ -275,7 +243,7 @@ class ResNetBase(BaseModel):
|
|
|
275
243
|
bias=False,
|
|
276
244
|
),
|
|
277
245
|
nn.GroupNorm(
|
|
278
|
-
|
|
246
|
+
compute_num_groups(out_channels * block.expansion),
|
|
279
247
|
out_channels * block.expansion,
|
|
280
248
|
),
|
|
281
249
|
)
|
|
@@ -495,21 +463,11 @@ class PretrainedResNetBase(BaseModel):
|
|
|
495
463
|
|
|
496
464
|
# Modify first conv for single-channel input
|
|
497
465
|
# Original: Conv2d(3, 64, ...) → New: Conv2d(1, 64, ...)
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
kernel_size=old_conv.kernel_size,
|
|
503
|
-
stride=old_conv.stride,
|
|
504
|
-
padding=old_conv.padding,
|
|
505
|
-
bias=False,
|
|
466
|
+
from wavedl.models._pretrained_utils import adapt_first_conv_for_single_channel
|
|
467
|
+
|
|
468
|
+
adapt_first_conv_for_single_channel(
|
|
469
|
+
self.backbone, "conv1", pretrained=pretrained
|
|
506
470
|
)
|
|
507
|
-
# Initialize new conv with mean of pretrained weights
|
|
508
|
-
if pretrained:
|
|
509
|
-
with torch.no_grad():
|
|
510
|
-
self.backbone.conv1.weight = nn.Parameter(
|
|
511
|
-
old_conv.weight.mean(dim=1, keepdim=True)
|
|
512
|
-
)
|
|
513
471
|
|
|
514
472
|
# Optionally freeze backbone
|
|
515
473
|
if freeze_backbone:
|
|
@@ -534,7 +492,7 @@ class ResNet18Pretrained(PretrainedResNetBase):
|
|
|
534
492
|
"""
|
|
535
493
|
ResNet-18 with ImageNet pretrained weights (2D only).
|
|
536
494
|
|
|
537
|
-
~
|
|
495
|
+
~11.2M backbone parameters. Good for: Transfer learning, faster convergence.
|
|
538
496
|
|
|
539
497
|
Args:
|
|
540
498
|
in_shape: (H, W) image dimensions
|
|
@@ -563,7 +521,7 @@ class ResNet50Pretrained(PretrainedResNetBase):
|
|
|
563
521
|
"""
|
|
564
522
|
ResNet-50 with ImageNet pretrained weights (2D only).
|
|
565
523
|
|
|
566
|
-
~
|
|
524
|
+
~23.5M backbone parameters. Good for: High accuracy with transfer learning.
|
|
567
525
|
|
|
568
526
|
Args:
|
|
569
527
|
in_shape: (H, W) image dimensions
|