wavedl 1.3.1__tar.gz → 1.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. {wavedl-1.3.1/src/wavedl.egg-info → wavedl-1.4.0}/PKG-INFO +111 -93
  2. {wavedl-1.3.1 → wavedl-1.4.0}/README.md +101 -70
  3. {wavedl-1.3.1 → wavedl-1.4.0}/pyproject.toml +15 -12
  4. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/__init__.py +1 -1
  5. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/hpc.py +28 -26
  6. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/models/__init__.py +33 -7
  7. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/models/_template.py +0 -1
  8. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/models/base.py +0 -1
  9. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/models/cnn.py +0 -1
  10. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/models/convnext.py +4 -1
  11. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/models/densenet.py +4 -1
  12. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/models/efficientnet.py +9 -5
  13. wavedl-1.4.0/src/wavedl/models/efficientnetv2.py +292 -0
  14. wavedl-1.4.0/src/wavedl/models/mobilenetv3.py +272 -0
  15. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/models/registry.py +0 -1
  16. wavedl-1.4.0/src/wavedl/models/regnet.py +383 -0
  17. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/models/resnet.py +7 -4
  18. wavedl-1.4.0/src/wavedl/models/resnet3d.py +258 -0
  19. wavedl-1.4.0/src/wavedl/models/swin.py +390 -0
  20. wavedl-1.4.0/src/wavedl/models/tcn.py +389 -0
  21. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/models/unet.py +44 -110
  22. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/models/vit.py +8 -4
  23. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/train.py +1113 -1116
  24. {wavedl-1.3.1 → wavedl-1.4.0/src/wavedl.egg-info}/PKG-INFO +111 -93
  25. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl.egg-info/SOURCES.txt +6 -0
  26. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl.egg-info/requires.txt +4 -24
  27. {wavedl-1.3.1 → wavedl-1.4.0}/LICENSE +0 -0
  28. {wavedl-1.3.1 → wavedl-1.4.0}/setup.cfg +0 -0
  29. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/hpo.py +0 -0
  30. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/test.py +0 -0
  31. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/utils/__init__.py +0 -0
  32. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/utils/config.py +0 -0
  33. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/utils/cross_validation.py +0 -0
  34. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/utils/data.py +0 -0
  35. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/utils/distributed.py +0 -0
  36. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/utils/losses.py +0 -0
  37. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/utils/metrics.py +0 -0
  38. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/utils/optimizers.py +0 -0
  39. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl/utils/schedulers.py +0 -0
  40. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl.egg-info/dependency_links.txt +0 -0
  41. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl.egg-info/entry_points.txt +0 -0
  42. {wavedl-1.3.1 → wavedl-1.4.0}/src/wavedl.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.3.1
3
+ Version: 1.4.0
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -30,31 +30,18 @@ Requires-Dist: scikit-learn>=1.2.0
30
30
  Requires-Dist: pandas>=2.0.0
31
31
  Requires-Dist: matplotlib>=3.7.0
32
32
  Requires-Dist: tqdm>=4.65.0
33
- Requires-Dist: wandb>=0.15.0
34
33
  Requires-Dist: pyyaml>=6.0.0
35
34
  Requires-Dist: h5py>=3.8.0
36
35
  Requires-Dist: safetensors>=0.3.0
