wavedl 1.5.6__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/models/__init__.py +52 -4
- wavedl/models/_timm_utils.py +238 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/convnext.py +108 -33
- wavedl/models/convnext_v2.py +504 -0
- wavedl/models/densenet.py +5 -5
- wavedl/models/efficientnet.py +30 -13
- wavedl/models/efficientnetv2.py +32 -9
- wavedl/models/fastvit.py +285 -0
- wavedl/models/mamba.py +535 -0
- wavedl/models/maxvit.py +251 -0
- wavedl/models/mobilenetv3.py +35 -12
- wavedl/models/regnet.py +39 -16
- wavedl/models/resnet.py +5 -5
- wavedl/models/resnet3d.py +2 -2
- wavedl/models/swin.py +41 -9
- wavedl/models/tcn.py +25 -5
- wavedl/models/unet.py +1 -1
- wavedl/models/vit.py +6 -6
- wavedl/test.py +7 -3
- wavedl/train.py +57 -23
- wavedl/utils/constraints.py +11 -5
- wavedl/utils/data.py +120 -18
- wavedl/utils/metrics.py +287 -326
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/METADATA +104 -67
- wavedl-1.6.0.dist-info/RECORD +44 -0
- wavedl-1.5.6.dist-info/RECORD +0 -38
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/LICENSE +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/WHEEL +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/top_level.txt +0 -0
wavedl/models/fastvit.py
ADDED
@@ -0,0 +1,285 @@
"""
FastViT: A Fast Hybrid Vision Transformer
==========================================

FastViT from Apple uses RepMixer for efficient token mixing with structural
reparameterization - train with skip connections, deploy without.

**Key Features**:
- RepMixer: Reparameterizable token mixing
- Train-time overparameterization
- Faster than EfficientNet/ConvNeXt on mobile
- CoreML compatible

**Variants**:
- fastvit_t8: 4M params (fastest)
- fastvit_t12: 7M params
- fastvit_s12: 9M params
- fastvit_sa12: ~11M params (with attention)

**Requirements**:
- timm >= 0.9.0 (for FastViT models)

Reference:
    Vasu, P.K.A., et al. (2023). FastViT: A Fast Hybrid Vision Transformer
    using Structural Reparameterization. ICCV 2023.
    https://arxiv.org/abs/2303.14189

Author: Ductho Le (ductho.le@outlook.com)
"""

import torch
import torch.nn as nn

from wavedl.models._timm_utils import build_regression_head
from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model


__all__ = [
    "FastViTBase",
    "FastViTS12",
    "FastViTSA12",
    "FastViTT8",
    "FastViTT12",
]


# =============================================================================
# FASTVIT BASE CLASS
# =============================================================================


class FastViTBase(BaseModel):
    """
    FastViT base class wrapping timm implementation.

    Uses RepMixer for efficient token mixing with reparameterization.
    2D only.
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        model_name: str = "fastvit_t8",
        pretrained: bool = True,
        freeze_backbone: bool = False,
        dropout_rate: float = 0.3,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        if len(in_shape) != 2:
            raise ValueError(f"FastViT requires 2D input (H, W), got {len(in_shape)}D")

        self.pretrained = pretrained
        self.freeze_backbone = freeze_backbone
        self.model_name = model_name

        # Try to load from timm
        try:
            import timm

            self.backbone = timm.create_model(
                model_name,
                pretrained=pretrained,
                num_classes=0,  # Remove classifier
            )

            # Get feature dimension
            with torch.no_grad():
                dummy = torch.zeros(1, 3, *in_shape)
                features = self.backbone(dummy)
                in_features = features.shape[-1]

        except ImportError:
            raise ImportError(
                "timm >= 0.9.0 is required for FastViT. "
                "Install with: pip install 'timm>=0.9.0'"
            ) from None
        except Exception as e:
            raise RuntimeError(f"Failed to load FastViT model '{model_name}': {e}") from e

        # Adapt input channels (3 -> 1)
        self._adapt_input_channels()

        # Regression head
        self.head = build_regression_head(in_features, out_size, dropout_rate)

        if freeze_backbone:
            self._freeze_backbone()

    def _adapt_input_channels(self):
        """Adapt all conv layers with 3 input channels for single-channel input."""
        # FastViT may have multiple modules with 3 input channels (e.g., conv_kxk,
        # conv_scale), and we need to adapt all of them. Snapshot named_modules()
        # into a list first, since the loop body mutates children mid-traversal.
        adapted_count = 0

        for name, module in list(self.backbone.named_modules()):
            if hasattr(module, "in_channels") and module.in_channels == 3:
                # Check if this is a wrapper (e.g., ConvNormAct) with inner .conv
                if hasattr(module, "conv") and isinstance(module.conv, nn.Conv2d):
                    # Adapt the inner conv layer
                    old_conv = module.conv
                    module.conv = self._make_new_conv(old_conv)
                    adapted_count += 1
                elif isinstance(module, nn.Conv2d):
                    # Direct Conv2d - replace it on its parent module
                    parts = name.split(".")
                    parent = self.backbone
                    for part in parts[:-1]:
                        parent = getattr(parent, part)
                    child_name = parts[-1]
                    new_conv = self._make_new_conv(module)
                    setattr(parent, child_name, new_conv)
                    adapted_count += 1

        if adapted_count == 0:
            import warnings

            warnings.warn(
                "Could not adapt FastViT input channels. Model may fail.", stacklevel=2
            )

    def _make_new_conv(self, old_conv: nn.Conv2d) -> nn.Conv2d:
        """Create new conv layer with 1 input channel."""
        new_conv = nn.Conv2d(
            1,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=old_conv.bias is not None,
        )
        if self.pretrained:
            with torch.no_grad():
                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
                if old_conv.bias is not None:
                    new_conv.bias.copy_(old_conv.bias)
        return new_conv
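
    # Illustrative note: averaging the pretrained RGB kernels means a
    # single-channel input x yields mean_c W[:, c] * x, i.e. one third of the
    # response the original conv gave to x replicated across R, G, B, so the
    # ImageNet features transfer up to a constant scale.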

    def _freeze_backbone(self):
        """Freeze backbone parameters."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        return self.head(features)

    def reparameterize(self):
        """
        Reparameterize model for inference.

        Fuses RepMixer blocks for faster inference. Call this before
        deployment; the fused model is not intended for further training.
        """
        # timm exposes reparameterize() on individual FastViT blocks rather
        # than (reliably) on the top-level model, so a plain hasattr check on
        # the backbone can silently do nothing. Prefer timm's model-wide
        # helper (available in recent timm), falling back to the top-level
        # hook if the helper is absent.
        try:
            from timm.utils.model import reparameterize_model

            self.backbone = reparameterize_model(self.backbone)
        except ImportError:
            if hasattr(self.backbone, "reparameterize"):
                self.backbone.reparameterize()


# =============================================================================
# REGISTERED VARIANTS
# =============================================================================


@register_model("fastvit_t8")
class FastViTT8(FastViTBase):
    """
    FastViT-T8: ~3.3M backbone parameters (fastest variant).

    Optimized for mobile and edge deployment.
    2D only.

    Example:
        >>> model = FastViTT8(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="fastvit_t8",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"FastViT_T8(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("fastvit_t12")
class FastViTT12(FastViTBase):
    """
    FastViT-T12: ~6.5M backbone parameters.

    Balanced speed and accuracy.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="fastvit_t12",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"FastViT_T12(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("fastvit_s12")
class FastViTS12(FastViTBase):
    """
    FastViT-S12: ~8.5M backbone parameters.

    Slightly larger for better accuracy.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="fastvit_s12",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"FastViT_S12(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("fastvit_sa12")
class FastViTSA12(FastViTBase):
    """
    FastViT-SA12: ~10.6M backbone parameters.

    With self-attention for better accuracy at the cost of speed.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="fastvit_sa12",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"FastViT_SA12(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )