wavedl 1.5.7.tar.gz → 1.6.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {wavedl-1.5.7/src/wavedl.egg-info → wavedl-1.6.0}/PKG-INFO +90 -62
  2. {wavedl-1.5.7 → wavedl-1.6.0}/README.md +88 -61
  3. {wavedl-1.5.7 → wavedl-1.6.0}/pyproject.toml +1 -0
  4. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/__init__.py +1 -1
  5. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/__init__.py +52 -4
  6. wavedl-1.6.0/src/wavedl/models/_timm_utils.py +238 -0
  7. wavedl-1.6.0/src/wavedl/models/caformer.py +270 -0
  8. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/convnext.py +108 -33
  9. wavedl-1.6.0/src/wavedl/models/convnext_v2.py +504 -0
  10. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/densenet.py +5 -5
  11. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/efficientnet.py +6 -6
  12. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/efficientnetv2.py +3 -3
  13. wavedl-1.6.0/src/wavedl/models/fastvit.py +285 -0
  14. wavedl-1.6.0/src/wavedl/models/mamba.py +535 -0
  15. wavedl-1.6.0/src/wavedl/models/maxvit.py +251 -0
  16. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/mobilenetv3.py +6 -6
  17. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/regnet.py +10 -10
  18. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/resnet.py +5 -5
  19. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/resnet3d.py +2 -2
  20. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/swin.py +3 -3
  21. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/tcn.py +3 -3
  22. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/unet.py +1 -1
  23. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/vit.py +6 -6
  24. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/train.py +21 -16
  25. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/utils/data.py +39 -6
  26. {wavedl-1.5.7 → wavedl-1.6.0/src/wavedl.egg-info}/PKG-INFO +90 -62
  27. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl.egg-info/SOURCES.txt +6 -0
  28. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl.egg-info/requires.txt +1 -0
  29. {wavedl-1.5.7 → wavedl-1.6.0}/LICENSE +0 -0
  30. {wavedl-1.5.7 → wavedl-1.6.0}/setup.cfg +0 -0
  31. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/hpc.py +0 -0
  32. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/hpo.py +0 -0
  33. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/_template.py +0 -0
  34. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/base.py +0 -0
  35. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/cnn.py +0 -0
  36. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/models/registry.py +0 -0
  37. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/test.py +0 -0
  38. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/utils/__init__.py +0 -0
  39. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/utils/config.py +0 -0
  40. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/utils/constraints.py +0 -0
  41. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/utils/cross_validation.py +0 -0
  42. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/utils/distributed.py +0 -0
  43. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/utils/losses.py +0 -0
  44. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/utils/metrics.py +0 -0
  45. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/utils/optimizers.py +0 -0
  46. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl/utils/schedulers.py +0 -0
  47. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl.egg-info/dependency_links.txt +0 -0
  48. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl.egg-info/entry_points.txt +0 -0
  49. {wavedl-1.5.7 → wavedl-1.6.0}/src/wavedl.egg-info/top_level.txt +0 -0
--- wavedl-1.5.7/src/wavedl.egg-info/PKG-INFO
+++ wavedl-1.6.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.7
+Version: 1.6.0
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -23,6 +23,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch>=2.0.0
 Requires-Dist: torchvision>=0.15.0
+Requires-Dist: timm>=0.9.0
 Requires-Dist: accelerate>=0.20.0
 Requires-Dist: numpy>=1.24.0
 Requires-Dist: scipy>=1.10.0
@@ -117,7 +118,7 @@ Train on datasets larger than RAM:
 
 **🧠 Models? We've Got Options**
 
-38 architectures, ready to go:
+57 architectures, ready to go:
 - CNNs, ResNets, ViTs, EfficientNets...
 - All adapted for regression
 - [Add your own](#adding-custom-models) in one line
@@ -202,7 +203,7 @@ Deploy models anywhere:
 #### From PyPI (recommended for all users)
 
 ```bash
-pip install wavedl
+pip install --upgrade wavedl
 ```
 
 This installs everything you need: training, inference, HPO, ONNX export.
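The switch to `pip install --upgrade wavedl` matters for existing environments, since 1.6.0 pulls in the new `timm` dependency (see the `requires.txt` and `pyproject.toml` hunks below). A minimal post-upgrade smoke test, sketched against the public API visible elsewhere in this diff (`__version__` in `src/wavedl/__init__.py`, `list_models` exported from `wavedl.models`):

```python
# Hedged sanity check after upgrading; assumes list_models() returns the
# registered architecture names, as the models/__init__.py docstring implies.
import wavedl
from wavedl.models import list_models

assert wavedl.__version__ == "1.6.0"
print(f"{len(list_models())} architectures registered")  # README advertises 57
```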
@@ -358,22 +359,10 @@ WaveDL/
 │ ├── hpo.py # Hyperparameter optimization
 │ ├── hpc.py # HPC distributed training launcher
 │ │
-│ ├── models/ # Model architectures (38 variants)
+│ ├── models/ # Model Zoo (57 architectures)
 │ │ ├── registry.py # Model factory (@register_model)
 │ │ ├── base.py # Abstract base class
-│ │ ├── cnn.py # Baseline CNN (1D/2D/3D)
-│ │ ├── resnet.py # ResNet-18/34/50 (1D/2D/3D)
-│ │ ├── resnet3d.py # ResNet3D-18, MC3-18 (3D only)
-│ │ ├── tcn.py # TCN (1D only)
-│ │ ├── efficientnet.py # EfficientNet-B0/B1/B2 (2D)
-│ │ ├── efficientnetv2.py # EfficientNetV2-S/M/L (2D)
-│ │ ├── mobilenetv3.py # MobileNetV3-Small/Large (2D)
-│ │ ├── regnet.py # RegNetY variants (2D)
-│ │ ├── swin.py # Swin Transformer (2D)
-│ │ ├── vit.py # Vision Transformer (1D/2D)
-│ │ ├── convnext.py # ConvNeXt (1D/2D/3D)
-│ │ ├── densenet.py # DenseNet-121/169 (1D/2D/3D)
-│ │ └── unet.py # U-Net Regression
+│ │ └── ... # See "Available Models" section
 │ │
 │ └── utils/ # Utilities
 │ ├── data.py # Memory-mapped data pipeline
@@ -388,7 +377,7 @@ WaveDL/
 ├── configs/ # YAML config templates
 ├── examples/ # Ready-to-run examples
 ├── notebooks/ # Jupyter notebooks
-├── unit_tests/ # Pytest test suite (903 tests)
+├── unit_tests/ # Pytest test suite
 
 ├── pyproject.toml # Package config, dependencies
 ├── CHANGELOG.md # Version history
@@ -411,71 +400,96 @@ WaveDL/
 > ```
 
 <details>
-<summary><b>Available Models</b> — 38 architectures</summary>
+<summary><b>Available Models</b> — 57 architectures</summary>
 
-| Model | Params | Dim |
-|-------|--------|-----|
+| Model | Backbone Params | Dim |
+|-------|-----------------|-----|
 | **CNN** — Convolutional Neural Network |||
-| `cnn` | 1.7M | 1D/2D/3D |
+| `cnn` | 1.6M | 1D/2D/3D |
 | **ResNet** — Residual Network |||
-| `resnet18` | 11.4M | 1D/2D/3D |
-| `resnet34` | 21.5M | 1D/2D/3D |
-| `resnet50` | 24.6M | 1D/2D/3D |
-| `resnet18_pretrained` ⭐ | 11.4M | 2D |
-| `resnet50_pretrained` ⭐ | 24.6M | 2D |
+| `resnet18` | 11.2M | 1D/2D/3D |
+| `resnet34` | 21.3M | 1D/2D/3D |
+| `resnet50` | 23.5M | 1D/2D/3D |
+| `resnet18_pretrained` ⭐ | 11.2M | 2D |
+| `resnet50_pretrained` ⭐ | 23.5M | 2D |
 | **ResNet3D** — 3D Residual Network |||
-| `resnet3d_18` | 33.6M | 3D |
-| `mc3_18` — Mixed Convolution 3D | 11.9M | 3D |
+| `resnet3d_18` | 33.2M | 3D |
+| `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
 | **TCN** — Temporal Convolutional Network |||
-| `tcn_small` | 1.0M | 1D |
-| `tcn` | 7.0M | 1D |
-| `tcn_large` | 10.2M | 1D |
+| `tcn_small` | 0.9M | 1D |
+| `tcn` | 6.9M | 1D |
+| `tcn_large` | 10.0M | 1D |
 | **EfficientNet** — Efficient Neural Network |||
-| `efficientnet_b0` ⭐ | 4.7M | 2D |
-| `efficientnet_b1` ⭐ | 7.2M | 2D |
-| `efficientnet_b2` ⭐ | 8.4M | 2D |
+| `efficientnet_b0` ⭐ | 4.0M | 2D |
+| `efficientnet_b1` ⭐ | 6.5M | 2D |
+| `efficientnet_b2` ⭐ | 7.7M | 2D |
 | **EfficientNetV2** — Efficient Neural Network V2 |||
-| `efficientnet_v2_s` ⭐ | 21.0M | 2D |
-| `efficientnet_v2_m` ⭐ | 53.6M | 2D |
-| `efficientnet_v2_l` ⭐ | 118.0M | 2D |
+| `efficientnet_v2_s` ⭐ | 20.2M | 2D |
+| `efficientnet_v2_m` ⭐ | 52.9M | 2D |
+| `efficientnet_v2_l` ⭐ | 117.2M | 2D |
 | **MobileNetV3** — Mobile Neural Network V3 |||
-| `mobilenet_v3_small` ⭐ | 1.1M | 2D |
-| `mobilenet_v3_large` ⭐ | 3.2M | 2D |
+| `mobilenet_v3_small` ⭐ | 0.9M | 2D |
+| `mobilenet_v3_large` ⭐ | 3.0M | 2D |
 | **RegNet** — Regularized Network |||
-| `regnet_y_400mf` ⭐ | 4.0M | 2D |
-| `regnet_y_800mf` ⭐ | 5.8M | 2D |
-| `regnet_y_1_6gf` ⭐ | 10.5M | 2D |
-| `regnet_y_3_2gf` ⭐ | 18.3M | 2D |
-| `regnet_y_8gf` ⭐ | 37.9M | 2D |
+| `regnet_y_400mf` ⭐ | 3.9M | 2D |
+| `regnet_y_800mf` ⭐ | 5.7M | 2D |
+| `regnet_y_1_6gf` ⭐ | 10.3M | 2D |
+| `regnet_y_3_2gf` ⭐ | 17.9M | 2D |
+| `regnet_y_8gf` ⭐ | 37.4M | 2D |
 | **Swin** — Shifted Window Transformer |||
-| `swin_t` ⭐ | 28.0M | 2D |
-| `swin_s` ⭐ | 49.4M | 2D |
-| `swin_b` ⭐ | 87.4M | 2D |
+| `swin_t` ⭐ | 27.5M | 2D |
+| `swin_s` ⭐ | 48.8M | 2D |
+| `swin_b` ⭐ | 86.7M | 2D |
 | **ConvNeXt** — Convolutional Next |||
-| `convnext_tiny` | 28.2M | 1D/2D/3D |
-| `convnext_small` | 49.8M | 1D/2D/3D |
-| `convnext_base` | 88.1M | 1D/2D/3D |
-| `convnext_tiny_pretrained` ⭐ | 28.2M | 2D |
+| `convnext_tiny` | 27.8M | 1D/2D/3D |
+| `convnext_small` | 49.5M | 1D/2D/3D |
+| `convnext_base` | 87.6M | 1D/2D/3D |
+| `convnext_tiny_pretrained` ⭐ | 27.8M | 2D |
 | **DenseNet** — Densely Connected Network |||
-| `densenet121` | 7.5M | 1D/2D/3D |
-| `densenet169` | 13.3M | 1D/2D/3D |
-| `densenet121_pretrained` ⭐ | 7.5M | 2D |
+| `densenet121` | 7.0M | 1D/2D/3D |
+| `densenet169` | 12.5M | 1D/2D/3D |
+| `densenet121_pretrained` ⭐ | 7.0M | 2D |
 | **ViT** — Vision Transformer |||
-| `vit_tiny` | 5.5M | 1D/2D |
-| `vit_small` | 21.6M | 1D/2D |
-| `vit_base` | 85.6M | 1D/2D |
+| `vit_tiny` | 5.4M | 1D/2D |
+| `vit_small` | 21.4M | 1D/2D |
+| `vit_base` | 85.3M | 1D/2D |
+| **ConvNeXt V2** — ConvNeXt with GRN |||
+| `convnext_v2_tiny` | 27.9M | 1D/2D/3D |
+| `convnext_v2_small` | 49.6M | 1D/2D/3D |
+| `convnext_v2_base` | 87.7M | 1D/2D/3D |
+| `convnext_v2_tiny_pretrained` ⭐ | 27.9M | 2D |
+| **Mamba** — State Space Model |||
+| `mamba_1d` | 3.4M | 1D |
+| **Vision Mamba (ViM)** — 2D Mamba |||
+| `vim_tiny` | 6.6M | 2D |
+| `vim_small` | 51.1M | 2D |
+| `vim_base` | 201.4M | 2D |
+| **MaxViT** — Multi-Axis ViT |||
+| `maxvit_tiny` ⭐ | 30.1M | 2D |
+| `maxvit_small` ⭐ | 67.6M | 2D |
+| `maxvit_base` ⭐ | 119.1M | 2D |
+| **FastViT** — Fast Hybrid CNN-ViT |||
+| `fastvit_t8` ⭐ | 4.0M | 2D |
+| `fastvit_t12` ⭐ | 6.8M | 2D |
+| `fastvit_s12` ⭐ | 8.8M | 2D |
+| `fastvit_sa12` ⭐ | 10.9M | 2D |
+| **CAFormer** — MetaFormer with Attention |||
+| `caformer_s18` ⭐ | 26.3M | 2D |
+| `caformer_s36` ⭐ | 39.2M | 2D |
+| `caformer_m36` ⭐ | 56.9M | 2D |
+| `poolformer_s12` ⭐ | 11.9M | 2D |
 | **U-Net** — U-shaped Network |||
-| `unet_regression` | 31.1M | 1D/2D/3D |
+| `unet_regression` | 31.0M | 1D/2D/3D |
+
 
 ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
 - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
-- **Size**: ~20–350 MB per model depending on architecture
 - **Train from scratch**: Use `--no_pretrained` to disable pretrained weights
 
 **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
 
 ```bash
-# Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
+# Run once on login node (with internet) — downloads ALL pretrained weights
 python -c "
 import os
 os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
@@ -483,7 +497,7 @@ os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
 from torchvision import models as m
 from torchvision.models import video as v
 
-# Model name -> Weights class mapping
+# === TorchVision Models ===
 weights = {
     'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
     'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
@@ -501,6 +515,20 @@ for name, w in weights.items():
 # 3D video models
 v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
 v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+
+# === Timm Models (MaxViT, FastViT, CAFormer, ConvNeXt V2) ===
+import timm
+
+timm_models = [
+    'maxvit_tiny_tf_224.in1k', 'maxvit_small_tf_224.in1k', 'maxvit_base_tf_224.in1k',
+    'fastvit_t8.apple_in1k', 'fastvit_t12.apple_in1k', 'fastvit_s12.apple_in1k', 'fastvit_sa12.apple_in1k',
+    'caformer_s18.sail_in1k', 'caformer_s36.sail_in22k_ft_in1k', 'caformer_m36.sail_in22k_ft_in1k',
+    'poolformer_s12.sail_in1k',
+    'convnextv2_tiny.fcmae_ft_in1k',
+]
+for name in timm_models:
+    timm.create_model(name, pretrained=True); print(f'✓ {name}')
+
 print('\\n✓ All pretrained weights cached!')
 "
 ```
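A detail the script above leaves implicit: torchvision checkpoints honor `TORCH_HOME`, while the timm weights added in this release are fetched through the Hugging Face Hub and cached under the HF cache directory (typically `~/.cache/huggingface/`). A hedged sketch for confirming the cache from a compute node before queueing a job; `HF_HUB_OFFLINE` is a standard huggingface_hub environment variable, not a WaveDL option:

```python
# Verify cached weights resolve with networking unavailable (sketch).
# Env vars are set before importing the libraries that read them.
import os

os.environ["TORCH_HOME"] = ".torch_cache"  # matches the login-node script above
os.environ["HF_HUB_OFFLINE"] = "1"         # force cache-only resolution for timm

import timm
from torchvision import models as m

m.resnet18(weights=m.ResNet18_Weights.DEFAULT)               # torch hub cache hit
timm.create_model("fastvit_t8.apple_in1k", pretrained=True)  # HF hub cache hit
print("✓ offline loads OK")
```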
--- wavedl-1.5.7/README.md
+++ wavedl-1.6.0/README.md
@@ -71,7 +71,7 @@ Train on datasets larger than RAM:
 
 **🧠 Models? We've Got Options**
 
-38 architectures, ready to go:
+57 architectures, ready to go:
 - CNNs, ResNets, ViTs, EfficientNets...
 - All adapted for regression
 - [Add your own](#adding-custom-models) in one line
@@ -156,7 +156,7 @@ Deploy models anywhere:
 #### From PyPI (recommended for all users)
 
 ```bash
-pip install wavedl
+pip install --upgrade wavedl
 ```
 
 This installs everything you need: training, inference, HPO, ONNX export.
@@ -312,22 +312,10 @@ WaveDL/
 │ ├── hpo.py # Hyperparameter optimization
 │ ├── hpc.py # HPC distributed training launcher
 │ │
-│ ├── models/ # Model architectures (38 variants)
+│ ├── models/ # Model Zoo (57 architectures)
 │ │ ├── registry.py # Model factory (@register_model)
 │ │ ├── base.py # Abstract base class
-│ │ ├── cnn.py # Baseline CNN (1D/2D/3D)
-│ │ ├── resnet.py # ResNet-18/34/50 (1D/2D/3D)
-│ │ ├── resnet3d.py # ResNet3D-18, MC3-18 (3D only)
-│ │ ├── tcn.py # TCN (1D only)
-│ │ ├── efficientnet.py # EfficientNet-B0/B1/B2 (2D)
-│ │ ├── efficientnetv2.py # EfficientNetV2-S/M/L (2D)
-│ │ ├── mobilenetv3.py # MobileNetV3-Small/Large (2D)
-│ │ ├── regnet.py # RegNetY variants (2D)
-│ │ ├── swin.py # Swin Transformer (2D)
-│ │ ├── vit.py # Vision Transformer (1D/2D)
-│ │ ├── convnext.py # ConvNeXt (1D/2D/3D)
-│ │ ├── densenet.py # DenseNet-121/169 (1D/2D/3D)
-│ │ └── unet.py # U-Net Regression
+│ │ └── ... # See "Available Models" section
 │ │
 │ └── utils/ # Utilities
 │ ├── data.py # Memory-mapped data pipeline
@@ -342,7 +330,7 @@ WaveDL/
 ├── configs/ # YAML config templates
 ├── examples/ # Ready-to-run examples
 ├── notebooks/ # Jupyter notebooks
-├── unit_tests/ # Pytest test suite (903 tests)
+├── unit_tests/ # Pytest test suite
 
 ├── pyproject.toml # Package config, dependencies
 ├── CHANGELOG.md # Version history
@@ -365,71 +353,96 @@ WaveDL/
 > ```
 
 <details>
-<summary><b>Available Models</b> — 38 architectures</summary>
+<summary><b>Available Models</b> — 57 architectures</summary>
 
-| Model | Params | Dim |
-|-------|--------|-----|
+| Model | Backbone Params | Dim |
+|-------|-----------------|-----|
 | **CNN** — Convolutional Neural Network |||
-| `cnn` | 1.7M | 1D/2D/3D |
+| `cnn` | 1.6M | 1D/2D/3D |
 | **ResNet** — Residual Network |||
-| `resnet18` | 11.4M | 1D/2D/3D |
-| `resnet34` | 21.5M | 1D/2D/3D |
-| `resnet50` | 24.6M | 1D/2D/3D |
-| `resnet18_pretrained` ⭐ | 11.4M | 2D |
-| `resnet50_pretrained` ⭐ | 24.6M | 2D |
+| `resnet18` | 11.2M | 1D/2D/3D |
+| `resnet34` | 21.3M | 1D/2D/3D |
+| `resnet50` | 23.5M | 1D/2D/3D |
+| `resnet18_pretrained` ⭐ | 11.2M | 2D |
+| `resnet50_pretrained` ⭐ | 23.5M | 2D |
 | **ResNet3D** — 3D Residual Network |||
-| `resnet3d_18` | 33.6M | 3D |
-| `mc3_18` — Mixed Convolution 3D | 11.9M | 3D |
+| `resnet3d_18` | 33.2M | 3D |
+| `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
 | **TCN** — Temporal Convolutional Network |||
-| `tcn_small` | 1.0M | 1D |
-| `tcn` | 7.0M | 1D |
-| `tcn_large` | 10.2M | 1D |
+| `tcn_small` | 0.9M | 1D |
+| `tcn` | 6.9M | 1D |
+| `tcn_large` | 10.0M | 1D |
 | **EfficientNet** — Efficient Neural Network |||
-| `efficientnet_b0` ⭐ | 4.7M | 2D |
-| `efficientnet_b1` ⭐ | 7.2M | 2D |
-| `efficientnet_b2` ⭐ | 8.4M | 2D |
+| `efficientnet_b0` ⭐ | 4.0M | 2D |
+| `efficientnet_b1` ⭐ | 6.5M | 2D |
+| `efficientnet_b2` ⭐ | 7.7M | 2D |
 | **EfficientNetV2** — Efficient Neural Network V2 |||
-| `efficientnet_v2_s` ⭐ | 21.0M | 2D |
-| `efficientnet_v2_m` ⭐ | 53.6M | 2D |
-| `efficientnet_v2_l` ⭐ | 118.0M | 2D |
+| `efficientnet_v2_s` ⭐ | 20.2M | 2D |
+| `efficientnet_v2_m` ⭐ | 52.9M | 2D |
+| `efficientnet_v2_l` ⭐ | 117.2M | 2D |
 | **MobileNetV3** — Mobile Neural Network V3 |||
-| `mobilenet_v3_small` ⭐ | 1.1M | 2D |
-| `mobilenet_v3_large` ⭐ | 3.2M | 2D |
+| `mobilenet_v3_small` ⭐ | 0.9M | 2D |
+| `mobilenet_v3_large` ⭐ | 3.0M | 2D |
 | **RegNet** — Regularized Network |||
-| `regnet_y_400mf` ⭐ | 4.0M | 2D |
-| `regnet_y_800mf` ⭐ | 5.8M | 2D |
-| `regnet_y_1_6gf` ⭐ | 10.5M | 2D |
-| `regnet_y_3_2gf` ⭐ | 18.3M | 2D |
-| `regnet_y_8gf` ⭐ | 37.9M | 2D |
+| `regnet_y_400mf` ⭐ | 3.9M | 2D |
+| `regnet_y_800mf` ⭐ | 5.7M | 2D |
+| `regnet_y_1_6gf` ⭐ | 10.3M | 2D |
+| `regnet_y_3_2gf` ⭐ | 17.9M | 2D |
+| `regnet_y_8gf` ⭐ | 37.4M | 2D |
 | **Swin** — Shifted Window Transformer |||
-| `swin_t` ⭐ | 28.0M | 2D |
-| `swin_s` ⭐ | 49.4M | 2D |
-| `swin_b` ⭐ | 87.4M | 2D |
+| `swin_t` ⭐ | 27.5M | 2D |
+| `swin_s` ⭐ | 48.8M | 2D |
+| `swin_b` ⭐ | 86.7M | 2D |
 | **ConvNeXt** — Convolutional Next |||
-| `convnext_tiny` | 28.2M | 1D/2D/3D |
-| `convnext_small` | 49.8M | 1D/2D/3D |
-| `convnext_base` | 88.1M | 1D/2D/3D |
-| `convnext_tiny_pretrained` ⭐ | 28.2M | 2D |
+| `convnext_tiny` | 27.8M | 1D/2D/3D |
+| `convnext_small` | 49.5M | 1D/2D/3D |
+| `convnext_base` | 87.6M | 1D/2D/3D |
+| `convnext_tiny_pretrained` ⭐ | 27.8M | 2D |
 | **DenseNet** — Densely Connected Network |||
-| `densenet121` | 7.5M | 1D/2D/3D |
-| `densenet169` | 13.3M | 1D/2D/3D |
-| `densenet121_pretrained` ⭐ | 7.5M | 2D |
+| `densenet121` | 7.0M | 1D/2D/3D |
+| `densenet169` | 12.5M | 1D/2D/3D |
+| `densenet121_pretrained` ⭐ | 7.0M | 2D |
 | **ViT** — Vision Transformer |||
-| `vit_tiny` | 5.5M | 1D/2D |
-| `vit_small` | 21.6M | 1D/2D |
-| `vit_base` | 85.6M | 1D/2D |
+| `vit_tiny` | 5.4M | 1D/2D |
+| `vit_small` | 21.4M | 1D/2D |
+| `vit_base` | 85.3M | 1D/2D |
+| **ConvNeXt V2** — ConvNeXt with GRN |||
+| `convnext_v2_tiny` | 27.9M | 1D/2D/3D |
+| `convnext_v2_small` | 49.6M | 1D/2D/3D |
+| `convnext_v2_base` | 87.7M | 1D/2D/3D |
+| `convnext_v2_tiny_pretrained` ⭐ | 27.9M | 2D |
+| **Mamba** — State Space Model |||
+| `mamba_1d` | 3.4M | 1D |
+| **Vision Mamba (ViM)** — 2D Mamba |||
+| `vim_tiny` | 6.6M | 2D |
+| `vim_small` | 51.1M | 2D |
+| `vim_base` | 201.4M | 2D |
+| **MaxViT** — Multi-Axis ViT |||
+| `maxvit_tiny` ⭐ | 30.1M | 2D |
+| `maxvit_small` ⭐ | 67.6M | 2D |
+| `maxvit_base` ⭐ | 119.1M | 2D |
+| **FastViT** — Fast Hybrid CNN-ViT |||
+| `fastvit_t8` ⭐ | 4.0M | 2D |
+| `fastvit_t12` ⭐ | 6.8M | 2D |
+| `fastvit_s12` ⭐ | 8.8M | 2D |
+| `fastvit_sa12` ⭐ | 10.9M | 2D |
+| **CAFormer** — MetaFormer with Attention |||
+| `caformer_s18` ⭐ | 26.3M | 2D |
+| `caformer_s36` ⭐ | 39.2M | 2D |
+| `caformer_m36` ⭐ | 56.9M | 2D |
+| `poolformer_s12` ⭐ | 11.9M | 2D |
 | **U-Net** — U-shaped Network |||
-| `unet_regression` | 31.1M | 1D/2D/3D |
+| `unet_regression` | 31.0M | 1D/2D/3D |
+
 
 ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
 - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
-- **Size**: ~20–350 MB per model depending on architecture
 - **Train from scratch**: Use `--no_pretrained` to disable pretrained weights
 
 **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
 
 ```bash
-# Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
+# Run once on login node (with internet) — downloads ALL pretrained weights
 python -c "
 import os
 os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
@@ -437,7 +450,7 @@ os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
 from torchvision import models as m
 from torchvision.models import video as v
 
-# Model name -> Weights class mapping
+# === TorchVision Models ===
 weights = {
     'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
     'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
@@ -455,6 +468,20 @@ for name, w in weights.items():
 # 3D video models
 v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
 v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+
+# === Timm Models (MaxViT, FastViT, CAFormer, ConvNeXt V2) ===
+import timm
+
+timm_models = [
+    'maxvit_tiny_tf_224.in1k', 'maxvit_small_tf_224.in1k', 'maxvit_base_tf_224.in1k',
+    'fastvit_t8.apple_in1k', 'fastvit_t12.apple_in1k', 'fastvit_s12.apple_in1k', 'fastvit_sa12.apple_in1k',
+    'caformer_s18.sail_in1k', 'caformer_s36.sail_in22k_ft_in1k', 'caformer_m36.sail_in22k_ft_in1k',
+    'poolformer_s12.sail_in1k',
+    'convnextv2_tiny.fcmae_ft_in1k',
+]
+for name in timm_models:
+    timm.create_model(name, pretrained=True); print(f'✓ {name}')
+
 print('\\n✓ All pretrained weights cached!')
 "
 ```
--- wavedl-1.5.7/pyproject.toml
+++ wavedl-1.6.0/pyproject.toml
@@ -52,6 +52,7 @@ dependencies = [
     # Core ML stack
     "torch>=2.0.0",
    "torchvision>=0.15.0",
+    "timm>=0.9.0",  # Pretrained models (MaxViT, FastViT, CAFormer)
     "accelerate>=0.20.0",
     "numpy>=1.24.0",
     "scipy>=1.10.0",
--- wavedl-1.5.7/src/wavedl/__init__.py
+++ wavedl-1.6.0/src/wavedl/__init__.py
@@ -18,7 +18,7 @@ For inference:
 # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
 """
 
-__version__ = "1.5.7"
+__version__ = "1.6.0"
 __author__ = "Ductho Le"
 __email__ = "ductho.le@outlook.com"
 
--- wavedl-1.5.7/src/wavedl/models/__init__.py
+++ wavedl-1.6.0/src/wavedl/models/__init__.py
@@ -6,10 +6,11 @@ This module provides a centralized registry for neural network architectures,
 enabling dynamic model selection via command-line arguments.
 
 **Dimensionality Coverage**:
-- 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, DenseNet, ViT
-- 2D (images): CNN, ResNet, ConvNeXt, DenseNet, ViT, UNet,
-  EfficientNet, MobileNetV3, RegNet, Swin
-- 3D (volumes): ResNet3D, CNN, ResNet, ConvNeXt, DenseNet
+- 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet, ViT, Mamba
+- 2D (images): CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet, ViT, UNet,
+  EfficientNet, MobileNetV3, RegNet, Swin, MaxViT, FastViT,
+  CAFormer, PoolFormer, Vision Mamba
+- 3D (volumes): ResNet3D, CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet
 
 Usage:
     from wavedl.models import get_model, list_models, MODEL_REGISTRY
@@ -46,9 +47,19 @@ from .base import BaseModel
 # Import model implementations (triggers registration via decorators)
 from .cnn import CNN
 from .convnext import ConvNeXtBase_, ConvNeXtSmall, ConvNeXtTiny
+
+# New models (v1.6+)
+from .convnext_v2 import (
+    ConvNeXtV2Base,
+    ConvNeXtV2BaseLarge,
+    ConvNeXtV2Small,
+    ConvNeXtV2Tiny,
+    ConvNeXtV2TinyPretrained,
+)
 from .densenet import DenseNet121, DenseNet169
 from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2
 from .efficientnetv2 import EfficientNetV2L, EfficientNetV2M, EfficientNetV2S
+from .mamba import Mamba1D, VimBase, VimSmall, VimTiny
 from .mobilenetv3 import MobileNetV3Large, MobileNetV3Small
 from .registry import (
     MODEL_REGISTRY,
@@ -66,6 +77,17 @@ from .unet import UNetRegression
 from .vit import ViTBase_, ViTSmall, ViTTiny
 
 
+# Optional timm-based models (imported conditionally)
+try:
+    from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
+    from .fastvit import FastViTS12, FastViTSA12, FastViTT8, FastViTT12
+    from .maxvit import MaxViTBaseLarge, MaxViTSmall, MaxViTTiny
+
+    _HAS_TIMM_MODELS = True
+except ImportError:
+    _HAS_TIMM_MODELS = False
+
+
 # Export public API (sorted alphabetically per RUF022)
 # See module docstring for dimensionality support details
 __all__ = [
@@ -77,6 +99,11 @@ __all__ = [
     "ConvNeXtBase_",
     "ConvNeXtSmall",
     "ConvNeXtTiny",
+    "ConvNeXtV2Base",
+    "ConvNeXtV2BaseLarge",
+    "ConvNeXtV2Small",
+    "ConvNeXtV2Tiny",
+    "ConvNeXtV2TinyPretrained",
     "DenseNet121",
     "DenseNet169",
     "EfficientNetB0",
@@ -85,6 +112,7 @@ __all__ = [
     "EfficientNetV2L",
     "EfficientNetV2M",
     "EfficientNetV2S",
+    "Mamba1D",
     "MobileNetV3Large",
     "MobileNetV3Small",
     "RegNetY1_6GF",
@@ -105,8 +133,28 @@ __all__ = [
     "ViTBase_",
     "ViTSmall",
     "ViTTiny",
+    "VimBase",
+    "VimSmall",
+    "VimTiny",
     "build_model",
     "get_model",
     "list_models",
     "register_model",
 ]
+
+# Add timm-based models to __all__ if available
+if _HAS_TIMM_MODELS:
+    __all__.extend(
+        [
+            "CaFormerS18",
+            "CaFormerS36",
+            "FastViTS12",
+            "FastViTSA12",
+            "FastViTT8",
+            "FastViTT12",
+            "MaxViTBaseLarge",
+            "MaxViTSmall",
+            "MaxViTTiny",
+            "PoolFormerS12",
+        ]
+    )
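One consequence of the `_HAS_TIMM_MODELS` guard worth flagging: although 1.6.0 declares timm as a hard dependency, the package stays importable if timm is missing or broken, and the timm-backed names simply never register. Downstream code should therefore probe the registry rather than import the classes directly. A hedged sketch (the exact `get_model` signature is not shown in this diff, so constructor arguments a real caller would likely need are omitted):

```python
# Hedged usage sketch against the registry API exported above.
from wavedl.models import get_model, list_models

name = "maxvit_tiny" if "maxvit_tiny" in list_models() else "resnet18"
model = get_model(name)  # real calls likely pass input/output shape kwargs
print(type(model).__name__)
```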