wavedl 1.3.1__tar.gz → 1.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. {wavedl-1.3.1/src/wavedl.egg-info → wavedl-1.4.1}/PKG-INFO +136 -98
  2. {wavedl-1.3.1 → wavedl-1.4.1}/README.md +126 -75
  3. {wavedl-1.3.1 → wavedl-1.4.1}/pyproject.toml +15 -12
  4. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/__init__.py +1 -1
  5. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/hpc.py +48 -28
  6. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/models/__init__.py +33 -7
  7. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/models/_template.py +28 -41
  8. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/models/base.py +49 -2
  9. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/models/cnn.py +0 -1
  10. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/models/convnext.py +4 -1
  11. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/models/densenet.py +4 -1
  12. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/models/efficientnet.py +9 -5
  13. wavedl-1.4.1/src/wavedl/models/efficientnetv2.py +292 -0
  14. wavedl-1.4.1/src/wavedl/models/mobilenetv3.py +272 -0
  15. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/models/registry.py +0 -1
  16. wavedl-1.4.1/src/wavedl/models/regnet.py +383 -0
  17. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/models/resnet.py +7 -4
  18. wavedl-1.4.1/src/wavedl/models/resnet3d.py +258 -0
  19. wavedl-1.4.1/src/wavedl/models/swin.py +390 -0
  20. wavedl-1.4.1/src/wavedl/models/tcn.py +389 -0
  21. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/models/unet.py +44 -110
  22. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/models/vit.py +8 -4
  23. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/train.py +1144 -1116
  24. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/utils/config.py +88 -2
  25. {wavedl-1.3.1 → wavedl-1.4.1/src/wavedl.egg-info}/PKG-INFO +136 -98
  26. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl.egg-info/SOURCES.txt +6 -0
  27. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl.egg-info/requires.txt +4 -24
  28. {wavedl-1.3.1 → wavedl-1.4.1}/LICENSE +0 -0
  29. {wavedl-1.3.1 → wavedl-1.4.1}/setup.cfg +0 -0
  30. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/hpo.py +0 -0
  31. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/test.py +0 -0
  32. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/utils/__init__.py +0 -0
  33. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/utils/cross_validation.py +0 -0
  34. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/utils/data.py +0 -0
  35. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/utils/distributed.py +0 -0
  36. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/utils/losses.py +0 -0
  37. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/utils/metrics.py +0 -0
  38. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/utils/optimizers.py +0 -0
  39. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl/utils/schedulers.py +0 -0
  40. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl.egg-info/dependency_links.txt +0 -0
  41. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl.egg-info/entry_points.txt +0 -0
  42. {wavedl-1.3.1 → wavedl-1.4.1}/src/wavedl.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.3.1
3
+ Version: 1.4.1
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -30,31 +30,18 @@ Requires-Dist: scikit-learn>=1.2.0
30
30
  Requires-Dist: pandas>=2.0.0
31
31
  Requires-Dist: matplotlib>=3.7.0
32
32
  Requires-Dist: tqdm>=4.65.0
33
- Requires-Dist: wandb>=0.15.0
34
33
  Requires-Dist: pyyaml>=6.0.0
35
34
  Requires-Dist: h5py>=3.8.0
36
35
  Requires-Dist: safetensors>=0.3.0
37
- Provides-Extra: dev
38
- Requires-Dist: pytest>=7.0.0; extra == "dev"
39
- Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
40
- Requires-Dist: ruff>=0.8.0; extra == "dev"
41
- Requires-Dist: pre-commit>=3.5.0; extra == "dev"
42
- Provides-Extra: onnx
43
- Requires-Dist: onnx>=1.14.0; extra == "onnx"
44
- Requires-Dist: onnxruntime>=1.15.0; extra == "onnx"
45
- Provides-Extra: compile
46
- Requires-Dist: triton; sys_platform == "linux" and extra == "compile"
47
- Provides-Extra: hpo
48
- Requires-Dist: optuna>=3.0.0; extra == "hpo"
49
- Provides-Extra: all
50
- Requires-Dist: pytest>=7.0.0; extra == "all"
51
- Requires-Dist: pytest-xdist>=3.5.0; extra == "all"
52
- Requires-Dist: ruff>=0.8.0; extra == "all"
53
- Requires-Dist: pre-commit>=3.5.0; extra == "all"
54
- Requires-Dist: onnx>=1.14.0; extra == "all"
55
- Requires-Dist: onnxruntime>=1.15.0; extra == "all"
56
- Requires-Dist: triton; sys_platform == "linux" and extra == "all"
57
- Requires-Dist: optuna>=3.0.0; extra == "all"
36
+ Requires-Dist: wandb>=0.15.0
37
+ Requires-Dist: optuna>=3.0.0
38
+ Requires-Dist: onnx>=1.14.0
39
+ Requires-Dist: onnxruntime>=1.15.0
40
+ Requires-Dist: pytest>=7.0.0
41
+ Requires-Dist: pytest-xdist>=3.5.0
42
+ Requires-Dist: ruff>=0.8.0
43
+ Requires-Dist: pre-commit>=3.5.0
44
+ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
58
45
 
59
46
  <div align="center">
60
47
 
@@ -210,20 +197,20 @@ Deploy models anywhere:
210
197
 
211
198
  ### Installation
212
199
 
200
+ #### From PyPI (recommended for all users)
201
+
213
202
  ```bash
214
- # Install from PyPI (recommended)
215
203
  pip install wavedl
216
-
217
- # Or install with all extras (ONNX export, HPO, dev tools)
218
- pip install wavedl[all]
219
204
  ```
220
205
 
206
+ This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
207
+
221
208
  #### From Source (for development)
222
209
 
223
210
  ```bash
224
211
  git clone https://github.com/ductho-le/WaveDL.git
225
212
  cd WaveDL
226
- pip install -e ".[dev]"
213
+ pip install -e .
227
214
  ```
228
215
 
229
216
  > [!NOTE]
@@ -359,41 +346,47 @@ WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU,
359
346
  ```
360
347
  WaveDL/
361
348
  ├── src/
362
- │ └── wavedl/ # Main package (namespaced)
363
- │ ├── __init__.py # Package init with __version__
364
- │ ├── train.py # Training entry point
365
- │ ├── test.py # Testing & inference script
366
- │ ├── hpo.py # Hyperparameter optimization
367
- │ ├── hpc.py # HPC distributed training launcher
349
+ │ └── wavedl/ # Main package (namespaced)
350
+ │ ├── __init__.py # Package init with __version__
351
+ │ ├── train.py # Training entry point
352
+ │ ├── test.py # Testing & inference script
353
+ │ ├── hpo.py # Hyperparameter optimization
354
+ │ ├── hpc.py # HPC distributed training launcher
368
355
  │ │
369
- │ ├── models/ # Model architectures
370
- │ │ ├── registry.py # Model factory (@register_model)
371
- │ │ ├── base.py # Abstract base class
372
- │ │ ├── cnn.py # Baseline CNN
373
- │ │ ├── resnet.py # ResNet-18/34/50 (1D/2D/3D)
374
- │ │ ├── efficientnet.py# EfficientNet-B0/B1/B2
375
- │ │ ├── vit.py # Vision Transformer (1D/2D)
376
- │ │ ├── convnext.py # ConvNeXt (1D/2D/3D)
377
- │ │ ├── densenet.py # DenseNet-121/169 (1D/2D/3D)
378
- │ │ └── unet.py # U-Net / U-Net Regression
356
+ │ ├── models/ # Model architectures (38 variants)
357
+ │ │ ├── registry.py # Model factory (@register_model)
358
+ │ │ ├── base.py # Abstract base class
359
+ │ │ ├── cnn.py # Baseline CNN (1D/2D/3D)
360
+ │ │ ├── resnet.py # ResNet-18/34/50 (1D/2D/3D)
361
+ │ │ ├── resnet3d.py # ResNet3D-18, MC3-18 (3D only)
362
+ │ │ ├── tcn.py # TCN (1D only)
363
+ │ │ ├── efficientnet.py # EfficientNet-B0/B1/B2 (2D)
364
+ │ │ ├── efficientnetv2.py # EfficientNetV2-S/M/L (2D)
365
+ │ │ ├── mobilenetv3.py # MobileNetV3-Small/Large (2D)
366
+ │ │ ├── regnet.py # RegNetY variants (2D)
367
+ │ │ ├── swin.py # Swin Transformer (2D)
368
+ │ │ ├── vit.py # Vision Transformer (1D/2D)
369
+ │ │ ├── convnext.py # ConvNeXt (1D/2D/3D)
370
+ │ │ ├── densenet.py # DenseNet-121/169 (1D/2D/3D)
371
+ │ │ └── unet.py # U-Net Regression
379
372
  │ │
380
- │ └── utils/ # Utilities
381
- │ ├── data.py # Memory-mapped data pipeline
382
- │ ├── metrics.py # R², Pearson, visualization
383
- │ ├── distributed.py # DDP synchronization
384
- │ ├── losses.py # Loss function factory
385
- │ ├── optimizers.py # Optimizer factory
386
- │ ├── schedulers.py # LR scheduler factory
387
- │ └── config.py # YAML configuration support
373
+ │ └── utils/ # Utilities
374
+ │ ├── data.py # Memory-mapped data pipeline
375
+ │ ├── metrics.py # R², Pearson, visualization
376
+ │ ├── distributed.py # DDP synchronization
377
+ │ ├── losses.py # Loss function factory
378
+ │ ├── optimizers.py # Optimizer factory
379
+ │ ├── schedulers.py # LR scheduler factory
380
+ │ └── config.py # YAML configuration support
388
381
 
389
- ├── configs/ # YAML config templates
390
- ├── examples/ # Ready-to-run examples
391
- ├── notebooks/ # Jupyter notebooks
392
- ├── unit_tests/ # Pytest test suite (422 tests)
382
+ ├── configs/ # YAML config templates
383
+ ├── examples/ # Ready-to-run examples
384
+ ├── notebooks/ # Jupyter notebooks
385
+ ├── unit_tests/ # Pytest test suite (704 tests)
393
386
 
394
- ├── pyproject.toml # Package config, dependencies
395
- ├── CHANGELOG.md # Version history
396
- └── CITATION.cff # Citation metadata
387
+ ├── pyproject.toml # Package config, dependencies
388
+ ├── CHANGELOG.md # Version history
389
+ └── CITATION.cff # Citation metadata
397
390
  ```
398
391
  ---
399
392
 
@@ -412,33 +405,63 @@ WaveDL/
412
405
  > ```
413
406
 
414
407
  <details>
415
- <summary><b>Available Models</b> — 21 pre-built architectures</summary>
416
-
417
- | Model | Best For | Params (2D) | Dimensionality |
418
- |-------|----------|-------------|----------------|
419
- | `cnn` | Baseline, lightweight | 1.7M | 1D/2D/3D |
420
- | `resnet18` | Fast training, smaller datasets | 11.4M | 1D/2D/3D |
421
- | `resnet34` | Balanced performance | 21.5M | 1D/2D/3D |
422
- | `resnet50` | High capacity, complex patterns | 24.6M | 1D/2D/3D |
423
- | `resnet18_pretrained` | **Transfer learning** ⭐ | 11.4M | 2D only |
424
- | `resnet50_pretrained` | **Transfer learning** ⭐ | 24.6M | 2D only |
425
- | `efficientnet_b0` | Efficient, **pretrained** ⭐ | 4.7M | 2D only |
426
- | `efficientnet_b1` | Efficient, **pretrained** ⭐ | 7.2M | 2D only |
427
- | `efficientnet_b2` | Efficient, **pretrained** | 8.4M | 2D only |
428
- | `vit_tiny` | Transformer, small datasets | 5.4M | 1D/2D |
429
- | `vit_small` | Transformer, balanced | 21.5M | 1D/2D |
430
- | `vit_base` | Transformer, high capacity | 85.5M | 1D/2D |
431
- | `convnext_tiny` | Modern CNN, transformer-inspired | 28.2M | 1D/2D/3D |
432
- | `convnext_tiny_pretrained` | **Transfer learning** ⭐ | 28.2M | 2D only |
433
- | `convnext_small` | Modern CNN, balanced | 49.8M | 1D/2D/3D |
434
- | `convnext_base` | Modern CNN, high capacity | 88.1M | 1D/2D/3D |
435
- | `densenet121` | Feature reuse, small data | 7.5M | 1D/2D/3D |
436
- | `densenet121_pretrained` | **Transfer learning** ⭐ | 7.5M | 2D only |
437
- | `densenet169` | Deeper DenseNet | 13.3M | 1D/2D/3D |
438
- | `unet` | Spatial output (velocity fields) | 31.0M | 1D/2D/3D |
439
- | `unet_regression` | Multi-scale features for regression | 31.1M | 1D/2D/3D |
440
-
441
- >**Pretrained models** use ImageNet weights for transfer learning.
408
+ <summary><b>Available Models</b> — 38 architectures</summary>
409
+
410
+ | Model | Params | Dim |
411
+ |-------|--------|-----|
412
+ | **CNN** — Convolutional Neural Network |||
413
+ | `cnn` | 1.7M | 1D/2D/3D |
414
+ | **ResNet** — Residual Network |||
415
+ | `resnet18` | 11.4M | 1D/2D/3D |
416
+ | `resnet34` | 21.5M | 1D/2D/3D |
417
+ | `resnet50` | 24.6M | 1D/2D/3D |
418
+ | `resnet18_pretrained` ⭐ | 11.4M | 2D |
419
+ | `resnet50_pretrained` ⭐ | 24.6M | 2D |
420
+ | **ResNet3D** — 3D Residual Network |||
421
+ | `resnet3d_18` | 33.6M | 3D |
422
+ | `mc3_18` Mixed Convolution 3D | 11.9M | 3D |
423
+ | **TCN** — Temporal Convolutional Network |||
424
+ | `tcn_small` | 1.0M | 1D |
425
+ | `tcn` | 7.0M | 1D |
426
+ | `tcn_large` | 10.2M | 1D |
427
+ | **EfficientNet** — Efficient Neural Network |||
428
+ | `efficientnet_b0` | 4.7M | 2D |
429
+ | `efficientnet_b1` ⭐ | 7.2M | 2D |
430
+ | `efficientnet_b2` | 8.4M | 2D |
431
+ | **EfficientNetV2** — Efficient Neural Network V2 |||
432
+ | `efficientnet_v2_s` | 21.0M | 2D |
433
+ | `efficientnet_v2_m` ⭐ | 53.6M | 2D |
434
+ | `efficientnet_v2_l` | 118.0M | 2D |
435
+ | **MobileNetV3** — Mobile Neural Network V3 |||
436
+ | `mobilenet_v3_small` ⭐ | 1.1M | 2D |
437
+ | `mobilenet_v3_large` ⭐ | 3.2M | 2D |
438
+ | **RegNet** — Regular Network (from the RegNet design space) |||
439
+ | `regnet_y_400mf` ⭐ | 4.0M | 2D |
440
+ | `regnet_y_800mf` ⭐ | 5.8M | 2D |
441
+ | `regnet_y_1_6gf` ⭐ | 10.5M | 2D |
442
+ | `regnet_y_3_2gf` ⭐ | 18.3M | 2D |
443
+ | `regnet_y_8gf` ⭐ | 37.9M | 2D |
444
+ | **Swin** — Shifted Window Transformer |||
445
+ | `swin_t` ⭐ | 28.0M | 2D |
446
+ | `swin_s` ⭐ | 49.4M | 2D |
447
+ | `swin_b` ⭐ | 87.4M | 2D |
448
+ | **ConvNeXt** — Convolutional Next |||
449
+ | `convnext_tiny` | 28.2M | 1D/2D/3D |
450
+ | `convnext_small` | 49.8M | 1D/2D/3D |
451
+ | `convnext_base` | 88.1M | 1D/2D/3D |
452
+ | `convnext_tiny_pretrained` ⭐ | 28.2M | 2D |
453
+ | **DenseNet** — Densely Connected Network |||
454
+ | `densenet121` | 7.5M | 1D/2D/3D |
455
+ | `densenet169` | 13.3M | 1D/2D/3D |
456
+ | `densenet121_pretrained` ⭐ | 7.5M | 2D |
457
+ | **ViT** — Vision Transformer |||
458
+ | `vit_tiny` | 5.5M | 1D/2D |
459
+ | `vit_small` | 21.6M | 1D/2D |
460
+ | `vit_base` | 85.6M | 1D/2D |
461
+ | **U-Net** — U-shaped Network |||
462
+ | `unet_regression` | 31.1M | 1D/2D/3D |
463
+
464
+ > ⭐ = Pretrained on ImageNet. Recommended for smaller datasets.
442
465
 
443
466
  </details>
444
467
 
@@ -479,12 +502,32 @@ WaveDL/
479
502
 
480
503
  | Argument | Default | Description |
481
504
  |----------|---------|-------------|
482
- | `--compile` | `False` | Enable `torch.compile` |
505
+ | `--compile` | `False` | Enable `torch.compile` (recommended for long runs) |
483
506
  | `--precision` | `bf16` | Mixed precision mode (`bf16`, `fp16`, `no`) |
507
+ | `--workers` | `-1` | DataLoader workers per GPU (-1=auto, up to 16) |
484
508
  | `--wandb` | `False` | Enable W&B logging |
509
+ | `--wandb_watch` | `False` | Enable W&B gradient watching (adds overhead) |
485
510
  | `--project_name` | `DL-Training` | W&B project name |
486
511
  | `--run_name` | `None` | W&B run name (auto-generated if not set) |
487
512
 
513
+ **Automatic GPU Optimizations:**
514
+
515
+ WaveDL automatically enables performance optimizations for modern GPUs:
516
+
517
+ | Optimization | Effect | GPU Support |
518
+ |--------------|--------|-------------|
519
+ | **TF32 precision** | ~2x speedup for float32 matmul | A100, H100 (Ampere+) |
520
+ | **cuDNN benchmark** | Auto-tuned convolutions | All NVIDIA GPUs |
521
+ | **Worker scaling** | Up to 16 workers per GPU | All systems |
522
+
523
+ > [!NOTE]
524
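+ > These optimizations are **backward compatible**. TF32 only takes effect on Ampere-or-newer GPUs and is simply ignored on older GPUs (V100, T4, GTX) and CPU-only systems. cuDNN benchmarking applies to any NVIDIA GPU, and worker scaling applies everywhere. No configuration needed.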
525
+
526
+ **HPC Best Practices:**
527
+ - Stage data to `$SLURM_TMPDIR` (local NVMe) for maximum I/O throughput
528
+ - Use `--compile` for training runs > 50 epochs
529
+ - Increase `--workers` manually if auto-detection is suboptimal
530
+
488
531
  </details>
489
532
 
490
533
  <details>
@@ -653,12 +696,7 @@ seed: 2025
653
696
 
654
697
  Automatically find the best training configuration using [Optuna](https://optuna.org/).
655
698
 
656
- **Step 1: Install**
657
- ```bash
658
- pip install -e ".[hpo]"
659
- ```
660
-
661
- **Step 2: Run HPO**
699
+ **Run HPO:**
662
700
 
663
701
  You specify which models to search and how many trials to run:
664
702
  ```bash
@@ -672,7 +710,7 @@ python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
672
710
  python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
673
711
  ```
674
712
 
675
- **Step 3: Train with best parameters**
713
+ **Train with best parameters:**
676
714
 
677
715
  After HPO completes, it prints the optimal command:
678
716
  ```bash
@@ -715,7 +753,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
715
753
  | `--output` | `hpo_results.json` | Output file |
716
754
 
717
755
  > [!TIP]
718
- > See [Available Models](#available-models) for all 21 architectures you can search.
756
+ > See [Available Models](#available-models) for all 38 architectures you can search.
719
757
 
720
758
  </details>
721
759
 
@@ -819,7 +857,7 @@ import numpy as np
819
857
  X = np.random.randn(1000, 256, 256).astype(np.float32)
820
858
  y = np.random.randn(1000, 5).astype(np.float32)
821
859
 
822
- np.savez('test_data.npz', input_train=X, output_train=y)
860
+ np.savez('test_data.npz', input_test=X, output_test=y)
823
861
  ```
824
862
 
825
863
  </details>
@@ -831,7 +869,7 @@ np.savez('test_data.npz', input_train=X, output_train=y)
831
869
  import numpy as np
832
870
 
833
871
  data = np.load('train_data.npz')
834
- assert data['input_train'].ndim == 3, "Input must be 3D: (N, H, W)"
872
+ assert data['input_train'].ndim >= 2, "Input must be at least 2D: (N, ...)"
835
873
  assert data['output_train'].ndim == 2, "Output must be 2D: (N, T)"
836
874
  assert len(data['input_train']) == len(data['output_train']), "Sample mismatch"
837
875
 
@@ -986,7 +1024,7 @@ Beyond the material characterization example above, the WaveDL pipeline can be a
986
1024
  | Resource | Description |
987
1025
  |----------|-------------|
988
1026
  | Technical Paper | In-depth framework description *(coming soon)* |
989
- | [`_template.py`](models/_template.py) | Template for new architectures |
1027
+ | [`_template.py`](src/wavedl/models/_template.py) | Template for custom architectures |
990
1028
 
991
1029
  ---
992
1030
 
@@ -152,20 +152,20 @@ Deploy models anywhere:
152
152
 
153
153
  ### Installation
154
154
 
155
+ #### From PyPI (recommended for all users)
156
+
155
157
  ```bash
156
- # Install from PyPI (recommended)
157
158
  pip install wavedl
158
-
159
- # Or install with all extras (ONNX export, HPO, dev tools)
160
- pip install wavedl[all]
161
159
  ```
162
160
 
161
+ This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
162
+
163
163
  #### From Source (for development)
164
164
 
165
165
  ```bash
166
166
  git clone https://github.com/ductho-le/WaveDL.git
167
167
  cd WaveDL
168
- pip install -e ".[dev]"
168
+ pip install -e .
169
169
  ```
170
170
 
171
171
  > [!NOTE]
@@ -301,41 +301,47 @@ WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU,
301
301
  ```
302
302
  WaveDL/
303
303
  ├── src/
304
- │ └── wavedl/ # Main package (namespaced)
305
- │ ├── __init__.py # Package init with __version__
306
- │ ├── train.py # Training entry point
307
- │ ├── test.py # Testing & inference script
308
- │ ├── hpo.py # Hyperparameter optimization
309
- │ ├── hpc.py # HPC distributed training launcher
304
+ │ └── wavedl/ # Main package (namespaced)
305
+ │ ├── __init__.py # Package init with __version__
306
+ │ ├── train.py # Training entry point
307
+ │ ├── test.py # Testing & inference script
308
+ │ ├── hpo.py # Hyperparameter optimization
309
+ │ ├── hpc.py # HPC distributed training launcher
310
310
  │ │
311
- │ ├── models/ # Model architectures
312
- │ │ ├── registry.py # Model factory (@register_model)
313
- │ │ ├── base.py # Abstract base class
314
- │ │ ├── cnn.py # Baseline CNN
315
- │ │ ├── resnet.py # ResNet-18/34/50 (1D/2D/3D)
316
- │ │ ├── efficientnet.py# EfficientNet-B0/B1/B2
317
- │ │ ├── vit.py # Vision Transformer (1D/2D)
318
- │ │ ├── convnext.py # ConvNeXt (1D/2D/3D)
319
- │ │ ├── densenet.py # DenseNet-121/169 (1D/2D/3D)
320
- │ │ └── unet.py # U-Net / U-Net Regression
311
+ │ ├── models/ # Model architectures (38 variants)
312
+ │ │ ├── registry.py # Model factory (@register_model)
313
+ │ │ ├── base.py # Abstract base class
314
+ │ │ ├── cnn.py # Baseline CNN (1D/2D/3D)
315
+ │ │ ├── resnet.py # ResNet-18/34/50 (1D/2D/3D)
316
+ │ │ ├── resnet3d.py # ResNet3D-18, MC3-18 (3D only)
317
+ │ │ ├── tcn.py # TCN (1D only)
318
+ │ │ ├── efficientnet.py # EfficientNet-B0/B1/B2 (2D)
319
+ │ │ ├── efficientnetv2.py # EfficientNetV2-S/M/L (2D)
320
+ │ │ ├── mobilenetv3.py # MobileNetV3-Small/Large (2D)
321
+ │ │ ├── regnet.py # RegNetY variants (2D)
322
+ │ │ ├── swin.py # Swin Transformer (2D)
323
+ │ │ ├── vit.py # Vision Transformer (1D/2D)
324
+ │ │ ├── convnext.py # ConvNeXt (1D/2D/3D)
325
+ │ │ ├── densenet.py # DenseNet-121/169 (1D/2D/3D)
326
+ │ │ └── unet.py # U-Net Regression
321
327
  │ │
322
- │ └── utils/ # Utilities
323
- │ ├── data.py # Memory-mapped data pipeline
324
- │ ├── metrics.py # R², Pearson, visualization
325
- │ ├── distributed.py # DDP synchronization
326
- │ ├── losses.py # Loss function factory
327
- │ ├── optimizers.py # Optimizer factory
328
- │ ├── schedulers.py # LR scheduler factory
329
- │ └── config.py # YAML configuration support
328
+ │ └── utils/ # Utilities
329
+ │ ├── data.py # Memory-mapped data pipeline
330
+ │ ├── metrics.py # R², Pearson, visualization
331
+ │ ├── distributed.py # DDP synchronization
332
+ │ ├── losses.py # Loss function factory
333
+ │ ├── optimizers.py # Optimizer factory
334
+ │ ├── schedulers.py # LR scheduler factory
335
+ │ └── config.py # YAML configuration support
330
336
 
331
- ├── configs/ # YAML config templates
332
- ├── examples/ # Ready-to-run examples
333
- ├── notebooks/ # Jupyter notebooks
334
- ├── unit_tests/ # Pytest test suite (422 tests)
337
+ ├── configs/ # YAML config templates
338
+ ├── examples/ # Ready-to-run examples
339
+ ├── notebooks/ # Jupyter notebooks
340
+ ├── unit_tests/ # Pytest test suite (704 tests)
335
341
 
336
- ├── pyproject.toml # Package config, dependencies
337
- ├── CHANGELOG.md # Version history
338
- └── CITATION.cff # Citation metadata
342
+ ├── pyproject.toml # Package config, dependencies
343
+ ├── CHANGELOG.md # Version history
344
+ └── CITATION.cff # Citation metadata
339
345
  ```
340
346
  ---
341
347
 
@@ -354,33 +360,63 @@ WaveDL/
354
360
  > ```
355
361
 
356
362
  <details>
357
- <summary><b>Available Models</b> — 21 pre-built architectures</summary>
358
-
359
- | Model | Best For | Params (2D) | Dimensionality |
360
- |-------|----------|-------------|----------------|
361
- | `cnn` | Baseline, lightweight | 1.7M | 1D/2D/3D |
362
- | `resnet18` | Fast training, smaller datasets | 11.4M | 1D/2D/3D |
363
- | `resnet34` | Balanced performance | 21.5M | 1D/2D/3D |
364
- | `resnet50` | High capacity, complex patterns | 24.6M | 1D/2D/3D |
365
- | `resnet18_pretrained` | **Transfer learning** ⭐ | 11.4M | 2D only |
366
- | `resnet50_pretrained` | **Transfer learning** ⭐ | 24.6M | 2D only |
367
- | `efficientnet_b0` | Efficient, **pretrained** ⭐ | 4.7M | 2D only |
368
- | `efficientnet_b1` | Efficient, **pretrained** ⭐ | 7.2M | 2D only |
369
- | `efficientnet_b2` | Efficient, **pretrained** | 8.4M | 2D only |
370
- | `vit_tiny` | Transformer, small datasets | 5.4M | 1D/2D |
371
- | `vit_small` | Transformer, balanced | 21.5M | 1D/2D |
372
- | `vit_base` | Transformer, high capacity | 85.5M | 1D/2D |
373
- | `convnext_tiny` | Modern CNN, transformer-inspired | 28.2M | 1D/2D/3D |
374
- | `convnext_tiny_pretrained` | **Transfer learning** ⭐ | 28.2M | 2D only |
375
- | `convnext_small` | Modern CNN, balanced | 49.8M | 1D/2D/3D |
376
- | `convnext_base` | Modern CNN, high capacity | 88.1M | 1D/2D/3D |
377
- | `densenet121` | Feature reuse, small data | 7.5M | 1D/2D/3D |
378
- | `densenet121_pretrained` | **Transfer learning** ⭐ | 7.5M | 2D only |
379
- | `densenet169` | Deeper DenseNet | 13.3M | 1D/2D/3D |
380
- | `unet` | Spatial output (velocity fields) | 31.0M | 1D/2D/3D |
381
- | `unet_regression` | Multi-scale features for regression | 31.1M | 1D/2D/3D |
382
-
383
- >**Pretrained models** use ImageNet weights for transfer learning.
363
+ <summary><b>Available Models</b> — 38 architectures</summary>
364
+
365
+ | Model | Params | Dim |
366
+ |-------|--------|-----|
367
+ | **CNN** — Convolutional Neural Network |||
368
+ | `cnn` | 1.7M | 1D/2D/3D |
369
+ | **ResNet** — Residual Network |||
370
+ | `resnet18` | 11.4M | 1D/2D/3D |
371
+ | `resnet34` | 21.5M | 1D/2D/3D |
372
+ | `resnet50` | 24.6M | 1D/2D/3D |
373
+ | `resnet18_pretrained` ⭐ | 11.4M | 2D |
374
+ | `resnet50_pretrained` ⭐ | 24.6M | 2D |
375
+ | **ResNet3D** — 3D Residual Network |||
376
+ | `resnet3d_18` | 33.6M | 3D |
377
+ | `mc3_18` Mixed Convolution 3D | 11.9M | 3D |
378
+ | **TCN** — Temporal Convolutional Network |||
379
+ | `tcn_small` | 1.0M | 1D |
380
+ | `tcn` | 7.0M | 1D |
381
+ | `tcn_large` | 10.2M | 1D |
382
+ | **EfficientNet** — Efficient Neural Network |||
383
+ | `efficientnet_b0` | 4.7M | 2D |
384
+ | `efficientnet_b1` ⭐ | 7.2M | 2D |
385
+ | `efficientnet_b2` | 8.4M | 2D |
386
+ | **EfficientNetV2** — Efficient Neural Network V2 |||
387
+ | `efficientnet_v2_s` | 21.0M | 2D |
388
+ | `efficientnet_v2_m` ⭐ | 53.6M | 2D |
389
+ | `efficientnet_v2_l` | 118.0M | 2D |
390
+ | **MobileNetV3** — Mobile Neural Network V3 |||
391
+ | `mobilenet_v3_small` ⭐ | 1.1M | 2D |
392
+ | `mobilenet_v3_large` ⭐ | 3.2M | 2D |
393
+ | **RegNet** — Regular Network (from the RegNet design space) |||
394
+ | `regnet_y_400mf` ⭐ | 4.0M | 2D |
395
+ | `regnet_y_800mf` ⭐ | 5.8M | 2D |
396
+ | `regnet_y_1_6gf` ⭐ | 10.5M | 2D |
397
+ | `regnet_y_3_2gf` ⭐ | 18.3M | 2D |
398
+ | `regnet_y_8gf` ⭐ | 37.9M | 2D |
399
+ | **Swin** — Shifted Window Transformer |||
400
+ | `swin_t` ⭐ | 28.0M | 2D |
401
+ | `swin_s` ⭐ | 49.4M | 2D |
402
+ | `swin_b` ⭐ | 87.4M | 2D |
403
+ | **ConvNeXt** — Convolutional Next |||
404
+ | `convnext_tiny` | 28.2M | 1D/2D/3D |
405
+ | `convnext_small` | 49.8M | 1D/2D/3D |
406
+ | `convnext_base` | 88.1M | 1D/2D/3D |
407
+ | `convnext_tiny_pretrained` ⭐ | 28.2M | 2D |
408
+ | **DenseNet** — Densely Connected Network |||
409
+ | `densenet121` | 7.5M | 1D/2D/3D |
410
+ | `densenet169` | 13.3M | 1D/2D/3D |
411
+ | `densenet121_pretrained` ⭐ | 7.5M | 2D |
412
+ | **ViT** — Vision Transformer |||
413
+ | `vit_tiny` | 5.5M | 1D/2D |
414
+ | `vit_small` | 21.6M | 1D/2D |
415
+ | `vit_base` | 85.6M | 1D/2D |
416
+ | **U-Net** — U-shaped Network |||
417
+ | `unet_regression` | 31.1M | 1D/2D/3D |
418
+
419
+ > ⭐ = Pretrained on ImageNet. Recommended for smaller datasets.
384
420
 
385
421
  </details>
386
422
 
@@ -421,12 +457,32 @@ WaveDL/
421
457
 
422
458
  | Argument | Default | Description |
423
459
  |----------|---------|-------------|
424
- | `--compile` | `False` | Enable `torch.compile` |
460
+ | `--compile` | `False` | Enable `torch.compile` (recommended for long runs) |
425
461
  | `--precision` | `bf16` | Mixed precision mode (`bf16`, `fp16`, `no`) |
462
+ | `--workers` | `-1` | DataLoader workers per GPU (-1=auto, up to 16) |
426
463
  | `--wandb` | `False` | Enable W&B logging |
464
+ | `--wandb_watch` | `False` | Enable W&B gradient watching (adds overhead) |
427
465
  | `--project_name` | `DL-Training` | W&B project name |
428
466
  | `--run_name` | `None` | W&B run name (auto-generated if not set) |
429
467
 
468
+ **Automatic GPU Optimizations:**
469
+
470
+ WaveDL automatically enables performance optimizations for modern GPUs:
471
+
472
+ | Optimization | Effect | GPU Support |
473
+ |--------------|--------|-------------|
474
+ | **TF32 precision** | ~2x speedup for float32 matmul | A100, H100 (Ampere+) |
475
+ | **cuDNN benchmark** | Auto-tuned convolutions | All NVIDIA GPUs |
476
+ | **Worker scaling** | Up to 16 workers per GPU | All systems |
477
+
478
+ > [!NOTE]
479
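+ > These optimizations are **backward compatible**. TF32 only takes effect on Ampere-or-newer GPUs and is simply ignored on older GPUs (V100, T4, GTX) and CPU-only systems. cuDNN benchmarking applies to any NVIDIA GPU, and worker scaling applies everywhere. No configuration needed.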
480
+
481
+ **HPC Best Practices:**
482
+ - Stage data to `$SLURM_TMPDIR` (local NVMe) for maximum I/O throughput
483
+ - Use `--compile` for training runs > 50 epochs
484
+ - Increase `--workers` manually if auto-detection is suboptimal
485
+
430
486
  </details>
431
487
 
432
488
  <details>
@@ -595,12 +651,7 @@ seed: 2025
595
651
 
596
652
  Automatically find the best training configuration using [Optuna](https://optuna.org/).
597
653
 
598
- **Step 1: Install**
599
- ```bash
600
- pip install -e ".[hpo]"
601
- ```
602
-
603
- **Step 2: Run HPO**
654
+ **Run HPO:**
604
655
 
605
656
  You specify which models to search and how many trials to run:
606
657
  ```bash
@@ -614,7 +665,7 @@ python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
614
665
  python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
615
666
  ```
616
667
 
617
- **Step 3: Train with best parameters**
668
+ **Train with best parameters:**
618
669
 
619
670
  After HPO completes, it prints the optimal command:
620
671
  ```bash
@@ -657,7 +708,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
657
708
  | `--output` | `hpo_results.json` | Output file |
658
709
 
659
710
  > [!TIP]
660
- > See [Available Models](#available-models) for all 21 architectures you can search.
711
+ > See [Available Models](#available-models) for all 38 architectures you can search.
661
712
 
662
713
  </details>
663
714
 
@@ -761,7 +812,7 @@ import numpy as np
761
812
  X = np.random.randn(1000, 256, 256).astype(np.float32)
762
813
  y = np.random.randn(1000, 5).astype(np.float32)
763
814
 
764
- np.savez('test_data.npz', input_train=X, output_train=y)
815
+ np.savez('test_data.npz', input_test=X, output_test=y)
765
816
  ```
766
817
 
767
818
  </details>
@@ -773,7 +824,7 @@ np.savez('test_data.npz', input_train=X, output_train=y)
773
824
  import numpy as np
774
825
 
775
826
  data = np.load('train_data.npz')
776
- assert data['input_train'].ndim == 3, "Input must be 3D: (N, H, W)"
827
+ assert data['input_train'].ndim >= 2, "Input must be at least 2D: (N, ...)"
777
828
  assert data['output_train'].ndim == 2, "Output must be 2D: (N, T)"
778
829
  assert len(data['input_train']) == len(data['output_train']), "Sample mismatch"
779
830
 
@@ -928,7 +979,7 @@ Beyond the material characterization example above, the WaveDL pipeline can be a
928
979
  | Resource | Description |
929
980
  |----------|-------------|
930
981
  | Technical Paper | In-depth framework description *(coming soon)* |
931
- | [`_template.py`](models/_template.py) | Template for new architectures |
982
+ | [`_template.py`](src/wavedl/models/_template.py) | Template for custom architectures |
932
983
 
933
984
  ---
934
985