wavedl 1.5.7-py3-none-any.whl → 1.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +80 -4
- wavedl/models/_pretrained_utils.py +366 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +113 -51
- wavedl/models/convnext_v2.py +488 -0
- wavedl/models/densenet.py +10 -23
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +252 -0
- wavedl/models/mamba.py +555 -0
- wavedl/models/maxvit.py +254 -0
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +19 -61
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +2 -6
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +9 -9
- wavedl/train.py +1430 -1425
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/data.py +39 -6
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/efficientvit.py
ADDED
@@ -0,0 +1,398 @@
"""
EfficientViT: Memory-Efficient Vision Transformer with Cascaded Group Attention
================================================================================

EfficientViT (MIT) achieves state-of-the-art speed-accuracy trade-off by using
cascaded group attention (CGA) which reduces computational redundancy in
multi-head self-attention while maintaining model capability.

**Key Features**:
- Cascaded Group Attention (CGA): Linear complexity attention
- Memory-efficient design for edge deployment
- Faster than Swin Transformer with similar accuracy
- Excellent for real-time NDE applications

**Variants**:
- efficientvit_m0: 2.3M params (mobile, fastest)
- efficientvit_m1: 2.9M params (mobile)
- efficientvit_m2: 4.2M params (mobile)
- efficientvit_b0: 3.4M params (balanced)
- efficientvit_b1: 9.1M params (balanced)
- efficientvit_b2: 24M params (balanced)
- efficientvit_b3: 49M params (balanced)
- efficientvit_l1: 53M params (large)
- efficientvit_l2: 64M params (large)

**Requirements**:
- timm >= 0.9.0 (for EfficientViT models)

Reference:
    Liu, X., et al. (2023). EfficientViT: Memory Efficient Vision Transformer
    with Cascaded Group Attention. CVPR 2023.
    https://arxiv.org/abs/2305.07027

Author: Ductho Le (ductho.le@outlook.com)
"""

import torch

from wavedl.models._pretrained_utils import build_regression_head
from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model


__all__ = [
    "EfficientViTB0",
    "EfficientViTB1",
    "EfficientViTB2",
    "EfficientViTB3",
    "EfficientViTBase",
    "EfficientViTL1",
    "EfficientViTL2",
    "EfficientViTM0",
    "EfficientViTM1",
    "EfficientViTM2",
]


# =============================================================================
# EFFICIENTVIT BASE CLASS
# =============================================================================


class EfficientViTBase(BaseModel):
    """
    EfficientViT base class wrapping timm implementation.

    Uses Cascaded Group Attention for efficient multi-head attention with
    linear complexity. 2D only due to attention structure.

    Args:
        in_shape: (H, W) input shape (2D only)
        out_size: Number of regression targets
        model_name: timm model name
        pretrained: Whether to load pretrained weights
        freeze_backbone: Whether to freeze backbone for fine-tuning
        dropout_rate: Dropout rate for regression head
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        model_name: str = "efficientvit_b0",
        pretrained: bool = True,
        freeze_backbone: bool = False,
        dropout_rate: float = 0.3,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        if len(in_shape) != 2:
            raise ValueError(
                f"EfficientViT requires 2D input (H, W), got {len(in_shape)}D"
            )

        self.pretrained = pretrained
        self.freeze_backbone = freeze_backbone
        self.model_name = model_name

        # Load from timm
        try:
            import timm

            self.backbone = timm.create_model(
                model_name,
                pretrained=pretrained,
                num_classes=0,  # Remove classifier
            )

            # Get feature dimension
            with torch.no_grad():
                dummy = torch.zeros(1, 3, *in_shape)
                features = self.backbone(dummy)
                in_features = features.shape[-1]

        except ImportError:
            raise ImportError(
                "timm >= 0.9.0 is required for EfficientViT. "
                "Install with: pip install timm>=0.9.0"
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load EfficientViT model '{model_name}': {e}")

        # Adapt input channels (3 -> 1)
        self._adapt_input_channels()

        # Regression head
        self.head = build_regression_head(in_features, out_size, dropout_rate)

        if freeze_backbone:
            self._freeze_backbone()

    def _adapt_input_channels(self):
        """Adapt first conv layer for single-channel input."""
        from wavedl.models._pretrained_utils import find_and_adapt_input_convs

        adapted_count = find_and_adapt_input_convs(
            self.backbone, pretrained=self.pretrained, adapt_all=False
        )

        if adapted_count == 0:
            import warnings

            warnings.warn(
                "Could not adapt EfficientViT input channels. Model may fail.",
                stacklevel=2,
            )

    def _freeze_backbone(self):
        """Freeze backbone parameters."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        return self.head(features)


# =============================================================================
# MOBILE VARIANTS (Ultra-lightweight)
# =============================================================================


@register_model("efficientvit_m0")
class EfficientViTM0(EfficientViTBase):
    """
    EfficientViT-M0: ~2.2M backbone parameters (fastest mobile variant).

    Cascaded group attention for efficient inference.
    Ideal for edge deployment and real-time NDE applications.
    2D only.

    Example:
        >>> model = EfficientViTM0(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="efficientvit_m0",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"EfficientViT_M0(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("efficientvit_m1")
class EfficientViTM1(EfficientViTBase):
    """
    EfficientViT-M1: ~2.6M backbone parameters.

    Slightly larger mobile variant with better accuracy.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="efficientvit_m1",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"EfficientViT_M1(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("efficientvit_m2")
class EfficientViTM2(EfficientViTBase):
    """
    EfficientViT-M2: ~3.8M backbone parameters.

    Largest mobile variant, best accuracy among M-series.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="efficientvit_m2",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"EfficientViT_M2(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


# =============================================================================
# BALANCED VARIANTS (B-series)
# =============================================================================


@register_model("efficientvit_b0")
class EfficientViTB0(EfficientViTBase):
    """
    EfficientViT-B0: ~2.1M backbone parameters.

    Smallest balanced variant. Good accuracy-speed trade-off.
    2D only.

    Example:
        >>> model = EfficientViTB0(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="efficientvit_b0",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"EfficientViT_B0(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("efficientvit_b1")
class EfficientViTB1(EfficientViTBase):
    """
    EfficientViT-B1: ~7.5M backbone parameters.

    Medium balanced variant with improved capacity.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="efficientvit_b1",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"EfficientViT_B1(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("efficientvit_b2")
class EfficientViTB2(EfficientViTBase):
    """
    EfficientViT-B2: ~21.8M backbone parameters.

    Larger balanced variant for complex patterns.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="efficientvit_b2",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"EfficientViT_B2(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("efficientvit_b3")
class EfficientViTB3(EfficientViTBase):
    """
    EfficientViT-B3: ~46.1M backbone parameters.

    Largest balanced variant, highest accuracy in B-series.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="efficientvit_b3",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"EfficientViT_B3(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


# =============================================================================
# LARGE VARIANTS (L-series)
# =============================================================================


@register_model("efficientvit_l1")
class EfficientViTL1(EfficientViTBase):
    """
    EfficientViT-L1: ~49.5M backbone parameters.

    Large variant for maximum accuracy.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="efficientvit_l1",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"EfficientViT_L1(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("efficientvit_l2")
class EfficientViTL2(EfficientViTBase):
    """
    EfficientViT-L2: ~60.5M backbone parameters.

    Largest variant, best accuracy.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="efficientvit_l2",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"EfficientViT_L2(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )
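The constructor above defines the fine-tuning workflow these wrappers are built for: single-channel 2D input, an optionally pretrained timm backbone, and a dropout-regularized regression head. A minimal usage sketch, assuming timm >= 0.9.0 is installed and importing the class directly from this new module (the updated re-exports in wavedl/models/__init__.py are not shown here):

import torch

from wavedl.models.efficientvit import EfficientViTB0

# Sketch: load pretrained weights, freeze the backbone, train only the head.
model = EfficientViTB0(
    in_shape=(224, 224),     # (H, W); EfficientViT wrappers are 2D only
    out_size=3,              # number of regression targets
    pretrained=True,         # requires timm >= 0.9.0 (downloads weights)
    freeze_backbone=True,    # backbone params get requires_grad = False
)

x = torch.randn(4, 1, 224, 224)  # single-channel batch; first conv is adapted 3 -> 1
out = model(x)                   # shape (4, 3)

# With freeze_backbone=True only the regression head remains trainable.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3)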
wavedl/models/fastvit.py
ADDED
@@ -0,0 +1,252 @@
"""
FastViT: A Fast Hybrid Vision Transformer
==========================================

FastViT from Apple uses RepMixer for efficient token mixing with structural
reparameterization - train with skip connections, deploy without.

**Key Features**:
- RepMixer: Reparameterizable token mixing
- Train-time overparameterization
- Faster than EfficientNet/ConvNeXt on mobile
- CoreML compatible

**Variants**:
- fastvit_t8: 4M params (fastest)
- fastvit_t12: 7M params
- fastvit_s12: 9M params
- fastvit_sa12: 21M params (with attention)

**Requirements**:
- timm >= 0.9.0 (for FastViT models)

Reference:
    Vasu, P.K.A., et al. (2023). FastViT: A Fast Hybrid Vision Transformer
    using Structural Reparameterization. ICCV 2023.
    https://arxiv.org/abs/2303.14189

Author: Ductho Le (ductho.le@outlook.com)
"""

import torch

from wavedl.models._pretrained_utils import build_regression_head
from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model


__all__ = [
    "FastViTBase",
    "FastViTS12",
    "FastViTSA12",
    "FastViTT8",
    "FastViTT12",
]


# =============================================================================
# FASTVIT BASE CLASS
# =============================================================================


class FastViTBase(BaseModel):
    """
    FastViT base class wrapping timm implementation.

    Uses RepMixer for efficient token mixing with reparameterization.
    2D only.
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        model_name: str = "fastvit_t8",
        pretrained: bool = True,
        freeze_backbone: bool = False,
        dropout_rate: float = 0.3,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        if len(in_shape) != 2:
            raise ValueError(f"FastViT requires 2D input (H, W), got {len(in_shape)}D")

        self.pretrained = pretrained
        self.freeze_backbone = freeze_backbone
        self.model_name = model_name

        # Try to load from timm
        try:
            import timm

            self.backbone = timm.create_model(
                model_name,
                pretrained=pretrained,
                num_classes=0,  # Remove classifier
            )

            # Get feature dimension
            with torch.no_grad():
                dummy = torch.zeros(1, 3, *in_shape)
                features = self.backbone(dummy)
                in_features = features.shape[-1]

        except ImportError:
            raise ImportError(
                "timm >= 0.9.0 is required for FastViT. "
                "Install with: pip install timm>=0.9.0"
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load FastViT model '{model_name}': {e}")

        # Adapt input channels (3 -> 1)
        self._adapt_input_channels()

        # Regression head
        self.head = build_regression_head(in_features, out_size, dropout_rate)

        if freeze_backbone:
            self._freeze_backbone()

    def _adapt_input_channels(self):
        """Adapt all conv layers with 3 input channels for single-channel input."""
        # FastViT may have multiple modules with 3 input channels (e.g., conv_kxk, conv_scale)
        # We need to adapt all of them
        from wavedl.models._pretrained_utils import find_and_adapt_input_convs

        adapted_count = find_and_adapt_input_convs(
            self.backbone, pretrained=self.pretrained, adapt_all=True
        )

        if adapted_count == 0:
            import warnings

            warnings.warn(
                "Could not adapt FastViT input channels. Model may fail.", stacklevel=2
            )

    def _freeze_backbone(self):
        """Freeze backbone parameters."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        return self.head(features)

    def reparameterize(self):
        """
        Reparameterize model for inference.

        Fuses RepMixer blocks for faster inference.
        Call this before deployment.
        """
        if hasattr(self.backbone, "reparameterize"):
            self.backbone.reparameterize()


# =============================================================================
# REGISTERED VARIANTS
# =============================================================================


@register_model("fastvit_t8")
class FastViTT8(FastViTBase):
    """
    FastViT-T8: ~3.3M backbone parameters (fastest variant).

    Optimized for mobile and edge deployment.
    2D only.

    Example:
        >>> model = FastViTT8(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="fastvit_t8",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"FastViT_T8(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("fastvit_t12")
class FastViTT12(FastViTBase):
    """
    FastViT-T12: ~6.5M backbone parameters.

    Balanced speed and accuracy.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="fastvit_t12",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"FastViT_T12(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("fastvit_s12")
class FastViTS12(FastViTBase):
    """
    FastViT-S12: ~8.5M backbone parameters.

    Slightly larger for better accuracy.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="fastvit_s12",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"FastViT_S12(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )


@register_model("fastvit_sa12")
class FastViTSA12(FastViTBase):
    """
    FastViT-SA12: ~10.6M backbone parameters.

    With self-attention for better accuracy at the cost of speed.
    2D only.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        super().__init__(
            in_shape=in_shape,
            out_size=out_size,
            model_name="fastvit_sa12",
            **kwargs,
        )

    def __repr__(self) -> str:
        return (
            f"FastViT_SA12(in_shape={self.in_shape}, out_size={self.out_size}, "
            f"pretrained={self.pretrained})"
        )
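The reparameterize() method is the deployment hook for this family: train with the over-parameterized RepMixer branches, then fuse them before inference. A minimal sketch of that flow, assuming timm >= 0.9.0 and a direct import from the new module:

import torch

from wavedl.models.fastvit import FastViTT8

model = FastViTT8(in_shape=(224, 224), out_size=3, pretrained=True)

# ... train as usual ...

# Before export/inference: fuse RepMixer blocks (no-op if the timm backbone
# does not expose a reparameterize() method).
model.eval()
model.reparameterize()

with torch.no_grad():
    x = torch.randn(1, 1, 224, 224)
    out = model(x)  # shape (1, 3)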