wavedl 1.6.0__tar.gz → 1.6.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavedl-1.6.0/src/wavedl.egg-info → wavedl-1.6.1}/PKG-INFO +93 -53
- {wavedl-1.6.0 → wavedl-1.6.1}/README.md +91 -52
- {wavedl-1.6.0 → wavedl-1.6.1}/pyproject.toml +1 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/__init__.py +1 -1
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/hpo.py +451 -451
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/__init__.py +28 -0
- wavedl-1.6.0/src/wavedl/models/_timm_utils.py → wavedl-1.6.1/src/wavedl/models/_pretrained_utils.py +128 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/base.py +48 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/caformer.py +1 -1
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/cnn.py +2 -27
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/convnext.py +5 -18
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/convnext_v2.py +6 -22
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/densenet.py +5 -18
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/efficientnetv2.py +315 -315
- wavedl-1.6.1/src/wavedl/models/efficientvit.py +398 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/fastvit.py +6 -39
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/mamba.py +44 -24
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/maxvit.py +51 -48
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/mobilenetv3.py +295 -295
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/regnet.py +406 -406
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/resnet.py +14 -56
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/resnet3d.py +258 -258
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/swin.py +443 -443
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/tcn.py +393 -409
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/unet.py +1 -5
- wavedl-1.6.1/src/wavedl/models/unireplknet.py +491 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/vit.py +3 -3
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/train.py +1430 -1430
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/config.py +367 -367
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/cross_validation.py +530 -530
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/losses.py +216 -216
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/optimizers.py +216 -216
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/schedulers.py +251 -251
- {wavedl-1.6.0 → wavedl-1.6.1/src/wavedl.egg-info}/PKG-INFO +93 -53
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl.egg-info/SOURCES.txt +3 -1
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl.egg-info/requires.txt +1 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/LICENSE +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/setup.cfg +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/hpc.py +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/_template.py +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/efficientnet.py +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/models/registry.py +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/test.py +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/__init__.py +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/constraints.py +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/data.py +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/distributed.py +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl/utils/metrics.py +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl.egg-info/dependency_links.txt +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl.egg-info/entry_points.txt +0 -0
- {wavedl-1.6.0 → wavedl-1.6.1}/src/wavedl.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.6.0
|
|
3
|
+
Version: 1.6.1
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -38,6 +38,7 @@ Requires-Dist: wandb>=0.15.0
|
|
|
38
38
|
Requires-Dist: optuna>=3.0.0
|
|
39
39
|
Requires-Dist: onnx>=1.14.0
|
|
40
40
|
Requires-Dist: onnxruntime>=1.15.0
|
|
41
|
+
Requires-Dist: onnxscript>=0.1.0
|
|
41
42
|
Requires-Dist: triton>=2.0.0; sys_platform == "linux"
|
|
42
43
|
Provides-Extra: dev
|
|
43
44
|
Requires-Dist: pytest>=7.0.0; extra == "dev"
|
|
@@ -118,7 +119,7 @@ Train on datasets larger than RAM:
|
|
|
118
119
|
|
|
119
120
|
**🧠 Models? We've Got Options**
|
|
120
121
|
|
|
121
|
-
|
|
122
|
+
69 architectures, ready to go:
|
|
122
123
|
- CNNs, ResNets, ViTs, EfficientNets...
|
|
123
124
|
- All adapted for regression
|
|
124
125
|
- [Add your own](#adding-custom-models) in one line
|
|
@@ -359,7 +360,7 @@ WaveDL/
|
|
|
359
360
|
│ ├── hpo.py # Hyperparameter optimization
|
|
360
361
|
│ ├── hpc.py # HPC distributed training launcher
|
|
361
362
|
│ │
|
|
362
|
-
│ ├── models/ # Model Zoo (
|
|
363
|
+
│ ├── models/ # Model Zoo (69 architectures)
|
|
363
364
|
│ │ ├── registry.py # Model factory (@register_model)
|
|
364
365
|
│ │ ├── base.py # Abstract base class
|
|
365
366
|
│ │ └── ... # See "Available Models" section
|
|
@@ -400,10 +401,11 @@ WaveDL/
|
|
|
400
401
|
> ```
|
|
401
402
|
|
|
402
403
|
<details>
|
|
403
|
-
<summary><b>Available Models</b> —
|
|
404
|
+
<summary><b>Available Models</b> — 69 architectures</summary>
|
|
404
405
|
|
|
405
406
|
| Model | Backbone Params | Dim |
|
|
406
407
|
|-------|-----------------|-----|
|
|
408
|
+
| **── Classic CNNs ──** |||
|
|
407
409
|
| **CNN** — Convolutional Neural Network |||
|
|
408
410
|
| `cnn` | 1.6M | 1D/2D/3D |
|
|
409
411
|
| **ResNet** — Residual Network |||
|
|
@@ -412,13 +414,14 @@ WaveDL/
|
|
|
412
414
|
| `resnet50` | 23.5M | 1D/2D/3D |
|
|
413
415
|
| `resnet18_pretrained` ⭐ | 11.2M | 2D |
|
|
414
416
|
| `resnet50_pretrained` ⭐ | 23.5M | 2D |
|
|
415
|
-
| **
|
|
416
|
-
| `
|
|
417
|
-
| `
|
|
418
|
-
|
|
|
419
|
-
|
|
|
420
|
-
|
|
|
421
|
-
| `
|
|
417
|
+
| **DenseNet** — Densely Connected Network |||
|
|
418
|
+
| `densenet121` | 7.0M | 1D/2D/3D |
|
|
419
|
+
| `densenet169` | 12.5M | 1D/2D/3D |
|
|
420
|
+
| `densenet121_pretrained` ⭐ | 7.0M | 2D |
|
|
421
|
+
| **── Efficient/Mobile CNNs ──** |||
|
|
422
|
+
| **MobileNetV3** — Mobile Neural Network V3 |||
|
|
423
|
+
| `mobilenet_v3_small` ⭐ | 0.9M | 2D |
|
|
424
|
+
| `mobilenet_v3_large` ⭐ | 3.0M | 2D |
|
|
422
425
|
| **EfficientNet** — Efficient Neural Network |||
|
|
423
426
|
| `efficientnet_b0` ⭐ | 4.0M | 2D |
|
|
424
427
|
| `efficientnet_b1` ⭐ | 6.5M | 2D |
|
|
@@ -427,47 +430,41 @@ WaveDL/
|
|
|
427
430
|
| `efficientnet_v2_s` ⭐ | 20.2M | 2D |
|
|
428
431
|
| `efficientnet_v2_m` ⭐ | 52.9M | 2D |
|
|
429
432
|
| `efficientnet_v2_l` ⭐ | 117.2M | 2D |
|
|
430
|
-
| **MobileNetV3** — Mobile Neural Network V3 |||
|
|
431
|
-
| `mobilenet_v3_small` ⭐ | 0.9M | 2D |
|
|
432
|
-
| `mobilenet_v3_large` ⭐ | 3.0M | 2D |
|
|
433
433
|
| **RegNet** — Regularized Network |||
|
|
434
434
|
| `regnet_y_400mf` ⭐ | 3.9M | 2D |
|
|
435
435
|
| `regnet_y_800mf` ⭐ | 5.7M | 2D |
|
|
436
436
|
| `regnet_y_1_6gf` ⭐ | 10.3M | 2D |
|
|
437
437
|
| `regnet_y_3_2gf` ⭐ | 17.9M | 2D |
|
|
438
438
|
| `regnet_y_8gf` ⭐ | 37.4M | 2D |
|
|
439
|
-
|
|
|
440
|
-
| `swin_t` ⭐ | 27.5M | 2D |
|
|
441
|
-
| `swin_s` ⭐ | 48.8M | 2D |
|
|
442
|
-
| `swin_b` ⭐ | 86.7M | 2D |
|
|
439
|
+
| **── Modern CNNs ──** |||
|
|
443
440
|
| **ConvNeXt** — Convolutional Next |||
|
|
444
441
|
| `convnext_tiny` | 27.8M | 1D/2D/3D |
|
|
445
442
|
| `convnext_small` | 49.5M | 1D/2D/3D |
|
|
446
443
|
| `convnext_base` | 87.6M | 1D/2D/3D |
|
|
447
444
|
| `convnext_tiny_pretrained` ⭐ | 27.8M | 2D |
|
|
448
|
-
| **DenseNet** — Densely Connected Network |||
|
|
449
|
-
| `densenet121` | 7.0M | 1D/2D/3D |
|
|
450
|
-
| `densenet169` | 12.5M | 1D/2D/3D |
|
|
451
|
-
| `densenet121_pretrained` ⭐ | 7.0M | 2D |
|
|
452
|
-
| **ViT** — Vision Transformer |||
|
|
453
|
-
| `vit_tiny` | 5.4M | 1D/2D |
|
|
454
|
-
| `vit_small` | 21.4M | 1D/2D |
|
|
455
|
-
| `vit_base` | 85.3M | 1D/2D |
|
|
456
445
|
| **ConvNeXt V2** — ConvNeXt with GRN |||
|
|
457
446
|
| `convnext_v2_tiny` | 27.9M | 1D/2D/3D |
|
|
458
447
|
| `convnext_v2_small` | 49.6M | 1D/2D/3D |
|
|
459
448
|
| `convnext_v2_base` | 87.7M | 1D/2D/3D |
|
|
460
449
|
| `convnext_v2_tiny_pretrained` ⭐ | 27.9M | 2D |
|
|
461
|
-
| **
|
|
462
|
-
| `
|
|
463
|
-
|
|
|
464
|
-
| `
|
|
465
|
-
|
|
|
466
|
-
|
|
|
450
|
+
| **UniRepLKNet** — Large-Kernel ConvNet |||
|
|
451
|
+
| `unireplknet_tiny` | 30.8M | 1D/2D/3D |
|
|
452
|
+
| `unireplknet_small` | 56.0M | 1D/2D/3D |
|
|
453
|
+
| `unireplknet_base` | 97.6M | 1D/2D/3D |
|
|
454
|
+
| **── Vision Transformers ──** |||
|
|
455
|
+
| **ViT** — Vision Transformer |||
|
|
456
|
+
| `vit_tiny` | 5.4M | 1D/2D |
|
|
457
|
+
| `vit_small` | 21.4M | 1D/2D |
|
|
458
|
+
| `vit_base` | 85.3M | 1D/2D |
|
|
459
|
+
| **Swin** — Shifted Window Transformer |||
|
|
460
|
+
| `swin_t` ⭐ | 27.5M | 2D |
|
|
461
|
+
| `swin_s` ⭐ | 48.8M | 2D |
|
|
462
|
+
| `swin_b` ⭐ | 86.7M | 2D |
|
|
467
463
|
| **MaxViT** — Multi-Axis ViT |||
|
|
468
464
|
| `maxvit_tiny` ⭐ | 30.1M | 2D |
|
|
469
465
|
| `maxvit_small` ⭐ | 67.6M | 2D |
|
|
470
466
|
| `maxvit_base` ⭐ | 119.1M | 2D |
|
|
467
|
+
| **── Hybrid CNN-Transformer ──** |||
|
|
471
468
|
| **FastViT** — Fast Hybrid CNN-ViT |||
|
|
472
469
|
| `fastvit_t8` ⭐ | 4.0M | 2D |
|
|
473
470
|
| `fastvit_t12` ⭐ | 6.8M | 2D |
|
|
@@ -478,6 +475,31 @@ WaveDL/
|
|
|
478
475
|
| `caformer_s36` ⭐ | 39.2M | 2D |
|
|
479
476
|
| `caformer_m36` ⭐ | 56.9M | 2D |
|
|
480
477
|
| `poolformer_s12` ⭐ | 11.9M | 2D |
|
|
478
|
+
| **EfficientViT** — Memory-Efficient ViT |||
|
|
479
|
+
| `efficientvit_m0` ⭐ | 2.2M | 2D |
|
|
480
|
+
| `efficientvit_m1` ⭐ | 2.6M | 2D |
|
|
481
|
+
| `efficientvit_m2` ⭐ | 3.8M | 2D |
|
|
482
|
+
| `efficientvit_b0` ⭐ | 2.1M | 2D |
|
|
483
|
+
| `efficientvit_b1` ⭐ | 7.5M | 2D |
|
|
484
|
+
| `efficientvit_b2` ⭐ | 21.8M | 2D |
|
|
485
|
+
| `efficientvit_b3` ⭐ | 46.1M | 2D |
|
|
486
|
+
| `efficientvit_l1` ⭐ | 49.5M | 2D |
|
|
487
|
+
| `efficientvit_l2` ⭐ | 60.5M | 2D |
|
|
488
|
+
| **── State Space Models ──** |||
|
|
489
|
+
| **Mamba** — State Space Model |||
|
|
490
|
+
| `mamba_1d` | 3.4M | 1D |
|
|
491
|
+
| **Vision Mamba (ViM)** — 2D Mamba |||
|
|
492
|
+
| `vim_tiny` | 6.6M | 2D |
|
|
493
|
+
| `vim_small` | 51.1M | 2D |
|
|
494
|
+
| `vim_base` | 201.4M | 2D |
|
|
495
|
+
| **── Specialized Architectures ──** |||
|
|
496
|
+
| **TCN** — Temporal Convolutional Network |||
|
|
497
|
+
| `tcn_small` | 0.9M | 1D |
|
|
498
|
+
| `tcn` | 6.9M | 1D |
|
|
499
|
+
| `tcn_large` | 10.0M | 1D |
|
|
500
|
+
| **ResNet3D** — 3D Residual Network |||
|
|
501
|
+
| `resnet3d_18` | 33.2M | 3D |
|
|
502
|
+
| `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
|
|
481
503
|
| **U-Net** — U-shaped Network |||
|
|
482
504
|
| `unet_regression` | 31.0M | 1D/2D/3D |
|
|
483
505
|
|
|
@@ -497,34 +519,52 @@ os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
|
|
|
497
519
|
from torchvision import models as m
|
|
498
520
|
from torchvision.models import video as v
|
|
499
521
|
|
|
500
|
-
# === TorchVision Models ===
|
|
501
|
-
|
|
502
|
-
'resnet18'
|
|
503
|
-
'
|
|
504
|
-
'
|
|
505
|
-
'
|
|
506
|
-
'
|
|
507
|
-
'
|
|
508
|
-
'
|
|
509
|
-
'
|
|
510
|
-
'
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
522
|
+
# === TorchVision Models (use IMAGENET1K_V1 to match WaveDL) ===
|
|
523
|
+
models = [
|
|
524
|
+
('resnet18', m.ResNet18_Weights.IMAGENET1K_V1),
|
|
525
|
+
('resnet50', m.ResNet50_Weights.IMAGENET1K_V1),
|
|
526
|
+
('efficientnet_b0', m.EfficientNet_B0_Weights.IMAGENET1K_V1),
|
|
527
|
+
('efficientnet_b1', m.EfficientNet_B1_Weights.IMAGENET1K_V1),
|
|
528
|
+
('efficientnet_b2', m.EfficientNet_B2_Weights.IMAGENET1K_V1),
|
|
529
|
+
('efficientnet_v2_s', m.EfficientNet_V2_S_Weights.IMAGENET1K_V1),
|
|
530
|
+
('efficientnet_v2_m', m.EfficientNet_V2_M_Weights.IMAGENET1K_V1),
|
|
531
|
+
('efficientnet_v2_l', m.EfficientNet_V2_L_Weights.IMAGENET1K_V1),
|
|
532
|
+
('mobilenet_v3_small', m.MobileNet_V3_Small_Weights.IMAGENET1K_V1),
|
|
533
|
+
('mobilenet_v3_large', m.MobileNet_V3_Large_Weights.IMAGENET1K_V1),
|
|
534
|
+
('regnet_y_400mf', m.RegNet_Y_400MF_Weights.IMAGENET1K_V1),
|
|
535
|
+
('regnet_y_800mf', m.RegNet_Y_800MF_Weights.IMAGENET1K_V1),
|
|
536
|
+
('regnet_y_1_6gf', m.RegNet_Y_1_6GF_Weights.IMAGENET1K_V1),
|
|
537
|
+
('regnet_y_3_2gf', m.RegNet_Y_3_2GF_Weights.IMAGENET1K_V1),
|
|
538
|
+
('regnet_y_8gf', m.RegNet_Y_8GF_Weights.IMAGENET1K_V1),
|
|
539
|
+
('swin_t', m.Swin_T_Weights.IMAGENET1K_V1),
|
|
540
|
+
('swin_s', m.Swin_S_Weights.IMAGENET1K_V1),
|
|
541
|
+
('swin_b', m.Swin_B_Weights.IMAGENET1K_V1),
|
|
542
|
+
('convnext_tiny', m.ConvNeXt_Tiny_Weights.IMAGENET1K_V1),
|
|
543
|
+
('densenet121', m.DenseNet121_Weights.IMAGENET1K_V1),
|
|
544
|
+
]
|
|
545
|
+
for name, w in models:
|
|
546
|
+
getattr(m, name)(weights=w); print(f'✓ {name}')
|
|
514
547
|
|
|
515
548
|
# 3D video models
|
|
516
|
-
v.r3d_18(weights=v.R3D_18_Weights.
|
|
517
|
-
v.mc3_18(weights=v.MC3_18_Weights.
|
|
549
|
+
v.r3d_18(weights=v.R3D_18_Weights.KINETICS400_V1); print('✓ r3d_18')
|
|
550
|
+
v.mc3_18(weights=v.MC3_18_Weights.KINETICS400_V1); print('✓ mc3_18')
|
|
518
551
|
|
|
519
552
|
# === Timm Models (MaxViT, FastViT, CAFormer, ConvNeXt V2, EfficientViT) ===
|
|
520
553
|
import timm
|
|
521
554
|
|
|
522
555
|
timm_models = [
|
|
523
|
-
|
|
524
|
-
'
|
|
525
|
-
|
|
526
|
-
'
|
|
527
|
-
|
|
556
|
+
# MaxViT (no suffix - timm resolves to default)
|
|
557
|
+
'maxvit_tiny_tf_224', 'maxvit_small_tf_224', 'maxvit_base_tf_224',
|
|
558
|
+
# FastViT (no suffix)
|
|
559
|
+
'fastvit_t8', 'fastvit_t12', 'fastvit_s12', 'fastvit_sa12',
|
|
560
|
+
# CAFormer/PoolFormer (no suffix)
|
|
561
|
+
'caformer_s18', 'caformer_s36', 'caformer_m36', 'poolformer_s12',
|
|
562
|
+
# ConvNeXt V2 (no suffix)
|
|
563
|
+
'convnextv2_tiny',
|
|
564
|
+
# EfficientViT (no suffix)
|
|
565
|
+
'efficientvit_m0', 'efficientvit_m1', 'efficientvit_m2',
|
|
566
|
+
'efficientvit_b0', 'efficientvit_b1', 'efficientvit_b2', 'efficientvit_b3',
|
|
567
|
+
'efficientvit_l1', 'efficientvit_l2',
|
|
528
568
|
]
|
|
529
569
|
for name in timm_models:
|
|
530
570
|
timm.create_model(name, pretrained=True); print(f'✓ {name}')
|
|
@@ -71,7 +71,7 @@ Train on datasets larger than RAM:
|
|
|
71
71
|
|
|
72
72
|
**🧠 Models? We've Got Options**
|
|
73
73
|
|
|
74
|
-
|
|
74
|
+
69 architectures, ready to go:
|
|
75
75
|
- CNNs, ResNets, ViTs, EfficientNets...
|
|
76
76
|
- All adapted for regression
|
|
77
77
|
- [Add your own](#adding-custom-models) in one line
|
|
@@ -312,7 +312,7 @@ WaveDL/
|
|
|
312
312
|
│ ├── hpo.py # Hyperparameter optimization
|
|
313
313
|
│ ├── hpc.py # HPC distributed training launcher
|
|
314
314
|
│ │
|
|
315
|
-
│ ├── models/ # Model Zoo (
|
|
315
|
+
│ ├── models/ # Model Zoo (69 architectures)
|
|
316
316
|
│ │ ├── registry.py # Model factory (@register_model)
|
|
317
317
|
│ │ ├── base.py # Abstract base class
|
|
318
318
|
│ │ └── ... # See "Available Models" section
|
|
@@ -353,10 +353,11 @@ WaveDL/
|
|
|
353
353
|
> ```
|
|
354
354
|
|
|
355
355
|
<details>
|
|
356
|
-
<summary><b>Available Models</b> —
|
|
356
|
+
<summary><b>Available Models</b> — 69 architectures</summary>
|
|
357
357
|
|
|
358
358
|
| Model | Backbone Params | Dim |
|
|
359
359
|
|-------|-----------------|-----|
|
|
360
|
+
| **── Classic CNNs ──** |||
|
|
360
361
|
| **CNN** — Convolutional Neural Network |||
|
|
361
362
|
| `cnn` | 1.6M | 1D/2D/3D |
|
|
362
363
|
| **ResNet** — Residual Network |||
|
|
@@ -365,13 +366,14 @@ WaveDL/
|
|
|
365
366
|
| `resnet50` | 23.5M | 1D/2D/3D |
|
|
366
367
|
| `resnet18_pretrained` ⭐ | 11.2M | 2D |
|
|
367
368
|
| `resnet50_pretrained` ⭐ | 23.5M | 2D |
|
|
368
|
-
| **
|
|
369
|
-
| `
|
|
370
|
-
| `
|
|
371
|
-
|
|
|
372
|
-
|
|
|
373
|
-
|
|
|
374
|
-
| `
|
|
369
|
+
| **DenseNet** — Densely Connected Network |||
|
|
370
|
+
| `densenet121` | 7.0M | 1D/2D/3D |
|
|
371
|
+
| `densenet169` | 12.5M | 1D/2D/3D |
|
|
372
|
+
| `densenet121_pretrained` ⭐ | 7.0M | 2D |
|
|
373
|
+
| **── Efficient/Mobile CNNs ──** |||
|
|
374
|
+
| **MobileNetV3** — Mobile Neural Network V3 |||
|
|
375
|
+
| `mobilenet_v3_small` ⭐ | 0.9M | 2D |
|
|
376
|
+
| `mobilenet_v3_large` ⭐ | 3.0M | 2D |
|
|
375
377
|
| **EfficientNet** — Efficient Neural Network |||
|
|
376
378
|
| `efficientnet_b0` ⭐ | 4.0M | 2D |
|
|
377
379
|
| `efficientnet_b1` ⭐ | 6.5M | 2D |
|
|
@@ -380,47 +382,41 @@ WaveDL/
|
|
|
380
382
|
| `efficientnet_v2_s` ⭐ | 20.2M | 2D |
|
|
381
383
|
| `efficientnet_v2_m` ⭐ | 52.9M | 2D |
|
|
382
384
|
| `efficientnet_v2_l` ⭐ | 117.2M | 2D |
|
|
383
|
-
| **MobileNetV3** — Mobile Neural Network V3 |||
|
|
384
|
-
| `mobilenet_v3_small` ⭐ | 0.9M | 2D |
|
|
385
|
-
| `mobilenet_v3_large` ⭐ | 3.0M | 2D |
|
|
386
385
|
| **RegNet** — Regularized Network |||
|
|
387
386
|
| `regnet_y_400mf` ⭐ | 3.9M | 2D |
|
|
388
387
|
| `regnet_y_800mf` ⭐ | 5.7M | 2D |
|
|
389
388
|
| `regnet_y_1_6gf` ⭐ | 10.3M | 2D |
|
|
390
389
|
| `regnet_y_3_2gf` ⭐ | 17.9M | 2D |
|
|
391
390
|
| `regnet_y_8gf` ⭐ | 37.4M | 2D |
|
|
392
|
-
|
|
|
393
|
-
| `swin_t` ⭐ | 27.5M | 2D |
|
|
394
|
-
| `swin_s` ⭐ | 48.8M | 2D |
|
|
395
|
-
| `swin_b` ⭐ | 86.7M | 2D |
|
|
391
|
+
| **── Modern CNNs ──** |||
|
|
396
392
|
| **ConvNeXt** — Convolutional Next |||
|
|
397
393
|
| `convnext_tiny` | 27.8M | 1D/2D/3D |
|
|
398
394
|
| `convnext_small` | 49.5M | 1D/2D/3D |
|
|
399
395
|
| `convnext_base` | 87.6M | 1D/2D/3D |
|
|
400
396
|
| `convnext_tiny_pretrained` ⭐ | 27.8M | 2D |
|
|
401
|
-
| **DenseNet** — Densely Connected Network |||
|
|
402
|
-
| `densenet121` | 7.0M | 1D/2D/3D |
|
|
403
|
-
| `densenet169` | 12.5M | 1D/2D/3D |
|
|
404
|
-
| `densenet121_pretrained` ⭐ | 7.0M | 2D |
|
|
405
|
-
| **ViT** — Vision Transformer |||
|
|
406
|
-
| `vit_tiny` | 5.4M | 1D/2D |
|
|
407
|
-
| `vit_small` | 21.4M | 1D/2D |
|
|
408
|
-
| `vit_base` | 85.3M | 1D/2D |
|
|
409
397
|
| **ConvNeXt V2** — ConvNeXt with GRN |||
|
|
410
398
|
| `convnext_v2_tiny` | 27.9M | 1D/2D/3D |
|
|
411
399
|
| `convnext_v2_small` | 49.6M | 1D/2D/3D |
|
|
412
400
|
| `convnext_v2_base` | 87.7M | 1D/2D/3D |
|
|
413
401
|
| `convnext_v2_tiny_pretrained` ⭐ | 27.9M | 2D |
|
|
414
|
-
| **
|
|
415
|
-
| `
|
|
416
|
-
|
|
|
417
|
-
| `
|
|
418
|
-
|
|
|
419
|
-
|
|
|
402
|
+
| **UniRepLKNet** — Large-Kernel ConvNet |||
|
|
403
|
+
| `unireplknet_tiny` | 30.8M | 1D/2D/3D |
|
|
404
|
+
| `unireplknet_small` | 56.0M | 1D/2D/3D |
|
|
405
|
+
| `unireplknet_base` | 97.6M | 1D/2D/3D |
|
|
406
|
+
| **── Vision Transformers ──** |||
|
|
407
|
+
| **ViT** — Vision Transformer |||
|
|
408
|
+
| `vit_tiny` | 5.4M | 1D/2D |
|
|
409
|
+
| `vit_small` | 21.4M | 1D/2D |
|
|
410
|
+
| `vit_base` | 85.3M | 1D/2D |
|
|
411
|
+
| **Swin** — Shifted Window Transformer |||
|
|
412
|
+
| `swin_t` ⭐ | 27.5M | 2D |
|
|
413
|
+
| `swin_s` ⭐ | 48.8M | 2D |
|
|
414
|
+
| `swin_b` ⭐ | 86.7M | 2D |
|
|
420
415
|
| **MaxViT** — Multi-Axis ViT |||
|
|
421
416
|
| `maxvit_tiny` ⭐ | 30.1M | 2D |
|
|
422
417
|
| `maxvit_small` ⭐ | 67.6M | 2D |
|
|
423
418
|
| `maxvit_base` ⭐ | 119.1M | 2D |
|
|
419
|
+
| **── Hybrid CNN-Transformer ──** |||
|
|
424
420
|
| **FastViT** — Fast Hybrid CNN-ViT |||
|
|
425
421
|
| `fastvit_t8` ⭐ | 4.0M | 2D |
|
|
426
422
|
| `fastvit_t12` ⭐ | 6.8M | 2D |
|
|
@@ -431,6 +427,31 @@ WaveDL/
|
|
|
431
427
|
| `caformer_s36` ⭐ | 39.2M | 2D |
|
|
432
428
|
| `caformer_m36` ⭐ | 56.9M | 2D |
|
|
433
429
|
| `poolformer_s12` ⭐ | 11.9M | 2D |
|
|
430
|
+
| **EfficientViT** — Memory-Efficient ViT |||
|
|
431
|
+
| `efficientvit_m0` ⭐ | 2.2M | 2D |
|
|
432
|
+
| `efficientvit_m1` ⭐ | 2.6M | 2D |
|
|
433
|
+
| `efficientvit_m2` ⭐ | 3.8M | 2D |
|
|
434
|
+
| `efficientvit_b0` ⭐ | 2.1M | 2D |
|
|
435
|
+
| `efficientvit_b1` ⭐ | 7.5M | 2D |
|
|
436
|
+
| `efficientvit_b2` ⭐ | 21.8M | 2D |
|
|
437
|
+
| `efficientvit_b3` ⭐ | 46.1M | 2D |
|
|
438
|
+
| `efficientvit_l1` ⭐ | 49.5M | 2D |
|
|
439
|
+
| `efficientvit_l2` ⭐ | 60.5M | 2D |
|
|
440
|
+
| **── State Space Models ──** |||
|
|
441
|
+
| **Mamba** — State Space Model |||
|
|
442
|
+
| `mamba_1d` | 3.4M | 1D |
|
|
443
|
+
| **Vision Mamba (ViM)** — 2D Mamba |||
|
|
444
|
+
| `vim_tiny` | 6.6M | 2D |
|
|
445
|
+
| `vim_small` | 51.1M | 2D |
|
|
446
|
+
| `vim_base` | 201.4M | 2D |
|
|
447
|
+
| **── Specialized Architectures ──** |||
|
|
448
|
+
| **TCN** — Temporal Convolutional Network |||
|
|
449
|
+
| `tcn_small` | 0.9M | 1D |
|
|
450
|
+
| `tcn` | 6.9M | 1D |
|
|
451
|
+
| `tcn_large` | 10.0M | 1D |
|
|
452
|
+
| **ResNet3D** — 3D Residual Network |||
|
|
453
|
+
| `resnet3d_18` | 33.2M | 3D |
|
|
454
|
+
| `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
|
|
434
455
|
| **U-Net** — U-shaped Network |||
|
|
435
456
|
| `unet_regression` | 31.0M | 1D/2D/3D |
|
|
436
457
|
|
|
@@ -450,34 +471,52 @@ os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
|
|
|
450
471
|
from torchvision import models as m
|
|
451
472
|
from torchvision.models import video as v
|
|
452
473
|
|
|
453
|
-
# === TorchVision Models ===
|
|
454
|
-
|
|
455
|
-
'resnet18'
|
|
456
|
-
'
|
|
457
|
-
'
|
|
458
|
-
'
|
|
459
|
-
'
|
|
460
|
-
'
|
|
461
|
-
'
|
|
462
|
-
'
|
|
463
|
-
'
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
474
|
+
# === TorchVision Models (use IMAGENET1K_V1 to match WaveDL) ===
|
|
475
|
+
models = [
|
|
476
|
+
('resnet18', m.ResNet18_Weights.IMAGENET1K_V1),
|
|
477
|
+
('resnet50', m.ResNet50_Weights.IMAGENET1K_V1),
|
|
478
|
+
('efficientnet_b0', m.EfficientNet_B0_Weights.IMAGENET1K_V1),
|
|
479
|
+
('efficientnet_b1', m.EfficientNet_B1_Weights.IMAGENET1K_V1),
|
|
480
|
+
('efficientnet_b2', m.EfficientNet_B2_Weights.IMAGENET1K_V1),
|
|
481
|
+
('efficientnet_v2_s', m.EfficientNet_V2_S_Weights.IMAGENET1K_V1),
|
|
482
|
+
('efficientnet_v2_m', m.EfficientNet_V2_M_Weights.IMAGENET1K_V1),
|
|
483
|
+
('efficientnet_v2_l', m.EfficientNet_V2_L_Weights.IMAGENET1K_V1),
|
|
484
|
+
('mobilenet_v3_small', m.MobileNet_V3_Small_Weights.IMAGENET1K_V1),
|
|
485
|
+
('mobilenet_v3_large', m.MobileNet_V3_Large_Weights.IMAGENET1K_V1),
|
|
486
|
+
('regnet_y_400mf', m.RegNet_Y_400MF_Weights.IMAGENET1K_V1),
|
|
487
|
+
('regnet_y_800mf', m.RegNet_Y_800MF_Weights.IMAGENET1K_V1),
|
|
488
|
+
('regnet_y_1_6gf', m.RegNet_Y_1_6GF_Weights.IMAGENET1K_V1),
|
|
489
|
+
('regnet_y_3_2gf', m.RegNet_Y_3_2GF_Weights.IMAGENET1K_V1),
|
|
490
|
+
('regnet_y_8gf', m.RegNet_Y_8GF_Weights.IMAGENET1K_V1),
|
|
491
|
+
('swin_t', m.Swin_T_Weights.IMAGENET1K_V1),
|
|
492
|
+
('swin_s', m.Swin_S_Weights.IMAGENET1K_V1),
|
|
493
|
+
('swin_b', m.Swin_B_Weights.IMAGENET1K_V1),
|
|
494
|
+
('convnext_tiny', m.ConvNeXt_Tiny_Weights.IMAGENET1K_V1),
|
|
495
|
+
('densenet121', m.DenseNet121_Weights.IMAGENET1K_V1),
|
|
496
|
+
]
|
|
497
|
+
for name, w in models:
|
|
498
|
+
getattr(m, name)(weights=w); print(f'✓ {name}')
|
|
467
499
|
|
|
468
500
|
# 3D video models
|
|
469
|
-
v.r3d_18(weights=v.R3D_18_Weights.
|
|
470
|
-
v.mc3_18(weights=v.MC3_18_Weights.
|
|
501
|
+
v.r3d_18(weights=v.R3D_18_Weights.KINETICS400_V1); print('✓ r3d_18')
|
|
502
|
+
v.mc3_18(weights=v.MC3_18_Weights.KINETICS400_V1); print('✓ mc3_18')
|
|
471
503
|
|
|
472
504
|
# === Timm Models (MaxViT, FastViT, CAFormer, ConvNeXt V2, EfficientViT) ===
|
|
473
505
|
import timm
|
|
474
506
|
|
|
475
507
|
timm_models = [
|
|
476
|
-
|
|
477
|
-
'
|
|
478
|
-
|
|
479
|
-
'
|
|
480
|
-
|
|
508
|
+
# MaxViT (no suffix - timm resolves to default)
|
|
509
|
+
'maxvit_tiny_tf_224', 'maxvit_small_tf_224', 'maxvit_base_tf_224',
|
|
510
|
+
# FastViT (no suffix)
|
|
511
|
+
'fastvit_t8', 'fastvit_t12', 'fastvit_s12', 'fastvit_sa12',
|
|
512
|
+
# CAFormer/PoolFormer (no suffix)
|
|
513
|
+
'caformer_s18', 'caformer_s36', 'caformer_m36', 'poolformer_s12',
|
|
514
|
+
# ConvNeXt V2 (no suffix)
|
|
515
|
+
'convnextv2_tiny',
|
|
516
|
+
# EfficientViT (no suffix)
|
|
517
|
+
'efficientvit_m0', 'efficientvit_m1', 'efficientvit_m2',
|
|
518
|
+
'efficientvit_b0', 'efficientvit_b1', 'efficientvit_b2', 'efficientvit_b3',
|
|
519
|
+
'efficientvit_l1', 'efficientvit_l2',
|
|
481
520
|
]
|
|
482
521
|
for name in timm_models:
|
|
483
522
|
timm.create_model(name, pretrained=True); print(f'✓ {name}')
|