37
- Provides-Extra: dev
38
- Requires-Dist: pytest>=7.0.0; extra == "dev"
39
- Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
40
- Requires-Dist: ruff>=0.8.0; extra == "dev"
41
- Requires-Dist: pre-commit>=3.5.0; extra == "dev"
42
- Provides-Extra: onnx
43
- Requires-Dist: onnx>=1.14.0; extra == "onnx"
44
- Requires-Dist: onnxruntime>=1.15.0; extra == "onnx"
45
- Provides-Extra: compile
46
- Requires-Dist: triton; sys_platform == "linux" and extra == "compile"
47
- Provides-Extra: hpo
48
- Requires-Dist: optuna>=3.0.0; extra == "hpo"
49
- Provides-Extra: all
50
- Requires-Dist: pytest>=7.0.0; extra == "all"
51
- Requires-Dist: pytest-xdist>=3.5.0; extra == "all"
52
- Requires-Dist: ruff>=0.8.0; extra == "all"
53
- Requires-Dist: pre-commit>=3.5.0; extra == "all"
54
- Requires-Dist: onnx>=1.14.0; extra == "all"
55
- Requires-Dist: onnxruntime>=1.15.0; extra == "all"
56
- Requires-Dist: triton; sys_platform == "linux" and extra == "all"
57
- Requires-Dist: optuna>=3.0.0; extra == "all"
36
+ Requires-Dist: wandb>=0.15.0
37
+ Requires-Dist: optuna>=3.0.0
38
+ Requires-Dist: onnx>=1.14.0
39
+ Requires-Dist: onnxruntime>=1.15.0
40
+ Requires-Dist: pytest>=7.0.0
41
+ Requires-Dist: pytest-xdist>=3.5.0
42
+ Requires-Dist: ruff>=0.8.0
43
+ Requires-Dist: pre-commit>=3.5.0
44
+ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
58
45
 
59
46
  <div align="center">
60
47
 
@@ -210,20 +197,20 @@ Deploy models anywhere:
210
197
 
211
198
  ### Installation
212
199
 
200
+ #### From PyPI (recommended for all users)
201
+
213
202
  ```bash
214
- # Install from PyPI (recommended)
215
203
  pip install wavedl
216
-
217
- # Or install with all extras (ONNX export, HPO, dev tools)
218
- pip install wavedl[all]
219
204
  ```
220
205
 
206
+ This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
207
+
221
208
  #### From Source (for development)
222
209
 
223
210
  ```bash
224
211
  git clone https://github.com/ductho-le/WaveDL.git
225
212
  cd WaveDL
226
- pip install -e ".[dev]"
213
+ pip install -e .
227
214
  ```
228
215
 
229
216
  > [!NOTE]
