wavedl-1.5.7-py3-none-any.whl → wavedl-1.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +80 -4
- wavedl/models/_pretrained_utils.py +366 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +113 -51
- wavedl/models/convnext_v2.py +488 -0
- wavedl/models/densenet.py +10 -23
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +252 -0
- wavedl/models/mamba.py +555 -0
- wavedl/models/maxvit.py +254 -0
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +19 -61
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +2 -6
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +9 -9
- wavedl/train.py +1430 -1425
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/data.py +39 -6
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/mamba.py
ADDED
@@ -0,0 +1,555 @@
"""
Vision Mamba: Efficient Visual Representation Learning with State Space Models
===============================================================================

Vision Mamba (Vim) adapts the Mamba selective state space model for vision tasks.
Provides O(n) linear complexity vs O(n²) for transformers, making it efficient
for long sequences and high-resolution images.

**Key Features**:
- Bidirectional SSM for image understanding
- O(n) linear complexity
- 2.8x faster than ViT, 86.8% less GPU memory
- Works for 1D (time-series) and 2D (images)

**Variants**:
- mamba_1d: For 1D time-series (alternative to TCN)
- vim_tiny: 7M params for 2D images
- vim_small: 26M params for 2D images
- vim_base: 98M params for 2D images

**Dependencies**:
- Optional: mamba-ssm (for optimized CUDA kernels)
- Fallback: Pure PyTorch implementation

Reference:
    Zhu, L., et al. (2024). Vision Mamba: Efficient Visual Representation
    Learning with Bidirectional State Space Model. ICML 2024.
    https://arxiv.org/abs/2401.09417

Author: Ductho Le (ductho.le@outlook.com)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from wavedl.models.base import BaseModel, SpatialShape1D, SpatialShape2D
from wavedl.models.registry import register_model


# Type alias for Mamba models (1D and 2D only)
SpatialShape = SpatialShape1D | SpatialShape2D

__all__ = [
    "Mamba1D",
    "Mamba1DBase",
    "MambaBlock",
    "VimBase",
    "VimSmall",
    "VimTiny",
    "VisionMambaBase",
]


# =============================================================================
# SELECTIVE SSM CORE (Pure PyTorch Implementation)
# =============================================================================


class SelectiveSSM(nn.Module):
    """
    Selective State Space Model (S6) - Core of Mamba.

    The key innovation is making the SSM parameters (B, C, Δ) input-dependent,
    allowing the model to selectively focus on or ignore inputs.

    This is a simplified pure-PyTorch implementation. For production use,
    consider the optimized mamba-ssm package.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
    ):
        super().__init__()

        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_model * expand

        # Input projection
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Conv for local context
        self.conv1d = nn.Conv1d(
            self.d_inner,
            self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
        )

        # SSM parameters (input-dependent)
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)

        # Learnable SSM matrices
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
        self.A_log = nn.Parameter(
            torch.log(torch.arange(1, d_state + 1, dtype=torch.float32))
        )
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, L, D) input sequence

        Returns:
            y: (B, L, D) output sequence
        """
        _B, L, _D = x.shape

        # Input projection and split
        xz = self.in_proj(x)  # (B, L, 2*d_inner)
        x, z = xz.chunk(2, dim=-1)  # Each: (B, L, d_inner)

        # Conv for local context
        x = x.transpose(1, 2)  # (B, d_inner, L)
        x = self.conv1d(x)[:, :, :L]  # Causal
        x = x.transpose(1, 2)  # (B, L, d_inner)
        x = F.silu(x)

        # SSM parameters from input
        x_proj = self.x_proj(x)  # (B, L, d_state*2 + 1)
        delta = F.softplus(self.dt_proj(x_proj[:, :, :1]))  # (B, L, d_inner)
        B_param = x_proj[:, :, 1 : self.d_state + 1]  # (B, L, d_state)
        C_param = x_proj[:, :, self.d_state + 1 :]  # (B, L, d_state)

        # Discretize A
        A = -torch.exp(self.A_log)  # (d_state,)

        # Selective scan (simplified, not optimized)
        y = self._selective_scan(x, delta, A, B_param, C_param, self.D)

        # Gating
        y = y * F.silu(z)

        # Output projection
        return self.out_proj(y)

    def _selective_scan(
        self,
        x: torch.Tensor,
        delta: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        D: torch.Tensor,
    ) -> torch.Tensor:
        """
        Vectorized selective scan using parallel associative scan.

        This implementation avoids the sequential for-loop by computing
        all timesteps in parallel using cumulative products and sums.
        ~100x faster than the naive sequential implementation.
        """

        # Compute discretized A_bar for all timesteps: (B, L, d_inner, d_state)
        A_bar = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, d_inner, d_state)

        # Compute input contribution: delta * B * x for all timesteps
        # B: (B, L, d_state), x: (B, L, d_inner), delta: (B, L, d_inner)
        # Result: (B, L, d_inner, d_state)
        BX = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)

        # Parallel scan using log-space cumulative products for numerical stability
        # For SSM: h[t] = A_bar[t] * h[t-1] + BX[t]
        # This is a linear recurrence that can be solved with associative scan

        # Use chunked approach for memory efficiency with parallel scan
        # Compute cumulative product of A_bar (in log space for stability)
        log_A_bar = torch.log(A_bar.clamp(min=1e-10))
        log_A_cumsum = torch.cumsum(log_A_bar, dim=1)  # (B, L, d_inner, d_state)
        A_cumsum = torch.exp(log_A_cumsum)

        # For each timestep t, we need: sum_{s=0}^{t} (prod_{k=s+1}^{t} A_bar[k]) * BX[s]
        # = sum_{s=0}^{t} (A_cumsum[t] / A_cumsum[s]) * BX[s]
        # = A_cumsum[t] * sum_{s=0}^{t} (BX[s] / A_cumsum[s])

        # Compute BX / A_cumsum (use A_cumsum shifted by 1 for proper indexing)
        # A_cumsum[s] represents prod_{k=0}^{s} A_bar[k], but we need prod_{k=0}^{s-1}
        # So we shift: use A_cumsum from previous timestep
        A_cumsum_shifted = F.pad(A_cumsum[:, :-1], (0, 0, 0, 0, 1, 0), value=1.0)

        # Weighted input: BX[s] / A_cumsum[s-1] = BX[s] * exp(-log_A_cumsum[s-1])
        weighted_BX = BX / A_cumsum_shifted.clamp(min=1e-10)

        # Cumulative sum of weighted inputs
        weighted_BX_cumsum = torch.cumsum(weighted_BX, dim=1)

        # Final state at each timestep: h[t] = A_cumsum[t] * weighted_BX_cumsum[t]
        # But A_cumsum includes A_bar[0], so adjust
        h = A_cumsum * weighted_BX_cumsum / A_bar.clamp(min=1e-10)

        # Output: y = C * h + D * x
        # h: (B, L, d_inner, d_state), C: (B, L, d_state)
        y = (C.unsqueeze(2) * h).sum(-1) + D * x  # (B, L, d_inner)

        return y


# =============================================================================
# MAMBA BLOCK
# =============================================================================


class MambaBlock(nn.Module):
    """
    Mamba Block with residual connection.

    Architecture:
        Input → Norm → SelectiveSSM → Residual
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
    ):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ssm(self.norm(x))


