wavedl 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/{hpc.py → launcher.py} +135 -61
- wavedl/models/__init__.py +28 -0
- wavedl/models/{_timm_utils.py → _pretrained_utils.py} +128 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +1 -1
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +5 -18
- wavedl/models/convnext_v2.py +6 -22
- wavedl/models/densenet.py +5 -18
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +6 -39
- wavedl/models/mamba.py +44 -24
- wavedl/models/maxvit.py +51 -48
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +14 -56
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +1 -5
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +3 -3
- wavedl/train.py +1427 -1430
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/METADATA +150 -113
- wavedl-1.6.2.dist-info/RECORD +46 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/entry_points.txt +2 -2
- wavedl-1.6.0.dist-info/RECORD +0 -44
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/LICENSE +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/WHEEL +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.2.dist-info}/top_level.txt +0 -0
wavedl/models/maxvit.py
CHANGED
|
@@ -28,9 +28,9 @@ Author: Ductho Le (ductho.le@outlook.com)
|
|
|
28
28
|
"""
|
|
29
29
|
|
|
30
30
|
import torch
|
|
31
|
-
import torch.nn as nn
|
|
31
|
+
import torch.nn.functional as F
|
|
32
32
|
|
|
33
|
-
from wavedl.models._timm_utils import build_regression_head
|
|
33
|
+
from wavedl.models._pretrained_utils import build_regression_head
|
|
34
34
|
from wavedl.models.base import BaseModel
|
|
35
35
|
from wavedl.models.registry import register_model
|
|
36
36
|
|
|
@@ -54,8 +54,16 @@ class MaxViTBase(BaseModel):
|
|
|
54
54
|
|
|
55
55
|
Multi-axis attention with local block and global grid attention.
|
|
56
56
|
2D only due to attention structure.
|
|
57
|
+
|
|
58
|
+
Note:
|
|
59
|
+
MaxViT requires input dimensions divisible by 28 (4x stem downsample × 7 window).
|
|
60
|
+
This implementation automatically resizes inputs to the nearest compatible size.
|
|
57
61
|
"""
|
|
58
62
|
|
|
63
|
+
# MaxViT stem downsamples by 4x, then requires divisibility by 7 (window size)
|
|
64
|
+
# So original input must be divisible by 4 * 7 = 28
|
|
65
|
+
_DIVISOR = 28
|
|
66
|
+
|
|
59
67
|
def __init__(
|
|
60
68
|
self,
|
|
61
69
|
in_shape: tuple[int, int],
|
|
@@ -75,6 +83,9 @@ class MaxViTBase(BaseModel):
|
|
|
75
83
|
self.freeze_backbone = freeze_backbone
|
|
76
84
|
self.model_name = model_name
|
|
77
85
|
|
|
86
|
+
# Compute compatible input size for MaxViT attention windows
|
|
87
|
+
self._target_size = self._compute_compatible_size(in_shape)
|
|
88
|
+
|
|
78
89
|
# Try to load from timm
|
|
79
90
|
try:
|
|
80
91
|
import timm
|
|
@@ -85,9 +96,9 @@ class MaxViTBase(BaseModel):
|
|
|
85
96
|
num_classes=0, # Remove classifier
|
|
86
97
|
)
|
|
87
98
|
|
|
88
|
-
# Get feature dimension
|
|
99
|
+
# Get feature dimension using compatible size
|
|
89
100
|
with torch.no_grad():
|
|
90
|
-
dummy = torch.zeros(1, 3, *in_shape)
|
|
101
|
+
dummy = torch.zeros(1, 3, *self._target_size)
|
|
91
102
|
features = self.backbone(dummy)
|
|
92
103
|
in_features = features.shape[-1]
|
|
93
104
|
|
|
@@ -109,62 +120,54 @@ class MaxViTBase(BaseModel):
|
|
|
109
120
|
|
|
110
121
|
def _adapt_input_channels(self):
|
|
111
122
|
"""Adapt first conv layer for single-channel input."""
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
parts = name.split(".")
|
|
120
|
-
parent = self.backbone
|
|
121
|
-
for part in parts[:-1]:
|
|
122
|
-
parent = getattr(parent, part)
|
|
123
|
-
child_name = parts[-1]
|
|
124
|
-
|
|
125
|
-
# Create new conv with 1 input channel
|
|
126
|
-
new_conv = self._make_new_conv(module)
|
|
127
|
-
setattr(parent, child_name, new_conv)
|
|
128
|
-
adapted = True
|
|
129
|
-
break
|
|
130
|
-
|
|
131
|
-
if not adapted:
|
|
123
|
+
from wavedl.models._pretrained_utils import find_and_adapt_input_convs
|
|
124
|
+
|
|
125
|
+
adapted_count = find_and_adapt_input_convs(
|
|
126
|
+
self.backbone, pretrained=self.pretrained, adapt_all=False
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
if adapted_count == 0:
|
|
132
130
|
import warnings
|
|
133
131
|
|
|
134
132
|
warnings.warn(
|
|
135
133
|
"Could not adapt MaxViT input channels. Model may fail.", stacklevel=2
|
|
136
134
|
)
|
|
137
135
|
|
|
138
|
-
def _make_new_conv(self, old_conv: nn.Module) -> nn.Module:
|
|
139
|
-
"""Create new conv layer with 1 input channel."""
|
|
140
|
-
# Handle both Conv2d and Conv2dSame from timm
|
|
141
|
-
type(old_conv)
|
|
142
|
-
|
|
143
|
-
# Get common parameters
|
|
144
|
-
kwargs = {
|
|
145
|
-
"out_channels": old_conv.out_channels,
|
|
146
|
-
"kernel_size": old_conv.kernel_size,
|
|
147
|
-
"stride": old_conv.stride,
|
|
148
|
-
"padding": old_conv.padding if hasattr(old_conv, "padding") else 0,
|
|
149
|
-
"bias": old_conv.bias is not None,
|
|
150
|
-
}
|
|
151
|
-
|
|
152
|
-
# Create new conv (use regular Conv2d for simplicity)
|
|
153
|
-
new_conv = nn.Conv2d(1, **kwargs)
|
|
154
|
-
|
|
155
|
-
if self.pretrained:
|
|
156
|
-
with torch.no_grad():
|
|
157
|
-
new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
|
|
158
|
-
if old_conv.bias is not None:
|
|
159
|
-
new_conv.bias.copy_(old_conv.bias)
|
|
160
|
-
return new_conv
|
|
161
|
-
|
|
162
136
|
def _freeze_backbone(self):
|
|
163
137
|
"""Freeze backbone parameters."""
|
|
164
138
|
for param in self.backbone.parameters():
|
|
165
139
|
param.requires_grad = False
|
|
166
140
|
|
|
141
|
+
def _compute_compatible_size(self, in_shape: tuple[int, int]) -> tuple[int, int]:
|
|
142
|
+
"""
|
|
143
|
+
Compute the nearest input size compatible with MaxViT attention windows.
|
|
144
|
+
|
|
145
|
+
MaxViT requires input dimensions divisible by 28 (4x stem downsample × 7 window).
|
|
146
|
+
This rounds up to the nearest compatible size.
|
|
147
|
+
|
|
148
|
+
Args:
|
|
149
|
+
in_shape: Original (H, W) input shape
|
|
150
|
+
|
|
151
|
+
Returns:
|
|
152
|
+
Compatible (H, W) shape divisible by 28
|
|
153
|
+
"""
|
|
154
|
+
import math
|
|
155
|
+
|
|
156
|
+
h, w = in_shape
|
|
157
|
+
target_h = math.ceil(h / self._DIVISOR) * self._DIVISOR
|
|
158
|
+
target_w = math.ceil(w / self._DIVISOR) * self._DIVISOR
|
|
159
|
+
return (target_h, target_w)
|
|
160
|
+
|
|
167
161
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
162
|
+
# Resize input to compatible size if needed
|
|
163
|
+
_, _, h, w = x.shape
|
|
164
|
+
if (h, w) != self._target_size:
|
|
165
|
+
x = F.interpolate(
|
|
166
|
+
x,
|
|
167
|
+
size=self._target_size,
|
|
168
|
+
mode="bilinear",
|
|
169
|
+
align_corners=False,
|
|
170
|
+
)
|
|
168
171
|
features = self.backbone(x)
|
|
169
172
|
return self.head(features)
|
|
170
173
|
|