@@ -359,41 +346,47 @@ WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU,
359
346
  ```
360
347
  WaveDL/
361
348
  ├── src/
362
- │ └── wavedl/ # Main package (namespaced)
363
- │ ├── __init__.py # Package init with __version__
364
- │ ├── train.py # Training entry point
365
- │ ├── test.py # Testing & inference script
366
- │ ├── hpo.py # Hyperparameter optimization
367
- │ ├── hpc.py # HPC distributed training launcher
349
+ │ └── wavedl/ # Main package (namespaced)
350
+ │ ├── __init__.py # Package init with __version__
351
+ │ ├── train.py # Training entry point
352
+ │ ├── test.py # Testing & inference script
353
+ │ ├── hpo.py # Hyperparameter optimization
354
+ │ ├── hpc.py # HPC distributed training launcher
368
355
  │ │
369
- │ ├── models/ # Model architectures
370
- │ │ ├── registry.py # Model factory (@register_model)
371
- │ │ ├── base.py # Abstract base class
372
- │ │ ├── cnn.py # Baseline CNN
373
- │ │ ├── resnet.py # ResNet-18/34/50 (1D/2D/3D)
374
- │ │ ├── efficientnet.py# EfficientNet-B0/B1/B2
375
- │ │ ├── vit.py # Vision Transformer (1D/2D)
376
- │ │ ├── convnext.py # ConvNeXt (1D/2D/3D)
377
- │ │ ├── densenet.py # DenseNet-121/169 (1D/2D/3D)
378
- │ │ └── unet.py # U-Net / U-Net Regression
356
+ │ ├── models/ # Model architectures (38 variants)
357
+ │ │ ├── registry.py # Model factory (@register_model)
358
+ │ │ ├── base.py # Abstract base class
359
+ │ │ ├── cnn.py # Baseline CNN (1D/2D/3D)
360
+ │ │ ├── resnet.py # ResNet-18/34/50 (1D/2D/3D)
361
+ │ │ ├── resnet3d.py # ResNet3D-18, MC3-18 (3D only)
362
+ │ │ ├── tcn.py # TCN (1D only)
363
+ │ │ ├── efficientnet.py # EfficientNet-B0/B1/B2 (2D)
364
+ │ │ ├── efficientnetv2.py # EfficientNetV2-S/M/L (2D)
365
+ │ │ ├── mobilenetv3.py # MobileNetV3-Small/Large (2D)
366
+ │ │ ├── regnet.py # RegNetY variants (2D)
367
+ │ │ ├── swin.py # Swin Transformer (2D)
368
+ │ │ ├── vit.py # Vision Transformer (1D/2D)
369
+ │ │ ├── convnext.py # ConvNeXt (1D/2D/3D)
370
+ │ │ ├── densenet.py # DenseNet-121/169 (1D/2D/3D)
371
+ │ │ └── unet.py # U-Net Regression
379
372
  │ │
380
- │ └── utils/ # Utilities
381
- │ ├── data.py # Memory-mapped data pipeline
382
- │ ├── metrics.py # R², Pearson, visualization
383
- │ ├── distributed.py # DDP synchronization
384
- │ ├── losses.py # Loss function factory
385
- │ ├── optimizers.py # Optimizer factory
386
- │ ├── schedulers.py # LR scheduler factory
387
- │ └── config.py # YAML configuration support
373
+ │ └── utils/ # Utilities
374
+ │ ├── data.py # Memory-mapped data pipeline
375
+ │ ├── metrics.py # R², Pearson, visualization
376
+ │ ├── distributed.py # DDP synchronization
377
+ │ ├── losses.py # Loss function factory
378
+ │ ├── optimizers.py # Optimizer factory
379
+ │ ├── schedulers.py # LR scheduler factory
380
+ │ └── config.py # YAML configuration support
388
381
 
389
- ├── configs/ # YAML config templates
390
- ├── examples/ # Ready-to-run examples
391
- ├── notebooks/ # Jupyter notebooks
392
- ├── unit_tests/ # Pytest test suite (422 tests)
382
+ ├── configs/ # YAML config templates
383
+ ├── examples/ # Ready-to-run examples
384
+ ├── notebooks/ # Jupyter notebooks
385
+ ├── unit_tests/ # Pytest test suite (704 tests)
393
386
 
394
- ├── pyproject.toml # Package config, dependencies
395
- ├── CHANGELOG.md # Version history
396
- └── CITATION.cff # Citation metadata
387
+ ├── pyproject.toml # Package config, dependencies
388
+ ├── CHANGELOG.md # Version history
389
+ └── CITATION.cff # Citation metadata
397
390
  ```
398
391
  ---
399
392
 
