wavedl 1.5.7.tar.gz → 1.6.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavedl-1.5.7/src/wavedl.egg-info → wavedl-1.6.1}/PKG-INFO +150 -82
- {wavedl-1.5.7 → wavedl-1.6.1}/README.md +147 -81
- {wavedl-1.5.7 → wavedl-1.6.1}/pyproject.toml +2 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/__init__.py +1 -1
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/hpo.py +451 -451
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/__init__.py +80 -4
- wavedl-1.6.1/src/wavedl/models/_pretrained_utils.py +366 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/base.py +48 -0
- wavedl-1.6.1/src/wavedl/models/caformer.py +270 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/cnn.py +2 -27
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/convnext.py +113 -51
- wavedl-1.6.1/src/wavedl/models/convnext_v2.py +488 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/densenet.py +10 -23
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/efficientnet.py +6 -6
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/efficientnetv2.py +315 -315
- wavedl-1.6.1/src/wavedl/models/efficientvit.py +398 -0
- wavedl-1.6.1/src/wavedl/models/fastvit.py +252 -0
- wavedl-1.6.1/src/wavedl/models/mamba.py +555 -0
- wavedl-1.6.1/src/wavedl/models/maxvit.py +254 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/mobilenetv3.py +295 -295
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/regnet.py +406 -406
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/resnet.py +19 -61
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/resnet3d.py +258 -258
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/swin.py +443 -443
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/tcn.py +393 -409
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/unet.py +2 -6
- wavedl-1.6.1/src/wavedl/models/unireplknet.py +491 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/vit.py +9 -9
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/train.py +1430 -1425
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/config.py +367 -367
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/cross_validation.py +530 -530
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/data.py +39 -6
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/losses.py +216 -216
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/optimizers.py +216 -216
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/schedulers.py +251 -251
- {wavedl-1.5.7 → wavedl-1.6.1/src/wavedl.egg-info}/PKG-INFO +150 -82
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl.egg-info/SOURCES.txt +8 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl.egg-info/requires.txt +2 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/LICENSE +0 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/setup.cfg +0 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/hpc.py +0 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/_template.py +0 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/models/registry.py +0 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/test.py +0 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/__init__.py +0 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/constraints.py +0 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/distributed.py +0 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl/utils/metrics.py +0 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl.egg-info/dependency_links.txt +0 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl.egg-info/entry_points.txt +0 -0
- {wavedl-1.5.7 → wavedl-1.6.1}/src/wavedl.egg-info/top_level.txt +0 -0
**PKG-INFO**

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.7
+Version: 1.6.1
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
````
````diff
@@ -23,6 +23,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch>=2.0.0
 Requires-Dist: torchvision>=0.15.0
+Requires-Dist: timm>=0.9.0
 Requires-Dist: accelerate>=0.20.0
 Requires-Dist: numpy>=1.24.0
 Requires-Dist: scipy>=1.10.0
````
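The new `timm>=0.9.0` pin is what backs the MaxViT, FastViT, CAFormer, EfficientViT, and ConvNeXt V2 entries in the model table further down. As a rough sketch of what the dependency provides: this is plain `timm` API, not WaveDL's own interface, and the model name, head size, and input shape below are illustrative.

```python
import timm
import torch

# Pretrained ImageNet backbone with the classifier head removed
# (num_classes=0 keeps global pooling and returns a feature vector),
# plus a small linear head: the usual way such backbones get
# repurposed for regression targets.
backbone = timm.create_model("caformer_s18", pretrained=True, num_classes=0)
head = torch.nn.Linear(backbone.num_features, 3)  # e.g. 3 regression outputs

x = torch.randn(1, 3, 224, 224)  # illustrative 2D input
print(head(backbone(x)).shape)   # -> torch.Size([1, 3])
```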
````diff
@@ -37,6 +38,7 @@ Requires-Dist: wandb>=0.15.0
 Requires-Dist: optuna>=3.0.0
 Requires-Dist: onnx>=1.14.0
 Requires-Dist: onnxruntime>=1.15.0
+Requires-Dist: onnxscript>=0.1.0
 Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 Provides-Extra: dev
 Requires-Dist: pytest>=7.0.0; extra == "dev"
````
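`onnxscript` is the library PyTorch's TorchDynamo-based ONNX exporter emits through, which matches the pyproject comment near the end of this diff ("Required by torch.onnx.export in PyTorch 2.1+"). A minimal sketch, assuming PyTorch 2.1+ where the dynamo exporter is available; the model is a stand-in, not a WaveDL architecture.

```python
import torch
import torch.nn as nn

# Stand-in 1D regression model (illustrative only)
model = nn.Sequential(
    nn.Conv1d(1, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(8, 2),
)
dummy = torch.randn(1, 1, 256)

# The dynamo-based export path is built on top of onnxscript;
# without it installed, export fails at this point.
onnx_program = torch.onnx.dynamo_export(model, dummy)
onnx_program.save("model.onnx")
```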
````diff
@@ -117,7 +119,7 @@ Train on datasets larger than RAM:
 
 **🧠 Models? We've Got Options**
 
-…
+69 architectures, ready to go:
 - CNNs, ResNets, ViTs, EfficientNets...
 - All adapted for regression
 - [Add your own](#adding-custom-models) in one line
````
````diff
@@ -202,7 +204,7 @@ Deploy models anywhere:
 #### From PyPI (recommended for all users)
 
 ```bash
-pip install wavedl
+pip install --upgrade wavedl
 ```
 
 This installs everything you need: training, inference, HPO, ONNX export.
````
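`--upgrade` matters here because 1.6.x adds new hard dependencies (`timm`, `onnxscript`) and a plain `pip install wavedl` on an environment that already has the package would keep the old version. To pin this exact release instead:

```bash
pip install --upgrade "wavedl==1.6.1"
```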
````diff
@@ -358,22 +360,10 @@ WaveDL/
 │   ├── hpo.py            # Hyperparameter optimization
 │   ├── hpc.py            # HPC distributed training launcher
 │   │
-│   ├── models/           # Model …
+│   ├── models/           # Model Zoo (69 architectures)
 │   │   ├── registry.py   # Model factory (@register_model)
 │   │   ├── base.py       # Abstract base class
-│   │   │
-│   │   ├── resnet.py     # ResNet-18/34/50 (1D/2D/3D)
-│   │   ├── resnet3d.py   # ResNet3D-18, MC3-18 (3D only)
-│   │   ├── tcn.py        # TCN (1D only)
-│   │   ├── efficientnet.py   # EfficientNet-B0/B1/B2 (2D)
-│   │   ├── efficientnetv2.py # EfficientNetV2-S/M/L (2D)
-│   │   ├── mobilenetv3.py    # MobileNetV3-Small/Large (2D)
-│   │   ├── regnet.py     # RegNetY variants (2D)
-│   │   ├── swin.py       # Swin Transformer (2D)
-│   │   ├── vit.py        # Vision Transformer (1D/2D)
-│   │   ├── convnext.py   # ConvNeXt (1D/2D/3D)
-│   │   ├── densenet.py   # DenseNet-121/169 (1D/2D/3D)
-│   │   └── unet.py       # U-Net Regression
+│   │   └── ...           # See "Available Models" section
 │   │
 │   └── utils/            # Utilities
 │       ├── data.py       # Memory-mapped data pipeline
````
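The tree names `registry.py` as a model factory keyed on an `@register_model` decorator, and the README elsewhere promises custom models "in one line". WaveDL's actual signatures are not part of this diff, so the sketch below is only the generic registry pattern those comments imply; the decorator, registry dict, and model names here are assumptions for illustration.

```python
import torch.nn as nn

MODEL_REGISTRY: dict[str, type[nn.Module]] = {}

def register_model(name: str):
    """Hypothetical decorator: map a string name to a model class."""
    def deco(cls: type[nn.Module]) -> type[nn.Module]:
        MODEL_REGISTRY[name] = cls
        return cls
    return deco

@register_model("my_tiny_cnn")  # the promised "one line"
class MyTinyCNN(nn.Module):
    def __init__(self, in_channels: int = 1, out_dim: int = 1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(in_channels, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(16, out_dim),
        )

    def forward(self, x):
        return self.net(x)

model = MODEL_REGISTRY["my_tiny_cnn"]()  # factory lookup by name
```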
````diff
@@ -388,7 +378,7 @@ WaveDL/
 ├── configs/              # YAML config templates
 ├── examples/             # Ready-to-run examples
 ├── notebooks/            # Jupyter notebooks
-├── unit_tests/           # Pytest test suite
+├── unit_tests/           # Pytest test suite
 │
 ├── pyproject.toml        # Package config, dependencies
 ├── CHANGELOG.md          # Version history
````
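Given the `unit_tests/` directory above and the `dev` extra declared in the metadata, the usual check after an editable install would be something like:

```bash
pip install -e ".[dev]"
pytest unit_tests/ -q
```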
````diff
@@ -411,71 +401,117 @@ WaveDL/
 > ```
 
 <details>
-<summary><b>Available Models</b> — …
+<summary><b>Available Models</b> — 69 architectures</summary>
 
-| Model | Params | Dim |
-…
+| Model | Backbone Params | Dim |
+|-------|-----------------|-----|
+| **── Classic CNNs ──** |||
 | **CNN** — Convolutional Neural Network |||
-| `cnn` | 1.…
+| `cnn` | 1.6M | 1D/2D/3D |
 | **ResNet** — Residual Network |||
-| `resnet18` | 11.…
-| `resnet34` | 21.…
-| `resnet50` | …
-| `resnet18_pretrained` ⭐ | 11.…
-| `resnet50_pretrained` ⭐ | …
-| **…
-| `…
-| `…
-…
-…
-…
-| `…
+| `resnet18` | 11.2M | 1D/2D/3D |
+| `resnet34` | 21.3M | 1D/2D/3D |
+| `resnet50` | 23.5M | 1D/2D/3D |
+| `resnet18_pretrained` ⭐ | 11.2M | 2D |
+| `resnet50_pretrained` ⭐ | 23.5M | 2D |
+| **DenseNet** — Densely Connected Network |||
+| `densenet121` | 7.0M | 1D/2D/3D |
+| `densenet169` | 12.5M | 1D/2D/3D |
+| `densenet121_pretrained` ⭐ | 7.0M | 2D |
+| **── Efficient/Mobile CNNs ──** |||
+| **MobileNetV3** — Mobile Neural Network V3 |||
+| `mobilenet_v3_small` ⭐ | 0.9M | 2D |
+| `mobilenet_v3_large` ⭐ | 3.0M | 2D |
 | **EfficientNet** — Efficient Neural Network |||
-| `efficientnet_b0` ⭐ | 4.…
-| `efficientnet_b1` ⭐ | …
-| `efficientnet_b2` ⭐ | …
+| `efficientnet_b0` ⭐ | 4.0M | 2D |
+| `efficientnet_b1` ⭐ | 6.5M | 2D |
+| `efficientnet_b2` ⭐ | 7.7M | 2D |
 | **EfficientNetV2** — Efficient Neural Network V2 |||
-| `efficientnet_v2_s` ⭐ | …
-| `efficientnet_v2_m` ⭐ | …
-| `efficientnet_v2_l` ⭐ | …
-| **MobileNetV3** — Mobile Neural Network V3 |||
-| `mobilenet_v3_small` ⭐ | 1.1M | 2D |
-| `mobilenet_v3_large` ⭐ | 3.2M | 2D |
+| `efficientnet_v2_s` ⭐ | 20.2M | 2D |
+| `efficientnet_v2_m` ⭐ | 52.9M | 2D |
+| `efficientnet_v2_l` ⭐ | 117.2M | 2D |
 | **RegNet** — Regularized Network |||
-| `regnet_y_400mf` ⭐ | …
-| `regnet_y_800mf` ⭐ | 5.…
-| `regnet_y_1_6gf` ⭐ | 10.…
-| `regnet_y_3_2gf` ⭐ | …
-| `regnet_y_8gf` ⭐ | 37.…
-…
-| `swin_t` ⭐ | 28.0M | 2D |
-| `swin_s` ⭐ | 49.4M | 2D |
-| `swin_b` ⭐ | 87.4M | 2D |
+| `regnet_y_400mf` ⭐ | 3.9M | 2D |
+| `regnet_y_800mf` ⭐ | 5.7M | 2D |
+| `regnet_y_1_6gf` ⭐ | 10.3M | 2D |
+| `regnet_y_3_2gf` ⭐ | 17.9M | 2D |
+| `regnet_y_8gf` ⭐ | 37.4M | 2D |
+| **── Modern CNNs ──** |||
 | **ConvNeXt** — Convolutional Next |||
-| `convnext_tiny` | …
-| `convnext_small` | 49.…
-| `convnext_base` | …
-| `convnext_tiny_pretrained` ⭐ | …
-| **…
-| `…
-| `…
-| `…
+| `convnext_tiny` | 27.8M | 1D/2D/3D |
+| `convnext_small` | 49.5M | 1D/2D/3D |
+| `convnext_base` | 87.6M | 1D/2D/3D |
+| `convnext_tiny_pretrained` ⭐ | 27.8M | 2D |
+| **ConvNeXt V2** — ConvNeXt with GRN |||
+| `convnext_v2_tiny` | 27.9M | 1D/2D/3D |
+| `convnext_v2_small` | 49.6M | 1D/2D/3D |
+| `convnext_v2_base` | 87.7M | 1D/2D/3D |
+| `convnext_v2_tiny_pretrained` ⭐ | 27.9M | 2D |
+| **UniRepLKNet** — Large-Kernel ConvNet |||
+| `unireplknet_tiny` | 30.8M | 1D/2D/3D |
+| `unireplknet_small` | 56.0M | 1D/2D/3D |
+| `unireplknet_base` | 97.6M | 1D/2D/3D |
+| **── Vision Transformers ──** |||
 | **ViT** — Vision Transformer |||
-| `vit_tiny` | 5.…
-| `vit_small` | 21.…
-| `vit_base` | 85.…
+| `vit_tiny` | 5.4M | 1D/2D |
+| `vit_small` | 21.4M | 1D/2D |
+| `vit_base` | 85.3M | 1D/2D |
+| **Swin** — Shifted Window Transformer |||
+| `swin_t` ⭐ | 27.5M | 2D |
+| `swin_s` ⭐ | 48.8M | 2D |
+| `swin_b` ⭐ | 86.7M | 2D |
+| **MaxViT** — Multi-Axis ViT |||
+| `maxvit_tiny` ⭐ | 30.1M | 2D |
+| `maxvit_small` ⭐ | 67.6M | 2D |
+| `maxvit_base` ⭐ | 119.1M | 2D |
+| **── Hybrid CNN-Transformer ──** |||
+| **FastViT** — Fast Hybrid CNN-ViT |||
+| `fastvit_t8` ⭐ | 4.0M | 2D |
+| `fastvit_t12` ⭐ | 6.8M | 2D |
+| `fastvit_s12` ⭐ | 8.8M | 2D |
+| `fastvit_sa12` ⭐ | 10.9M | 2D |
+| **CAFormer** — MetaFormer with Attention |||
+| `caformer_s18` ⭐ | 26.3M | 2D |
+| `caformer_s36` ⭐ | 39.2M | 2D |
+| `caformer_m36` ⭐ | 56.9M | 2D |
+| `poolformer_s12` ⭐ | 11.9M | 2D |
+| **EfficientViT** — Memory-Efficient ViT |||
+| `efficientvit_m0` ⭐ | 2.2M | 2D |
+| `efficientvit_m1` ⭐ | 2.6M | 2D |
+| `efficientvit_m2` ⭐ | 3.8M | 2D |
+| `efficientvit_b0` ⭐ | 2.1M | 2D |
+| `efficientvit_b1` ⭐ | 7.5M | 2D |
+| `efficientvit_b2` ⭐ | 21.8M | 2D |
+| `efficientvit_b3` ⭐ | 46.1M | 2D |
+| `efficientvit_l1` ⭐ | 49.5M | 2D |
+| `efficientvit_l2` ⭐ | 60.5M | 2D |
+| **── State Space Models ──** |||
+| **Mamba** — State Space Model |||
+| `mamba_1d` | 3.4M | 1D |
+| **Vision Mamba (ViM)** — 2D Mamba |||
+| `vim_tiny` | 6.6M | 2D |
+| `vim_small` | 51.1M | 2D |
+| `vim_base` | 201.4M | 2D |
+| **── Specialized Architectures ──** |||
+| **TCN** — Temporal Convolutional Network |||
+| `tcn_small` | 0.9M | 1D |
+| `tcn` | 6.9M | 1D |
+| `tcn_large` | 10.0M | 1D |
+| **ResNet3D** — 3D Residual Network |||
+| `resnet3d_18` | 33.2M | 3D |
+| `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
 | **U-Net** — U-shaped Network |||
-| `unet_regression` | 31.…
+| `unet_regression` | 31.0M | 1D/2D/3D |
+
 
 ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
 - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
-- **Size**: ~20–350 MB per model depending on architecture
 - **Train from scratch**: Use `--no_pretrained` to disable pretrained weights
 
 **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
 
 ```bash
-# Run once on login node (with internet) — downloads ALL pretrained weights
+# Run once on login node (with internet) — downloads ALL pretrained weights
 python -c "
 import os
 os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
````
````diff
@@ -483,24 +519,56 @@ os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
 from torchvision import models as m
 from torchvision.models import video as v
 
-# …
-…
-'resnet18'…
-'…
-'…
-'…
-'…
-'…
-'…
-'…
-'…
-…
-…
-…
+# === TorchVision Models (use IMAGENET1K_V1 to match WaveDL) ===
+models = [
+    ('resnet18', m.ResNet18_Weights.IMAGENET1K_V1),
+    ('resnet50', m.ResNet50_Weights.IMAGENET1K_V1),
+    ('efficientnet_b0', m.EfficientNet_B0_Weights.IMAGENET1K_V1),
+    ('efficientnet_b1', m.EfficientNet_B1_Weights.IMAGENET1K_V1),
+    ('efficientnet_b2', m.EfficientNet_B2_Weights.IMAGENET1K_V1),
+    ('efficientnet_v2_s', m.EfficientNet_V2_S_Weights.IMAGENET1K_V1),
+    ('efficientnet_v2_m', m.EfficientNet_V2_M_Weights.IMAGENET1K_V1),
+    ('efficientnet_v2_l', m.EfficientNet_V2_L_Weights.IMAGENET1K_V1),
+    ('mobilenet_v3_small', m.MobileNet_V3_Small_Weights.IMAGENET1K_V1),
+    ('mobilenet_v3_large', m.MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+    ('regnet_y_400mf', m.RegNet_Y_400MF_Weights.IMAGENET1K_V1),
+    ('regnet_y_800mf', m.RegNet_Y_800MF_Weights.IMAGENET1K_V1),
+    ('regnet_y_1_6gf', m.RegNet_Y_1_6GF_Weights.IMAGENET1K_V1),
+    ('regnet_y_3_2gf', m.RegNet_Y_3_2GF_Weights.IMAGENET1K_V1),
+    ('regnet_y_8gf', m.RegNet_Y_8GF_Weights.IMAGENET1K_V1),
+    ('swin_t', m.Swin_T_Weights.IMAGENET1K_V1),
+    ('swin_s', m.Swin_S_Weights.IMAGENET1K_V1),
+    ('swin_b', m.Swin_B_Weights.IMAGENET1K_V1),
+    ('convnext_tiny', m.ConvNeXt_Tiny_Weights.IMAGENET1K_V1),
+    ('densenet121', m.DenseNet121_Weights.IMAGENET1K_V1),
+]
+for name, w in models:
+    getattr(m, name)(weights=w); print(f'✓ {name}')
 
 # 3D video models
-v.r3d_18(weights=v.R3D_18_Weights.…
-v.mc3_18(weights=v.MC3_18_Weights.…
+v.r3d_18(weights=v.R3D_18_Weights.KINETICS400_V1); print('✓ r3d_18')
+v.mc3_18(weights=v.MC3_18_Weights.KINETICS400_V1); print('✓ mc3_18')
+
+# === Timm Models (MaxViT, FastViT, CAFormer, ConvNeXt V2) ===
+import timm
+
+timm_models = [
+    # MaxViT (no suffix - timm resolves to default)
+    'maxvit_tiny_tf_224', 'maxvit_small_tf_224', 'maxvit_base_tf_224',
+    # FastViT (no suffix)
+    'fastvit_t8', 'fastvit_t12', 'fastvit_s12', 'fastvit_sa12',
+    # CAFormer/PoolFormer (no suffix)
+    'caformer_s18', 'caformer_s36', 'caformer_m36', 'poolformer_s12',
+    # ConvNeXt V2 (no suffix)
+    'convnextv2_tiny',
+    # EfficientViT (no suffix)
+    'efficientvit_m0', 'efficientvit_m1', 'efficientvit_m2',
+    'efficientvit_b0', 'efficientvit_b1', 'efficientvit_b2', 'efficientvit_b3',
+    'efficientvit_l1', 'efficientvit_l2',
+]
+for name in timm_models:
+    timm.create_model(name, pretrained=True); print(f'✓ {name}')
+
 print('\\n✓ All pretrained weights cached!')
 "
 ```
````
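Before submitting a compute-node job, it is worth confirming the cache actually populated. A sketch, assuming the `.torch_cache` location set via `TORCH_HOME` above (note that `timm` may route some downloads through the Hugging Face Hub cache at `~/.cache/huggingface` instead):

```bash
ls .torch_cache/hub/checkpoints/   # torchvision / torch.hub weight files
du -sh .torch_cache/               # total size of the pre-downloaded cache
```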
**README.md**

````diff
@@ -71,7 +71,7 @@ Train on datasets larger than RAM:
 
 **🧠 Models? We've Got Options**
 
-…
+69 architectures, ready to go:
 - CNNs, ResNets, ViTs, EfficientNets...
 - All adapted for regression
 - [Add your own](#adding-custom-models) in one line
````

````diff
@@ -156,7 +156,7 @@ Deploy models anywhere:
 #### From PyPI (recommended for all users)
 
 ```bash
-pip install wavedl
+pip install --upgrade wavedl
 ```
 
 This installs everything you need: training, inference, HPO, ONNX export.
````

````diff
@@ -312,22 +312,10 @@ WaveDL/
 │   ├── hpo.py            # Hyperparameter optimization
 │   ├── hpc.py            # HPC distributed training launcher
 │   │
-│   ├── models/           # Model …
+│   ├── models/           # Model Zoo (69 architectures)
 │   │   ├── registry.py   # Model factory (@register_model)
 │   │   ├── base.py       # Abstract base class
-│   │   │
-│   │   ├── resnet.py     # ResNet-18/34/50 (1D/2D/3D)
-│   │   ├── resnet3d.py   # ResNet3D-18, MC3-18 (3D only)
-│   │   ├── tcn.py        # TCN (1D only)
-│   │   ├── efficientnet.py   # EfficientNet-B0/B1/B2 (2D)
-│   │   ├── efficientnetv2.py # EfficientNetV2-S/M/L (2D)
-│   │   ├── mobilenetv3.py    # MobileNetV3-Small/Large (2D)
-│   │   ├── regnet.py     # RegNetY variants (2D)
-│   │   ├── swin.py       # Swin Transformer (2D)
-│   │   ├── vit.py        # Vision Transformer (1D/2D)
-│   │   ├── convnext.py   # ConvNeXt (1D/2D/3D)
-│   │   ├── densenet.py   # DenseNet-121/169 (1D/2D/3D)
-│   │   └── unet.py       # U-Net Regression
+│   │   └── ...           # See "Available Models" section
 │   │
 │   └── utils/            # Utilities
 │       ├── data.py       # Memory-mapped data pipeline
````

````diff
@@ -342,7 +330,7 @@ WaveDL/
 ├── configs/              # YAML config templates
 ├── examples/             # Ready-to-run examples
 ├── notebooks/            # Jupyter notebooks
-├── unit_tests/           # Pytest test suite
+├── unit_tests/           # Pytest test suite
 │
 ├── pyproject.toml        # Package config, dependencies
 ├── CHANGELOG.md          # Version history
````

````diff
@@ -365,71 +353,117 @@ WaveDL/
 > ```
 
 <details>
-<summary><b>Available Models</b> — …
+<summary><b>Available Models</b> — 69 architectures</summary>
 
-| Model | Params | Dim |
-…
+| Model | Backbone Params | Dim |
+|-------|-----------------|-----|
+| **── Classic CNNs ──** |||
 | **CNN** — Convolutional Neural Network |||
-| `cnn` | 1.…
+| `cnn` | 1.6M | 1D/2D/3D |
 | **ResNet** — Residual Network |||
-| `resnet18` | 11.…
-| `resnet34` | 21.…
-| `resnet50` | …
-| `resnet18_pretrained` ⭐ | 11.…
-| `resnet50_pretrained` ⭐ | …
-| **…
-| `…
-| `…
-…
-…
-…
-| `…
+| `resnet18` | 11.2M | 1D/2D/3D |
+| `resnet34` | 21.3M | 1D/2D/3D |
+| `resnet50` | 23.5M | 1D/2D/3D |
+| `resnet18_pretrained` ⭐ | 11.2M | 2D |
+| `resnet50_pretrained` ⭐ | 23.5M | 2D |
+| **DenseNet** — Densely Connected Network |||
+| `densenet121` | 7.0M | 1D/2D/3D |
+| `densenet169` | 12.5M | 1D/2D/3D |
+| `densenet121_pretrained` ⭐ | 7.0M | 2D |
+| **── Efficient/Mobile CNNs ──** |||
+| **MobileNetV3** — Mobile Neural Network V3 |||
+| `mobilenet_v3_small` ⭐ | 0.9M | 2D |
+| `mobilenet_v3_large` ⭐ | 3.0M | 2D |
 | **EfficientNet** — Efficient Neural Network |||
-| `efficientnet_b0` ⭐ | 4.…
-| `efficientnet_b1` ⭐ | …
-| `efficientnet_b2` ⭐ | …
+| `efficientnet_b0` ⭐ | 4.0M | 2D |
+| `efficientnet_b1` ⭐ | 6.5M | 2D |
+| `efficientnet_b2` ⭐ | 7.7M | 2D |
 | **EfficientNetV2** — Efficient Neural Network V2 |||
-| `efficientnet_v2_s` ⭐ | …
-| `efficientnet_v2_m` ⭐ | …
-| `efficientnet_v2_l` ⭐ | …
-| **MobileNetV3** — Mobile Neural Network V3 |||
-| `mobilenet_v3_small` ⭐ | 1.1M | 2D |
-| `mobilenet_v3_large` ⭐ | 3.2M | 2D |
+| `efficientnet_v2_s` ⭐ | 20.2M | 2D |
+| `efficientnet_v2_m` ⭐ | 52.9M | 2D |
+| `efficientnet_v2_l` ⭐ | 117.2M | 2D |
 | **RegNet** — Regularized Network |||
-| `regnet_y_400mf` ⭐ | …
-| `regnet_y_800mf` ⭐ | 5.…
-| `regnet_y_1_6gf` ⭐ | 10.…
-| `regnet_y_3_2gf` ⭐ | …
-| `regnet_y_8gf` ⭐ | 37.…
-…
-| `swin_t` ⭐ | 28.0M | 2D |
-| `swin_s` ⭐ | 49.4M | 2D |
-| `swin_b` ⭐ | 87.4M | 2D |
+| `regnet_y_400mf` ⭐ | 3.9M | 2D |
+| `regnet_y_800mf` ⭐ | 5.7M | 2D |
+| `regnet_y_1_6gf` ⭐ | 10.3M | 2D |
+| `regnet_y_3_2gf` ⭐ | 17.9M | 2D |
+| `regnet_y_8gf` ⭐ | 37.4M | 2D |
+| **── Modern CNNs ──** |||
 | **ConvNeXt** — Convolutional Next |||
-| `convnext_tiny` | …
-| `convnext_small` | 49.…
-| `convnext_base` | …
-| `convnext_tiny_pretrained` ⭐ | …
-| **…
-| `…
-| `…
-| `…
+| `convnext_tiny` | 27.8M | 1D/2D/3D |
+| `convnext_small` | 49.5M | 1D/2D/3D |
+| `convnext_base` | 87.6M | 1D/2D/3D |
+| `convnext_tiny_pretrained` ⭐ | 27.8M | 2D |
+| **ConvNeXt V2** — ConvNeXt with GRN |||
+| `convnext_v2_tiny` | 27.9M | 1D/2D/3D |
+| `convnext_v2_small` | 49.6M | 1D/2D/3D |
+| `convnext_v2_base` | 87.7M | 1D/2D/3D |
+| `convnext_v2_tiny_pretrained` ⭐ | 27.9M | 2D |
+| **UniRepLKNet** — Large-Kernel ConvNet |||
+| `unireplknet_tiny` | 30.8M | 1D/2D/3D |
+| `unireplknet_small` | 56.0M | 1D/2D/3D |
+| `unireplknet_base` | 97.6M | 1D/2D/3D |
+| **── Vision Transformers ──** |||
 | **ViT** — Vision Transformer |||
-| `vit_tiny` | 5.…
-| `vit_small` | 21.…
-| `vit_base` | 85.…
+| `vit_tiny` | 5.4M | 1D/2D |
+| `vit_small` | 21.4M | 1D/2D |
+| `vit_base` | 85.3M | 1D/2D |
+| **Swin** — Shifted Window Transformer |||
+| `swin_t` ⭐ | 27.5M | 2D |
+| `swin_s` ⭐ | 48.8M | 2D |
+| `swin_b` ⭐ | 86.7M | 2D |
+| **MaxViT** — Multi-Axis ViT |||
+| `maxvit_tiny` ⭐ | 30.1M | 2D |
+| `maxvit_small` ⭐ | 67.6M | 2D |
+| `maxvit_base` ⭐ | 119.1M | 2D |
+| **── Hybrid CNN-Transformer ──** |||
+| **FastViT** — Fast Hybrid CNN-ViT |||
+| `fastvit_t8` ⭐ | 4.0M | 2D |
+| `fastvit_t12` ⭐ | 6.8M | 2D |
+| `fastvit_s12` ⭐ | 8.8M | 2D |
+| `fastvit_sa12` ⭐ | 10.9M | 2D |
+| **CAFormer** — MetaFormer with Attention |||
+| `caformer_s18` ⭐ | 26.3M | 2D |
+| `caformer_s36` ⭐ | 39.2M | 2D |
+| `caformer_m36` ⭐ | 56.9M | 2D |
+| `poolformer_s12` ⭐ | 11.9M | 2D |
+| **EfficientViT** — Memory-Efficient ViT |||
+| `efficientvit_m0` ⭐ | 2.2M | 2D |
+| `efficientvit_m1` ⭐ | 2.6M | 2D |
+| `efficientvit_m2` ⭐ | 3.8M | 2D |
+| `efficientvit_b0` ⭐ | 2.1M | 2D |
+| `efficientvit_b1` ⭐ | 7.5M | 2D |
+| `efficientvit_b2` ⭐ | 21.8M | 2D |
+| `efficientvit_b3` ⭐ | 46.1M | 2D |
+| `efficientvit_l1` ⭐ | 49.5M | 2D |
+| `efficientvit_l2` ⭐ | 60.5M | 2D |
+| **── State Space Models ──** |||
+| **Mamba** — State Space Model |||
+| `mamba_1d` | 3.4M | 1D |
+| **Vision Mamba (ViM)** — 2D Mamba |||
+| `vim_tiny` | 6.6M | 2D |
+| `vim_small` | 51.1M | 2D |
+| `vim_base` | 201.4M | 2D |
+| **── Specialized Architectures ──** |||
+| **TCN** — Temporal Convolutional Network |||
+| `tcn_small` | 0.9M | 1D |
+| `tcn` | 6.9M | 1D |
+| `tcn_large` | 10.0M | 1D |
+| **ResNet3D** — 3D Residual Network |||
+| `resnet3d_18` | 33.2M | 3D |
+| `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
 | **U-Net** — U-shaped Network |||
-| `unet_regression` | 31.…
+| `unet_regression` | 31.0M | 1D/2D/3D |
+
 
 ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
 - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
-- **Size**: ~20–350 MB per model depending on architecture
 - **Train from scratch**: Use `--no_pretrained` to disable pretrained weights
 
 **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
 
 ```bash
-# Run once on login node (with internet) — downloads ALL pretrained weights
+# Run once on login node (with internet) — downloads ALL pretrained weights
 python -c "
 import os
 os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
````

````diff
@@ -437,24 +471,56 @@ os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
 from torchvision import models as m
 from torchvision.models import video as v
 
-# …
-…
-'resnet18'…
-'…
-'…
-'…
-'…
-'…
-'…
-'…
-'…
-…
-…
-…
+# === TorchVision Models (use IMAGENET1K_V1 to match WaveDL) ===
+models = [
+    ('resnet18', m.ResNet18_Weights.IMAGENET1K_V1),
+    ('resnet50', m.ResNet50_Weights.IMAGENET1K_V1),
+    ('efficientnet_b0', m.EfficientNet_B0_Weights.IMAGENET1K_V1),
+    ('efficientnet_b1', m.EfficientNet_B1_Weights.IMAGENET1K_V1),
+    ('efficientnet_b2', m.EfficientNet_B2_Weights.IMAGENET1K_V1),
+    ('efficientnet_v2_s', m.EfficientNet_V2_S_Weights.IMAGENET1K_V1),
+    ('efficientnet_v2_m', m.EfficientNet_V2_M_Weights.IMAGENET1K_V1),
+    ('efficientnet_v2_l', m.EfficientNet_V2_L_Weights.IMAGENET1K_V1),
+    ('mobilenet_v3_small', m.MobileNet_V3_Small_Weights.IMAGENET1K_V1),
+    ('mobilenet_v3_large', m.MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+    ('regnet_y_400mf', m.RegNet_Y_400MF_Weights.IMAGENET1K_V1),
+    ('regnet_y_800mf', m.RegNet_Y_800MF_Weights.IMAGENET1K_V1),
+    ('regnet_y_1_6gf', m.RegNet_Y_1_6GF_Weights.IMAGENET1K_V1),
+    ('regnet_y_3_2gf', m.RegNet_Y_3_2GF_Weights.IMAGENET1K_V1),
+    ('regnet_y_8gf', m.RegNet_Y_8GF_Weights.IMAGENET1K_V1),
+    ('swin_t', m.Swin_T_Weights.IMAGENET1K_V1),
+    ('swin_s', m.Swin_S_Weights.IMAGENET1K_V1),
+    ('swin_b', m.Swin_B_Weights.IMAGENET1K_V1),
+    ('convnext_tiny', m.ConvNeXt_Tiny_Weights.IMAGENET1K_V1),
+    ('densenet121', m.DenseNet121_Weights.IMAGENET1K_V1),
+]
+for name, w in models:
+    getattr(m, name)(weights=w); print(f'✓ {name}')
 
 # 3D video models
-v.r3d_18(weights=v.R3D_18_Weights.…
-v.mc3_18(weights=v.MC3_18_Weights.…
+v.r3d_18(weights=v.R3D_18_Weights.KINETICS400_V1); print('✓ r3d_18')
+v.mc3_18(weights=v.MC3_18_Weights.KINETICS400_V1); print('✓ mc3_18')
+
+# === Timm Models (MaxViT, FastViT, CAFormer, ConvNeXt V2) ===
+import timm
+
+timm_models = [
+    # MaxViT (no suffix - timm resolves to default)
+    'maxvit_tiny_tf_224', 'maxvit_small_tf_224', 'maxvit_base_tf_224',
+    # FastViT (no suffix)
+    'fastvit_t8', 'fastvit_t12', 'fastvit_s12', 'fastvit_sa12',
+    # CAFormer/PoolFormer (no suffix)
+    'caformer_s18', 'caformer_s36', 'caformer_m36', 'poolformer_s12',
+    # ConvNeXt V2 (no suffix)
+    'convnextv2_tiny',
+    # EfficientViT (no suffix)
+    'efficientvit_m0', 'efficientvit_m1', 'efficientvit_m2',
+    'efficientvit_b0', 'efficientvit_b1', 'efficientvit_b2', 'efficientvit_b3',
+    'efficientvit_l1', 'efficientvit_l2',
+]
+for name in timm_models:
+    timm.create_model(name, pretrained=True); print(f'✓ {name}')
+
 print('\\n✓ All pretrained weights cached!')
 "
 ```
````
**pyproject.toml**

````diff
@@ -52,6 +52,7 @@ dependencies = [
     # Core ML stack
     "torch>=2.0.0",
     "torchvision>=0.15.0",
+    "timm>=0.9.0",  # Pretrained models (MaxViT, FastViT, CAFormer)
     "accelerate>=0.20.0",
     "numpy>=1.24.0",
     "scipy>=1.10.0",
````

````diff
@@ -70,6 +71,7 @@ dependencies = [
     # ONNX export
     "onnx>=1.14.0",
     "onnxruntime>=1.15.0",
+    "onnxscript>=0.1.0",  # Required by torch.onnx.export in PyTorch 2.1+
     # torch.compile backend (Linux only)
     "triton>=2.0.0; sys_platform == 'linux'",
 ]
````