wavedl 1.5.6.tar.gz → 1.6.0.tar.gz
This diff shows the contents of publicly released package versions as they appear in the public registry, and is provided for informational purposes only.
- {wavedl-1.5.6/src/wavedl.egg-info → wavedl-1.6.0}/PKG-INFO +104 -67
- {wavedl-1.5.6 → wavedl-1.6.0}/README.md +102 -66
- {wavedl-1.5.6 → wavedl-1.6.0}/pyproject.toml +1 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/__init__.py +1 -1
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/__init__.py +52 -4
- wavedl-1.6.0/src/wavedl/models/_timm_utils.py +238 -0
- wavedl-1.6.0/src/wavedl/models/caformer.py +270 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/convnext.py +108 -33
- wavedl-1.6.0/src/wavedl/models/convnext_v2.py +504 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/densenet.py +5 -5
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/efficientnet.py +30 -13
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/efficientnetv2.py +32 -9
- wavedl-1.6.0/src/wavedl/models/fastvit.py +285 -0
- wavedl-1.6.0/src/wavedl/models/mamba.py +535 -0
- wavedl-1.6.0/src/wavedl/models/maxvit.py +251 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/mobilenetv3.py +35 -12
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/regnet.py +39 -16
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/resnet.py +5 -5
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/resnet3d.py +2 -2
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/swin.py +41 -9
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/tcn.py +25 -5
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/unet.py +1 -1
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/vit.py +6 -6
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/test.py +7 -3
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/train.py +57 -23
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/utils/constraints.py +11 -5
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/utils/data.py +120 -18
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/utils/metrics.py +287 -326
- {wavedl-1.5.6 → wavedl-1.6.0/src/wavedl.egg-info}/PKG-INFO +104 -67
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl.egg-info/SOURCES.txt +6 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl.egg-info/requires.txt +1 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/LICENSE +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/setup.cfg +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/hpc.py +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/hpo.py +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/_template.py +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/base.py +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/cnn.py +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/models/registry.py +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/utils/__init__.py +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/utils/config.py +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/utils/cross_validation.py +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/utils/distributed.py +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/utils/losses.py +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/utils/optimizers.py +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl/utils/schedulers.py +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl.egg-info/dependency_links.txt +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl.egg-info/entry_points.txt +0 -0
- {wavedl-1.5.6 → wavedl-1.6.0}/src/wavedl.egg-info/top_level.txt +0 -0
**PKG-INFO** (+104 -67)

*Removed lines that the diff viewer truncated are marked with a trailing `…`.*

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.6
+Version: 1.6.0
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
````
````diff
@@ -23,6 +23,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch>=2.0.0
 Requires-Dist: torchvision>=0.15.0
+Requires-Dist: timm>=0.9.0
 Requires-Dist: accelerate>=0.20.0
 Requires-Dist: numpy>=1.24.0
 Requires-Dist: scipy>=1.10.0
````
````diff
@@ -117,7 +118,7 @@ Train on datasets larger than RAM:
 
 **🧠 Models? We've Got Options**
 
-…
+57 architectures, ready to go:
 - CNNs, ResNets, ViTs, EfficientNets...
 - All adapted for regression
 - [Add your own](#adding-custom-models) in one line
````
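The "add your own in one line" hook is the `@register_model` decorator from `models/registry.py`, both visible later in this diff. A minimal sketch, assuming the decorator takes the registry name and that `BaseModel` is an `nn.Module` subclass; the constructor arguments are illustrative, not the framework's actual interface:

```python
# Hedged sketch: the decorator argument and BaseModel's abstract interface are
# assumptions; only the names come from this diff.
import torch.nn as nn

from wavedl.models import BaseModel, register_model


@register_model("my_tiny_cnn")  # assumed: the name used for model selection
class MyTinyCNN(BaseModel):
    def __init__(self, in_channels: int = 1, out_dim: int = 3):  # illustrative args
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(16, out_dim),
        )

    def forward(self, x):
        return self.net(x)
```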
````diff
@@ -202,7 +203,7 @@ Deploy models anywhere:
 #### From PyPI (recommended for all users)
 
 ```bash
-pip install wavedl
+pip install --upgrade wavedl
 ```
 
 This installs everything you need: training, inference, HPO, ONNX export.
````
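After upgrading, a quick import check confirms the new `timm>=0.9.0` dependency came along. That `wavedl` exposes `__version__` is an assumption, suggested by the one-line change to `src/wavedl/__init__.py` in this release:

```python
# Post-upgrade sanity check. wavedl.__version__ is assumed to exist, based on
# the one-line version bump to src/wavedl/__init__.py.
import timm
import wavedl

print(wavedl.__version__)  # expected: 1.6.0
print(timm.__version__)    # should satisfy >= 0.9.0
```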
````diff
@@ -358,22 +359,10 @@ WaveDL/
 │   ├── hpo.py                # Hyperparameter optimization
 │   ├── hpc.py                # HPC distributed training launcher
 │   │
-│   ├── models/               # Model…
+│   ├── models/               # Model Zoo (57 architectures)
 │   │   ├── registry.py       # Model factory (@register_model)
 │   │   ├── base.py           # Abstract base class
-│   │   │
-│   │   ├── resnet.py         # ResNet-18/34/50 (1D/2D/3D)
-│   │   ├── resnet3d.py       # ResNet3D-18, MC3-18 (3D only)
-│   │   ├── tcn.py            # TCN (1D only)
-│   │   ├── efficientnet.py   # EfficientNet-B0/B1/B2 (2D)
-│   │   ├── efficientnetv2.py # EfficientNetV2-S/M/L (2D)
-│   │   ├── mobilenetv3.py    # MobileNetV3-Small/Large (2D)
-│   │   ├── regnet.py         # RegNetY variants (2D)
-│   │   ├── swin.py           # Swin Transformer (2D)
-│   │   ├── vit.py            # Vision Transformer (1D/2D)
-│   │   ├── convnext.py       # ConvNeXt (1D/2D/3D)
-│   │   ├── densenet.py       # DenseNet-121/169 (1D/2D/3D)
-│   │   └── unet.py           # U-Net Regression
+│   │   └── ...               # See "Available Models" section
 │   │
 │   └── utils/                # Utilities
 │       ├── data.py           # Memory-mapped data pipeline
````
````diff
@@ -388,7 +377,7 @@ WaveDL/
 ├── configs/                  # YAML config templates
 ├── examples/                 # Ready-to-run examples
 ├── notebooks/                # Jupyter notebooks
-├── unit_tests/               # Pytest test suite
+├── unit_tests/               # Pytest test suite
 │
 ├── pyproject.toml            # Package config, dependencies
 ├── CHANGELOG.md              # Version history
````
````diff
@@ -411,71 +400,96 @@ WaveDL/
 > ```
 
 <details>
-<summary><b>Available Models</b> — …
+<summary><b>Available Models</b> — 57 architectures</summary>
 
-| Model | Params | Dim |
-…
+| Model | Backbone Params | Dim |
+|-------|-----------------|-----|
 | **CNN** — Convolutional Neural Network |||
-| `cnn` | 1.…
+| `cnn` | 1.6M | 1D/2D/3D |
 | **ResNet** — Residual Network |||
-| `resnet18` | 11.…
-| `resnet34` | 21.…
-| `resnet50` |…
-| `resnet18_pretrained` ⭐ | 11.…
-| `resnet50_pretrained` ⭐ |…
+| `resnet18` | 11.2M | 1D/2D/3D |
+| `resnet34` | 21.3M | 1D/2D/3D |
+| `resnet50` | 23.5M | 1D/2D/3D |
+| `resnet18_pretrained` ⭐ | 11.2M | 2D |
+| `resnet50_pretrained` ⭐ | 23.5M | 2D |
 | **ResNet3D** — 3D Residual Network |||
-| `resnet3d_18` | 33.…
-| `mc3_18` — Mixed Convolution 3D | 11.…
+| `resnet3d_18` | 33.2M | 3D |
+| `mc3_18` — Mixed Convolution 3D | 11.5M | 3D |
 | **TCN** — Temporal Convolutional Network |||
-| `tcn_small` |…
-| `tcn` |…
-| `tcn_large` | 10.…
+| `tcn_small` | 0.9M | 1D |
+| `tcn` | 6.9M | 1D |
+| `tcn_large` | 10.0M | 1D |
 | **EfficientNet** — Efficient Neural Network |||
-| `efficientnet_b0` ⭐ | 4.…
-| `efficientnet_b1` ⭐ |…
-| `efficientnet_b2` ⭐ |…
+| `efficientnet_b0` ⭐ | 4.0M | 2D |
+| `efficientnet_b1` ⭐ | 6.5M | 2D |
+| `efficientnet_b2` ⭐ | 7.7M | 2D |
 | **EfficientNetV2** — Efficient Neural Network V2 |||
-| `efficientnet_v2_s` ⭐ |…
-| `efficientnet_v2_m` ⭐ |…
-| `efficientnet_v2_l` ⭐ |…
+| `efficientnet_v2_s` ⭐ | 20.2M | 2D |
+| `efficientnet_v2_m` ⭐ | 52.9M | 2D |
+| `efficientnet_v2_l` ⭐ | 117.2M | 2D |
 | **MobileNetV3** — Mobile Neural Network V3 |||
-| `mobilenet_v3_small` ⭐ |…
-| `mobilenet_v3_large` ⭐ | 3.…
+| `mobilenet_v3_small` ⭐ | 0.9M | 2D |
+| `mobilenet_v3_large` ⭐ | 3.0M | 2D |
 | **RegNet** — Regularized Network |||
-| `regnet_y_400mf` ⭐ |…
-| `regnet_y_800mf` ⭐ | 5.…
-| `regnet_y_1_6gf` ⭐ | 10.…
-| `regnet_y_3_2gf` ⭐ |…
-| `regnet_y_8gf` ⭐ | 37.…
+| `regnet_y_400mf` ⭐ | 3.9M | 2D |
+| `regnet_y_800mf` ⭐ | 5.7M | 2D |
+| `regnet_y_1_6gf` ⭐ | 10.3M | 2D |
+| `regnet_y_3_2gf` ⭐ | 17.9M | 2D |
+| `regnet_y_8gf` ⭐ | 37.4M | 2D |
 | **Swin** — Shifted Window Transformer |||
-| `swin_t` ⭐ |…
-| `swin_s` ⭐ |…
-| `swin_b` ⭐ |…
+| `swin_t` ⭐ | 27.5M | 2D |
+| `swin_s` ⭐ | 48.8M | 2D |
+| `swin_b` ⭐ | 86.7M | 2D |
 | **ConvNeXt** — Convolutional Next |||
-| `convnext_tiny` |…
-| `convnext_small` | 49.…
-| `convnext_base` |…
-| `convnext_tiny_pretrained` ⭐ |…
+| `convnext_tiny` | 27.8M | 1D/2D/3D |
+| `convnext_small` | 49.5M | 1D/2D/3D |
+| `convnext_base` | 87.6M | 1D/2D/3D |
+| `convnext_tiny_pretrained` ⭐ | 27.8M | 2D |
 | **DenseNet** — Densely Connected Network |||
-| `densenet121` | 7.…
-| `densenet169` |…
-| `densenet121_pretrained` ⭐ | 7.…
+| `densenet121` | 7.0M | 1D/2D/3D |
+| `densenet169` | 12.5M | 1D/2D/3D |
+| `densenet121_pretrained` ⭐ | 7.0M | 2D |
 | **ViT** — Vision Transformer |||
-| `vit_tiny` | 5.…
-| `vit_small` | 21.…
-| `vit_base` | 85.…
+| `vit_tiny` | 5.4M | 1D/2D |
+| `vit_small` | 21.4M | 1D/2D |
+| `vit_base` | 85.3M | 1D/2D |
+| **ConvNeXt V2** — ConvNeXt with GRN |||
+| `convnext_v2_tiny` | 27.9M | 1D/2D/3D |
+| `convnext_v2_small` | 49.6M | 1D/2D/3D |
+| `convnext_v2_base` | 87.7M | 1D/2D/3D |
+| `convnext_v2_tiny_pretrained` ⭐ | 27.9M | 2D |
+| **Mamba** — State Space Model |||
+| `mamba_1d` | 3.4M | 1D |
+| **Vision Mamba (ViM)** — 2D Mamba |||
+| `vim_tiny` | 6.6M | 2D |
+| `vim_small` | 51.1M | 2D |
+| `vim_base` | 201.4M | 2D |
+| **MaxViT** — Multi-Axis ViT |||
+| `maxvit_tiny` ⭐ | 30.1M | 2D |
+| `maxvit_small` ⭐ | 67.6M | 2D |
+| `maxvit_base` ⭐ | 119.1M | 2D |
+| **FastViT** — Fast Hybrid CNN-ViT |||
+| `fastvit_t8` ⭐ | 4.0M | 2D |
+| `fastvit_t12` ⭐ | 6.8M | 2D |
+| `fastvit_s12` ⭐ | 8.8M | 2D |
+| `fastvit_sa12` ⭐ | 10.9M | 2D |
+| **CAFormer** — MetaFormer with Attention |||
+| `caformer_s18` ⭐ | 26.3M | 2D |
+| `caformer_s36` ⭐ | 39.2M | 2D |
+| `caformer_m36` ⭐ | 56.9M | 2D |
+| `poolformer_s12` ⭐ | 11.9M | 2D |
 | **U-Net** — U-shaped Network |||
-| `unet_regression` | 31.…
+| `unet_regression` | 31.0M | 1D/2D/3D |
+
 
 ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
 - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
-- **Size**: ~20–350 MB per model depending on architecture
 - **Train from scratch**: Use `--no_pretrained` to disable pretrained weights
 
 **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
 
 ```bash
-# Run once on login node (with internet) — downloads ALL pretrained weights
+# Run once on login node (with internet) — downloads ALL pretrained weights
 python -c "
 import os
 os.environ['TORCH_HOME'] = '.torch_cache'  # Match WaveDL's HPC cache location
````
````diff
@@ -483,7 +497,7 @@ os.environ['TORCH_HOME'] = '.torch_cache'  # Match WaveDL's HPC cache location
 from torchvision import models as m
 from torchvision.models import video as v
 
-#…
+# === TorchVision Models ===
 weights = {
     'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
     'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
````
````diff
@@ -501,6 +515,20 @@ for name, w in weights.items():
 # 3D video models
 v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
 v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+
+# === Timm Models (MaxViT, FastViT, CAFormer, ConvNeXt V2) ===
+import timm
+
+timm_models = [
+    'maxvit_tiny_tf_224.in1k', 'maxvit_small_tf_224.in1k', 'maxvit_base_tf_224.in1k',
+    'fastvit_t8.apple_in1k', 'fastvit_t12.apple_in1k', 'fastvit_s12.apple_in1k', 'fastvit_sa12.apple_in1k',
+    'caformer_s18.sail_in1k', 'caformer_s36.sail_in22k_ft_in1k', 'caformer_m36.sail_in22k_ft_in1k',
+    'poolformer_s12.sail_in1k',
+    'convnextv2_tiny.fcmae_ft_in1k',
+]
+for name in timm_models:
+    timm.create_model(name, pretrained=True); print(f'✓ {name}')
+
 print('\\n✓ All pretrained weights cached!')
 "
 ```
````
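On the compute node, the cached weights are reused by pointing the caches at the same place before any model is built. `TORCH_HOME` covers the torchvision/torch.hub checkpoints; the timm checkpoints travel through the Hugging Face Hub cache, so `HF_HUB_OFFLINE` (a standard `huggingface_hub` switch) keeps them from attempting downloads. A sketch, with the exact cache interplay on a given cluster left as something to verify:

```python
# Compute-node sketch: reuse the cache populated on the login node.
import os

os.environ["TORCH_HOME"] = ".torch_cache"  # match the login-node script above
os.environ["HF_HUB_OFFLINE"] = "1"         # serve timm/HF checkpoints from cache
# If $HOME is unavailable on compute nodes, HF_HOME may also need to point at
# a shared path so the login-node Hugging Face cache is visible here.

import timm  # import after the env vars are set

model = timm.create_model("fastvit_t8.apple_in1k", pretrained=True)  # no network
```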
````diff
@@ -1035,12 +1063,20 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater…
 
 | Parameter | Unit | Description |
 |-----------|------|-------------|
-…
-…
-…
+| $h$ | mm | Plate thickness |
+| $\sqrt{E/\rho}$ | km/s | Square root of Young's modulus over density |
+| $\nu$ | — | Poisson's ratio |
 
 > [!NOTE]
-> This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"…
+> This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"A lightweight deep learning model for ultrasonic assessment of plate thickness and elasticity
+"*](https://spie.org/spie-smart-structures-and-materials-nondestructive-evaluation/presentation/A-lightweight-deep-learning-model-for-ultrasonic-assessment-of-plate/13951-4) (Paper 13951-4, to appear).
+
+**Sample Dispersion Data:**
+
+<p align="center">
+  <img src="examples/elasticity_prediction/dispersion_samples.png" alt="Dispersion curve samples" width="700"><br>
+  <em>Test samples showing the wavenumber-frequency relationship for different plate properties</em>
+</p>
 
 **Try it yourself:**
 
````
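For scale, the second target behaves like a material sound speed; a quick back-of-envelope check with illustrative aluminum-like constants (not values from the dataset):

```python
# sqrt(E/rho) with E = 69 GPa and rho = 2700 kg/m^3 (typical aluminum figures).
E = 69e9      # Young's modulus [Pa]
rho = 2700.0  # density [kg/m^3]
print((E / rho) ** 0.5 / 1e3)  # ~5.06, in the km/s unit used in the table above
```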
````diff
@@ -1061,7 +1097,8 @@ python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpo
 | File | Description |
 |------|-------------|
 | `best_checkpoint/` | Pre-trained MobileNetV3 checkpoint |
-| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → …
+| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → $h$, $\sqrt{E/\rho}$, $\nu$) |
+| `dispersion_samples.png` | Visualization of sample dispersion curves with material parameters |
 | `model.onnx` | ONNX export with embedded de-normalization |
 | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
 | `training_curves.png` | Training/validation loss and learning rate plot |
````
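Because de-normalization is embedded in `model.onnx`, inference needs only `onnxruntime`. A minimal sketch, assuming a single NCHW `float32` input of shape `(1, 1, 500, 500)` and one output carrying the three targets; verify both against `sess.get_inputs()` and `sess.get_outputs()`:

```python
# Hedged inference sketch; input layout and output packing are assumptions.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("examples/elasticity_prediction/model.onnx")
inp = sess.get_inputs()[0]
x = np.random.rand(1, 1, 500, 500).astype(np.float32)  # stand-in dispersion map
outputs = sess.run(None, {inp.name: x})
h, sqrt_E_rho, nu = np.asarray(outputs[0]).reshape(-1)[:3]  # physical units
print(f"h={h:.3f} mm, sqrt(E/rho)={sqrt_E_rho:.3f} km/s, nu={nu:.3f}")
```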
**README.md** (+102 -66): the hunks are identical to the PKG-INFO hunks above, since PKG-INFO embeds the README as the package's long description; only the `@@` offsets differ by the length of the metadata header.
**src/wavedl/models/__init__.py** (+52 -4)

````diff
@@ -6,10 +6,11 @@ This module provides a centralized registry for neural network architectures,
 enabling dynamic model selection via command-line arguments.
 
 **Dimensionality Coverage**:
-- 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, DenseNet, ViT
-- 2D (images): CNN, ResNet, ConvNeXt, DenseNet, ViT, UNet,
-  EfficientNet, MobileNetV3, RegNet, Swin
-…
+- 1D (waveforms): TCN, CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet, ViT, Mamba
+- 2D (images): CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet, ViT, UNet,
+  EfficientNet, MobileNetV3, RegNet, Swin, MaxViT, FastViT,
+  CAFormer, PoolFormer, Vision Mamba
+- 3D (volumes): ResNet3D, CNN, ResNet, ConvNeXt, ConvNeXt V2, DenseNet
 
 Usage:
     from wavedl.models import get_model, list_models, MODEL_REGISTRY
````
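The docstring's usage line is enough to enumerate the zoo. A short sketch, assuming `list_models()` returns the registered names and `MODEL_REGISTRY` is a name-to-class mapping (their exact types are not shown in this diff):

```python
# Assumes list_models() -> registered names, MODEL_REGISTRY: name -> class.
from wavedl.models import MODEL_REGISTRY, list_models

names = list_models()
print(len(names))                            # 57 when timm is installed
print("convnext_v2_tiny" in MODEL_REGISTRY)  # one of the new 1.6.0 entries
```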
````diff
@@ -46,9 +47,19 @@ from .base import BaseModel
 # Import model implementations (triggers registration via decorators)
 from .cnn import CNN
 from .convnext import ConvNeXtBase_, ConvNeXtSmall, ConvNeXtTiny
+
+# New models (v1.6+)
+from .convnext_v2 import (
+    ConvNeXtV2Base,
+    ConvNeXtV2BaseLarge,
+    ConvNeXtV2Small,
+    ConvNeXtV2Tiny,
+    ConvNeXtV2TinyPretrained,
+)
 from .densenet import DenseNet121, DenseNet169
 from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2
 from .efficientnetv2 import EfficientNetV2L, EfficientNetV2M, EfficientNetV2S
+from .mamba import Mamba1D, VimBase, VimSmall, VimTiny
 from .mobilenetv3 import MobileNetV3Large, MobileNetV3Small
 from .registry import (
     MODEL_REGISTRY,
````
````diff
@@ -66,6 +77,17 @@ from .unet import UNetRegression
 from .vit import ViTBase_, ViTSmall, ViTTiny
 
 
+# Optional timm-based models (imported conditionally)
+try:
+    from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
+    from .fastvit import FastViTS12, FastViTSA12, FastViTT8, FastViTT12
+    from .maxvit import MaxViTBaseLarge, MaxViTSmall, MaxViTTiny
+
+    _HAS_TIMM_MODELS = True
+except ImportError:
+    _HAS_TIMM_MODELS = False
+
+
 # Export public API (sorted alphabetically per RUF022)
 # See module docstring for dimensionality support details
 __all__ = [
````
````diff
@@ -77,6 +99,11 @@ __all__ = [
     "ConvNeXtBase_",
     "ConvNeXtSmall",
     "ConvNeXtTiny",
+    "ConvNeXtV2Base",
+    "ConvNeXtV2BaseLarge",
+    "ConvNeXtV2Small",
+    "ConvNeXtV2Tiny",
+    "ConvNeXtV2TinyPretrained",
     "DenseNet121",
     "DenseNet169",
     "EfficientNetB0",
````
````diff
@@ -85,6 +112,7 @@ __all__ = [
     "EfficientNetV2L",
     "EfficientNetV2M",
     "EfficientNetV2S",
+    "Mamba1D",
     "MobileNetV3Large",
     "MobileNetV3Small",
     "RegNetY1_6GF",
````
````diff
@@ -105,8 +133,28 @@ __all__ = [
     "ViTBase_",
     "ViTSmall",
     "ViTTiny",
+    "VimBase",
+    "VimSmall",
+    "VimTiny",
     "build_model",
     "get_model",
     "list_models",
     "register_model",
 ]
+
+# Add timm-based models to __all__ if available
+if _HAS_TIMM_MODELS:
+    __all__.extend(
+        [
+            "CaFormerS18",
+            "CaFormerS36",
+            "FastViTS12",
+            "FastViTSA12",
+            "FastViTT8",
+            "FastViTT12",
+            "MaxViTBaseLarge",
+            "MaxViTSmall",
+            "MaxViTTiny",
+            "PoolFormerS12",
+        ]
+    )
````