# =============================================================================
# BIDIRECTIONAL MAMBA (For Vision)
# =============================================================================


class BidirectionalMambaBlock(nn.Module):
    """
    Bidirectional Mamba Block for vision tasks.

    Processes sequence in both forward and backward directions
    to capture global context in images.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
    ):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ssm_forward = SelectiveSSM(d_model, d_state, d_conv, expand)
        self.ssm_backward = SelectiveSSM(d_model, d_state, d_conv, expand)
        self.merge = nn.Linear(d_model * 2, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_norm = self.norm(x)

        # Forward pass
        y_forward = self.ssm_forward(x_norm)

        # Backward pass (flip, process, flip back)
        y_backward = self.ssm_backward(x_norm.flip(dims=[1])).flip(dims=[1])

        # Merge
        y = self.merge(torch.cat([y_forward, y_backward], dim=-1))

        return x + y


# =============================================================================
# MAMBA 1D (For Time-Series)
# =============================================================================


class Mamba1DBase(BaseModel):
    """
    Mamba for 1D time-series data.

    Alternative to TCN with theoretically infinite receptive field
    and linear complexity.
    """

    def __init__(
        self,
        in_shape: tuple[int],
        out_size: int,
        d_model: int = 256,
        n_layers: int = 8,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        dropout_rate: float = 0.1,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        if len(in_shape) != 1:
            raise ValueError(f"Mamba1D requires 1D input (L,), got {len(in_shape)}D")

        self.d_model = d_model

        # Input projection
        self.input_proj = nn.Linear(1, d_model)

        # Positional encoding
        self.pos_embed = nn.Parameter(torch.zeros(1, in_shape[0], d_model))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Mamba blocks
        self.blocks = nn.ModuleList(
            [MambaBlock(d_model, d_state, d_conv, expand) for _ in range(n_layers)]
        )

        # Final norm
        self.norm = nn.LayerNorm(d_model)

        # Regression head
        self.head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(d_model // 2, out_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, 1, L) input signal

        Returns:
            (B, out_size) regression output
        """
        _B, _C, L = x.shape

        # Reshape to sequence
        x = x.transpose(1, 2)  # (B, L, 1)
        x = self.input_proj(x)  # (B, L, d_model)

        # Add positional encoding
        x = x + self.pos_embed[:, :L, :]

        # Mamba blocks
        for block in self.blocks:
            x = block(x)

        # Global pooling (mean over sequence)
        x = x.mean(dim=1)  # (B, d_model)

        # Final norm and head
        x = self.norm(x)
        return self.head(x)


