wavedl 1.5.7__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/models/__init__.py +80 -4
  4. wavedl/models/_pretrained_utils.py +366 -0
  5. wavedl/models/base.py +48 -0
  6. wavedl/models/caformer.py +270 -0
  7. wavedl/models/cnn.py +2 -27
  8. wavedl/models/convnext.py +113 -51
  9. wavedl/models/convnext_v2.py +488 -0
  10. wavedl/models/densenet.py +10 -23
  11. wavedl/models/efficientnet.py +6 -6
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +252 -0
  15. wavedl/models/mamba.py +555 -0
  16. wavedl/models/maxvit.py +254 -0
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +19 -61
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +2 -6
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +9 -9
  26. wavedl/train.py +1430 -1425
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/data.py +39 -6
  30. wavedl/utils/losses.py +216 -216
  31. wavedl/utils/optimizers.py +216 -216
  32. wavedl/utils/schedulers.py +251 -251
  33. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
  34. wavedl-1.6.1.dist-info/RECORD +46 -0
  35. wavedl-1.5.7.dist-info/RECORD +0 -38
  36. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
  37. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
  38. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
  39. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/mamba.py ADDED
@@ -0,0 +1,555 @@
"""
Vision Mamba: Efficient Visual Representation Learning with State Space Models
===============================================================================

Vision Mamba (Vim) adapts the Mamba selective state space model for vision tasks.
Provides O(n) linear complexity vs O(n²) for transformers, making it efficient
for long sequences and high-resolution images.

**Key Features**:
- Bidirectional SSM for image understanding
- O(n) linear complexity
- 2.8x faster than ViT, 86.8% less GPU memory
- Works for 1D (time-series) and 2D (images)

**Variants**:
- mamba_1d: For 1D time-series (alternative to TCN)
- vim_tiny: 7M params for 2D images
- vim_small: 26M params for 2D images
- vim_base: 98M params for 2D images

**Dependencies**:
- Optional: mamba-ssm (for optimized CUDA kernels)
- Fallback: Pure PyTorch implementation

Reference:
    Zhu, L., et al. (2024). Vision Mamba: Efficient Visual Representation
    Learning with Bidirectional State Space Model. ICML 2024.
    https://arxiv.org/abs/2401.09417

Author: Ductho Le (ductho.le@outlook.com)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from wavedl.models.base import BaseModel, SpatialShape1D, SpatialShape2D
from wavedl.models.registry import register_model


# Type alias for Mamba models (1D and 2D only)
SpatialShape = SpatialShape1D | SpatialShape2D

__all__ = [
    "Mamba1D",
    "Mamba1DBase",
    "MambaBlock",
    "VimBase",
    "VimSmall",
    "VimTiny",
    "VisionMambaBase",
]


# =============================================================================
# SELECTIVE SSM CORE (Pure PyTorch Implementation)
# =============================================================================


class SelectiveSSM(nn.Module):
    """
    Selective State Space Model (S6) - Core of Mamba.

    The key innovation is making the SSM parameters (B, C, Δ) input-dependent,
    allowing the model to selectively focus on or ignore inputs.

    This is a simplified pure-PyTorch implementation. For production use,
    consider the optimized mamba-ssm package.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
    ):
        super().__init__()

        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_model * expand

        # Input projection
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Conv for local context
        self.conv1d = nn.Conv1d(
            self.d_inner,
            self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
        )

        # SSM parameters (input-dependent)
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)

        # Learnable SSM matrices
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
        self.A_log = nn.Parameter(
            torch.log(torch.arange(1, d_state + 1, dtype=torch.float32))
        )
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, L, D) input sequence

        Returns:
            y: (B, L, D) output sequence
        """
        _B, L, _D = x.shape

        # Input projection and split
        xz = self.in_proj(x)  # (B, L, 2*d_inner)
        x, z = xz.chunk(2, dim=-1)  # Each: (B, L, d_inner)

        # Conv for local context
        x = x.transpose(1, 2)  # (B, d_inner, L)
        x = self.conv1d(x)[:, :, :L]  # Causal
        x = x.transpose(1, 2)  # (B, L, d_inner)
        x = F.silu(x)

        # SSM parameters from input
        x_proj = self.x_proj(x)  # (B, L, d_state*2 + 1)
        delta = F.softplus(self.dt_proj(x_proj[:, :, :1]))  # (B, L, d_inner)
        B_param = x_proj[:, :, 1 : self.d_state + 1]  # (B, L, d_state)
        C_param = x_proj[:, :, self.d_state + 1 :]  # (B, L, d_state)

        # Discretize A
        A = -torch.exp(self.A_log)  # (d_state,)

        # Selective scan (simplified, not optimized)
        y = self._selective_scan(x, delta, A, B_param, C_param, self.D)

        # Gating
        y = y * F.silu(z)

        # Output projection
        return self.out_proj(y)

    def _selective_scan(
        self,
        x: torch.Tensor,
        delta: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        D: torch.Tensor,
    ) -> torch.Tensor:
        """
        Vectorized selective scan using parallel associative scan.

        This implementation avoids the sequential for-loop by computing
        all timesteps in parallel using cumulative products and sums.
        ~100x faster than the naive sequential implementation.
        """

        # Compute discretized A_bar for all timesteps: (B, L, d_inner, d_state)
        A_bar = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, d_inner, d_state)

        # Compute input contribution: delta * B * x for all timesteps
        # B: (B, L, d_state), x: (B, L, d_inner), delta: (B, L, d_inner)
        # Result: (B, L, d_inner, d_state)
        BX = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)

        # Parallel scan using log-space cumulative products for numerical stability
        # For SSM: h[t] = A_bar[t] * h[t-1] + BX[t]
        # This is a linear recurrence that can be solved with associative scan

        # Use chunked approach for memory efficiency with parallel scan
        # Compute cumulative product of A_bar (in log space for stability)
        log_A_bar = torch.log(A_bar.clamp(min=1e-10))
        log_A_cumsum = torch.cumsum(log_A_bar, dim=1)  # (B, L, d_inner, d_state)
        A_cumsum = torch.exp(log_A_cumsum)

        # For each timestep t, we need: sum_{s=0}^{t} (prod_{k=s+1}^{t} A_bar[k]) * BX[s]
        #   = sum_{s=0}^{t} (A_cumsum[t] / A_cumsum[s]) * BX[s]
        #   = A_cumsum[t] * sum_{s=0}^{t} (BX[s] / A_cumsum[s])

        # Compute BX / A_cumsum (use A_cumsum shifted by 1 for proper indexing)
        # A_cumsum[s] represents prod_{k=0}^{s} A_bar[k], but we need prod_{k=0}^{s-1}
        # So we shift: use A_cumsum from previous timestep
        A_cumsum_shifted = F.pad(A_cumsum[:, :-1], (0, 0, 0, 0, 1, 0), value=1.0)

        # Weighted input: BX[s] / A_cumsum[s-1] = BX[s] * exp(-log_A_cumsum[s-1])
        weighted_BX = BX / A_cumsum_shifted.clamp(min=1e-10)

        # Cumulative sum of weighted inputs
        weighted_BX_cumsum = torch.cumsum(weighted_BX, dim=1)

        # Final state at each timestep: h[t] = A_cumsum[t] * weighted_BX_cumsum[t]
        # But A_cumsum includes A_bar[0], so adjust
        h = A_cumsum * weighted_BX_cumsum / A_bar.clamp(min=1e-10)

        # Output: y = C * h + D * x
        # h: (B, L, d_inner, d_state), C: (B, L, d_state)
        y = (C.unsqueeze(2) * h).sum(-1) + D * x  # (B, L, d_inner)

        return y
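

# A naive sequential sketch of the same recurrence (illustrative reference for
# checking the vectorized scan above on small inputs; it assumes the shapes
# documented in `_selective_scan` and is not the optimized path):
def _reference_selective_scan(x, A_bar, BX, C, D):
    """Sequential scan of h[t] = A_bar[t] * h[t-1] + BX[t]; y[t] = C[t]·h[t] + D·x[t]."""
    batch, seq_len, d_inner, d_state = A_bar.shape
    h = torch.zeros(batch, d_inner, d_state, device=x.device, dtype=x.dtype)
    ys = []
    for t in range(seq_len):
        h = A_bar[:, t] * h + BX[:, t]  # state update
        ys.append((C[:, t].unsqueeze(1) * h).sum(-1) + D * x[:, t])  # readout
    return torch.stack(ys, dim=1)  # (B, L, d_inner)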


# =============================================================================
# MAMBA BLOCK
# =============================================================================


class MambaBlock(nn.Module):
    """
    Mamba Block with residual connection.

    Architecture:
        Input → Norm → SelectiveSSM → Residual
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
    ):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ssm(self.norm(x))


# =============================================================================
# BIDIRECTIONAL MAMBA (For Vision)
# =============================================================================


class BidirectionalMambaBlock(nn.Module):
    """
    Bidirectional Mamba Block for vision tasks.

    Processes the sequence in both forward and backward directions
    to capture global context in images.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
    ):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ssm_forward = SelectiveSSM(d_model, d_state, d_conv, expand)
        self.ssm_backward = SelectiveSSM(d_model, d_state, d_conv, expand)
        self.merge = nn.Linear(d_model * 2, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_norm = self.norm(x)

        # Forward pass
        y_forward = self.ssm_forward(x_norm)

        # Backward pass (flip, process, flip back)
        y_backward = self.ssm_backward(x_norm.flip(dims=[1])).flip(dims=[1])

        # Merge
        y = self.merge(torch.cat([y_forward, y_backward], dim=-1))

        return x + y


# =============================================================================
# MAMBA 1D (For Time-Series)
# =============================================================================


class Mamba1DBase(BaseModel):
    """
    Mamba for 1D time-series data.

    Alternative to TCN with theoretically infinite receptive field
    and linear complexity.
    """

    def __init__(
        self,
        in_shape: tuple[int],
        out_size: int,
        d_model: int = 256,
        n_layers: int = 8,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        dropout_rate: float = 0.1,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        if len(in_shape) != 1:
            raise ValueError(f"Mamba1D requires 1D input (L,), got {len(in_shape)}D")

        self.d_model = d_model

        # Input projection
        self.input_proj = nn.Linear(1, d_model)

        # Positional encoding
        self.pos_embed = nn.Parameter(torch.zeros(1, in_shape[0], d_model))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Mamba blocks
        self.blocks = nn.ModuleList(
            [MambaBlock(d_model, d_state, d_conv, expand) for _ in range(n_layers)]
        )

        # Final norm
        self.norm = nn.LayerNorm(d_model)

        # Regression head
        self.head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(d_model // 2, out_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, 1, L) input signal

        Returns:
            (B, out_size) regression output
        """
        _B, _C, L = x.shape

        # Reshape to sequence
        x = x.transpose(1, 2)  # (B, L, 1)
        x = self.input_proj(x)  # (B, L, d_model)

        # Add positional encoding
        x = x + self.pos_embed[:, :L, :]

        # Mamba blocks
        for block in self.blocks:
            x = block(x)

        # Global pooling (mean over sequence)
        x = x.mean(dim=1)  # (B, d_model)

        # Final norm and head
        x = self.norm(x)
        return self.head(x)
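

# A minimal shape sketch for the 1D path (illustrative; assumes the defaults
# d_model=256, n_layers=8 used by the registered Mamba1D variant below):
#
#     >>> m = Mamba1DBase(in_shape=(4096,), out_size=3)
#     >>> x = torch.randn(4, 1, 4096)   # (B, 1, L)
#     >>> m(x).shape                    # (B, L, 1) -> (B, L, 256) -> mean-pool -> head
#     torch.Size([4, 3])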


# =============================================================================
# VISION MAMBA (For 2D Images)
# =============================================================================


class VisionMambaBase(BaseModel):
    """
    Vision Mamba (Vim) for 2D images.

    Uses bidirectional SSM to capture global context efficiently.
    O(n) complexity instead of O(n²) for transformers.
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        patch_size: int = 16,
        d_model: int = 192,
        n_layers: int = 12,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        dropout_rate: float = 0.1,
        **kwargs,
    ):
        super().__init__(in_shape, out_size)

        if len(in_shape) != 2:
            raise ValueError(
                f"VisionMamba requires 2D input (H, W), got {len(in_shape)}D"
            )

        self.patch_size = patch_size
        self.d_model = d_model

        H, W = in_shape
        self.num_patches = (H // patch_size) * (W // patch_size)
        self.grid_size = (H // patch_size, W // patch_size)

        # Patch embedding
        self.patch_embed = nn.Conv2d(
            1, d_model, kernel_size=patch_size, stride=patch_size
        )

        # CLS token for classification/regression
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        # Positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, d_model))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Bidirectional Mamba blocks
        self.blocks = nn.ModuleList(
            [
                BidirectionalMambaBlock(d_model, d_state, d_conv, expand)
                for _ in range(n_layers)
            ]
        )

        # Final norm
        self.norm = nn.LayerNorm(d_model)

        # Regression head
        self.head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(d_model // 2, out_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, 1, H, W) input image

        Returns:
            (B, out_size) regression output
        """
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # (B, d_model, H', W')
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, d_model)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, 1 + num_patches, d_model)

        # Add positional embedding
        x = x + self.pos_embed

        # Bidirectional Mamba blocks
        for block in self.blocks:
            x = block(x)

        # Extract CLS token
        cls_output = x[:, 0]  # (B, d_model)

        # Final norm and head
        cls_output = self.norm(cls_output)
        return self.head(cls_output)
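

# Token-count sketch (illustrative; assumes in_shape=(224, 224) and the default
# patch_size=16): the patch grid is 224 // 16 = 14 per side, so num_patches is
# 14 * 14 = 196 and each Mamba block sees 197 tokens (196 patches + 1 CLS):
#
#     >>> m = VisionMambaBase(in_shape=(224, 224), out_size=3)
#     >>> m.num_patches, m.pos_embed.shape
#     (196, torch.Size([1, 197, 192]))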


# =============================================================================
# REGISTERED VARIANTS
# =============================================================================


@register_model("mamba_1d")
class Mamba1D(Mamba1DBase):
    """
    Mamba 1D: ~3.4M backbone parameters (for time-series regression).

    8 layers, 256 dim. Alternative to TCN for time-series.
    Pure PyTorch implementation.

    Example:
        >>> model = Mamba1D(in_shape=(4096,), out_size=3)
        >>> x = torch.randn(4, 1, 4096)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int], out_size: int, **kwargs):
        kwargs.setdefault("d_model", 256)
        kwargs.setdefault("n_layers", 8)
        super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)

    def __repr__(self) -> str:
        return f"Mamba1D(in_shape={self.in_shape}, out_size={self.out_size})"


@register_model("vim_tiny")
class VimTiny(VisionMambaBase):
    """
    Vision Mamba Tiny: ~6.6M backbone parameters.

    12 layers, 192 dim. For 2D images.
    Pure PyTorch implementation with O(n) complexity.

    Example:
        >>> model = VimTiny(in_shape=(224, 224), out_size=3)
        >>> x = torch.randn(4, 1, 224, 224)
        >>> out = model(x)  # (4, 3)
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        kwargs.setdefault("patch_size", 16)
        kwargs.setdefault("d_model", 192)
        kwargs.setdefault("n_layers", 12)
        super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)

    def __repr__(self) -> str:
        return f"VisionMamba_Tiny(in_shape={self.in_shape}, out_size={self.out_size})"


@register_model("vim_small")
class VimSmall(VisionMambaBase):
    """
    Vision Mamba Small: ~51.1M backbone parameters.

    24 layers, 384 dim. For 2D images.
    Pure PyTorch implementation with O(n) complexity.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        kwargs.setdefault("patch_size", 16)
        kwargs.setdefault("d_model", 384)
        kwargs.setdefault("n_layers", 24)
        super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)

    def __repr__(self) -> str:
        return f"VisionMamba_Small(in_shape={self.in_shape}, out_size={self.out_size})"


@register_model("vim_base")
class VimBase(VisionMambaBase):
    """
    Vision Mamba Base: ~201.4M backbone parameters.

    24 layers, 768 dim. For 2D images.
    Pure PyTorch implementation with O(n) complexity.
    """

    def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
        kwargs.setdefault("patch_size", 16)
        kwargs.setdefault("d_model", 768)
        kwargs.setdefault("n_layers", 24)
        super().__init__(in_shape=in_shape, out_size=out_size, **kwargs)

    def __repr__(self) -> str:
        return f"VisionMamba_Base(in_shape={self.in_shape}, out_size={self.out_size})"
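

# Minimal smoke test (an illustrative sketch built from the docstring examples
# above; assumes a CPU run with the registered defaults):
if __name__ == "__main__":
    model_1d = Mamba1D(in_shape=(4096,), out_size=3)
    out_1d = model_1d(torch.randn(4, 1, 4096))
    print(model_1d, out_1d.shape)  # expected: torch.Size([4, 3])

    model_2d = VimTiny(in_shape=(224, 224), out_size=3)
    out_2d = model_2d(torch.randn(4, 1, 224, 224))
    print(model_2d, out_2d.shape)  # expected: torch.Size([4, 3])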