wavedl 1.5.7__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/models/__init__.py +52 -4
- wavedl/models/_timm_utils.py +238 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/convnext.py +108 -33
- wavedl/models/convnext_v2.py +504 -0
- wavedl/models/densenet.py +5 -5
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +3 -3
- wavedl/models/fastvit.py +285 -0
- wavedl/models/mamba.py +535 -0
- wavedl/models/maxvit.py +251 -0
- wavedl/models/mobilenetv3.py +6 -6
- wavedl/models/regnet.py +10 -10
- wavedl/models/resnet.py +5 -5
- wavedl/models/resnet3d.py +2 -2
- wavedl/models/swin.py +3 -3
- wavedl/models/tcn.py +3 -3
- wavedl/models/unet.py +1 -1
- wavedl/models/vit.py +6 -6
- wavedl/train.py +21 -16
- wavedl/utils/data.py +39 -6
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/METADATA +90 -62
- wavedl-1.6.0.dist-info/RECORD +44 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.0.dist-info}/top_level.txt +0 -0
wavedl/models/mamba.py
ADDED
@@ -0,0 +1,535 @@
"""
Vision Mamba: Efficient Visual Representation Learning with State Space Models
===============================================================================

Vision Mamba (Vim) adapts the Mamba selective state space model for vision tasks.
Provides O(n) linear complexity vs O(n²) for transformers, making it efficient
for long sequences and high-resolution images.

**Key Features**:
- Bidirectional SSM for image understanding
- O(n) linear complexity
- 2.8x faster than ViT, 86.8% less GPU memory
- Works for 1D (time-series) and 2D (images)

**Variants**:
- mamba_1d: For 1D time-series (alternative to TCN)
- vim_tiny: 7M params for 2D images
- vim_small: 26M params for 2D images
- vim_base: 98M params for 2D images

**Dependencies**:
- Optional: mamba-ssm (for optimized CUDA kernels)
- Fallback: Pure PyTorch implementation

Reference:
    Zhu, L., et al. (2024). Vision Mamba: Efficient Visual Representation
    Learning with Bidirectional State Space Model. ICML 2024.
    https://arxiv.org/abs/2401.09417

Author: Ductho Le (ductho.le@outlook.com)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model


# Type aliases
SpatialShape = tuple[int] | tuple[int, int]

__all__ = [
    "Mamba1D",
    "Mamba1DBase",
    "MambaBlock",
    "VimBase",
    "VimSmall",
    "VimTiny",
    "VisionMambaBase",
]

# =============================================================================
# SELECTIVE SSM CORE (Pure PyTorch Implementation)
# =============================================================================


class SelectiveSSM(nn.Module):
    """
    Selective State Space Model (S6) - Core of Mamba.

    The key innovation is making the SSM parameters (B, C, Δ) input-dependent,
    allowing the model to selectively focus on or ignore inputs.

    This is a simplified pure-PyTorch implementation. For production use,
    consider the optimized mamba-ssm package.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
    ):
        super().__init__()

        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_model * expand

        # Input projection
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Conv for local context
        self.conv1d = nn.Conv1d(
            self.d_inner,
            self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
        )

        # SSM parameters (input-dependent)
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)

        # Learnable SSM matrices
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
        self.A_log = nn.Parameter(
            torch.log(torch.arange(1, d_state + 1, dtype=torch.float32))
        )
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, L, D) input sequence

        Returns:
            y: (B, L, D) output sequence
        """
        _B, L, _D = x.shape

        # Input projection and split
        xz = self.in_proj(x)  # (B, L, 2*d_inner)
        x, z = xz.chunk(2, dim=-1)  # Each: (B, L, d_inner)

        # Conv for local context
        x = x.transpose(1, 2)  # (B, d_inner, L)
        x = self.conv1d(x)[:, :, :L]  # Causal
        x = x.transpose(1, 2)  # (B, L, d_inner)
        x = F.silu(x)

        # SSM parameters from input
        x_proj = self.x_proj(x)  # (B, L, d_state*2 + 1)
        delta = F.softplus(self.dt_proj(x_proj[:, :, :1]))  # (B, L, d_inner)
        B_param = x_proj[:, :, 1 : self.d_state + 1]  # (B, L, d_state)
        C_param = x_proj[:, :, self.d_state + 1 :]  # (B, L, d_state)

        # Discretize A
        A = -torch.exp(self.A_log)  # (d_state,)

        # Selective scan (simplified, not optimized)
        y = self._selective_scan(x, delta, A, B_param, C_param, self.D)

        # Gating
        y = y * F.silu(z)

        # Output projection
        return self.out_proj(y)

    def _selective_scan(
        self,
        x: torch.Tensor,
        delta: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        D: torch.Tensor,
    ) -> torch.Tensor:
        """
        Simplified selective scan.

        For real applications, use the CUDA-optimized version from mamba-ssm.
        This implementation is for understanding and testing only.
        """
        B_batch, L, d_inner = x.shape
        d_state = A.shape[0]

        # Initialize state
        h = torch.zeros(B_batch, d_inner, d_state, device=x.device, dtype=x.dtype)

        outputs = []
        for t in range(L):
            x_t = x[:, t, :]  # (B, d_inner)
            delta_t = delta[:, t, :]  # (B, d_inner)
            B_t = B[:, t, :]  # (B, d_state)
            C_t = C[:, t, :]  # (B, d_state)

            # Discretize: A_bar = exp(delta * A)
            A_bar = torch.exp(delta_t.unsqueeze(-1) * A)  # (B, d_inner, d_state)

            # Update state: h = A_bar * h + delta * B * x
            h = A_bar * h + delta_t.unsqueeze(-1) * B_t.unsqueeze(1) * x_t.unsqueeze(-1)

            # Output: y = C * h + D * x
            y_t = (C_t.unsqueeze(1) * h).sum(-1) + D * x_t  # (B, d_inner)
            outputs.append(y_t)

        return torch.stack(outputs, dim=1)  # (B, L, d_inner)

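For reference, the per-timestep update implemented by the _selective_scan loop (a restatement of its inline comments, not part of the diff) is, in LaTeX notation:

\[
\bar{A}_t = \exp(\Delta_t A), \qquad
h_t = \bar{A}_t \odot h_{t-1} + \Delta_t B_t x_t, \qquad
y_t = C_t h_t + D x_t
\]

with h_0 = 0, where Δ_t, B_t, and C_t are produced from the input by dt_proj and x_proj (the "selective" part), and the batch and d_inner dimensions are broadcast elementwise.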
# =============================================================================
# MAMBA BLOCK
# =============================================================================


class MambaBlock(nn.Module):
    """
    Mamba Block with residual connection.

    Architecture:
        Input → Norm → SelectiveSSM → Residual
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
    ):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ssm(self.norm(x))

# =============================================================================
# BIDIRECTIONAL MAMBA (For Vision)
# =============================================================================


class BidirectionalMambaBlock(nn.Module):
    """
    Bidirectional Mamba Block for vision tasks.

    Processes sequence in both forward and backward directions
    to capture global context in images.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
    ):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ssm_forward = SelectiveSSM(d_model, d_state, d_conv, expand)
        self.ssm_backward = SelectiveSSM(d_model, d_state, d_conv, expand)
        self.merge = nn.Linear(d_model * 2, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_norm = self.norm(x)

        # Forward pass
        y_forward = self.ssm_forward(x_norm)

        # Backward pass (flip, process, flip back)
        y_backward = self.ssm_backward(x_norm.flip(dims=[1])).flip(dims=[1])

        # Merge
        y = self.merge(torch.cat([y_forward, y_backward], dim=-1))

        return x + y

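A minimal shape-check sketch for the bidirectional block (not part of the diff; the sizes are illustrative only):

    block = BidirectionalMambaBlock(d_model=64)
    tokens = torch.randn(2, 196, 64)   # e.g. a 14x14 patch grid flattened to 196 tokens
    out = block(tokens)                # forward and reversed scans merged back to (2, 196, 64)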
# =============================================================================
# MAMBA 1D (For Time-Series)
# =============================================================================


class Mamba1DBase(BaseModel):
    """
    Mamba for 1D time-series data.

    Alternative to TCN with theoretically infinite receptive field
    and linear complexity.
    """

    def __init__(
        self,
        in_shape: tuple[int],
        out_size: int,
        d_model: int = 256,
        n_layers: int = 8,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        dropout_rate: float = 0.1,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        if len(in_shape) != 1:
            raise ValueError(f"Mamba1D requires 1D input (L,), got {len(in_shape)}D")

        self.d_model = d_model

        # Input projection
        self.input_proj = nn.Linear(1, d_model)

        # Positional encoding
        self.pos_embed = nn.Parameter(torch.zeros(1, in_shape[0], d_model))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Mamba blocks
        self.blocks = nn.ModuleList(
            [MambaBlock(d_model, d_state, d_conv, expand) for _ in range(n_layers)]
        )

        # Final norm
        self.norm = nn.LayerNorm(d_model)

        # Regression head
        self.head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(d_model // 2, out_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, 1, L) input signal

        Returns:
            (B, out_size) regression output
        """
        _B, _C, L = x.shape

        # Reshape to sequence
        x = x.transpose(1, 2)  # (B, L, 1)
        x = self.input_proj(x)  # (B, L, d_model)

        # Add positional encoding
        x = x + self.pos_embed[:, :L, :]

        # Mamba blocks
        for block in self.blocks:
            x = block(x)

        # Global pooling (mean over sequence)
        x = x.mean(dim=1)  # (B, d_model)

        # Final norm and head
        x = self.norm(x)
        return self.head(x)

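A minimal usage sketch for the 1D base class (not part of the diff; the hyperparameters here are examples only):

    model = Mamba1DBase(in_shape=(2048,), out_size=2, d_model=128, n_layers=4)
    signal = torch.randn(8, 1, 2048)   # (batch, channels=1, length) matching in_shape
    pred = model(signal)               # (8, 2) regression output

Note that pos_embed is sized to in_shape[0], so the input length must not exceed the configured in_shape.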
# =============================================================================
# VISION MAMBA (For 2D Images)
# =============================================================================


class VisionMambaBase(BaseModel):
    """
    Vision Mamba (Vim) for 2D images.

    Uses bidirectional SSM to capture global context efficiently.
    O(n) complexity instead of O(n²) for transformers.
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        patch_size: int = 16,
        d_model: int = 192,
        n_layers: int = 12,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        dropout_rate: float = 0.1,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        if len(in_shape) != 2:
            raise ValueError(
                f"VisionMamba requires 2D input (H, W), got {len(in_shape)}D"
            )

        self.patch_size = patch_size
        self.d_model = d_model

        H, W = in_shape
        self.num_patches = (H // patch_size) * (W // patch_size)
        self.grid_size = (H // patch_size, W // patch_size)

        # Patch embedding
        self.patch_embed = nn.Conv2d(
            1, d_model, kernel_size=patch_size, stride=patch_size
        )

        # CLS token for classification/regression
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        # Positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, d_model))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Bidirectional Mamba blocks
        self.blocks = nn.ModuleList(
            [
                BidirectionalMambaBlock(d_model, d_state, d_conv, expand)
                for _ in range(n_layers)
            ]
        )

        # Final norm
        self.norm = nn.LayerNorm(d_model)

        # Regression head
        self.head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(d_model // 2, out_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, 1, H, W) input image

        Returns:
            (B, out_size) regression output
        """
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # (B, d_model, H', W')
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, d_model)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, 1 + num_patches, d_model)

        # Add positional embedding
        x = x + self.pos_embed

        # Bidirectional Mamba blocks
        for block in self.blocks:
            x = block(x)

        # Extract CLS token
        cls_output = x[:, 0]  # (B, d_model)

        # Final norm and head
        cls_output = self.norm(cls_output)
        return self.head(cls_output)

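A minimal usage sketch for the 2D base class (not part of the diff; the sizes are examples only):

    model = VisionMambaBase(in_shape=(224, 224), out_size=3, patch_size=16, n_layers=2)
    images = torch.randn(2, 1, 224, 224)   # single-channel images matching in_shape
    pred = model(images)                   # (2, 3); 224 // 16 = 14, so 196 patches + 1 CLS token

Because pos_embed is sized to num_patches + 1, inputs must match the configured in_shape, and H and W should be divisible by patch_size.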
# =============================================================================
# REGISTERED VARIANTS
# =============================================================================


@register_model("mamba_1d")
class Mamba1D(Mamba1DBase):
    """
    Mamba 1D: ~3.4M backbone parameters (for time-series regression).

    8 layers, 256 dim. Alternative to TCN for time-series.
    Pure PyTorch implementation.

    Example:
        >>> model = Mamba1D(in_shape=(4096,), out_size=3)
        >>> x = torch.randn(4, 1, 4096)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
        kwargs.setdefault("d_model", 256)
        kwargs.setdefault("n_layers", 8)
        super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)

    def __repr__(self) -> str:
        return f"Mamba1D(in_shape={self.in_shape}, out_size={self.out_size})"


@register_model("vim_tiny")
class VimTiny(VisionMambaBase):
    """
    Vision Mamba Tiny: ~6.6M backbone parameters.

    12 layers, 192 dim. For 2D images.
    Pure PyTorch implementation with O(n) complexity.

    Example:
        >>> model = VimTiny(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        kwargs.setdefault("patch_size", 16)
        kwargs.setdefault("d_model", 192)
        kwargs.setdefault("n_layers", 12)
        super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)

    def __repr__(self) -> str:
        return f"VisionMamba_Tiny(in_shape={self.in_shape}, out_size={self.out_size})"


@register_model("vim_small")
class VimSmall(VisionMambaBase):
    """
    Vision Mamba Small: ~51.1M backbone parameters.

    24 layers, 384 dim. For 2D images.
    Pure PyTorch implementation with O(n) complexity.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        kwargs.setdefault("patch_size", 16)
        kwargs.setdefault("d_model", 384)
        kwargs.setdefault("n_layers", 24)
        super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)

    def __repr__(self) -> str:
        return f"VisionMamba_Small(in_shape={self.in_shape}, out_size={self.out_size})"


@register_model("vim_base")
class VimBase(VisionMambaBase):
    """
    Vision Mamba Base: ~201.4M backbone parameters.

    24 layers, 768 dim. For 2D images.
    Pure PyTorch implementation with O(n) complexity.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        kwargs.setdefault("patch_size", 16)
        kwargs.setdefault("d_model", 768)
        kwargs.setdefault("n_layers", 24)
        super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)

    def __repr__(self) -> str:
        return f"VisionMamba_Base(in_shape={self.in_shape}, out_size={self.out_size})"