wavedl 1.5.7__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/models/mamba.py ADDED
@@ -0,0 +1,535 @@
+ """
+ Vision Mamba: Efficient Visual Representation Learning with State Space Models
+ ===============================================================================
+
+ Vision Mamba (Vim) adapts the Mamba selective state space model for vision tasks.
+ Provides O(n) linear complexity vs O(n²) for transformers, making it efficient
+ for long sequences and high-resolution images.
+
+ **Key Features**:
+ - Bidirectional SSM for image understanding
+ - O(n) linear complexity
+ - 2.8x faster than ViT, 86.8% less GPU memory
+ - Works for 1D (time-series) and 2D (images)
+
+ **Variants**:
+ - mamba_1d: For 1D time-series (alternative to TCN)
+ - vim_tiny: ~6.6M params for 2D images
+ - vim_small: ~51.1M params for 2D images
+ - vim_base: ~201.4M params for 2D images
+
+ **Dependencies**:
+ - Optional: mamba-ssm (for optimized CUDA kernels)
+ - Fallback: Pure PyTorch implementation
+
+ Reference:
+     Zhu, L., et al. (2024). Vision Mamba: Efficient Visual Representation
+     Learning with Bidirectional State Space Model. ICML 2024.
+     https://arxiv.org/abs/2401.09417
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from wavedl.models.base import BaseModel
+ from wavedl.models.registry import register_model
+
+
+ # Type aliases
+ SpatialShape = tuple[int] | tuple[int, int]
+
+ __all__ = [
+     "Mamba1D",
+     "Mamba1DBase",
+     "MambaBlock",
+     "VimBase",
+     "VimSmall",
+     "VimTiny",
+     "VisionMambaBase",
+ ]
+
+
+ # =============================================================================
+ # SELECTIVE SSM CORE (Pure PyTorch Implementation)
+ # =============================================================================
+
+
+ class SelectiveSSM(nn.Module):
+     """
+     Selective State Space Model (S6) - Core of Mamba.
+
+     The key innovation is making the SSM parameters (B, C, Δ) input-dependent,
+     allowing the model to selectively focus on or ignore inputs.
+
+     This is a simplified pure-PyTorch implementation. For production use,
+     consider the optimized mamba-ssm package.
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         d_state: int = 16,
+         d_conv: int = 4,
+         expand: int = 2,
+     ):
+         super().__init__()
+
+         self.d_model = d_model
+         self.d_state = d_state
+         self.d_inner = d_model * expand
+
+         # Input projection
+         self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
+
+         # Conv for local context
+         self.conv1d = nn.Conv1d(
+             self.d_inner,
+             self.d_inner,
+             kernel_size=d_conv,
+             padding=d_conv - 1,
+             groups=self.d_inner,
+         )
+
+         # SSM parameters (input-dependent)
+         self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
+
+         # Learnable SSM matrices
+         self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
+         self.A_log = nn.Parameter(
+             torch.log(torch.arange(1, d_state + 1, dtype=torch.float32))
+         )
+         self.D = nn.Parameter(torch.ones(self.d_inner))
+
+         # Output projection
+         self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x: (B, L, D) input sequence
+
+         Returns:
+             y: (B, L, D) output sequence
+         """
+         _B, L, _D = x.shape
+
+         # Input projection and split
+         xz = self.in_proj(x)  # (B, L, 2*d_inner)
+         x, z = xz.chunk(2, dim=-1)  # Each: (B, L, d_inner)
+
+         # Conv for local context
+         x = x.transpose(1, 2)  # (B, d_inner, L)
+         x = self.conv1d(x)[:, :, :L]  # Causal
+         x = x.transpose(1, 2)  # (B, L, d_inner)
+         x = F.silu(x)
+
+         # SSM parameters from input
+         x_proj = self.x_proj(x)  # (B, L, d_state*2 + 1)
+         delta = F.softplus(self.dt_proj(x_proj[:, :, :1]))  # (B, L, d_inner)
+         B_param = x_proj[:, :, 1 : self.d_state + 1]  # (B, L, d_state)
+         C_param = x_proj[:, :, self.d_state + 1 :]  # (B, L, d_state)
+
+         # Discretize A
+         A = -torch.exp(self.A_log)  # (d_state,)
+
+         # Selective scan (simplified, not optimized)
+         y = self._selective_scan(x, delta, A, B_param, C_param, self.D)
+
+         # Gating
+         y = y * F.silu(z)
+
+         # Output projection
+         return self.out_proj(y)
+
+     def _selective_scan(
+         self,
+         x: torch.Tensor,
+         delta: torch.Tensor,
+         A: torch.Tensor,
+         B: torch.Tensor,
+         C: torch.Tensor,
+         D: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Simplified selective scan.
+
+         For real applications, use the CUDA-optimized version from mamba-ssm.
+         This implementation is for understanding and testing only.
+         """
+         B_batch, L, d_inner = x.shape
+         d_state = A.shape[0]
+
+         # Initialize state
+         h = torch.zeros(B_batch, d_inner, d_state, device=x.device, dtype=x.dtype)
+
+         outputs = []
+         for t in range(L):
+             x_t = x[:, t, :]  # (B, d_inner)
+             delta_t = delta[:, t, :]  # (B, d_inner)
+             B_t = B[:, t, :]  # (B, d_state)
+             C_t = C[:, t, :]  # (B, d_state)
+
+             # Discretize: A_bar = exp(delta * A)
+             A_bar = torch.exp(delta_t.unsqueeze(-1) * A)  # (B, d_inner, d_state)
+
+             # Update state: h = A_bar * h + delta * B * x
+             h = A_bar * h + delta_t.unsqueeze(-1) * B_t.unsqueeze(1) * x_t.unsqueeze(-1)
+
+             # Output: y = C * h + D * x
+             y_t = (C_t.unsqueeze(1) * h).sum(-1) + D * x_t  # (B, d_inner)
+             outputs.append(y_t)
+
+         return torch.stack(outputs, dim=1)  # (B, L, d_inner)
+
+
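The `_selective_scan` loop above is a direct per-timestep evaluation of the discretized recurrence h_t = exp(Δ_t·A) ⊙ h_{t-1} + Δ_t·B_t·x_t, y_t = C_t·h_t + D·x_t, so this fallback costs O(L) sequential steps per layer rather than a parallel scan. A minimal shape check, assuming wavedl 1.6.0 is installed (illustrative only, not part of the packaged file):

import torch
from wavedl.models.mamba import SelectiveSSM

ssm = SelectiveSSM(d_model=32, d_state=16, d_conv=4, expand=2)
x = torch.randn(2, 64, 32)      # (B, L, D) dummy sequence
y = ssm(x)
assert y.shape == (2, 64, 32)   # the selective SSM is shape-preserving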
+ # =============================================================================
+ # MAMBA BLOCK
+ # =============================================================================
+
+
+ class MambaBlock(nn.Module):
+     """
+     Mamba Block with residual connection.
+
+     Architecture:
+         Input → Norm → SelectiveSSM → Residual
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         d_state: int = 16,
+         d_conv: int = 4,
+         expand: int = 2,
+     ):
+         super().__init__()
+         self.norm = nn.LayerNorm(d_model)
+         self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return x + self.ssm(self.norm(x))
+
+
+ # =============================================================================
+ # BIDIRECTIONAL MAMBA (For Vision)
+ # =============================================================================
+
+
+ class BidirectionalMambaBlock(nn.Module):
+     """
+     Bidirectional Mamba Block for vision tasks.
+
+     Processes sequence in both forward and backward directions
+     to capture global context in images.
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         d_state: int = 16,
+         d_conv: int = 4,
+         expand: int = 2,
+     ):
+         super().__init__()
+         self.norm = nn.LayerNorm(d_model)
+         self.ssm_forward = SelectiveSSM(d_model, d_state, d_conv, expand)
+         self.ssm_backward = SelectiveSSM(d_model, d_state, d_conv, expand)
+         self.merge = nn.Linear(d_model * 2, d_model)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x_norm = self.norm(x)
+
+         # Forward pass
+         y_forward = self.ssm_forward(x_norm)
+
+         # Backward pass (flip, process, flip back)
+         y_backward = self.ssm_backward(x_norm.flip(dims=[1])).flip(dims=[1])
+
+         # Merge
+         y = self.merge(torch.cat([y_forward, y_backward], dim=-1))
+
+         return x + y
+
+
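The backward branch is the same causal scan applied to the reversed token sequence; flipping its output restores the original order before the concatenated features are projected from 2*d_model back to d_model and added to the residual. A quick hedged check with illustrative sizes (not part of the packaged file):

import torch
from wavedl.models.mamba import BidirectionalMambaBlock

block = BidirectionalMambaBlock(d_model=192)
tokens = torch.randn(2, 197, 192)   # e.g. one CLS token + 196 patch tokens
out = block(tokens)
assert out.shape == tokens.shape    # merge maps 2*d_model back to d_model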
+ # =============================================================================
+ # MAMBA 1D (For Time-Series)
+ # =============================================================================
+
+
+ class Mamba1DBase(BaseModel):
+     """
+     Mamba for 1D time-series data.
+
+     Alternative to TCN with theoretically infinite receptive field
+     and linear complexity.
+     """
+
+     def __init__(
+         self,
+         in_shape: tuple[int],
+         out_size: int,
+         d_model: int = 256,
+         n_layers: int = 8,
+         d_state: int = 16,
+         d_conv: int = 4,
+         expand: int = 2,
+         dropout_rate: float = 0.1,
+         **kwargs,
+     ):
+         super().__init__(in_shape, out_size)
+
+         if len(in_shape) != 1:
+             raise ValueError(f"Mamba1D requires 1D input (L,), got {len(in_shape)}D")
+
+         self.d_model = d_model
+
+         # Input projection
+         self.input_proj = nn.Linear(1, d_model)
+
+         # Positional encoding
+         self.pos_embed = nn.Parameter(torch.zeros(1, in_shape[0], d_model))
+         nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+         # Mamba blocks
+         self.blocks = nn.ModuleList(
+             [MambaBlock(d_model, d_state, d_conv, expand) for _ in range(n_layers)]
+         )
+
+         # Final norm
+         self.norm = nn.LayerNorm(d_model)
+
+         # Regression head
+         self.head = nn.Sequential(
+             nn.Dropout(dropout_rate),
+             nn.Linear(d_model, d_model // 2),
+             nn.GELU(),
+             nn.Dropout(dropout_rate * 0.5),
+             nn.Linear(d_model // 2, out_size),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x: (B, 1, L) input signal
+
+         Returns:
+             (B, out_size) regression output
+         """
+         _B, _C, L = x.shape
+
+         # Reshape to sequence
+         x = x.transpose(1, 2)  # (B, L, 1)
+         x = self.input_proj(x)  # (B, L, d_model)
+
+         # Add positional encoding
+         x = x + self.pos_embed[:, :L, :]
+
+         # Mamba blocks
+         for block in self.blocks:
+             x = block(x)
+
+         # Global pooling (mean over sequence)
+         x = x.mean(dim=1)  # (B, d_model)
+
+         # Final norm and head
+         x = self.norm(x)
+         return self.head(x)
+
+
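Because `pos_embed` is allocated for `in_shape[0]` positions and sliced to the runtime length `L`, any sequence up to `in_shape[0]` samples is accepted, while a longer one fails on the broadcast in `x + self.pos_embed[:, :L, :]`. A small sketch with reduced, illustrative sizes (not part of the packaged file):

import torch
from wavedl.models.mamba import Mamba1DBase

model = Mamba1DBase(in_shape=(1024,), out_size=3, d_model=128, n_layers=2)
assert model(torch.randn(2, 1, 1024)).shape == (2, 3)  # full-length input
assert model(torch.randn(2, 1, 256)).shape == (2, 3)   # shorter inputs also work
# Inputs longer than 1024 samples would raise, since pos_embed covers 1024 positions.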
+ # =============================================================================
+ # VISION MAMBA (For 2D Images)
+ # =============================================================================
+
+
+ class VisionMambaBase(BaseModel):
+     """
+     Vision Mamba (Vim) for 2D images.
+
+     Uses bidirectional SSM to capture global context efficiently.
+     O(n) complexity instead of O(n²) for transformers.
+     """
+
+     def __init__(
+         self,
+         in_shape: tuple[int, int],
+         out_size: int,
+         patch_size: int = 16,
+         d_model: int = 192,
+         n_layers: int = 12,
+         d_state: int = 16,
+         d_conv: int = 4,
+         expand: int = 2,
+         dropout_rate: float = 0.1,
+         **kwargs,
+     ):
+         super().__init__(in_shape, out_size)
+
+         if len(in_shape) != 2:
+             raise ValueError(
+                 f"VisionMamba requires 2D input (H, W), got {len(in_shape)}D"
+             )
+
+         self.patch_size = patch_size
+         self.d_model = d_model
+
+         H, W = in_shape
+         self.num_patches = (H // patch_size) * (W // patch_size)
+         self.grid_size = (H // patch_size, W // patch_size)
+
+         # Patch embedding
+         self.patch_embed = nn.Conv2d(
+             1, d_model, kernel_size=patch_size, stride=patch_size
+         )
+
+         # CLS token for classification/regression
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
+         nn.init.trunc_normal_(self.cls_token, std=0.02)
+
+         # Positional embedding
+         self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, d_model))
+         nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+         # Bidirectional Mamba blocks
+         self.blocks = nn.ModuleList(
+             [
+                 BidirectionalMambaBlock(d_model, d_state, d_conv, expand)
+                 for _ in range(n_layers)
+             ]
+         )
+
+         # Final norm
+         self.norm = nn.LayerNorm(d_model)
+
+         # Regression head
+         self.head = nn.Sequential(
+             nn.Dropout(dropout_rate),
+             nn.Linear(d_model, d_model // 2),
+             nn.GELU(),
+             nn.Dropout(dropout_rate * 0.5),
+             nn.Linear(d_model // 2, out_size),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x: (B, 1, H, W) input image
+
+         Returns:
+             (B, out_size) regression output
+         """
+         B = x.shape[0]
+
+         # Patch embedding
+         x = self.patch_embed(x)  # (B, d_model, H', W')
+         x = x.flatten(2).transpose(1, 2)  # (B, num_patches, d_model)
+
+         # Prepend CLS token
+         cls_tokens = self.cls_token.expand(B, -1, -1)
+         x = torch.cat([cls_tokens, x], dim=1)  # (B, 1 + num_patches, d_model)
+
+         # Add positional embedding
+         x = x + self.pos_embed
+
+         # Bidirectional Mamba blocks
+         for block in self.blocks:
+             x = block(x)
+
+         # Extract CLS token
+         cls_output = x[:, 0]  # (B, d_model)
+
+         # Final norm and head
+         cls_output = self.norm(cls_output)
+         return self.head(cls_output)
+
+
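The token count follows directly from the patch grid: a 224×224 input with patch_size=16 gives a 14×14 grid, i.e. 196 patches, plus one CLS token for the 197 positions in `pos_embed`; rows and columns that do not fill a whole patch are dropped by the strided Conv2d. A hedged sketch with a deliberately small d_model (illustrative only, not part of the packaged file):

import torch
from wavedl.models.mamba import VisionMambaBase

model = VisionMambaBase(
    in_shape=(224, 224), out_size=3, patch_size=16, d_model=64, n_layers=2
)
assert model.grid_size == (14, 14)   # 224 // 16 = 14 patches per side
assert model.num_patches == 196      # 14 * 14 patches, +1 CLS token = 197 positions
assert model(torch.randn(2, 1, 224, 224)).shape == (2, 3)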
+ # =============================================================================
+ # REGISTERED VARIANTS
+ # =============================================================================
+
+
+ @register_model("mamba_1d")
+ class Mamba1D(Mamba1DBase):
+     """
+     Mamba 1D: ~3.4M backbone parameters (for time-series regression).
+
+     8 layers, 256 dim. Alternative to TCN for time-series.
+     Pure PyTorch implementation.
+
+     Example:
+         >>> model = Mamba1D(in_shape=(4096,), out_size=3)
+         >>> x = torch.randn(4, 1, 4096)
+         >>> out = model(x)  # (4, 3)
+     """
+
+     def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
+         kwargs.setdefault("d_model", 256)
+         kwargs.setdefault("n_layers", 8)
+         super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)
+
+     def __repr__(self) -> str:
+         return f"Mamba1D(in_shape={self.in_shape}, out_size={self.out_size})"
+
+
+ @register_model("vim_tiny")
+ class VimTiny(VisionMambaBase):
+     """
+     Vision Mamba Tiny: ~6.6M backbone parameters.
+
+     12 layers, 192 dim. For 2D images.
+     Pure PyTorch implementation with O(n) complexity.
+
+     Example:
+         >>> model = VimTiny(in_shape=(224, 224), out_size=3)
+         >>> x = torch.randn(4, 1, 224, 224)
+         >>> out = model(x)  # (4, 3)
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         kwargs.setdefault("patch_size", 16)
+         kwargs.setdefault("d_model", 192)
+         kwargs.setdefault("n_layers", 12)
+         super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)
+
+     def __repr__(self) -> str:
+         return f"VisionMamba_Tiny(in_shape={self.in_shape}, out_size={self.out_size})"
+
+
+ @register_model("vim_small")
+ class VimSmall(VisionMambaBase):
+     """
+     Vision Mamba Small: ~51.1M backbone parameters.
+
+     24 layers, 384 dim. For 2D images.
+     Pure PyTorch implementation with O(n) complexity.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         kwargs.setdefault("patch_size", 16)
+         kwargs.setdefault("d_model", 384)
+         kwargs.setdefault("n_layers", 24)
+         super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)
+
+     def __repr__(self) -> str:
+         return f"VisionMamba_Small(in_shape={self.in_shape}, out_size={self.out_size})"
+
+
+ @register_model("vim_base")
+ class VimBase(VisionMambaBase):
+     """
+     Vision Mamba Base: ~201.4M backbone parameters.
+
+     24 layers, 768 dim. For 2D images.
+     Pure PyTorch implementation with O(n) complexity.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         kwargs.setdefault("patch_size", 16)
+         kwargs.setdefault("d_model", 768)
+         kwargs.setdefault("n_layers", 24)
+         super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)
+
+     def __repr__(self) -> str:
+         return f"VisionMamba_Base(in_shape={self.in_shape}, out_size={self.out_size})"