wavedl 1.5.7__tar.gz → 1.6.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. {wavedl-1.5.7/src/wavedl.egg-info → wavedl-1.6.1}/PKG-INFO +150 -82
  2. {wavedl-1.5.7 → wavedl-1.6.1}/README.md +147 -81
  3. {wavedl-1.5.7 → wavedl-1.6.1}/pyproject.toml +2 -0
  4. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/__init__.py +1 -1
  5. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/hpo.py +451 -451
  6. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/__init__.py +80 -4
  7. wavedl-1.6.1/src/wavedl/models/_pretrained_utils.py +366 -0
  8. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/base.py +48 -0
  9. wavedl-1.6.1/src/wavedl/models/caformer.py +270 -0
  10. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/cnn.py +2 -27
  11. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/convnext.py +113 -51
  12. wavedl-1.6.1/src/wavedl/models/convnext_v2.py +488 -0
  13. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/densenet.py +10 -23
  14. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/efficientnet.py +6 -6
  15. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/efficientnetv2.py +315 -315
  16. wavedl-1.6.1/src/wavedl/models/efficientvit.py +398 -0
  17. wavedl-1.6.1/src/wavedl/models/fastvit.py +252 -0
  18. wavedl-1.6.1/src/wavedl/models/mamba.py +555 -0
  19. wavedl-1.6.1/src/wavedl/models/maxvit.py +254 -0
  20. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/mobilenetv3.py +295 -295
  21. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/regnet.py +406 -406
  22. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/resnet.py +19 -61
  23. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/resnet3d.py +258 -258
  24. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/swin.py +443 -443
  25. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/tcn.py +393 -409
  26. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/unet.py +2 -6
  27. wavedl-1.6.1/src/wavedl/models/unireplknet.py +491 -0
  28. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/vit.py +9 -9
  29. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/train.py +1430 -1425
  30. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/config.py +367 -367
  31. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/cross_validation.py +530 -530
  32. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/data.py +39 -6
  33. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/losses.py +216 -216
  34. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/optimizers.py +216 -216
  35. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/schedulers.py +251 -251
  36. {wavedl-1.5.7 → wavedl-1.6.1/src/wavedl.egg-info}/PKG-INFO +150 -82
  37. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl.egg-info/SOURCES.txt +8 -0
  38. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl.egg-info/requires.txt +2 -0
  39. {wavedl-1.5.7 → wavedl-1.6.1}/LICENSE +0 -0
  40. {wavedl-1.5.7 → wavedl-1.6.1}/setup.cfg +0 -0
  41. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/hpc.py +0 -0
  42. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/_template.py +0 -0
  43. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/registry.py +0 -0
  44. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/test.py +0 -0
  45. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/__init__.py +0 -0
  46. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/constraints.py +0 -0
  47. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/distributed.py +0 -0
  48. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/metrics.py +0 -0
  49. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl.egg-info/dependency_links.txt +0 -0
  50. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl.egg-info/entry_points.txt +0 -0
  51. {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: wavedl
- Version: 1.5.7
+ Version: 1.6.1
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
  Author: Ductho Le
  License: MIT
@@ -23,6 +23,7 @@ Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: torch>=2.0.0
  Requires-Dist: torchvision>=0.15.0
+ Requires-Dist: timm>=0.9.0
  Requires-Dist: accelerate>=0.20.0
  Requires-Dist: numpy>=1.24.0
  Requires-Dist: scipy>=1.10.0
@@ -37,6 +38,7 @@ Requires-Dist: wandb>=0.15.0
  Requires-Dist: optuna>=3.0.0
  Requires-Dist: onnx>=1.14.0
  Requires-Dist: onnxruntime>=1.15.0
+ Requires-Dist: onnxscript>=0.1.0
  Requires-Dist: triton>=2.0.0; sys_platform == "linux"
  Provides-Extra: dev
  Requires-Dist: pytest>=7.0.0; extra == "dev"
@@ -117,7 +119,7 @@ Train on datasets larger than RAM:

  **🧠 Models? We've Got Options**

- 38 architectures, ready to go:
+ 69 architectures, ready to go:
  - CNNs, ResNets, ViTs, EfficientNets...
  - All adapted for regression
  - [Add your own](#adding-custom-models) in one line
@@ -202,7 +204,7 @@ Deploy models anywhere:
  #### From PyPI (recommended for all users)

  ```bash
- pip install wavedl
+ pip install --upgrade wavedl
  ```

  This installs everything you need: training, inference, HPO, ONNX export.
@@ -358,22 +360,10 @@ WaveDL/
  │   ├── hpo.py            # Hyperparameter optimization
  │   ├── hpc.py            # HPC distributed training launcher
  │   │
- │   ├── models/           # Model architectures (38 variants)
+ │   ├── models/           # Model Zoo (69 architectures)
  │   │   ├── registry.py   # Model factory (@register_model)
  │   │   ├── base.py       # Abstract base class
- │   │   ├── cnn.py        # Baseline CNN (1D/2D/3D)
- │   │   ├── resnet.py     # ResNet-18/34/50 (1D/2D/3D)
- │   │   ├── resnet3d.py   # ResNet3D-18, MC3-18 (3D only)
- │   │   ├── tcn.py        # TCN (1D only)
- │   │   ├── efficientnet.py   # EfficientNet-B0/B1/B2 (2D)
- │   │   ├── efficientnetv2.py # EfficientNetV2-S/M/L (2D)
- │   │   ├── mobilenetv3.py    # MobileNetV3-Small/Large (2D)
- │   │   ├── regnet.py     # RegNetY variants (2D)
- │   │   ├── swin.py       # Swin Transformer (2D)
- │   │   ├── vit.py        # Vision Transformer (1D/2D)
- │   │   ├── convnext.py   # ConvNeXt (1D/2D/3D)
- │   │   ├── densenet.py   # DenseNet-121/169 (1D/2D/3D)
- │   │   └── unet.py       # U-Net Regression
+ │   │   └── ...           # See "Available Models" section
  │   │
  │   └── utils/            # Utilities
  │       ├── data.py       # Memory-mapped data pipeline
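Editor's note: the `registry.py` entry above is what the README's "[Add your own](#adding-custom-models) in one line" claim points at. A minimal sketch of how a decorator-based model factory of this kind typically works; `MODEL_REGISTRY`, the decorator signature, and the example class are illustrative assumptions, not wavedl's confirmed internals:

```python
import torch
import torch.nn as nn

# Hypothetical registry; wavedl's actual factory lives in models/registry.py.
MODEL_REGISTRY = {}

def register_model(name):
    """Record a model class under a string key so it can be built by name."""
    def wrap(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return wrap

@register_model("my_tiny_cnn")  # the "one line" that adds a custom model
class MyTinyCNN(nn.Module):
    def __init__(self, in_channels=1, out_dim=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(16, out_dim),  # regression head, no softmax
        )

    def forward(self, x):
        return self.net(x)

model = MODEL_REGISTRY["my_tiny_cnn"]()
print(model(torch.randn(2, 1, 64, 64)).shape)  # torch.Size([2, 3])
```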
@@ -388,7 +378,7 @@ WaveDL/
  ├── configs/              # YAML config templates
  ├── examples/             # Ready-to-run examples
  ├── notebooks/            # Jupyter notebooks
- ├── unit_tests/           # Pytest test suite (903 tests)
+ ├── unit_tests/           # Pytest test suite

  ├── pyproject.toml        # Package config, dependencies
  ├── CHANGELOG.md          # Version history
@@ -411,71 +401,117 @@ WaveDL/
  > ```

  <details>
- <summary><b>Available Models</b> — 38 architectures</summary>
+ <summary><b>Available Models</b> — 69 architectures</summary>

- | Model | Params | Dim |
- |-------|--------|-----|
+ | Model | Backbone Params | Dim |
+ |-------|-----------------|-----|
+ | **── Classic CNNs ──** |||
  | **CNN** — Convolutional Neural Network |||
- | `cnn` | 1.7M | 1D/2D/3D |
+ | `cnn` | 1.6M | 1D/2D/3D |
  | **ResNet** — Residual Network |||
- | `resnet18` | 11.4M | 1D/2D/3D |
- | `resnet34` | 21.5M | 1D/2D/3D |
- | `resnet50` | 24.6M | 1D/2D/3D |
- | `resnet18_pretrained` ⭐ | 11.4M | 2D |
- | `resnet50_pretrained` ⭐ | 24.6M | 2D |
- | **ResNet3D** — 3D Residual Network |||
- | `resnet3d_18` | 33.6M | 3D |
- | `mc3_18` — Mixed Convolution 3D | 11.9M | 3D |
- | **TCN** Temporal Convolutional Network |||
- | `tcn_small` | 1.0M | 1D |
- | `tcn` | 7.0M | 1D |
- | `tcn_large` | 10.2M | 1D |
+ | `resnet18` | 11.2M | 1D/2D/3D |
+ | `resnet34` | 21.3M | 1D/2D/3D |
+ | `resnet50` | 23.5M | 1D/2D/3D |
+ | `resnet18_pretrained` ⭐ | 11.2M | 2D |
+ | `resnet50_pretrained` ⭐ | 23.5M | 2D |
+ | **DenseNet** — Densely Connected Network |||
+ | `densenet121` | 7.0M | 1D/2D/3D |
+ | `densenet169` | 12.5M | 1D/2D/3D |
+ | `densenet121_pretrained` | 7.0M | 2D |
+ | **── Efficient/Mobile CNNs ──** |||
+ | **MobileNetV3** Mobile Neural Network V3 |||
+ | `mobilenet_v3_small` | 0.9M | 2D |
+ | `mobilenet_v3_large` ⭐ | 3.0M | 2D |
  | **EfficientNet** — Efficient Neural Network |||
- | `efficientnet_b0` ⭐ | 4.7M | 2D |
- | `efficientnet_b1` ⭐ | 7.2M | 2D |
- | `efficientnet_b2` ⭐ | 8.4M | 2D |
+ | `efficientnet_b0` ⭐ | 4.0M | 2D |
+ | `efficientnet_b1` ⭐ | 6.5M | 2D |
+ | `efficientnet_b2` ⭐ | 7.7M | 2D |
  | **EfficientNetV2** — Efficient Neural Network V2 |||
- | `efficientnet_v2_s` ⭐ | 21.0M | 2D |
- | `efficientnet_v2_m` ⭐ | 53.6M | 2D |
- | `efficientnet_v2_l` ⭐ | 118.0M | 2D |
- | **MobileNetV3** — Mobile Neural Network V3 |||
- | `mobilenet_v3_small` ⭐ | 1.1M | 2D |
- | `mobilenet_v3_large` ⭐ | 3.2M | 2D |
+ | `efficientnet_v2_s` ⭐ | 20.2M | 2D |
+ | `efficientnet_v2_m` ⭐ | 52.9M | 2D |
+ | `efficientnet_v2_l` ⭐ | 117.2M | 2D |
  | **RegNet** — Regularized Network |||
- | `regnet_y_400mf` ⭐ | 4.0M | 2D |
- | `regnet_y_800mf` ⭐ | 5.8M | 2D |
- | `regnet_y_1_6gf` ⭐ | 10.5M | 2D |
- | `regnet_y_3_2gf` ⭐ | 18.3M | 2D |
- | `regnet_y_8gf` ⭐ | 37.9M | 2D |
- | **Swin** Shifted Window Transformer |||
- | `swin_t` ⭐ | 28.0M | 2D |
- | `swin_s` ⭐ | 49.4M | 2D |
- | `swin_b` ⭐ | 87.4M | 2D |
+ | `regnet_y_400mf` ⭐ | 3.9M | 2D |
+ | `regnet_y_800mf` ⭐ | 5.7M | 2D |
+ | `regnet_y_1_6gf` ⭐ | 10.3M | 2D |
+ | `regnet_y_3_2gf` ⭐ | 17.9M | 2D |
+ | `regnet_y_8gf` ⭐ | 37.4M | 2D |
+ | **── Modern CNNs ──** |||
  | **ConvNeXt** — Convolutional Next |||
- | `convnext_tiny` | 28.2M | 1D/2D/3D |
- | `convnext_small` | 49.8M | 1D/2D/3D |
- | `convnext_base` | 88.1M | 1D/2D/3D |
- | `convnext_tiny_pretrained` ⭐ | 28.2M | 2D |
- | **DenseNet** — Densely Connected Network |||
- | `densenet121` | 7.5M | 1D/2D/3D |
- | `densenet169` | 13.3M | 1D/2D/3D |
- | `densenet121_pretrained` | 7.5M | 2D |
+ | `convnext_tiny` | 27.8M | 1D/2D/3D |
+ | `convnext_small` | 49.5M | 1D/2D/3D |
+ | `convnext_base` | 87.6M | 1D/2D/3D |
+ | `convnext_tiny_pretrained` ⭐ | 27.8M | 2D |
+ | **ConvNeXt V2** — ConvNeXt with GRN |||
+ | `convnext_v2_tiny` | 27.9M | 1D/2D/3D |
+ | `convnext_v2_small` | 49.6M | 1D/2D/3D |
+ | `convnext_v2_base` | 87.7M | 1D/2D/3D |
+ | `convnext_v2_tiny_pretrained` ⭐ | 27.9M | 2D |
+ | **UniRepLKNet** — Large-Kernel ConvNet |||
+ | `unireplknet_tiny` | 30.8M | 1D/2D/3D |
+ | `unireplknet_small` | 56.0M | 1D/2D/3D |
+ | `unireplknet_base` | 97.6M | 1D/2D/3D |
+ | **── Vision Transformers ──** |||
  | **ViT** — Vision Transformer |||
- | `vit_tiny` | 5.5M | 1D/2D |
- | `vit_small` | 21.6M | 1D/2D |
- | `vit_base` | 85.6M | 1D/2D |
+ | `vit_tiny` | 5.4M | 1D/2D |
+ | `vit_small` | 21.4M | 1D/2D |
+ | `vit_base` | 85.3M | 1D/2D |
+ | **Swin** — Shifted Window Transformer |||
+ | `swin_t` ⭐ | 27.5M | 2D |
+ | `swin_s` ⭐ | 48.8M | 2D |
+ | `swin_b` ⭐ | 86.7M | 2D |
+ | **MaxViT** — Multi-Axis ViT |||
+ | `maxvit_tiny` ⭐ | 30.1M | 2D |
+ | `maxvit_small` ⭐ | 67.6M | 2D |
+ | `maxvit_base` ⭐ | 119.1M | 2D |
+ | **── Hybrid CNN-Transformer ──** |||
+ | **FastViT** — Fast Hybrid CNN-ViT |||
+ | `fastvit_t8` ⭐ | 4.0M | 2D |
+ | `fastvit_t12` ⭐ | 6.8M | 2D |
+ | `fastvit_s12` ⭐ | 8.8M | 2D |
+ | `fastvit_sa12` ⭐ | 10.9M | 2D |
+ | **CAFormer** — MetaFormer with Attention |||
+ | `caformer_s18` ⭐ | 26.3M | 2D |
+ | `caformer_s36` ⭐ | 39.2M | 2D |
+ | `caformer_m36` ⭐ | 56.9M | 2D |
+ | `poolformer_s12` ⭐ | 11.9M | 2D |
+ | **EfficientViT** — Memory-Efficient ViT |||
+ | `efficientvit_m0` ⭐ | 2.2M | 2D |
+ | `efficientvit_m1` ⭐ | 2.6M | 2D |
+ | `efficientvit_m2` ⭐ | 3.8M | 2D |
+ | `efficientvit_b0` ⭐ | 2.1M | 2D |
+ | `efficientvit_b1` ⭐ | 7.5M | 2D |
+ | `efficientvit_b2` ⭐ | 21.8M | 2D |
+ | `efficientvit_b3` ⭐ | 46.1M | 2D |
+ | `efficientvit_l1` ⭐ | 49.5M | 2D |
+ | `efficientvit_l2` ⭐ | 60.5M | 2D |
+ | **── State Space Models ──** |||
+ | **Mamba** — State Space Model |||
+ | `mamba_1d` | 3.4M | 1D |
+ | **Vision Mamba (ViM)** — 2D Mamba |||
+ | `vim_tiny` | 6.6M | 2D |
+ | `vim_small` | 51.1M | 2D |
+ | `vim_base` | 201.4M | 2D |
+ | **── Specialized Architectures ──** |||
+ | **TCN** — Temporal Convolutional Network |||
+ | `tcn_small` | 0.9M | 1D |
+ | `tcn` | 6.9M | 1D |
+ | `tcn_large` | 10.0M | 1D |
+ | **ResNet3D** — 3D Residual Network |||
+ | `resnet3d_18` | 33.2M | 3D |
+ | `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
  | **U-Net** — U-shaped Network |||
- | `unet_regression` | 31.1M | 1D/2D/3D |
+ | `unet_regression` | 31.0M | 1D/2D/3D |
+

  ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
  - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
- - **Size**: ~20–350 MB per model depending on architecture
  - **Train from scratch**: Use `--no_pretrained` to disable pretrained weights

  **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:

  ```bash
- # Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
+ # Run once on login node (with internet) — downloads ALL pretrained weights
  python -c "
  import os
  os.environ['TORCH_HOME'] = '.torch_cache'  # Match WaveDL's HPC cache location
@@ -483,24 +519,56 @@ os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
  from torchvision import models as m
  from torchvision.models import video as v

- # Model name -> Weights class mapping
- weights = {
-     'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
-     'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
-     'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
-     'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
-     'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
-     'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
-     'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
-     'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
-     'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
- }
- for name, w in weights.items():
-     getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
+ # === TorchVision Models (use IMAGENET1K_V1 to match WaveDL) ===
+ models = [
+     ('resnet18', m.ResNet18_Weights.IMAGENET1K_V1),
+     ('resnet50', m.ResNet50_Weights.IMAGENET1K_V1),
+     ('efficientnet_b0', m.EfficientNet_B0_Weights.IMAGENET1K_V1),
+     ('efficientnet_b1', m.EfficientNet_B1_Weights.IMAGENET1K_V1),
+     ('efficientnet_b2', m.EfficientNet_B2_Weights.IMAGENET1K_V1),
+     ('efficientnet_v2_s', m.EfficientNet_V2_S_Weights.IMAGENET1K_V1),
+     ('efficientnet_v2_m', m.EfficientNet_V2_M_Weights.IMAGENET1K_V1),
+     ('efficientnet_v2_l', m.EfficientNet_V2_L_Weights.IMAGENET1K_V1),
+     ('mobilenet_v3_small', m.MobileNet_V3_Small_Weights.IMAGENET1K_V1),
+     ('mobilenet_v3_large', m.MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+     ('regnet_y_400mf', m.RegNet_Y_400MF_Weights.IMAGENET1K_V1),
+     ('regnet_y_800mf', m.RegNet_Y_800MF_Weights.IMAGENET1K_V1),
+     ('regnet_y_1_6gf', m.RegNet_Y_1_6GF_Weights.IMAGENET1K_V1),
+     ('regnet_y_3_2gf', m.RegNet_Y_3_2GF_Weights.IMAGENET1K_V1),
+     ('regnet_y_8gf', m.RegNet_Y_8GF_Weights.IMAGENET1K_V1),
+     ('swin_t', m.Swin_T_Weights.IMAGENET1K_V1),
+     ('swin_s', m.Swin_S_Weights.IMAGENET1K_V1),
+     ('swin_b', m.Swin_B_Weights.IMAGENET1K_V1),
+     ('convnext_tiny', m.ConvNeXt_Tiny_Weights.IMAGENET1K_V1),
+     ('densenet121', m.DenseNet121_Weights.IMAGENET1K_V1),
+ ]
+ for name, w in models:
+     getattr(m, name)(weights=w); print(f'✓ {name}')

  # 3D video models
- v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
- v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+ v.r3d_18(weights=v.R3D_18_Weights.KINETICS400_V1); print('✓ r3d_18')
+ v.mc3_18(weights=v.MC3_18_Weights.KINETICS400_V1); print('✓ mc3_18')
+
+ # === Timm Models (MaxViT, FastViT, CAFormer, ConvNeXt V2) ===
+ import timm
+
+ timm_models = [
+     # MaxViT (no suffix - timm resolves to default)
+     'maxvit_tiny_tf_224', 'maxvit_small_tf_224', 'maxvit_base_tf_224',
+     # FastViT (no suffix)
+     'fastvit_t8', 'fastvit_t12', 'fastvit_s12', 'fastvit_sa12',
+     # CAFormer/PoolFormer (no suffix)
+     'caformer_s18', 'caformer_s36', 'caformer_m36', 'poolformer_s12',
+     # ConvNeXt V2 (no suffix)
+     'convnextv2_tiny',
+     # EfficientViT (no suffix)
+     'efficientvit_m0', 'efficientvit_m1', 'efficientvit_m2',
+     'efficientvit_b0', 'efficientvit_b1', 'efficientvit_b2', 'efficientvit_b3',
+     'efficientvit_l1', 'efficientvit_l2',
+ ]
+ for name in timm_models:
+     timm.create_model(name, pretrained=True); print(f'✓ {name}')
+

  print('\\n✓ All pretrained weights cached!')
  "
@@ -71,7 +71,7 @@ Train on datasets larger than RAM:

  **🧠 Models? We've Got Options**

- 38 architectures, ready to go:
+ 69 architectures, ready to go:
  - CNNs, ResNets, ViTs, EfficientNets...
  - All adapted for regression
  - [Add your own](#adding-custom-models) in one line
@@ -156,7 +156,7 @@ Deploy models anywhere:
  #### From PyPI (recommended for all users)

  ```bash
- pip install wavedl
+ pip install --upgrade wavedl
  ```

  This installs everything you need: training, inference, HPO, ONNX export.
@@ -312,22 +312,10 @@ WaveDL/
  │   ├── hpo.py            # Hyperparameter optimization
  │   ├── hpc.py            # HPC distributed training launcher
  │   │
- │   ├── models/           # Model architectures (38 variants)
+ │   ├── models/           # Model Zoo (69 architectures)
  │   │   ├── registry.py   # Model factory (@register_model)
  │   │   ├── base.py       # Abstract base class
- │   │   ├── cnn.py        # Baseline CNN (1D/2D/3D)
- │   │   ├── resnet.py     # ResNet-18/34/50 (1D/2D/3D)
- │   │   ├── resnet3d.py   # ResNet3D-18, MC3-18 (3D only)
- │   │   ├── tcn.py        # TCN (1D only)
- │   │   ├── efficientnet.py   # EfficientNet-B0/B1/B2 (2D)
- │   │   ├── efficientnetv2.py # EfficientNetV2-S/M/L (2D)
- │   │   ├── mobilenetv3.py    # MobileNetV3-Small/Large (2D)
- │   │   ├── regnet.py     # RegNetY variants (2D)
- │   │   ├── swin.py       # Swin Transformer (2D)
- │   │   ├── vit.py        # Vision Transformer (1D/2D)
- │   │   ├── convnext.py   # ConvNeXt (1D/2D/3D)
- │   │   ├── densenet.py   # DenseNet-121/169 (1D/2D/3D)
- │   │   └── unet.py       # U-Net Regression
+ │   │   └── ...           # See "Available Models" section
  │   │
  │   └── utils/            # Utilities
  │       ├── data.py       # Memory-mapped data pipeline
@@ -342,7 +330,7 @@ WaveDL/
  ├── configs/              # YAML config templates
  ├── examples/             # Ready-to-run examples
  ├── notebooks/            # Jupyter notebooks
- ├── unit_tests/           # Pytest test suite (903 tests)
+ ├── unit_tests/           # Pytest test suite

  ├── pyproject.toml        # Package config, dependencies
  ├── CHANGELOG.md          # Version history
@@ -365,71 +353,117 @@ WaveDL/
  > ```

  <details>
- <summary><b>Available Models</b> — 38 architectures</summary>
+ <summary><b>Available Models</b> — 69 architectures</summary>

- | Model | Params | Dim |
- |-------|--------|-----|
+ | Model | Backbone Params | Dim |
+ |-------|-----------------|-----|
+ | **── Classic CNNs ──** |||
  | **CNN** — Convolutional Neural Network |||
- | `cnn` | 1.7M | 1D/2D/3D |
+ | `cnn` | 1.6M | 1D/2D/3D |
  | **ResNet** — Residual Network |||
- | `resnet18` | 11.4M | 1D/2D/3D |
- | `resnet34` | 21.5M | 1D/2D/3D |
- | `resnet50` | 24.6M | 1D/2D/3D |
- | `resnet18_pretrained` ⭐ | 11.4M | 2D |
- | `resnet50_pretrained` ⭐ | 24.6M | 2D |
- | **ResNet3D** — 3D Residual Network |||
- | `resnet3d_18` | 33.6M | 3D |
- | `mc3_18` — Mixed Convolution 3D | 11.9M | 3D |
- | **TCN** Temporal Convolutional Network |||
- | `tcn_small` | 1.0M | 1D |
- | `tcn` | 7.0M | 1D |
- | `tcn_large` | 10.2M | 1D |
+ | `resnet18` | 11.2M | 1D/2D/3D |
+ | `resnet34` | 21.3M | 1D/2D/3D |
+ | `resnet50` | 23.5M | 1D/2D/3D |
+ | `resnet18_pretrained` ⭐ | 11.2M | 2D |
+ | `resnet50_pretrained` ⭐ | 23.5M | 2D |
+ | **DenseNet** — Densely Connected Network |||
+ | `densenet121` | 7.0M | 1D/2D/3D |
+ | `densenet169` | 12.5M | 1D/2D/3D |
+ | `densenet121_pretrained` | 7.0M | 2D |
+ | **── Efficient/Mobile CNNs ──** |||
+ | **MobileNetV3** Mobile Neural Network V3 |||
+ | `mobilenet_v3_small` | 0.9M | 2D |
+ | `mobilenet_v3_large` ⭐ | 3.0M | 2D |
  | **EfficientNet** — Efficient Neural Network |||
- | `efficientnet_b0` ⭐ | 4.7M | 2D |
- | `efficientnet_b1` ⭐ | 7.2M | 2D |
- | `efficientnet_b2` ⭐ | 8.4M | 2D |
+ | `efficientnet_b0` ⭐ | 4.0M | 2D |
+ | `efficientnet_b1` ⭐ | 6.5M | 2D |
+ | `efficientnet_b2` ⭐ | 7.7M | 2D |
  | **EfficientNetV2** — Efficient Neural Network V2 |||
- | `efficientnet_v2_s` ⭐ | 21.0M | 2D |
- | `efficientnet_v2_m` ⭐ | 53.6M | 2D |
- | `efficientnet_v2_l` ⭐ | 118.0M | 2D |
- | **MobileNetV3** — Mobile Neural Network V3 |||
- | `mobilenet_v3_small` ⭐ | 1.1M | 2D |
- | `mobilenet_v3_large` ⭐ | 3.2M | 2D |
+ | `efficientnet_v2_s` ⭐ | 20.2M | 2D |
+ | `efficientnet_v2_m` ⭐ | 52.9M | 2D |
+ | `efficientnet_v2_l` ⭐ | 117.2M | 2D |
  | **RegNet** — Regularized Network |||
- | `regnet_y_400mf` ⭐ | 4.0M | 2D |
- | `regnet_y_800mf` ⭐ | 5.8M | 2D |
- | `regnet_y_1_6gf` ⭐ | 10.5M | 2D |
- | `regnet_y_3_2gf` ⭐ | 18.3M | 2D |
- | `regnet_y_8gf` ⭐ | 37.9M | 2D |
- | **Swin** Shifted Window Transformer |||
- | `swin_t` ⭐ | 28.0M | 2D |
- | `swin_s` ⭐ | 49.4M | 2D |
- | `swin_b` ⭐ | 87.4M | 2D |
+ | `regnet_y_400mf` ⭐ | 3.9M | 2D |
+ | `regnet_y_800mf` ⭐ | 5.7M | 2D |
+ | `regnet_y_1_6gf` ⭐ | 10.3M | 2D |
+ | `regnet_y_3_2gf` ⭐ | 17.9M | 2D |
+ | `regnet_y_8gf` ⭐ | 37.4M | 2D |
+ | **── Modern CNNs ──** |||
  | **ConvNeXt** — Convolutional Next |||
- | `convnext_tiny` | 28.2M | 1D/2D/3D |
- | `convnext_small` | 49.8M | 1D/2D/3D |
- | `convnext_base` | 88.1M | 1D/2D/3D |
- | `convnext_tiny_pretrained` ⭐ | 28.2M | 2D |
- | **DenseNet** — Densely Connected Network |||
- | `densenet121` | 7.5M | 1D/2D/3D |
- | `densenet169` | 13.3M | 1D/2D/3D |
- | `densenet121_pretrained` | 7.5M | 2D |
+ | `convnext_tiny` | 27.8M | 1D/2D/3D |
+ | `convnext_small` | 49.5M | 1D/2D/3D |
+ | `convnext_base` | 87.6M | 1D/2D/3D |
+ | `convnext_tiny_pretrained` ⭐ | 27.8M | 2D |
+ | **ConvNeXt V2** — ConvNeXt with GRN |||
+ | `convnext_v2_tiny` | 27.9M | 1D/2D/3D |
+ | `convnext_v2_small` | 49.6M | 1D/2D/3D |
+ | `convnext_v2_base` | 87.7M | 1D/2D/3D |
+ | `convnext_v2_tiny_pretrained` ⭐ | 27.9M | 2D |
+ | **UniRepLKNet** — Large-Kernel ConvNet |||
+ | `unireplknet_tiny` | 30.8M | 1D/2D/3D |
+ | `unireplknet_small` | 56.0M | 1D/2D/3D |
+ | `unireplknet_base` | 97.6M | 1D/2D/3D |
+ | **── Vision Transformers ──** |||
  | **ViT** — Vision Transformer |||
- | `vit_tiny` | 5.5M | 1D/2D |
- | `vit_small` | 21.6M | 1D/2D |
- | `vit_base` | 85.6M | 1D/2D |
+ | `vit_tiny` | 5.4M | 1D/2D |
+ | `vit_small` | 21.4M | 1D/2D |
+ | `vit_base` | 85.3M | 1D/2D |
+ | **Swin** — Shifted Window Transformer |||
+ | `swin_t` ⭐ | 27.5M | 2D |
+ | `swin_s` ⭐ | 48.8M | 2D |
+ | `swin_b` ⭐ | 86.7M | 2D |
+ | **MaxViT** — Multi-Axis ViT |||
+ | `maxvit_tiny` ⭐ | 30.1M | 2D |
+ | `maxvit_small` ⭐ | 67.6M | 2D |
+ | `maxvit_base` ⭐ | 119.1M | 2D |
+ | **── Hybrid CNN-Transformer ──** |||
+ | **FastViT** — Fast Hybrid CNN-ViT |||
+ | `fastvit_t8` ⭐ | 4.0M | 2D |
+ | `fastvit_t12` ⭐ | 6.8M | 2D |
+ | `fastvit_s12` ⭐ | 8.8M | 2D |
+ | `fastvit_sa12` ⭐ | 10.9M | 2D |
+ | **CAFormer** — MetaFormer with Attention |||
+ | `caformer_s18` ⭐ | 26.3M | 2D |
+ | `caformer_s36` ⭐ | 39.2M | 2D |
+ | `caformer_m36` ⭐ | 56.9M | 2D |
+ | `poolformer_s12` ⭐ | 11.9M | 2D |
+ | **EfficientViT** — Memory-Efficient ViT |||
+ | `efficientvit_m0` ⭐ | 2.2M | 2D |
+ | `efficientvit_m1` ⭐ | 2.6M | 2D |
+ | `efficientvit_m2` ⭐ | 3.8M | 2D |
+ | `efficientvit_b0` ⭐ | 2.1M | 2D |
+ | `efficientvit_b1` ⭐ | 7.5M | 2D |
+ | `efficientvit_b2` ⭐ | 21.8M | 2D |
+ | `efficientvit_b3` ⭐ | 46.1M | 2D |
+ | `efficientvit_l1` ⭐ | 49.5M | 2D |
+ | `efficientvit_l2` ⭐ | 60.5M | 2D |
+ | **── State Space Models ──** |||
+ | **Mamba** — State Space Model |||
+ | `mamba_1d` | 3.4M | 1D |
+ | **Vision Mamba (ViM)** — 2D Mamba |||
+ | `vim_tiny` | 6.6M | 2D |
+ | `vim_small` | 51.1M | 2D |
+ | `vim_base` | 201.4M | 2D |
+ | **── Specialized Architectures ──** |||
+ | **TCN** — Temporal Convolutional Network |||
+ | `tcn_small` | 0.9M | 1D |
+ | `tcn` | 6.9M | 1D |
+ | `tcn_large` | 10.0M | 1D |
+ | **ResNet3D** — 3D Residual Network |||
+ | `resnet3d_18` | 33.2M | 3D |
+ | `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
  | **U-Net** — U-shaped Network |||
- | `unet_regression` | 31.1M | 1D/2D/3D |
+ | `unet_regression` | 31.0M | 1D/2D/3D |
+

  ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
  - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
- - **Size**: ~20–350 MB per model depending on architecture
  - **Train from scratch**: Use `--no_pretrained` to disable pretrained weights

  **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:

  ```bash
- # Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
+ # Run once on login node (with internet) — downloads ALL pretrained weights
  python -c "
  import os
  os.environ['TORCH_HOME'] = '.torch_cache'  # Match WaveDL's HPC cache location
@@ -437,24 +471,56 @@ os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
  from torchvision import models as m
  from torchvision.models import video as v

- # Model name -> Weights class mapping
- weights = {
-     'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
-     'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
-     'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
-     'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
-     'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
-     'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
-     'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
-     'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
-     'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
- }
- for name, w in weights.items():
-     getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
+ # === TorchVision Models (use IMAGENET1K_V1 to match WaveDL) ===
+ models = [
+     ('resnet18', m.ResNet18_Weights.IMAGENET1K_V1),
+     ('resnet50', m.ResNet50_Weights.IMAGENET1K_V1),
+     ('efficientnet_b0', m.EfficientNet_B0_Weights.IMAGENET1K_V1),
+     ('efficientnet_b1', m.EfficientNet_B1_Weights.IMAGENET1K_V1),
+     ('efficientnet_b2', m.EfficientNet_B2_Weights.IMAGENET1K_V1),
+     ('efficientnet_v2_s', m.EfficientNet_V2_S_Weights.IMAGENET1K_V1),
+     ('efficientnet_v2_m', m.EfficientNet_V2_M_Weights.IMAGENET1K_V1),
+     ('efficientnet_v2_l', m.EfficientNet_V2_L_Weights.IMAGENET1K_V1),
+     ('mobilenet_v3_small', m.MobileNet_V3_Small_Weights.IMAGENET1K_V1),
+     ('mobilenet_v3_large', m.MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+     ('regnet_y_400mf', m.RegNet_Y_400MF_Weights.IMAGENET1K_V1),
+     ('regnet_y_800mf', m.RegNet_Y_800MF_Weights.IMAGENET1K_V1),
+     ('regnet_y_1_6gf', m.RegNet_Y_1_6GF_Weights.IMAGENET1K_V1),
+     ('regnet_y_3_2gf', m.RegNet_Y_3_2GF_Weights.IMAGENET1K_V1),
+     ('regnet_y_8gf', m.RegNet_Y_8GF_Weights.IMAGENET1K_V1),
+     ('swin_t', m.Swin_T_Weights.IMAGENET1K_V1),
+     ('swin_s', m.Swin_S_Weights.IMAGENET1K_V1),
+     ('swin_b', m.Swin_B_Weights.IMAGENET1K_V1),
+     ('convnext_tiny', m.ConvNeXt_Tiny_Weights.IMAGENET1K_V1),
+     ('densenet121', m.DenseNet121_Weights.IMAGENET1K_V1),
+ ]
+ for name, w in models:
+     getattr(m, name)(weights=w); print(f'✓ {name}')

  # 3D video models
- v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
- v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+ v.r3d_18(weights=v.R3D_18_Weights.KINETICS400_V1); print('✓ r3d_18')
+ v.mc3_18(weights=v.MC3_18_Weights.KINETICS400_V1); print('✓ mc3_18')
+
+ # === Timm Models (MaxViT, FastViT, CAFormer, ConvNeXt V2) ===
+ import timm
+
+ timm_models = [
+     # MaxViT (no suffix - timm resolves to default)
+     'maxvit_tiny_tf_224', 'maxvit_small_tf_224', 'maxvit_base_tf_224',
+     # FastViT (no suffix)
+     'fastvit_t8', 'fastvit_t12', 'fastvit_s12', 'fastvit_sa12',
+     # CAFormer/PoolFormer (no suffix)
+     'caformer_s18', 'caformer_s36', 'caformer_m36', 'poolformer_s12',
+     # ConvNeXt V2 (no suffix)
+     'convnextv2_tiny',
+     # EfficientViT (no suffix)
+     'efficientvit_m0', 'efficientvit_m1', 'efficientvit_m2',
+     'efficientvit_b0', 'efficientvit_b1', 'efficientvit_b2', 'efficientvit_b3',
+     'efficientvit_l1', 'efficientvit_l2',
+ ]
+ for name in timm_models:
+     timm.create_model(name, pretrained=True); print(f'✓ {name}')
+
  print('\\n✓ All pretrained weights cached!')
  "
  ```
@@ -52,6 +52,7 @@ dependencies = [
  # Core ML stack
  "torch>=2.0.0",
  "torchvision>=0.15.0",
+ "timm>=0.9.0",  # Pretrained models (MaxViT, FastViT, CAFormer)
  "accelerate>=0.20.0",
  "numpy>=1.24.0",
  "scipy>=1.10.0",
@@ -70,6 +71,7 @@ dependencies = [
  # ONNX export
  "onnx>=1.14.0",
  "onnxruntime>=1.15.0",
+ "onnxscript>=0.1.0",  # Required by torch.onnx.export in PyTorch 2.1+
  # torch.compile backend (Linux only)
  "triton>=2.0.0; sys_platform == 'linux'",
  ]
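Editor's note: the `onnxscript` comment above refers to PyTorch's newer dynamo-based ONNX exporter, which imports onnxscript at export time. A minimal export sketch through the long-standing `torch.onnx.export` entry point; the toy 1D model and tensor names are placeholders, not wavedl's actual export path:

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv1d(1, 8, kernel_size=3, padding=1),  # (N, 1, 128) -> (N, 8, 128)
    nn.Flatten(),                               # -> (N, 1024)
    nn.Linear(8 * 128, 2),                      # -> (N, 2) regression outputs
)
model.eval()

dummy = torch.randn(1, 1, 128)
torch.onnx.export(
    model, (dummy,), "model.onnx",
    input_names=["signal"], output_names=["prediction"],
)
```

The exported file can then be loaded with `onnxruntime.InferenceSession("model.onnx")`, matching the `onnxruntime` dependency listed above.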
@@ -18,7 +18,7 @@ For inference:
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
  """

- __version__ = "1.5.7"
+ __version__ = "1.6.1"
  __author__ = "Ductho Le"
  __email__ = "ductho.le@outlook.com"