@@ -412,33 +405,63 @@ WaveDL/
412
405
  > ```
413
406
 
414
407
  <details>
415
- <summary><b>Available Models</b> — 21 pre-built architectures</summary>
416
-
417
- | Model | Best For | Params (2D) | Dimensionality |
418
- |-------|----------|-------------|----------------|
419
- | `cnn` | Baseline, lightweight | 1.7M | 1D/2D/3D |
420
- | `resnet18` | Fast training, smaller datasets | 11.4M | 1D/2D/3D |
421
- | `resnet34` | Balanced performance | 21.5M | 1D/2D/3D |
422
- | `resnet50` | High capacity, complex patterns | 24.6M | 1D/2D/3D |
423
- | `resnet18_pretrained` | **Transfer learning** ⭐ | 11.4M | 2D only |
424
- | `resnet50_pretrained` | **Transfer learning** ⭐ | 24.6M | 2D only |
425
- | `efficientnet_b0` | Efficient, **pretrained** ⭐ | 4.7M | 2D only |
426
- | `efficientnet_b1` | Efficient, **pretrained** ⭐ | 7.2M | 2D only |
427
- | `efficientnet_b2` | Efficient, **pretrained** | 8.4M | 2D only |
428
- | `vit_tiny` | Transformer, small datasets | 5.4M | 1D/2D |
429
- | `vit_small` | Transformer, balanced | 21.5M | 1D/2D |
430
- | `vit_base` | Transformer, high capacity | 85.5M | 1D/2D |
431
- | `convnext_tiny` | Modern CNN, transformer-inspired | 28.2M | 1D/2D/3D |
432
- | `convnext_tiny_pretrained` | **Transfer learning** ⭐ | 28.2M | 2D only |
433
- | `convnext_small` | Modern CNN, balanced | 49.8M | 1D/2D/3D |
434
- | `convnext_base` | Modern CNN, high capacity | 88.1M | 1D/2D/3D |
435
- | `densenet121` | Feature reuse, small data | 7.5M | 1D/2D/3D |
436
- | `densenet121_pretrained` | **Transfer learning** ⭐ | 7.5M | 2D only |
437
- | `densenet169` | Deeper DenseNet | 13.3M | 1D/2D/3D |
438
- | `unet` | Spatial output (velocity fields) | 31.0M | 1D/2D/3D |
439
- | `unet_regression` | Multi-scale features for regression | 31.1M | 1D/2D/3D |
440
-
441
- >**Pretrained models** use ImageNet weights for transfer learning.
408
+ <summary><b>Available Models</b> — 38 architectures</summary>
409
+
410
+ | Model | Params | Dim |
411
+ |-------|--------|-----|
412
+ | **CNN** — Convolutional Neural Network |||
413
+ | `cnn` | 1.7M | 1D/2D/3D |
414
+ | **ResNet** — Residual Network |||
415
+ | `resnet18` | 11.4M | 1D/2D/3D |
416
+ | `resnet34` | 21.5M | 1D/2D/3D |
417
+ | `resnet50` | 24.6M | 1D/2D/3D |
418
+ | `resnet18_pretrained` ⭐ | 11.4M | 2D |
419
+ | `resnet50_pretrained` ⭐ | 24.6M | 2D |
420
+ | **ResNet3D** — 3D Residual Network |||
421
+ | `resnet3d_18` | 33.6M | 3D |
422
+ | `mc3_18` Mixed Convolution 3D | 11.9M | 3D |
423
+ | **TCN** — Temporal Convolutional Network |||
424
+ | `tcn_small` | 1.0M | 1D |
425
+ | `tcn` | 7.0M | 1D |
426
+ | `tcn_large` | 10.2M | 1D |
427
+ | **EfficientNet** — Efficient Neural Network |||
428
+ | `efficientnet_b0` | 4.7M | 2D |
429
+ | `efficientnet_b1` ⭐ | 7.2M | 2D |
430
+ | `efficientnet_b2` | 8.4M | 2D |
431
+ | **EfficientNetV2** — Efficient Neural Network V2 |||
432
+ | `efficientnet_v2_s` | 21.0M | 2D |
433
+ | `efficientnet_v2_m` ⭐ | 53.6M | 2D |
434
+ | `efficientnet_v2_l` | 118.0M | 2D |
435
+ | **MobileNetV3** — Mobile Neural Network V3 |||
436
+ | `mobilenet_v3_small` ⭐ | 1.1M | 2D |
437
+ | `mobilenet_v3_large` ⭐ | 3.2M | 2D |
438
+ | **RegNet** — Regularized Network |||
439
+ | `regnet_y_400mf` ⭐ | 4.0M | 2D |
440
+ | `regnet_y_800mf` ⭐ | 5.8M | 2D |
441
+ | `regnet_y_1_6gf` ⭐ | 10.5M | 2D |
442
+ | `regnet_y_3_2gf` ⭐ | 18.3M | 2D |
443
+ | `regnet_y_8gf` ⭐ | 37.9M | 2D |
444
+ | **Swin** — Shifted Window Transformer |||
445
+ | `swin_t` ⭐ | 28.0M | 2D |
446
+ | `swin_s` ⭐ | 49.4M | 2D |
447
+ | `swin_b` ⭐ | 87.4M | 2D |
448
+ | **ConvNeXt** — Convolutional Next |||
449
+ | `convnext_tiny` | 28.2M | 1D/2D/3D |
450
+ | `convnext_small` | 49.8M | 1D/2D/3D |
451
+ | `convnext_base` | 88.1M | 1D/2D/3D |
452
+ | `convnext_tiny_pretrained` ⭐ | 28.2M | 2D |
453
+ | **DenseNet** — Densely Connected Network |||
454
+ | `densenet121` | 7.5M | 1D/2D/3D |
455
+ | `densenet169` | 13.3M | 1D/2D/3D |
456
+ | `densenet121_pretrained` ⭐ | 7.5M | 2D |
457
+ | **ViT** — Vision Transformer |||
458
+ | `vit_tiny` | 5.5M | 1D/2D |
459
+ | `vit_small` | 21.6M | 1D/2D |
460
+ | `vit_base` | 85.6M | 1D/2D |
461
+ | **U-Net** — U-shaped Network |||
462
+ | `unet_regression` | 31.1M | 1D/2D/3D |
463
+
464
+ > ⭐ = Pretrained on ImageNet. Recommended for smaller datasets.
442
465
 
443
466
  </details>
444
467
 
@@ -653,12 +676,7 @@ seed: 2025
653
676
 
654
677
  Automatically find the best training configuration using [Optuna](https://optuna.org/).
655
678
 
656
- **Step 1: Install**
657
- ```bash
658
- pip install -e ".[hpo]"
659
- ```
660
-
661
- **Step 2: Run HPO**
679
+ **Run HPO:**
662
680
 
663
681
  You specify which models to search and how many trials to run:
664
682
  ```bash
