wavedl-1.5.7-py3-none-any.whl → wavedl-1.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +80 -4
- wavedl/models/_pretrained_utils.py +366 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +113 -51
- wavedl/models/convnext_v2.py +488 -0
- wavedl/models/densenet.py +10 -23
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +252 -0
- wavedl/models/mamba.py +555 -0
- wavedl/models/maxvit.py +254 -0
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +19 -61
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +2 -6
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +9 -9
- wavedl/train.py +1430 -1425
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/data.py +39 -6
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/maxvit.py
ADDED
@@ -0,0 +1,254 @@
"""
MaxViT: Multi-Axis Vision Transformer
======================================

MaxViT combines local and global attention with O(n) complexity using
multi-axis attention: block attention (local) + grid attention (global sparse).

**Key Features**:
- Multi-axis attention for both local and global context
- Hybrid design with MBConv + attention
- Linear O(n) complexity
- Hierarchical multi-scale features

**Variants**:
- maxvit_tiny: 31M params
- maxvit_small: 69M params
- maxvit_base: 120M params

**Requirements**:
- timm (for pretrained models and architecture)
- torchvision (fallback, limited support)

Reference:
    Tu, Z., et al. (2022). MaxViT: Multi-Axis Vision Transformer.
    ECCV 2022. https://arxiv.org/abs/2204.01697

Author: Ductho Le (ductho.le@outlook.com)
"""

import torch
import torch.nn.functional as F

from wavedl.models._pretrained_utils import build_regression_head
from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model


__all__ = [
    "MaxViTBase",
    "MaxViTBaseLarge",
    "MaxViTSmall",
    "MaxViTTiny",
]


# =============================================================================
# MAXVIT BASE CLASS
# =============================================================================


class MaxViTBase(BaseModel):
    """
    MaxViT base class wrapping timm implementation.

    Multi-axis attention with local block and global grid attention.
    2D only due to attention structure.

    Note:
        MaxViT requires input dimensions divisible by 28 (4x stem downsample × 7 window).
        This implementation automatically resizes inputs to the nearest compatible size.
    """

    # MaxViT stem downsamples by 4x, then requires divisibility by 7 (window size)
    # So original input must be divisible by 4 * 7 = 28
    _DIVISOR = 28

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        model_name: str = "maxvit_tiny_tf_224",
        pretrained: bool = True,
        freeze_backbone: bool = False,
        dropout_rate: float = 0.3,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        if len(in_shape) != 2:
            raise ValueError(f"MaxViT requires 2D input (H, W), got {len(in_shape)}D")

        self.pretrained = pretrained
        self.freeze_backbone = freeze_backbone
        self.model_name = model_name

        # Compute compatible input size for MaxViT attention windows
        self._target_size = self._compute_compatible_size(in_shape)

        # Try to load from timm
        try:
            import timm

            self.backbone = timm.create_model(
                model_name,
                pretrained=pretrained,
                num_classes=0,  # Remove classifier
            )

            # Get feature dimension using compatible size
            with torch.no_grad():
                dummy = torch.zeros(1, 3, *self._target_size)
                features = self.backbone(dummy)
                in_features = features.shape[-1]

        except ImportError:
            raise ImportError(
                "timm is required for MaxViT. Install with: pip install timm"
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load MaxViT model '{model_name}': {e}")

        # Adapt input channels (3 -> 1)
        self._adapt_input_channels()

        # Regression head
        self.head = build_regression_head(in_features, out_size, dropout_rate)

        if freeze_backbone:
            self._freeze_backbone()

    def _adapt_input_channels(self):
        """Adapt first conv layer for single-channel input."""
        from wavedl.models._pretrained_utils import find_and_adapt_input_convs

        adapted_count = find_and_adapt_input_convs(
            self.backbone, pretrained=self.pretrained, adapt_all=False
        )

        if adapted_count == 0:
            import warnings

            warnings.warn(
                "Could not adapt MaxViT input channels. Model may fail.", stacklevel=2
            )

    def _freeze_backbone(self):
        """Freeze backbone parameters."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def _compute_compatible_size(self, in_shape: tuple[int, int]) -> tuple[int, int]:
        """
        Compute the nearest input size compatible with MaxViT attention windows.

        MaxViT requires input dimensions divisible by 28 (4x stem downsample × 7 window).
        This rounds up to the nearest compatible size.

        Args:
            in_shape: Original (H, W) input shape

        Returns:
            Compatible (H, W) shape divisible by 28
        """
        import math

        h, w = in_shape
        target_h = math.ceil(h / self._DIVISOR) * self._DIVISOR
        target_w = math.ceil(w / self._DIVISOR) * self._DIVISOR
        return (target_h, target_w)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Resize input to compatible size if needed
        _, _, h, w = x.shape
        if (h, w) != self._target_size:
            x = F.interpolate(
                x,
                size=self._target_size,
                mode="bilinear",
                align_corners=False,
            )
        features = self.backbone(x)
        return self.head(features)


# =============================================================================
# REGISTERED VARIANTS
# =============================================================================


@register_model("maxvit_tiny")
class MaxViTTiny(MaxViTBase):
    """
    MaxViT Tiny: ~30.1M backbone parameters.

    Multi-axis attention with local+global context.
    2D only.

    Example:
        >>> model = MaxViTTiny(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="maxvit_tiny_tf_224",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"MaxViT_Tiny(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("maxvit_small")
class MaxViTSmall(MaxViTBase):
    """
    MaxViT Small: ~67.6M backbone parameters.

    Multi-axis attention with local+global context.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="maxvit_small_tf_224",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"MaxViT_Small(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("maxvit_base")
class MaxViTBaseLarge(MaxViTBase):
    """
    MaxViT Base: ~118.1M backbone parameters.

    Multi-axis attention with local+global context.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="maxvit_base_tf_224",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"MaxViT_Base(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )
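
A minimal usage sketch for the model added above, assuming timm is installed and that the single-channel adaptation and regression head are built by the helpers imported in the file. It illustrates the automatic size handling described in the class docstring: an input whose height and width are not multiples of 28 is interpolated to the nearest compatible size before the backbone runs.

import torch

from wavedl.models.maxvit import MaxViTTiny

# 300 x 300 is not divisible by 28; forward() interpolates to the nearest
# compatible size (308 x 308) before running the timm backbone.
model = MaxViTTiny(in_shape=(300, 300), out_size=3, pretrained=False)

x = torch.randn(2, 1, 300, 300)  # single-channel input, batch of 2
out = model(x)
print(out.shape)  # expected: torch.Size([2, 3])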