# =============================================================================
# VISION MAMBA (For 2D Images)
# =============================================================================


class VisionMambaBase(BaseModel):
    """
    Vision Mamba (Vim) for 2D images.

    Uses bidirectional SSM to capture global context efficiently.
    O(n) complexity instead of O(n²) for transformers.
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        patch_size: int = 16,
        d_model: int = 192,
        n_layers: int = 12,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        dropout_rate: float = 0.1,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        if len(in_shape) != 2:
            raise ValueError(
                f"VisionMamba requires 2D input (H, W), got {len(in_shape)}D"
            )

        self.patch_size = patch_size
        self.d_model = d_model

        H, W = in_shape
        self.num_patches = (H // patch_size) * (W // patch_size)
        self.grid_size = (H // patch_size, W // patch_size)

        # Patch embedding
        self.patch_embed = nn.Conv2d(
            1, d_model, kernel_size=patch_size, stride=patch_size
        )

        # CLS token for classification/regression
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        # Positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, d_model))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Bidirectional Mamba blocks
        self.blocks = nn.ModuleList(
            [
                BidirectionalMambaBlock(d_model, d_state, d_conv, expand)
                for _ in range(n_layers)
            ]
        )

        # Final norm
        self.norm = nn.LayerNorm(d_model)

        # Regression head
        self.head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(d_model // 2, out_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, 1, H, W) input image

        Returns:
            (B, out_size) regression output
        """
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # (B, d_model, H', W')
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, d_model)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, 1 + num_patches, d_model)

        # Add positional embedding
        x = x + self.pos_embed

        # Bidirectional Mamba blocks
        for block in self.blocks:
            x = block(x)

        # Extract CLS token
        cls_output = x[:, 0]  # (B, d_model)

        # Final norm and head
        cls_output = self.norm(cls_output)
        return self.head(cls_output)


# =============================================================================
# REGISTERED VARIANTS
# =============================================================================


@register_model("mamba_1d")
class Mamba1D(Mamba1DBase):
    """
    Mamba 1D: ~3.4M backbone parameters (for time-series regression).

    8 layers, 256 dim. Alternative to TCN for time-series.
    Pure PyTorch implementation.

    Example:
        >>> model = Mamba1D(in_shape=(4096,), out_size=3)
        >>> x = torch.randn(4, 1, 4096)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
        kwargs.setdefault("d_model", 256)
        kwargs.setdefault("n_layers", 8)
        super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)

    def __repr__(self) -> str:
        return f"Mamba1D(in_shape={self.in_shape}, out_size={self.out_size})"


@register_model("vim_tiny")
class VimTiny(VisionMambaBase):
    """
    Vision Mamba Tiny: ~6.6M backbone parameters.

    12 layers, 192 dim. For 2D images.
    Pure PyTorch implementation with O(n) complexity.

    Example:
        >>> model = VimTiny(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        kwargs.setdefault("patch_size", 16)
        kwargs.setdefault("d_model", 192)
        kwargs.setdefault("n_layers", 12)
        super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)

    def __repr__(self) -> str:
        return f"VisionMamba_Tiny(in_shape={self.in_shape}, out_size={self.out_size})"


@register_model("vim_small")
class VimSmall(VisionMambaBase):
    """
    Vision Mamba Small: ~51.1M backbone parameters.

    24 layers, 384 dim. For 2D images.
    Pure PyTorch implementation with O(n) complexity.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        kwargs.setdefault("patch_size", 16)
        kwargs.setdefault("d_model", 384)
        kwargs.setdefault("n_layers", 24)
        super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)

    def __repr__(self) -> str:
        return f"VisionMamba_Small(in_shape={self.in_shape}, out_size={self.out_size})"


@register_model("vim_base")
class VimBase(VisionMambaBase):
    """
    Vision Mamba Base: ~201.4M backbone parameters.

    24 layers, 768 dim. For 2D images.
    Pure PyTorch implementation with O(n) complexity.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        kwargs.setdefault("patch_size", 16)
        kwargs.setdefault("d_model", 768)
        kwargs.setdefault("n_layers", 24)
        super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)

    def __repr__(self) -> str:
        return f"VisionMamba_Base(in_shape={self.in_shape}, out_size={self.out_size})"