@@ -715,7 +733,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
715
733
  | `--output` | `hpo_results.json` | Output file |
716
734
 
717
735
  > [!TIP]
718
- > See [Available Models](#available-models) for all 21 architectures you can search.
736
+ > See [Available Models](#available-models) for all 38 architectures you can search.
719
737
 
720
738
  </details>
721
739
 
@@ -152,20 +152,20 @@ Deploy models anywhere:
152
152
 
153
153
  ### Installation
154
154
 
155
+ #### From PyPI (recommended for all users)
156
+
155
157
  ```bash
156
- # Install from PyPI (recommended)
157
158
  pip install wavedl
158
-
159
- # Or install with all extras (ONNX export, HPO, dev tools)
160
- pip install wavedl[all]
161
159
  ```
162
160
 
161
+ This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
162
+
163
163
  #### From Source (for development)
164
164
 
165
165
  ```bash
166
166
  git clone https://github.com/ductho-le/WaveDL.git
167
167
  cd WaveDL
168
- pip install -e ".[dev]"
168
+ pip install -e .
169
169
  ```
170
170
 
171
171
  > [!NOTE]
@@ -301,41 +301,47 @@ WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU,
301
301
  ```
302
302
  WaveDL/
303
303
  ├── src/
304
- │ └── wavedl/ # Main package (namespaced)
305
- │ ├── __init__.py # Package init with __version__
306
- │ ├── train.py # Training entry point
307
- │ ├── test.py # Testing & inference script
308
- │ ├── hpo.py # Hyperparameter optimization
309
- │ ├── hpc.py # HPC distributed training launcher
304
+ │ └── wavedl/ # Main package (namespaced)
305
+ │ ├── __init__.py # Package init with __version__
306
+ │ ├── train.py # Training entry point
307
+ │ ├── test.py # Testing & inference script
308
+ │ ├── hpo.py # Hyperparameter optimization
309
+ │ ├── hpc.py # HPC distributed training launcher
310
310
  │ │
311
- │ ├── models/ # Model architectures
312
- │ │ ├── registry.py # Model factory (@register_model)
313
- │ │ ├── base.py # Abstract base class
314
- │ │ ├── cnn.py # Baseline CNN
315
- │ │ ├── resnet.py # ResNet-18/34/50 (1D/2D/3D)
316
- │ │ ├── efficientnet.py# EfficientNet-B0/B1/B2
317
- │ │ ├── vit.py # Vision Transformer (1D/2D)
318
- │ │ ├── convnext.py # ConvNeXt (1D/2D/3D)
319
- │ │ ├── densenet.py # DenseNet-121/169 (1D/2D/3D)
320
- │ │ └── unet.py # U-Net / U-Net Regression
311
+ │ ├── models/ # Model architectures (38 variants)
312
+ │ │ ├── registry.py # Model factory (@register_model)
313
+ │ │ ├── base.py # Abstract base class
314
+ │ │ ├── cnn.py # Baseline CNN (1D/2D/3D)
315
+ │ │ ├── resnet.py # ResNet-18/34/50 (1D/2D/3D)
316
+ │ │ ├── resnet3d.py # ResNet3D-18, MC3-18 (3D only)
317
+ │ │ ├── tcn.py # TCN (1D only)
318
+ │ │ ├── efficientnet.py # EfficientNet-B0/B1/B2 (2D)
319
+ │ │ ├── efficientnetv2.py # EfficientNetV2-S/M/L (2D)
320
+ │ │ ├── mobilenetv3.py # MobileNetV3-Small/Large (2D)
321
+ │ │ ├── regnet.py # RegNetY variants (2D)
322
+ │ │ ├── swin.py # Swin Transformer (2D)
323
+ │ │ ├── vit.py # Vision Transformer (1D/2D)
324
+ │ │ ├── convnext.py # ConvNeXt (1D/2D/3D)
325
+ │ │ ├── densenet.py # DenseNet-121/169 (1D/2D/3D)
326
+ │ │ └── unet.py # U-Net Regression
321
327
  │ │
322
- │ └── utils/ # Utilities
323
- │ ├── data.py # Memory-mapped data pipeline
324
- │ ├── metrics.py # R², Pearson, visualization
325
- │ ├── distributed.py # DDP synchronization
326
- │ ├── losses.py # Loss function factory
327
- │ ├── optimizers.py # Optimizer factory
328
- │ ├── schedulers.py # LR scheduler factory
329
- │ └── config.py # YAML configuration support
328
+ │ └── utils/ # Utilities
329
+ │ ├── data.py # Memory-mapped data pipeline
330
+ │ ├── metrics.py # R², Pearson, visualization
331
+ │ ├── distributed.py # DDP synchronization
332
+ │ ├── losses.py # Loss function factory
333
+ │ ├── optimizers.py # Optimizer factory
334
+ │ ├── schedulers.py # LR scheduler factory
335
+ │ └── config.py # YAML configuration support
330
336
 
331
- ├── configs/ # YAML config templates
332
- ├── examples/ # Ready-to-run examples
333
- ├── notebooks/ # Jupyter notebooks
334
- ├── unit_tests/ # Pytest test suite (422 tests)
337
+ ├── configs/ # YAML config templates
338
+ ├── examples/ # Ready-to-run examples
339
+ ├── notebooks/ # Jupyter notebooks
340
+ ├── unit_tests/ # Pytest test suite (704 tests)
335
341
 
336
- ├── pyproject.toml # Package config, dependencies
337
- ├── CHANGELOG.md # Version history
338
- └── CITATION.cff # Citation metadata
342
+ ├── pyproject.toml # Package config, dependencies
343
+ ├── CHANGELOG.md # Version history
344
+ └── CITATION.cff # Citation metadata
339
345
  ```
340
346
  ---
341
347
 
@@ -354,33 +360,63 @@ WaveDL/
354
360
  > ```
355
361
 
356
362
  <details>
357
- <summary><b>Available Models</b> — 21 pre-built architectures</summary>
358
-
359
- | Model | Best For | Params (2D) | Dimensionality |
360
- |-------|----------|-------------|----------------|
361
- | `cnn` | Baseline, lightweight | 1.7M | 1D/2D/3D |
362
- | `resnet18` | Fast training, smaller datasets | 11.4M | 1D/2D/3D |
363
- | `resnet34` | Balanced performance | 21.5M | 1D/2D/3D |
364
- | `resnet50` | High capacity, complex patterns | 24.6M | 1D/2D/3D |
365
- | `resnet18_pretrained` | **Transfer learning** ⭐ | 11.4M | 2D only |
366
- | `resnet50_pretrained` | **Transfer learning** ⭐ | 24.6M | 2D only |
367
- | `efficientnet_b0` | Efficient, **pretrained** ⭐ | 4.7M | 2D only |
368
- | `efficientnet_b1` | Efficient, **pretrained** ⭐ | 7.2M | 2D only |
369
- | `efficientnet_b2` | Efficient, **pretrained** | 8.4M | 2D only |
370
- | `vit_tiny` | Transformer, small datasets | 5.4M | 1D/2D |
371
- | `vit_small` | Transformer, balanced | 21.5M | 1D/2D |
372
- | `vit_base` | Transformer, high capacity | 85.5M | 1D/2D |
373
- | `convnext_tiny` | Modern CNN, transformer-inspired | 28.2M | 1D/2D/3D |
374
- | `convnext_tiny_pretrained` | **Transfer learning** ⭐ | 28.2M | 2D only |
375
- | `convnext_small` | Modern CNN, balanced | 49.8M | 1D/2D/3D |
376
- | `convnext_base` | Modern CNN, high capacity | 88.1M | 1D/2D/3D |
377
- | `densenet121` | Feature reuse, small data | 7.5M | 1D/2D/3D |
378
- | `densenet121_pretrained` | **Transfer learning** ⭐ | 7.5M | 2D only |
379
- | `densenet169` | Deeper DenseNet | 13.3M | 1D/2D/3D |
380
- | `unet` | Spatial output (velocity fields) | 31.0M | 1D/2D/3D |
381
- | `unet_regression` | Multi-scale features for regression | 31.1M | 1D/2D/3D |
382
-
383
- >**Pretrained models** use ImageNet weights for transfer learning.
363
+ <summary><b>Available Models</b> — 38 architectures</summary>
364
+
365
+ | Model | Params | Dim |
366
+ |-------|--------|-----|
367
+ | **CNN** — Convolutional Neural Network |||
368
+ | `cnn` | 1.7M | 1D/2D/3D |
369
+ | **ResNet** — Residual Network |||
370
+ | `resnet18` | 11.4M | 1D/2D/3D |
371
+ | `resnet34` | 21.5M | 1D/2D/3D |
372
+ | `resnet50` | 24.6M | 1D/2D/3D |
373
+ | `resnet18_pretrained` ⭐ | 11.4M | 2D |
374
+ | `resnet50_pretrained` ⭐ | 24.6M | 2D |
375
+ | **ResNet3D** — 3D Residual Network |||
376
+ | `resnet3d_18` | 33.6M | 3D |
377
+ | `mc3_18` Mixed Convolution 3D | 11.9M | 3D |
378
+ | **TCN** — Temporal Convolutional Network |||
379
+ | `tcn_small` | 1.0M | 1D |
380
+ | `tcn` | 7.0M | 1D |
381
+ | `tcn_large` | 10.2M | 1D |
382
+ | **EfficientNet** — Efficient Neural Network |||
383
+ | `efficientnet_b0` | 4.7M | 2D |
384
+ | `efficientnet_b1` ⭐ | 7.2M | 2D |
385
+ | `efficientnet_b2` | 8.4M | 2D |
386
+ | **EfficientNetV2** — Efficient Neural Network V2 |||
387
+ | `efficientnet_v2_s` | 21.0M | 2D |
388
+ | `efficientnet_v2_m` ⭐ | 53.6M | 2D |
389
+ | `efficientnet_v2_l` | 118.0M | 2D |
390
+ | **MobileNetV3** — Mobile Neural Network V3 |||
391
+ | `mobilenet_v3_small` ⭐ | 1.1M | 2D |
392
+ | `mobilenet_v3_large` ⭐ | 3.2M | 2D |
393
+ | **RegNet** — Regularized Network |||
394
+ | `regnet_y_400mf` ⭐ | 4.0M | 2D |
395
+ | `regnet_y_800mf` ⭐ | 5.8M | 2D |
396
+ | `regnet_y_1_6gf` ⭐ | 10.5M | 2D |
397
+ | `regnet_y_3_2gf` ⭐ | 18.3M | 2D |
398
+ | `regnet_y_8gf` ⭐ | 37.9M | 2D |
399
+ | **Swin** — Shifted Window Transformer |||
400
+ | `swin_t` ⭐ | 28.0M | 2D |
401
+ | `swin_s` ⭐ | 49.4M | 2D |
402
+ | `swin_b` ⭐ | 87.4M | 2D |
403
+ | **ConvNeXt** — Convolutional Next |||
404
+ | `convnext_tiny` | 28.2M | 1D/2D/3D |
405
+ | `convnext_small` | 49.8M | 1D/2D/3D |
406
+ | `convnext_base` | 88.1M | 1D/2D/3D |
407
+ | `convnext_tiny_pretrained` ⭐ | 28.2M | 2D |
408
+ | **DenseNet** — Densely Connected Network |||
409
+ | `densenet121` | 7.5M | 1D/2D/3D |
410
+ | `densenet169` | 13.3M | 1D/2D/3D |
411
+ | `densenet121_pretrained` ⭐ | 7.5M | 2D |
412
+ | **ViT** — Vision Transformer |||
413
+ | `vit_tiny` | 5.5M | 1D/2D |
414
+ | `vit_small` | 21.6M | 1D/2D |
415
+ | `vit_base` | 85.6M | 1D/2D |
416
+ | **U-Net** — U-shaped Network |||
417
+ | `unet_regression` | 31.1M | 1D/2D/3D |
418
+
419
+ > ⭐ = Pretrained on ImageNet. Recommended for smaller datasets.
384
420
 
385
421
  </details>
386
422
 
@@ -595,12 +631,7 @@ seed: 2025
595
631
 
596
632
  Automatically find the best training configuration using [Optuna](https://optuna.org/).
597
633
 
598
- **Step 1: Install**
599
- ```bash
600
- pip install -e ".[hpo]"
601
- ```
602
-
603
- **Step 2: Run HPO**
634
+ **Run HPO:**
604
635
 
605
636
  You specify which models to search and how many trials to run:
606
637
  ```bash
@@ -657,7 +688,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
657
688
  | `--output` | `hpo_results.json` | Output file |
658
689
 
659
690
  > [!TIP]
660
- > See [Available Models](#available-models) for all 21 architectures you can search.
691
+ > See [Available Models](#available-models) for all 38 architectures you can search.
661
692
 
662
693
  </details>
663
694
 
@@ -49,6 +49,7 @@ classifiers = [
49
49
  "Topic :: Scientific/Engineering :: Physics",
50
50
  ]
51
51
  dependencies = [
52
+ # Core ML stack
52
53
  "torch>=2.0.0",
53
54
  "torchvision>=0.15.0",
54
55
  "accelerate>=0.20.0",
@@ -58,22 +59,24 @@ dependencies = [
58
59
  "pandas>=2.0.0",
59
60
  "matplotlib>=3.7.0",
60
61
  "tqdm>=4.65.0",
61
- "wandb>=0.15.0",
62
62
  "pyyaml>=6.0.0",
63
+ # Data formats
63
64
  "h5py>=3.8.0",
64
65
  "safetensors>=0.3.0",
65
- ]
66
-
67
- [project.optional-dependencies]
68
- dev = ["pytest>=7.0.0", "pytest-xdist>=3.5.0", "ruff>=0.8.0", "pre-commit>=3.5.0"]
69
- onnx = ["onnx>=1.14.0", "onnxruntime>=1.15.0"]
70
- compile = ["triton; sys_platform == 'linux'"] # Linux-only, enables torch.compile
71
- hpo = ["optuna>=3.0.0"] # Hyperparameter optimization
72
- all = [
73
- "pytest>=7.0.0", "pytest-xdist>=3.5.0", "ruff>=0.8.0", "pre-commit>=3.5.0",
74
- "onnx>=1.14.0", "onnxruntime>=1.15.0",
75
- "triton; sys_platform == 'linux'",
66
+ # Logging
67
+ "wandb>=0.15.0",
68
+ # HPO
76
69
  "optuna>=3.0.0",
70
+ # ONNX export
71
+ "onnx>=1.14.0",
72
+ "onnxruntime>=1.15.0",
73
+ # Development tools
74
+ "pytest>=7.0.0",
75
+ "pytest-xdist>=3.5.0",
76
+ "ruff>=0.8.0",
77
+ "pre-commit>=3.5.0",
78
+ # torch.compile backend (Linux only)
79
+ "triton>=2.0.0; sys_platform == 'linux'",
77
80
  ]
78
81
 
79
82
  [project.scripts]
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.3.1"
21
+ __version__ = "1.4.0"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24