wavedl 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +28 -0
- wavedl/models/{_timm_utils.py → _pretrained_utils.py} +128 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +1 -1
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +5 -18
- wavedl/models/convnext_v2.py +6 -22
- wavedl/models/densenet.py +5 -18
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +6 -39
- wavedl/models/mamba.py +44 -24
- wavedl/models/maxvit.py +51 -48
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +14 -56
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +1 -5
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +3 -3
- wavedl/train.py +1430 -1430
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/METADATA +93 -53
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.6.0.dist-info/RECORD +0 -44
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.6.0.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
"""
|
|
2
|
+
EfficientViT: Memory-Efficient Vision Transformer with Cascaded Group Attention
|
|
3
|
+
================================================================================
|
|
4
|
+
|
|
5
|
+
EfficientViT (MIT) achieves state-of-the-art speed-accuracy trade-off by using
|
|
6
|
+
cascaded group attention (CGA) which reduces computational redundancy in
|
|
7
|
+
multi-head self-attention while maintaining model capability.
|
|
8
|
+
|
|
9
|
+
**Key Features**:
|
|
10
|
+
- Cascaded Group Attention (CGA): Linear complexity attention
|
|
11
|
+
- Memory-efficient design for edge deployment
|
|
12
|
+
- Faster than Swin Transformer with similar accuracy
|
|
13
|
+
- Excellent for real-time NDE applications
|
|
14
|
+
|
|
15
|
+
**Variants**:
|
|
16
|
+
- efficientvit_m0: 2.3M params (mobile, fastest)
|
|
17
|
+
- efficientvit_m1: 2.9M params (mobile)
|
|
18
|
+
- efficientvit_m2: 4.2M params (mobile)
|
|
19
|
+
- efficientvit_b0: 3.4M params (balanced)
|
|
20
|
+
- efficientvit_b1: 9.1M params (balanced)
|
|
21
|
+
- efficientvit_b2: 24M params (balanced)
|
|
22
|
+
- efficientvit_b3: 49M params (balanced)
|
|
23
|
+
- efficientvit_l1: 53M params (large)
|
|
24
|
+
- efficientvit_l2: 64M params (large)
|
|
25
|
+
|
|
26
|
+
**Requirements**:
|
|
27
|
+
- timm >= 0.9.0 (for EfficientViT models)
|
|
28
|
+
|
|
29
|
+
Reference:
|
|
30
|
+
Liu, X., et al. (2023). EfficientViT: Memory Efficient Vision Transformer
|
|
31
|
+
with Cascaded Group Attention. CVPR 2023.
|
|
32
|
+
https://arxiv.org/abs/2305.07027
|
|
33
|
+
|
|
34
|
+
Author: Ductho Le (ductho.le@outlook.com)
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
import torch
|
|
38
|
+
|
|
39
|
+
from wavedl.models._pretrained_utils import build_regression_head
|
|
40
|
+
from wavedl.models.base import BaseModel
|
|
41
|
+
from wavedl.models.registry import register_model
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
__all__ = [
|
|
45
|
+
"EfficientViTB0",
|
|
46
|
+
"EfficientViTB1",
|
|
47
|
+
"EfficientViTB2",
|
|
48
|
+
"EfficientViTB3",
|
|
49
|
+
"EfficientViTBase",
|
|
50
|
+
"EfficientViTL1",
|
|
51
|
+
"EfficientViTL2",
|
|
52
|
+
"EfficientViTM0",
|
|
53
|
+
"EfficientViTM1",
|
|
54
|
+
"EfficientViTM2",
|
|
55
|
+
]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# =============================================================================
|
|
59
|
+
# EFFICIENTVIT BASE CLASS
|
|
60
|
+
# =============================================================================
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class EfficientViTBase(BaseModel):
|
|
64
|
+
"""
|
|
65
|
+
EfficientViT base class wrapping timm implementation.
|
|
66
|
+
|
|
67
|
+
Uses Cascaded Group Attention for efficient multi-head attention with
|
|
68
|
+
linear complexity. 2D only due to attention structure.
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
in_shape: (H, W) input shape (2D only)
|
|
72
|
+
out_size: Number of regression targets
|
|
73
|
+
model_name: timm model name
|
|
74
|
+
pretrained: Whether to load pretrained weights
|
|
75
|
+
freeze_backbone: Whether to freeze backbone for fine-tuning
|
|
76
|
+
dropout_rate: Dropout rate for regression head
|
|
77
|
+
"""
|
|
78
|
+
|
|
79
|
+
def __init__(
|
|
80
|
+
self,
|
|
81
|
+
in_shape: tuple[int, int],
|
|
82
|
+
out_size: int,
|
|
83
|
+
model_name: str = "efficientvit_b0",
|
|
84
|
+
pretrained: bool = True,
|
|
85
|
+
freeze_backbone: bool = False,
|
|
86
|
+
dropout_rate: float = 0.3,
|
|
87
|
+
**kwargs,
|
|
88
|
+
):
|
|
89
|
+
super().__init__(in_shape, out_size)
|
|
90
|
+
|
|
91
|
+
if len(in_shape) != 2:
|
|
92
|
+
raise ValueError(
|
|
93
|
+
f"EfficientViT requires 2D input (H, W), got {len(in_shape)}D"
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
self.pretrained = pretrained
|
|
97
|
+
self.freeze_backbone = freeze_backbone
|
|
98
|
+
self.model_name = model_name
|
|
99
|
+
|
|
100
|
+
# Load from timm
|
|
101
|
+
try:
|
|
102
|
+
import timm
|
|
103
|
+
|
|
104
|
+
self.backbone = timm.create_model(
|
|
105
|
+
model_name,
|
|
106
|
+
pretrained=pretrained,
|
|
107
|
+
num_classes=0, # Remove classifier
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
# Get feature dimension
|
|
111
|
+
with torch.no_grad():
|
|
112
|
+
dummy = torch.zeros(1, 3, *in_shape)
|
|
113
|
+
features = self.backbone(dummy)
|
|
114
|
+
in_features = features.shape[-1]
|
|
115
|
+
|
|
116
|
+
except ImportError:
|
|
117
|
+
raise ImportError(
|
|
118
|
+
"timm >= 0.9.0 is required for EfficientViT. "
|
|
119
|
+
"Install with: pip install timm>=0.9.0"
|
|
120
|
+
)
|
|
121
|
+
except Exception as e:
|
|
122
|
+
raise RuntimeError(f"Failed to load EfficientViT model '{model_name}': {e}")
|
|
123
|
+
|
|
124
|
+
# Adapt input channels (3 -> 1)
|
|
125
|
+
self._adapt_input_channels()
|
|
126
|
+
|
|
127
|
+
# Regression head
|
|
128
|
+
self.head = build_regression_head(in_features, out_size, dropout_rate)
|
|
129
|
+
|
|
130
|
+
if freeze_backbone:
|
|
131
|
+
self._freeze_backbone()
|
|
132
|
+
|
|
133
|
+
def _adapt_input_channels(self):
|
|
134
|
+
"""Adapt first conv layer for single-channel input."""
|
|
135
|
+
from wavedl.models._pretrained_utils import find_and_adapt_input_convs
|
|
136
|
+
|
|
137
|
+
adapted_count = find_and_adapt_input_convs(
|
|
138
|
+
self.backbone, pretrained=self.pretrained, adapt_all=False
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
if adapted_count == 0:
|
|
142
|
+
import warnings
|
|
143
|
+
|
|
144
|
+
warnings.warn(
|
|
145
|
+
"Could not adapt EfficientViT input channels. Model may fail.",
|
|
146
|
+
stacklevel=2,
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
def _freeze_backbone(self):
|
|
150
|
+
"""Freeze backbone parameters."""
|
|
151
|
+
for param in self.backbone.parameters():
|
|
152
|
+
param.requires_grad = False
|
|
153
|
+
|
|
154
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
155
|
+
features = self.backbone(x)
|
|
156
|
+
return self.head(features)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# =============================================================================
|
|
160
|
+
# MOBILE VARIANTS (Ultra-lightweight)
|
|
161
|
+
# =============================================================================
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@register_model("efficientvit_m0")
|
|
165
|
+
class EfficientViTM0(EfficientViTBase):
|
|
166
|
+
"""
|
|
167
|
+
EfficientViT-M0: ~2.2M backbone parameters (fastest mobile variant).
|
|
168
|
+
|
|
169
|
+
Cascaded group attention for efficient inference.
|
|
170
|
+
Ideal for edge deployment and real-time NDE applications.
|
|
171
|
+
2D only.
|
|
172
|
+
|
|
173
|
+
Example:
|
|
174
|
+
>>> model = EfficientViTM0(in_shape=(224, 224), out_size=3)
|
|
175
|
+
>>> x = torch.randn(4, 1, 224, 224)
|
|
176
|
+
>>> out = model(x) # (4, 3)
|
|
177
|
+
"""
|
|
178
|
+
|
|
179
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
180
|
+
super().__init__(
|
|
181
|
+
in_shape=in_shape,
|
|
182
|
+
out_size=out_size,
|
|
183
|
+
model_name="efficientvit_m0",
|
|
184
|
+
**kwargs,
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
def __repr__(self) -> str:
|
|
188
|
+
return (
|
|
189
|
+
f"EfficientViT_M0(in_shape={self.in_shape}, out_size={self.out_size}, "
|
|
190
|
+
f"pretrained={self.pretrained})"
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
@register_model("efficientvit_m1")
|
|
195
|
+
class EfficientViTM1(EfficientViTBase):
|
|
196
|
+
"""
|
|
197
|
+
EfficientViT-M1: ~2.6M backbone parameters.
|
|
198
|
+
|
|
199
|
+
Slightly larger mobile variant with better accuracy.
|
|
200
|
+
2D only.
|
|
201
|
+
"""
|
|
202
|
+
|
|
203
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
204
|
+
super().__init__(
|
|
205
|
+
in_shape=in_shape,
|
|
206
|
+
out_size=out_size,
|
|
207
|
+
model_name="efficientvit_m1",
|
|
208
|
+
**kwargs,
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
def __repr__(self) -> str:
|
|
212
|
+
return (
|
|
213
|
+
f"EfficientViT_M1(in_shape={self.in_shape}, out_size={self.out_size}, "
|
|
214
|
+
f"pretrained={self.pretrained})"
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
@register_model("efficientvit_m2")
|
|
219
|
+
class EfficientViTM2(EfficientViTBase):
|
|
220
|
+
"""
|
|
221
|
+
EfficientViT-M2: ~3.8M backbone parameters.
|
|
222
|
+
|
|
223
|
+
Largest mobile variant, best accuracy among M-series.
|
|
224
|
+
2D only.
|
|
225
|
+
"""
|
|
226
|
+
|
|
227
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
228
|
+
super().__init__(
|
|
229
|
+
in_shape=in_shape,
|
|
230
|
+
out_size=out_size,
|
|
231
|
+
model_name="efficientvit_m2",
|
|
232
|
+
**kwargs,
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
def __repr__(self) -> str:
|
|
236
|
+
return (
|
|
237
|
+
f"EfficientViT_M2(in_shape={self.in_shape}, out_size={self.out_size}, "
|
|
238
|
+
f"pretrained={self.pretrained})"
|
|
239
|
+
)
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# =============================================================================
|
|
243
|
+
# BALANCED VARIANTS (B-series)
|
|
244
|
+
# =============================================================================
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
@register_model("efficientvit_b0")
|
|
248
|
+
class EfficientViTB0(EfficientViTBase):
|
|
249
|
+
"""
|
|
250
|
+
EfficientViT-B0: ~2.1M backbone parameters.
|
|
251
|
+
|
|
252
|
+
Smallest balanced variant. Good accuracy-speed trade-off.
|
|
253
|
+
2D only.
|
|
254
|
+
|
|
255
|
+
Example:
|
|
256
|
+
>>> model = EfficientViTB0(in_shape=(224, 224), out_size=3)
|
|
257
|
+
>>> x = torch.randn(4, 1, 224, 224)
|
|
258
|
+
>>> out = model(x) # (4, 3)
|
|
259
|
+
"""
|
|
260
|
+
|
|
261
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
262
|
+
super().__init__(
|
|
263
|
+
in_shape=in_shape,
|
|
264
|
+
out_size=out_size,
|
|
265
|
+
model_name="efficientvit_b0",
|
|
266
|
+
**kwargs,
|
|
267
|
+
)
|
|
268
|
+
|
|
269
|
+
def __repr__(self) -> str:
|
|
270
|
+
return (
|
|
271
|
+
f"EfficientViT_B0(in_shape={self.in_shape}, out_size={self.out_size}, "
|
|
272
|
+
f"pretrained={self.pretrained})"
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
@register_model("efficientvit_b1")
|
|
277
|
+
class EfficientViTB1(EfficientViTBase):
|
|
278
|
+
"""
|
|
279
|
+
EfficientViT-B1: ~7.5M backbone parameters.
|
|
280
|
+
|
|
281
|
+
Medium balanced variant with improved capacity.
|
|
282
|
+
2D only.
|
|
283
|
+
"""
|
|
284
|
+
|
|
285
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
286
|
+
super().__init__(
|
|
287
|
+
in_shape=in_shape,
|
|
288
|
+
out_size=out_size,
|
|
289
|
+
model_name="efficientvit_b1",
|
|
290
|
+
**kwargs,
|
|
291
|
+
)
|
|
292
|
+
|
|
293
|
+
def __repr__(self) -> str:
|
|
294
|
+
return (
|
|
295
|
+
f"EfficientViT_B1(in_shape={self.in_shape}, out_size={self.out_size}, "
|
|
296
|
+
f"pretrained={self.pretrained})"
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
@register_model("efficientvit_b2")
|
|
301
|
+
class EfficientViTB2(EfficientViTBase):
|
|
302
|
+
"""
|
|
303
|
+
EfficientViT-B2: ~21.8M backbone parameters.
|
|
304
|
+
|
|
305
|
+
Larger balanced variant for complex patterns.
|
|
306
|
+
2D only.
|
|
307
|
+
"""
|
|
308
|
+
|
|
309
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
310
|
+
super().__init__(
|
|
311
|
+
in_shape=in_shape,
|
|
312
|
+
out_size=out_size,
|
|
313
|
+
model_name="efficientvit_b2",
|
|
314
|
+
**kwargs,
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
def __repr__(self) -> str:
|
|
318
|
+
return (
|
|
319
|
+
f"EfficientViT_B2(in_shape={self.in_shape}, out_size={self.out_size}, "
|
|
320
|
+
f"pretrained={self.pretrained})"
|
|
321
|
+
)
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
@register_model("efficientvit_b3")
|
|
325
|
+
class EfficientViTB3(EfficientViTBase):
|
|
326
|
+
"""
|
|
327
|
+
EfficientViT-B3: ~46.1M backbone parameters.
|
|
328
|
+
|
|
329
|
+
Largest balanced variant, highest accuracy in B-series.
|
|
330
|
+
2D only.
|
|
331
|
+
"""
|
|
332
|
+
|
|
333
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
334
|
+
super().__init__(
|
|
335
|
+
in_shape=in_shape,
|
|
336
|
+
out_size=out_size,
|
|
337
|
+
model_name="efficientvit_b3",
|
|
338
|
+
**kwargs,
|
|
339
|
+
)
|
|
340
|
+
|
|
341
|
+
def __repr__(self) -> str:
|
|
342
|
+
return (
|
|
343
|
+
f"EfficientViT_B3(in_shape={self.in_shape}, out_size={self.out_size}, "
|
|
344
|
+
f"pretrained={self.pretrained})"
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
# =============================================================================
|
|
349
|
+
# LARGE VARIANTS (L-series)
|
|
350
|
+
# =============================================================================
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
@register_model("efficientvit_l1")
|
|
354
|
+
class EfficientViTL1(EfficientViTBase):
|
|
355
|
+
"""
|
|
356
|
+
EfficientViT-L1: ~49.5M backbone parameters.
|
|
357
|
+
|
|
358
|
+
Large variant for maximum accuracy.
|
|
359
|
+
2D only.
|
|
360
|
+
"""
|
|
361
|
+
|
|
362
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
363
|
+
super().__init__(
|
|
364
|
+
in_shape=in_shape,
|
|
365
|
+
out_size=out_size,
|
|
366
|
+
model_name="efficientvit_l1",
|
|
367
|
+
**kwargs,
|
|
368
|
+
)
|
|
369
|
+
|
|
370
|
+
def __repr__(self) -> str:
|
|
371
|
+
return (
|
|
372
|
+
f"EfficientViT_L1(in_shape={self.in_shape}, out_size={self.out_size}, "
|
|
373
|
+
f"pretrained={self.pretrained})"
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
@register_model("efficientvit_l2")
|
|
378
|
+
class EfficientViTL2(EfficientViTBase):
|
|
379
|
+
"""
|
|
380
|
+
EfficientViT-L2: ~60.5M backbone parameters.
|
|
381
|
+
|
|
382
|
+
Largest variant, best accuracy.
|
|
383
|
+
2D only.
|
|
384
|
+
"""
|
|
385
|
+
|
|
386
|
+
def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
|
|
387
|
+
super().__init__(
|
|
388
|
+
in_shape=in_shape,
|
|
389
|
+
out_size=out_size,
|
|
390
|
+
model_name="efficientvit_l2",
|
|
391
|
+
**kwargs,
|
|
392
|
+
)
|
|
393
|
+
|
|
394
|
+
def __repr__(self) -> str:
|
|
395
|
+
return (
|
|
396
|
+
f"EfficientViT_L2(in_shape={self.in_shape}, out_size={self.out_size}, "
|
|
397
|
+
f"pretrained={self.pretrained})"
|
|
398
|
+
)
|
wavedl/models/fastvit.py
CHANGED
|
@@ -29,9 +29,8 @@ Author: Ductho Le (ductho.le@outlook.com)
|
|
|
29
29
|
"""
|
|
30
30
|
|
|
31
31
|
import torch
|
|
32
|
-
import torch.nn as nn
|
|
33
32
|
|
|
34
|
-
from wavedl.models.
|
|
33
|
+
from wavedl.models._pretrained_utils import build_regression_head
|
|
35
34
|
from wavedl.models.base import BaseModel
|
|
36
35
|
from wavedl.models.registry import register_model
|
|
37
36
|
|
|
@@ -114,26 +113,11 @@ class FastViTBase(BaseModel):
|
|
|
114
113
|
"""Adapt all conv layers with 3 input channels for single-channel input."""
|
|
115
114
|
# FastViT may have multiple modules with 3 input channels (e.g., conv_kxk, conv_scale)
|
|
116
115
|
# We need to adapt all of them
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
if hasattr(module, "conv") and isinstance(module.conv, nn.Conv2d):
|
|
123
|
-
# Adapt the inner conv layer
|
|
124
|
-
old_conv = module.conv
|
|
125
|
-
module.conv = self._make_new_conv(old_conv)
|
|
126
|
-
adapted_count += 1
|
|
127
|
-
elif isinstance(module, nn.Conv2d):
|
|
128
|
-
# Direct Conv2d - replace it
|
|
129
|
-
parts = name.split(".")
|
|
130
|
-
parent = self.backbone
|
|
131
|
-
for part in parts[:-1]:
|
|
132
|
-
parent = getattr(parent, part)
|
|
133
|
-
child_name = parts[-1]
|
|
134
|
-
new_conv = self._make_new_conv(module)
|
|
135
|
-
setattr(parent, child_name, new_conv)
|
|
136
|
-
adapted_count += 1
|
|
116
|
+
from wavedl.models._pretrained_utils import find_and_adapt_input_convs
|
|
117
|
+
|
|
118
|
+
adapted_count = find_and_adapt_input_convs(
|
|
119
|
+
self.backbone, pretrained=self.pretrained, adapt_all=True
|
|
120
|
+
)
|
|
137
121
|
|
|
138
122
|
if adapted_count == 0:
|
|
139
123
|
import warnings
|
|
@@ -142,23 +126,6 @@ class FastViTBase(BaseModel):
|
|
|
142
126
|
"Could not adapt FastViT input channels. Model may fail.", stacklevel=2
|
|
143
127
|
)
|
|
144
128
|
|
|
145
|
-
def _make_new_conv(self, old_conv: nn.Conv2d) -> nn.Conv2d:
|
|
146
|
-
"""Create new conv layer with 1 input channel."""
|
|
147
|
-
new_conv = nn.Conv2d(
|
|
148
|
-
1,
|
|
149
|
-
old_conv.out_channels,
|
|
150
|
-
kernel_size=old_conv.kernel_size,
|
|
151
|
-
stride=old_conv.stride,
|
|
152
|
-
padding=old_conv.padding,
|
|
153
|
-
bias=old_conv.bias is not None,
|
|
154
|
-
)
|
|
155
|
-
if self.pretrained:
|
|
156
|
-
with torch.no_grad():
|
|
157
|
-
new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
|
|
158
|
-
if old_conv.bias is not None:
|
|
159
|
-
new_conv.bias.copy_(old_conv.bias)
|
|
160
|
-
return new_conv
|
|
161
|
-
|
|
162
129
|
def _freeze_backbone(self):
|
|
163
130
|
"""Freeze backbone parameters."""
|
|
164
131
|
for param in self.backbone.parameters():
|
wavedl/models/mamba.py
CHANGED
|
@@ -34,12 +34,12 @@ import torch
|
|
|
34
34
|
import torch.nn as nn
|
|
35
35
|
import torch.nn.functional as F
|
|
36
36
|
|
|
37
|
-
from wavedl.models.base import BaseModel
|
|
37
|
+
from wavedl.models.base import BaseModel, SpatialShape1D, SpatialShape2D
|
|
38
38
|
from wavedl.models.registry import register_model
|
|
39
39
|
|
|
40
40
|
|
|
41
|
-
# Type
|
|
42
|
-
SpatialShape =
|
|
41
|
+
# Type alias for Mamba models (1D and 2D only)
|
|
42
|
+
SpatialShape = SpatialShape1D | SpatialShape2D
|
|
43
43
|
|
|
44
44
|
__all__ = [
|
|
45
45
|
"Mamba1D",
|
|
@@ -154,35 +154,55 @@ class SelectiveSSM(nn.Module):
|
|
|
154
154
|
D: torch.Tensor,
|
|
155
155
|
) -> torch.Tensor:
|
|
156
156
|
"""
|
|
157
|
-
|
|
157
|
+
Vectorized selective scan using parallel associative scan.
|
|
158
158
|
|
|
159
|
-
|
|
160
|
-
|
|
159
|
+
This implementation avoids the sequential for-loop by computing
|
|
160
|
+
all timesteps in parallel using cumulative products and sums.
|
|
161
|
+
~100x faster than the naive sequential implementation.
|
|
161
162
|
"""
|
|
162
|
-
B_batch, L, d_inner = x.shape
|
|
163
|
-
d_state = A.shape[0]
|
|
164
163
|
|
|
165
|
-
#
|
|
166
|
-
|
|
164
|
+
# Compute discretized A_bar for all timesteps: (B, L, d_inner, d_state)
|
|
165
|
+
A_bar = torch.exp(delta.unsqueeze(-1) * A) # (B, L, d_inner, d_state)
|
|
167
166
|
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
B_t = B[:, t, :] # (B, d_state)
|
|
173
|
-
C_t = C[:, t, :] # (B, d_state)
|
|
167
|
+
# Compute input contribution: delta * B * x for all timesteps
|
|
168
|
+
# B: (B, L, d_state), x: (B, L, d_inner), delta: (B, L, d_inner)
|
|
169
|
+
# Result: (B, L, d_inner, d_state)
|
|
170
|
+
BX = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)
|
|
174
171
|
|
|
175
|
-
|
|
176
|
-
|
|
172
|
+
# Parallel scan using log-space cumulative products for numerical stability
|
|
173
|
+
# For SSM: h[t] = A_bar[t] * h[t-1] + BX[t]
|
|
174
|
+
# This is a linear recurrence that can be solved with associative scan
|
|
177
175
|
|
|
178
|
-
|
|
179
|
-
|
|
176
|
+
# Use chunked approach for memory efficiency with parallel scan
|
|
177
|
+
# Compute cumulative product of A_bar (in log space for stability)
|
|
178
|
+
log_A_bar = torch.log(A_bar.clamp(min=1e-10))
|
|
179
|
+
log_A_cumsum = torch.cumsum(log_A_bar, dim=1) # (B, L, d_inner, d_state)
|
|
180
|
+
A_cumsum = torch.exp(log_A_cumsum)
|
|
180
181
|
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
182
|
+
# For each timestep t, we need: sum_{s=0}^{t} (prod_{k=s+1}^{t} A_bar[k]) * BX[s]
|
|
183
|
+
# = sum_{s=0}^{t} (A_cumsum[t] / A_cumsum[s]) * BX[s]
|
|
184
|
+
# = A_cumsum[t] * sum_{s=0}^{t} (BX[s] / A_cumsum[s])
|
|
184
185
|
|
|
185
|
-
|
|
186
|
+
# Compute BX / A_cumsum (use A_cumsum shifted by 1 for proper indexing)
|
|
187
|
+
# A_cumsum[s] represents prod_{k=0}^{s} A_bar[k], but we need prod_{k=0}^{s-1}
|
|
188
|
+
# So we shift: use A_cumsum from previous timestep
|
|
189
|
+
A_cumsum_shifted = F.pad(A_cumsum[:, :-1], (0, 0, 0, 0, 1, 0), value=1.0)
|
|
190
|
+
|
|
191
|
+
# Weighted input: BX[s] / A_cumsum[s-1] = BX[s] * exp(-log_A_cumsum[s-1])
|
|
192
|
+
weighted_BX = BX / A_cumsum_shifted.clamp(min=1e-10)
|
|
193
|
+
|
|
194
|
+
# Cumulative sum of weighted inputs
|
|
195
|
+
weighted_BX_cumsum = torch.cumsum(weighted_BX, dim=1)
|
|
196
|
+
|
|
197
|
+
# Final state at each timestep: h[t] = A_cumsum[t] * weighted_BX_cumsum[t]
|
|
198
|
+
# But A_cumsum includes A_bar[0], so adjust
|
|
199
|
+
h = A_cumsum * weighted_BX_cumsum / A_bar.clamp(min=1e-10)
|
|
200
|
+
|
|
201
|
+
# Output: y = C * h + D * x
|
|
202
|
+
# h: (B, L, d_inner, d_state), C: (B, L, d_state)
|
|
203
|
+
y = (C.unsqueeze(2) * h).sum(-1) + D * x # (B, L, d_inner)
|
|
204
|
+
|
|
205
|
+
return y
|
|
186
206
|
|
|
187
207
|
|
|
188
208
|
# =============================================================================
|