wavedl 1.6.0__tar.gz → 1.6.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. {wavedl-1.6.0/src/wavedl.egg-info → wavedl-1.6.1}/PKG-INFO +93 -53
  2. {wavedl-1.6.0 → wavedl-1.6.1}/README.md +91 -52
  3. {wavedl-1.6.0 → wavedl-1.6.1}/pyproject.toml +1 -0
  4. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/__init__.py +1 -1
  5. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/hpo.py +451 -451
  6. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/__init__.py +28 -0
  7. wavedl-1.6.0/src/wavedl/models/_timm_utils.py → wavedl-1.6.1/src/wavedl/models/_pretrained_utils.py +128 -0
  8. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/base.py +48 -0
  9. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/caformer.py +1 -1
  10. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/cnn.py +2 -27
  11. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/convnext.py +5 -18
  12. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/convnext_v2.py +6 -22
  13. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/densenet.py +5 -18
  14. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/efficientnetv2.py +315 -315
  15. wavedl-1.6.1/src/wavedl/models/efficientvit.py +398 -0
  16. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/fastvit.py +6 -39
  17. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/mamba.py +44 -24
  18. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/maxvit.py +51 -48
  19. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/mobilenetv3.py +295 -295
  20. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/regnet.py +406 -406
  21. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/resnet.py +14 -56
  22. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/resnet3d.py +258 -258
  23. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/swin.py +443 -443
  24. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/tcn.py +393 -409
  25. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/unet.py +1 -5
  26. wavedl-1.6.1/src/wavedl/models/unireplknet.py +491 -0
  27. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/vit.py +3 -3
  28. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/train.py +1430 -1430
  29. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/config.py +367 -367
  30. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/cross_validation.py +530 -530
  31. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/losses.py +216 -216
  32. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/optimizers.py +216 -216
  33. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/schedulers.py +251 -251
  34. {wavedl-1.6.0 → wavedl-1.6.1/src/wavedl.egg-info}/PKG-INFO +93 -53
  35. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl.egg-info/SOURCES.txt +3 -1
  36. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl.egg-info/requires.txt +1 -0
  37. {wavedl-1.6.0 → wavedl-1.6.1}/LICENSE +0 -0
  38. {wavedl-1.6.0 → wavedl-1.6.1}/setup.cfg +0 -0
  39. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/hpc.py +0 -0
  40. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/_template.py +0 -0
  41. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/efficientnet.py +0 -0
  42. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/registry.py +0 -0
  43. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/test.py +0 -0
  44. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/__init__.py +0 -0
  45. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/constraints.py +0 -0
  46. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/data.py +0 -0
  47. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/distributed.py +0 -0
  48. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/metrics.py +0 -0
  49. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl.egg-info/dependency_links.txt +0 -0
  50. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl.egg-info/entry_points.txt +0 -0
  51. {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.6.0
3
+ Version: 1.6.1
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -38,6 +38,7 @@ Requires-Dist: wandb>=0.15.0
38
38
  Requires-Dist: optuna>=3.0.0
39
39
  Requires-Dist: onnx>=1.14.0
40
40
  Requires-Dist: onnxruntime>=1.15.0
41
+ Requires-Dist: onnxscript>=0.1.0
41
42
  Requires-Dist: triton>=2.0.0; sys_platform == "linux"
42
43
  Provides-Extra: dev
43
44
  Requires-Dist: pytest>=7.0.0; extra == "dev"
@@ -118,7 +119,7 @@ Train on datasets larger than RAM:
118
119
 
119
120
  **🧠 Models? We've Got Options**
120
121
 
121
- 57 architectures, ready to go:
122
+ 69 architectures, ready to go:
122
123
  - CNNs, ResNets, ViTs, EfficientNets...
123
124
  - All adapted for regression
124
125
  - [Add your own](#adding-custom-models) in one line
@@ -359,7 +360,7 @@ WaveDL/
359
360
  │ ├── hpo.py # Hyperparameter optimization
360
361
  │ ├── hpc.py # HPC distributed training launcher
361
362
  │ │
362
- │ ├── models/ # Model Zoo (57 architectures)
363
+ │ ├── models/ # Model Zoo (69 architectures)
363
364
  │ │ ├── registry.py # Model factory (@register_model)
364
365
  │ │ ├── base.py # Abstract base class
365
366
  │ │ └── ... # See "Available Models" section
@@ -400,10 +401,11 @@ WaveDL/
400
401
  > ```
401
402
 
402
403
  <details>
403
- <summary><b>Available Models</b> — 57 architectures</summary>
404
+ <summary><b>Available Models</b> — 69 architectures</summary>
404
405
 
405
406
  | Model | Backbone Params | Dim |
406
407
  |-------|-----------------|-----|
408
+ | **── Classic CNNs ──** |||
407
409
  | **CNN** — Convolutional Neural Network |||
408
410
  | `cnn` | 1.6M | 1D/2D/3D |
409
411
  | **ResNet** — Residual Network |||
@@ -412,13 +414,14 @@ WaveDL/
412
414
  | `resnet50` | 23.5M | 1D/2D/3D |
413
415
  | `resnet18_pretrained` ⭐ | 11.2M | 2D |
414
416
  | `resnet50_pretrained` ⭐ | 23.5M | 2D |
415
- | **ResNet3D** — 3D Residual Network |||
416
- | `resnet3d_18` | 33.2M | 3D |
417
- | `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
418
- | **TCN** Temporal Convolutional Network |||
419
- | `tcn_small` | 0.9M | 1D |
420
- | `tcn` | 6.9M | 1D |
421
- | `tcn_large` | 10.0M | 1D |
417
+ | **DenseNet** — Densely Connected Network |||
418
+ | `densenet121` | 7.0M | 1D/2D/3D |
419
+ | `densenet169` | 12.5M | 1D/2D/3D |
420
+ | `densenet121_pretrained` | 7.0M | 2D |
421
+ | **── Efficient/Mobile CNNs ──** |||
422
+ | **MobileNetV3** — Mobile Neural Network V3 |||
423
+ | `mobilenet_v3_small` | 0.9M | 2D |
424
+ | `mobilenet_v3_large` ⭐ | 3.0M | 2D |
422
425
  | **EfficientNet** — Efficient Neural Network |||
423
426
  | `efficientnet_b0` ⭐ | 4.0M | 2D |
424
427
  | `efficientnet_b1` ⭐ | 6.5M | 2D |
@@ -427,47 +430,41 @@ WaveDL/
427
430
  | `efficientnet_v2_s` ⭐ | 20.2M | 2D |
428
431
  | `efficientnet_v2_m` ⭐ | 52.9M | 2D |
429
432
  | `efficientnet_v2_l` ⭐ | 117.2M | 2D |
430
- | **MobileNetV3** — Mobile Neural Network V3 |||
431
- | `mobilenet_v3_small` ⭐ | 0.9M | 2D |
432
- | `mobilenet_v3_large` ⭐ | 3.0M | 2D |
433
433
  | **RegNet** — Regularized Network |||
434
434
  | `regnet_y_400mf` ⭐ | 3.9M | 2D |
435
435
  | `regnet_y_800mf` ⭐ | 5.7M | 2D |
436
436
  | `regnet_y_1_6gf` ⭐ | 10.3M | 2D |
437
437
  | `regnet_y_3_2gf` ⭐ | 17.9M | 2D |
438
438
  | `regnet_y_8gf` ⭐ | 37.4M | 2D |
439
- | **Swin** Shifted Window Transformer |||
440
- | `swin_t` ⭐ | 27.5M | 2D |
441
- | `swin_s` ⭐ | 48.8M | 2D |
442
- | `swin_b` ⭐ | 86.7M | 2D |
439
+ | **── Modern CNNs ──** |||
443
440
  | **ConvNeXt** — Convolutional Next |||
444
441
  | `convnext_tiny` | 27.8M | 1D/2D/3D |
445
442
  | `convnext_small` | 49.5M | 1D/2D/3D |
446
443
  | `convnext_base` | 87.6M | 1D/2D/3D |
447
444
  | `convnext_tiny_pretrained` ⭐ | 27.8M | 2D |
448
- | **DenseNet** — Densely Connected Network |||
449
- | `densenet121` | 7.0M | 1D/2D/3D |
450
- | `densenet169` | 12.5M | 1D/2D/3D |
451
- | `densenet121_pretrained` ⭐ | 7.0M | 2D |
452
- | **ViT** — Vision Transformer |||
453
- | `vit_tiny` | 5.4M | 1D/2D |
454
- | `vit_small` | 21.4M | 1D/2D |
455
- | `vit_base` | 85.3M | 1D/2D |
456
445
  | **ConvNeXt V2** — ConvNeXt with GRN |||
457
446
  | `convnext_v2_tiny` | 27.9M | 1D/2D/3D |
458
447
  | `convnext_v2_small` | 49.6M | 1D/2D/3D |
459
448
  | `convnext_v2_base` | 87.7M | 1D/2D/3D |
460
449
  | `convnext_v2_tiny_pretrained` ⭐ | 27.9M | 2D |
461
- | **Mamba** — State Space Model |||
462
- | `mamba_1d` | 3.4M | 1D |
463
- | **Vision Mamba (ViM)** 2D Mamba |||
464
- | `vim_tiny` | 6.6M | 2D |
465
- | `vim_small` | 51.1M | 2D |
466
- | `vim_base` | 201.4M | 2D |
450
+ | **UniRepLKNet** — Large-Kernel ConvNet |||
451
+ | `unireplknet_tiny` | 30.8M | 1D/2D/3D |
452
+ | `unireplknet_small` | 56.0M | 1D/2D/3D |
453
+ | `unireplknet_base` | 97.6M | 1D/2D/3D |
454
+ | **── Vision Transformers ──** |||
455
+ | **ViT** — Vision Transformer |||
456
+ | `vit_tiny` | 5.4M | 1D/2D |
457
+ | `vit_small` | 21.4M | 1D/2D |
458
+ | `vit_base` | 85.3M | 1D/2D |
459
+ | **Swin** — Shifted Window Transformer |||
460
+ | `swin_t` ⭐ | 27.5M | 2D |
461
+ | `swin_s` ⭐ | 48.8M | 2D |
462
+ | `swin_b` ⭐ | 86.7M | 2D |
467
463
  | **MaxViT** — Multi-Axis ViT |||
468
464
  | `maxvit_tiny` ⭐ | 30.1M | 2D |
469
465
  | `maxvit_small` ⭐ | 67.6M | 2D |
470
466
  | `maxvit_base` ⭐ | 119.1M | 2D |
467
+ | **── Hybrid CNN-Transformer ──** |||
471
468
  | **FastViT** — Fast Hybrid CNN-ViT |||
472
469
  | `fastvit_t8` ⭐ | 4.0M | 2D |
473
470
  | `fastvit_t12` ⭐ | 6.8M | 2D |
@@ -478,6 +475,31 @@ WaveDL/
478
475
  | `caformer_s36` ⭐ | 39.2M | 2D |
479
476
  | `caformer_m36` ⭐ | 56.9M | 2D |
480
477
  | `poolformer_s12` ⭐ | 11.9M | 2D |
478
+ | **EfficientViT** — Memory-Efficient ViT |||
479
+ | `efficientvit_m0` ⭐ | 2.2M | 2D |
480
+ | `efficientvit_m1` ⭐ | 2.6M | 2D |
481
+ | `efficientvit_m2` ⭐ | 3.8M | 2D |
482
+ | `efficientvit_b0` ⭐ | 2.1M | 2D |
483
+ | `efficientvit_b1` ⭐ | 7.5M | 2D |
484
+ | `efficientvit_b2` ⭐ | 21.8M | 2D |
485
+ | `efficientvit_b3` ⭐ | 46.1M | 2D |
486
+ | `efficientvit_l1` ⭐ | 49.5M | 2D |
487
+ | `efficientvit_l2` ⭐ | 60.5M | 2D |
488
+ | **── State Space Models ──** |||
489
+ | **Mamba** — State Space Model |||
490
+ | `mamba_1d` | 3.4M | 1D |
491
+ | **Vision Mamba (ViM)** — 2D Mamba |||
492
+ | `vim_tiny` | 6.6M | 2D |
493
+ | `vim_small` | 51.1M | 2D |
494
+ | `vim_base` | 201.4M | 2D |
495
+ | **── Specialized Architectures ──** |||
496
+ | **TCN** — Temporal Convolutional Network |||
497
+ | `tcn_small` | 0.9M | 1D |
498
+ | `tcn` | 6.9M | 1D |
499
+ | `tcn_large` | 10.0M | 1D |
500
+ | **ResNet3D** — 3D Residual Network |||
501
+ | `resnet3d_18` | 33.2M | 3D |
502
+ | `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
481
503
  | **U-Net** — U-shaped Network |||
482
504
  | `unet_regression` | 31.0M | 1D/2D/3D |
483
505
 
@@ -497,34 +519,52 @@ os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
497
519
  from torchvision import models as m
498
520
  from torchvision.models import video as v
499
521
 
500
- # === TorchVision Models ===
501
- weights = {
502
- 'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
503
- 'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
504
- 'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
505
- 'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
506
- 'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
507
- 'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
508
- 'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
509
- 'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
510
- 'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
511
- }
512
- for name, w in weights.items():
513
- getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
522
+ # === TorchVision Models (use IMAGENET1K_V1 to match WaveDL) ===
523
+ models = [
524
+ ('resnet18', m.ResNet18_Weights.IMAGENET1K_V1),
525
+ ('resnet50', m.ResNet50_Weights.IMAGENET1K_V1),
526
+ ('efficientnet_b0', m.EfficientNet_B0_Weights.IMAGENET1K_V1),
527
+ ('efficientnet_b1', m.EfficientNet_B1_Weights.IMAGENET1K_V1),
528
+ ('efficientnet_b2', m.EfficientNet_B2_Weights.IMAGENET1K_V1),
529
+ ('efficientnet_v2_s', m.EfficientNet_V2_S_Weights.IMAGENET1K_V1),
530
+ ('efficientnet_v2_m', m.EfficientNet_V2_M_Weights.IMAGENET1K_V1),
531
+ ('efficientnet_v2_l', m.EfficientNet_V2_L_Weights.IMAGENET1K_V1),
532
+ ('mobilenet_v3_small', m.MobileNet_V3_Small_Weights.IMAGENET1K_V1),
533
+ ('mobilenet_v3_large', m.MobileNet_V3_Large_Weights.IMAGENET1K_V1),
534
+ ('regnet_y_400mf', m.RegNet_Y_400MF_Weights.IMAGENET1K_V1),
535
+ ('regnet_y_800mf', m.RegNet_Y_800MF_Weights.IMAGENET1K_V1),
536
+ ('regnet_y_1_6gf', m.RegNet_Y_1_6GF_Weights.IMAGENET1K_V1),
537
+ ('regnet_y_3_2gf', m.RegNet_Y_3_2GF_Weights.IMAGENET1K_V1),
538
+ ('regnet_y_8gf', m.RegNet_Y_8GF_Weights.IMAGENET1K_V1),
539
+ ('swin_t', m.Swin_T_Weights.IMAGENET1K_V1),
540
+ ('swin_s', m.Swin_S_Weights.IMAGENET1K_V1),
541
+ ('swin_b', m.Swin_B_Weights.IMAGENET1K_V1),
542
+ ('convnext_tiny', m.ConvNeXt_Tiny_Weights.IMAGENET1K_V1),
543
+ ('densenet121', m.DenseNet121_Weights.IMAGENET1K_V1),
544
+ ]
545
+ for name, w in models:
546
+ getattr(m, name)(weights=w); print(f'✓ {name}')
514
547
 
515
548
  # 3D video models
516
- v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
517
- v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
549
+ v.r3d_18(weights=v.R3D_18_Weights.KINETICS400_V1); print('✓ r3d_18')
550
+ v.mc3_18(weights=v.MC3_18_Weights.KINETICS400_V1); print('✓ mc3_18')
518
551
 
519
552
  # === Timm Models (MaxViT, FastViT, CAFormer, ConvNeXt V2) ===
520
553
  import timm
521
554
 
522
555
  timm_models = [
523
- 'maxvit_tiny_tf_224.in1k', 'maxvit_small_tf_224.in1k', 'maxvit_base_tf_224.in1k',
524
- 'fastvit_t8.apple_in1k', 'fastvit_t12.apple_in1k', 'fastvit_s12.apple_in1k', 'fastvit_sa12.apple_in1k',
525
- 'caformer_s18.sail_in1k', 'caformer_s36.sail_in22k_ft_in1k', 'caformer_m36.sail_in22k_ft_in1k',
526
- 'poolformer_s12.sail_in1k',
527
- 'convnextv2_tiny.fcmae_ft_in1k',
556
+ # MaxViT (no suffix - timm resolves to default)
557
+ 'maxvit_tiny_tf_224', 'maxvit_small_tf_224', 'maxvit_base_tf_224',
558
+ # FastViT (no suffix)
559
+ 'fastvit_t8', 'fastvit_t12', 'fastvit_s12', 'fastvit_sa12',
560
+ # CAFormer/PoolFormer (no suffix)
561
+ 'caformer_s18', 'caformer_s36', 'caformer_m36', 'poolformer_s12',
562
+ # ConvNeXt V2 (no suffix)
563
+ 'convnextv2_tiny',
564
+ # EfficientViT (no suffix)
565
+ 'efficientvit_m0', 'efficientvit_m1', 'efficientvit_m2',
566
+ 'efficientvit_b0', 'efficientvit_b1', 'efficientvit_b2', 'efficientvit_b3',
567
+ 'efficientvit_l1', 'efficientvit_l2',
528
568
  ]
529
569
  for name in timm_models:
530
570
  timm.create_model(name, pretrained=True); print(f'✓ {name}')
@@ -71,7 +71,7 @@ Train on datasets larger than RAM:
71
71
 
72
72
  **🧠 Models? We've Got Options**
73
73
 
74
- 57 architectures, ready to go:
74
+ 69 architectures, ready to go:
75
75
  - CNNs, ResNets, ViTs, EfficientNets...
76
76
  - All adapted for regression
77
77
  - [Add your own](#adding-custom-models) in one line
@@ -312,7 +312,7 @@ WaveDL/
312
312
  │ ├── hpo.py # Hyperparameter optimization
313
313
  │ ├── hpc.py # HPC distributed training launcher
314
314
  │ │
315
- │ ├── models/ # Model Zoo (57 architectures)
315
+ │ ├── models/ # Model Zoo (69 architectures)
316
316
  │ │ ├── registry.py # Model factory (@register_model)
317
317
  │ │ ├── base.py # Abstract base class
318
318
  │ │ └── ... # See "Available Models" section
@@ -353,10 +353,11 @@ WaveDL/
353
353
  > ```
354
354
 
355
355
  <details>
356
- <summary><b>Available Models</b> — 57 architectures</summary>
356
+ <summary><b>Available Models</b> — 69 architectures</summary>
357
357
 
358
358
  | Model | Backbone Params | Dim |
359
359
  |-------|-----------------|-----|
360
+ | **── Classic CNNs ──** |||
360
361
  | **CNN** — Convolutional Neural Network |||
361
362
  | `cnn` | 1.6M | 1D/2D/3D |
362
363
  | **ResNet** — Residual Network |||
@@ -365,13 +366,14 @@ WaveDL/
365
366
  | `resnet50` | 23.5M | 1D/2D/3D |
366
367
  | `resnet18_pretrained` ⭐ | 11.2M | 2D |
367
368
  | `resnet50_pretrained` ⭐ | 23.5M | 2D |
368
- | **ResNet3D** — 3D Residual Network |||
369
- | `resnet3d_18` | 33.2M | 3D |
370
- | `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
371
- | **TCN** Temporal Convolutional Network |||
372
- | `tcn_small` | 0.9M | 1D |
373
- | `tcn` | 6.9M | 1D |
374
- | `tcn_large` | 10.0M | 1D |
369
+ | **DenseNet** — Densely Connected Network |||
370
+ | `densenet121` | 7.0M | 1D/2D/3D |
371
+ | `densenet169` | 12.5M | 1D/2D/3D |
372
+ | `densenet121_pretrained` | 7.0M | 2D |
373
+ | **── Efficient/Mobile CNNs ──** |||
374
+ | **MobileNetV3** — Mobile Neural Network V3 |||
375
+ | `mobilenet_v3_small` | 0.9M | 2D |
376
+ | `mobilenet_v3_large` ⭐ | 3.0M | 2D |
375
377
  | **EfficientNet** — Efficient Neural Network |||
376
378
  | `efficientnet_b0` ⭐ | 4.0M | 2D |
377
379
  | `efficientnet_b1` ⭐ | 6.5M | 2D |
@@ -380,47 +382,41 @@ WaveDL/
380
382
  | `efficientnet_v2_s` ⭐ | 20.2M | 2D |
381
383
  | `efficientnet_v2_m` ⭐ | 52.9M | 2D |
382
384
  | `efficientnet_v2_l` ⭐ | 117.2M | 2D |
383
- | **MobileNetV3** — Mobile Neural Network V3 |||
384
- | `mobilenet_v3_small` ⭐ | 0.9M | 2D |
385
- | `mobilenet_v3_large` ⭐ | 3.0M | 2D |
386
385
  | **RegNet** — Regularized Network |||
387
386
  | `regnet_y_400mf` ⭐ | 3.9M | 2D |
388
387
  | `regnet_y_800mf` ⭐ | 5.7M | 2D |
389
388
  | `regnet_y_1_6gf` ⭐ | 10.3M | 2D |
390
389
  | `regnet_y_3_2gf` ⭐ | 17.9M | 2D |
391
390
  | `regnet_y_8gf` ⭐ | 37.4M | 2D |
392
- | **Swin** Shifted Window Transformer |||
393
- | `swin_t` ⭐ | 27.5M | 2D |
394
- | `swin_s` ⭐ | 48.8M | 2D |
395
- | `swin_b` ⭐ | 86.7M | 2D |
391
+ | **── Modern CNNs ──** |||
396
392
  | **ConvNeXt** — Convolutional Next |||
397
393
  | `convnext_tiny` | 27.8M | 1D/2D/3D |
398
394
  | `convnext_small` | 49.5M | 1D/2D/3D |
399
395
  | `convnext_base` | 87.6M | 1D/2D/3D |
400
396
  | `convnext_tiny_pretrained` ⭐ | 27.8M | 2D |
401
- | **DenseNet** — Densely Connected Network |||
402
- | `densenet121` | 7.0M | 1D/2D/3D |
403
- | `densenet169` | 12.5M | 1D/2D/3D |
404
- | `densenet121_pretrained` ⭐ | 7.0M | 2D |
405
- | **ViT** — Vision Transformer |||
406
- | `vit_tiny` | 5.4M | 1D/2D |
407
- | `vit_small` | 21.4M | 1D/2D |
408
- | `vit_base` | 85.3M | 1D/2D |
409
397
  | **ConvNeXt V2** — ConvNeXt with GRN |||
410
398
  | `convnext_v2_tiny` | 27.9M | 1D/2D/3D |
411
399
  | `convnext_v2_small` | 49.6M | 1D/2D/3D |
412
400
  | `convnext_v2_base` | 87.7M | 1D/2D/3D |
413
401
  | `convnext_v2_tiny_pretrained` ⭐ | 27.9M | 2D |
414
- | **Mamba** — State Space Model |||
415
- | `mamba_1d` | 3.4M | 1D |
416
- | **Vision Mamba (ViM)** 2D Mamba |||
417
- | `vim_tiny` | 6.6M | 2D |
418
- | `vim_small` | 51.1M | 2D |
419
- | `vim_base` | 201.4M | 2D |
402
+ | **UniRepLKNet** — Large-Kernel ConvNet |||
403
+ | `unireplknet_tiny` | 30.8M | 1D/2D/3D |
404
+ | `unireplknet_small` | 56.0M | 1D/2D/3D |
405
+ | `unireplknet_base` | 97.6M | 1D/2D/3D |
406
+ | **── Vision Transformers ──** |||
407
+ | **ViT** — Vision Transformer |||
408
+ | `vit_tiny` | 5.4M | 1D/2D |
409
+ | `vit_small` | 21.4M | 1D/2D |
410
+ | `vit_base` | 85.3M | 1D/2D |
411
+ | **Swin** — Shifted Window Transformer |||
412
+ | `swin_t` ⭐ | 27.5M | 2D |
413
+ | `swin_s` ⭐ | 48.8M | 2D |
414
+ | `swin_b` ⭐ | 86.7M | 2D |
420
415
  | **MaxViT** — Multi-Axis ViT |||
421
416
  | `maxvit_tiny` ⭐ | 30.1M | 2D |
422
417
  | `maxvit_small` ⭐ | 67.6M | 2D |
423
418
  | `maxvit_base` ⭐ | 119.1M | 2D |
419
+ | **── Hybrid CNN-Transformer ──** |||
424
420
  | **FastViT** — Fast Hybrid CNN-ViT |||
425
421
  | `fastvit_t8` ⭐ | 4.0M | 2D |
426
422
  | `fastvit_t12` ⭐ | 6.8M | 2D |
@@ -431,6 +427,31 @@ WaveDL/
431
427
  | `caformer_s36` ⭐ | 39.2M | 2D |
432
428
  | `caformer_m36` ⭐ | 56.9M | 2D |
433
429
  | `poolformer_s12` ⭐ | 11.9M | 2D |
430
+ | **EfficientViT** — Memory-Efficient ViT |||
431
+ | `efficientvit_m0` ⭐ | 2.2M | 2D |
432
+ | `efficientvit_m1` ⭐ | 2.6M | 2D |
433
+ | `efficientvit_m2` ⭐ | 3.8M | 2D |
434
+ | `efficientvit_b0` ⭐ | 2.1M | 2D |
435
+ | `efficientvit_b1` ⭐ | 7.5M | 2D |
436
+ | `efficientvit_b2` ⭐ | 21.8M | 2D |
437
+ | `efficientvit_b3` ⭐ | 46.1M | 2D |
438
+ | `efficientvit_l1` ⭐ | 49.5M | 2D |
439
+ | `efficientvit_l2` ⭐ | 60.5M | 2D |
440
+ | **── State Space Models ──** |||
441
+ | **Mamba** — State Space Model |||
442
+ | `mamba_1d` | 3.4M | 1D |
443
+ | **Vision Mamba (ViM)** — 2D Mamba |||
444
+ | `vim_tiny` | 6.6M | 2D |
445
+ | `vim_small` | 51.1M | 2D |
446
+ | `vim_base` | 201.4M | 2D |
447
+ | **── Specialized Architectures ──** |||
448
+ | **TCN** — Temporal Convolutional Network |||
449
+ | `tcn_small` | 0.9M | 1D |
450
+ | `tcn` | 6.9M | 1D |
451
+ | `tcn_large` | 10.0M | 1D |
452
+ | **ResNet3D** — 3D Residual Network |||
453
+ | `resnet3d_18` | 33.2M | 3D |
454
+ | `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
434
455
  | **U-Net** — U-shaped Network |||
435
456
  | `unet_regression` | 31.0M | 1D/2D/3D |
436
457
 
@@ -450,34 +471,52 @@ os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
450
471
  from torchvision import models as m
451
472
  from torchvision.models import video as v
452
473
 
453
- # === TorchVision Models ===
454
- weights = {
455
- 'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
456
- 'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
457
- 'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
458
- 'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
459
- 'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
460
- 'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
461
- 'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
462
- 'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
463
- 'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
464
- }
465
- for name, w in weights.items():
466
- getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
474
+ # === TorchVision Models (use IMAGENET1K_V1 to match WaveDL) ===
475
+ models = [
476
+ ('resnet18', m.ResNet18_Weights.IMAGENET1K_V1),
477
+ ('resnet50', m.ResNet50_Weights.IMAGENET1K_V1),
478
+ ('efficientnet_b0', m.EfficientNet_B0_Weights.IMAGENET1K_V1),
479
+ ('efficientnet_b1', m.EfficientNet_B1_Weights.IMAGENET1K_V1),
480
+ ('efficientnet_b2', m.EfficientNet_B2_Weights.IMAGENET1K_V1),
481
+ ('efficientnet_v2_s', m.EfficientNet_V2_S_Weights.IMAGENET1K_V1),
482
+ ('efficientnet_v2_m', m.EfficientNet_V2_M_Weights.IMAGENET1K_V1),
483
+ ('efficientnet_v2_l', m.EfficientNet_V2_L_Weights.IMAGENET1K_V1),
484
+ ('mobilenet_v3_small', m.MobileNet_V3_Small_Weights.IMAGENET1K_V1),
485
+ ('mobilenet_v3_large', m.MobileNet_V3_Large_Weights.IMAGENET1K_V1),
486
+ ('regnet_y_400mf', m.RegNet_Y_400MF_Weights.IMAGENET1K_V1),
487
+ ('regnet_y_800mf', m.RegNet_Y_800MF_Weights.IMAGENET1K_V1),
488
+ ('regnet_y_1_6gf', m.RegNet_Y_1_6GF_Weights.IMAGENET1K_V1),
489
+ ('regnet_y_3_2gf', m.RegNet_Y_3_2GF_Weights.IMAGENET1K_V1),
490
+ ('regnet_y_8gf', m.RegNet_Y_8GF_Weights.IMAGENET1K_V1),
491
+ ('swin_t', m.Swin_T_Weights.IMAGENET1K_V1),
492
+ ('swin_s', m.Swin_S_Weights.IMAGENET1K_V1),
493
+ ('swin_b', m.Swin_B_Weights.IMAGENET1K_V1),
494
+ ('convnext_tiny', m.ConvNeXt_Tiny_Weights.IMAGENET1K_V1),
495
+ ('densenet121', m.DenseNet121_Weights.IMAGENET1K_V1),
496
+ ]
497
+ for name, w in models:
498
+ getattr(m, name)(weights=w); print(f'✓ {name}')
467
499
 
468
500
  # 3D video models
469
- v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
470
- v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
501
+ v.r3d_18(weights=v.R3D_18_Weights.KINETICS400_V1); print('✓ r3d_18')
502
+ v.mc3_18(weights=v.MC3_18_Weights.KINETICS400_V1); print('✓ mc3_18')
471
503
 
472
504
  # === Timm Models (MaxViT, FastViT, CAFormer, ConvNeXt V2) ===
473
505
  import timm
474
506
 
475
507
  timm_models = [
476
- 'maxvit_tiny_tf_224.in1k', 'maxvit_small_tf_224.in1k', 'maxvit_base_tf_224.in1k',
477
- 'fastvit_t8.apple_in1k', 'fastvit_t12.apple_in1k', 'fastvit_s12.apple_in1k', 'fastvit_sa12.apple_in1k',
478
- 'caformer_s18.sail_in1k', 'caformer_s36.sail_in22k_ft_in1k', 'caformer_m36.sail_in22k_ft_in1k',
479
- 'poolformer_s12.sail_in1k',
480
- 'convnextv2_tiny.fcmae_ft_in1k',
508
+ # MaxViT (no suffix - timm resolves to default)
509
+ 'maxvit_tiny_tf_224', 'maxvit_small_tf_224', 'maxvit_base_tf_224',
510
+ # FastViT (no suffix)
511
+ 'fastvit_t8', 'fastvit_t12', 'fastvit_s12', 'fastvit_sa12',
512
+ # CAFormer/PoolFormer (no suffix)
513
+ 'caformer_s18', 'caformer_s36', 'caformer_m36', 'poolformer_s12',
514
+ # ConvNeXt V2 (no suffix)
515
+ 'convnextv2_tiny',
516
+ # EfficientViT (no suffix)
517
+ 'efficientvit_m0', 'efficientvit_m1', 'efficientvit_m2',
518
+ 'efficientvit_b0', 'efficientvit_b1', 'efficientvit_b2', 'efficientvit_b3',
519
+ 'efficientvit_l1', 'efficientvit_l2',
481
520
  ]
482
521
  for name in timm_models:
483
522
  timm.create_model(name, pretrained=True); print(f'✓ {name}')
@@ -71,6 +71,7 @@ dependencies = [
71
71
  # ONNX export
72
72
  "onnx>=1.14.0",
73
73
  "onnxruntime>=1.15.0",
74
+ "onnxscript>=0.1.0", # Required by torch.onnx.export in PyTorch 2.1+
74
75
  # torch.compile backend (Linux only)
75
76
  "triton>=2.0.0; sys_platform == 'linux'",
76
77
  ]
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.6.0"
21
+ __version__ = "1.6.1"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24