wavedl 1.4.5__tar.gz → 1.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {wavedl-1.4.5/src/wavedl.egg-info → wavedl-1.5.0}/PKG-INFO +127 -29
  2. {wavedl-1.4.5 → wavedl-1.5.0}/README.md +126 -28
  3. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/__init__.py +1 -1
  4. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/hpc.py +11 -2
  5. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/hpo.py +60 -3
  6. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/test.py +13 -7
  7. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/train.py +100 -9
  8. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/utils/__init__.py +11 -0
  9. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/utils/config.py +10 -0
  10. wavedl-1.5.0/src/wavedl/utils/constraints.py +470 -0
  11. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/utils/metrics.py +49 -2
  12. {wavedl-1.4.5 → wavedl-1.5.0/src/wavedl.egg-info}/PKG-INFO +127 -29
  13. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl.egg-info/SOURCES.txt +1 -0
  14. {wavedl-1.4.5 → wavedl-1.5.0}/LICENSE +0 -0
  15. {wavedl-1.4.5 → wavedl-1.5.0}/pyproject.toml +0 -0
  16. {wavedl-1.4.5 → wavedl-1.5.0}/setup.cfg +0 -0
  17. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/__init__.py +0 -0
  18. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/_template.py +0 -0
  19. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/base.py +0 -0
  20. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/cnn.py +0 -0
  21. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/convnext.py +0 -0
  22. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/densenet.py +0 -0
  23. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/efficientnet.py +0 -0
  24. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/efficientnetv2.py +0 -0
  25. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/mobilenetv3.py +0 -0
  26. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/registry.py +0 -0
  27. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/regnet.py +0 -0
  28. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/resnet.py +0 -0
  29. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/resnet3d.py +0 -0
  30. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/swin.py +0 -0
  31. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/tcn.py +0 -0
  32. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/unet.py +0 -0
  33. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/models/vit.py +0 -0
  34. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/utils/cross_validation.py +0 -0
  35. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/utils/data.py +0 -0
  36. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/utils/distributed.py +0 -0
  37. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/utils/losses.py +0 -0
  38. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/utils/optimizers.py +0 -0
  39. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl/utils/schedulers.py +0 -0
  40. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl.egg-info/dependency_links.txt +0 -0
  41. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl.egg-info/entry_points.txt +0 -0
  42. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl.egg-info/requires.txt +0 -0
  43. {wavedl-1.4.5 → wavedl-1.5.0}/src/wavedl.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.4.5
3
+ Version: 1.5.0
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -49,7 +49,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
49
49
 
50
50
  ### A Scalable Deep Learning Framework for Wave-Based Inverse Problems
51
51
 
52
- [![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
52
+ [![Python 3.11+](https://img.shields.io/badge/Python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
53
53
  [![PyTorch 2.x](https://img.shields.io/badge/PyTorch-2.x-ee4c2c.svg?style=plastic&logo=pytorch&logoColor=white)](https://pytorch.org/)
54
54
  [![Accelerate](https://img.shields.io/badge/Accelerate-Enabled-yellow.svg?style=plastic&logo=huggingface&logoColor=white)](https://huggingface.co/docs/accelerate/)
55
55
  <br>
@@ -57,7 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
57
57
  [![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
58
58
  [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
59
59
  <br>
60
- [![Downloads](https://img.shields.io/pepy/dt/wavedl?style=plastic&logo=pypi&logoColor=white&color=9ACD32)](https://pepy.tech/project/wavedl)
60
+ [![Downloads](https://img.shields.io/badge/dynamic/json?url=https://pypistats.org/api/packages/wavedl/recent?period=month%26mirrors=false&query=data.last_month&style=plastic&logo=pypi&logoColor=white&color=9ACD32&label=Downloads&suffix=/month)](https://pypistats.org/packages/wavedl)
61
61
  [![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
62
62
  [![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)
63
63
 
@@ -113,14 +113,12 @@ Train on datasets larger than RAM:
113
113
  </td>
114
114
  <td width="50%" valign="top">
115
115
 
116
- **🧠 One-Line Model Registration**
116
+ **🧠 Models? We've Got Options**
117
117
 
118
- Plug in any architecture:
119
- ```python
120
- @register_model("my_net")
121
- class MyNet(BaseModel): ...
122
- ```
123
- Design your model. Register with one line.
118
+ 38 architectures, ready to go:
119
+ - CNNs, ResNets, ViTs, EfficientNets...
120
+ - All adapted for regression
121
+ - [Add your own](#adding-custom-models) in one line
124
122
 
125
123
  </td>
126
124
  </tr>
@@ -137,12 +135,12 @@ Multi-GPU training without the pain:
137
135
  </td>
138
136
  <td width="50%" valign="top">
139
137
 
140
- **📊 Publish-Ready Output**
138
+ **🔬 Physics-Constrained Training**
141
139
 
142
- Results go straight to your paper:
143
- - 11 diagnostic plots with LaTeX styling
144
- - Multi-format export (PNG, PDF, SVG, ...)
145
- - MAE in physical units per parameter
140
+ Make your model respect the laws:
141
+ - Enforce bounds, positivity, equations
142
+ - Simple expression syntax or Python
143
+ - [Custom constraints](#physical-constraints) for various laws
146
144
 
147
145
  </td>
148
146
  </tr>
@@ -383,7 +381,7 @@ WaveDL/
383
381
  ├── configs/ # YAML config templates
384
382
  ├── examples/ # Ready-to-run examples
385
383
  ├── notebooks/ # Jupyter notebooks
386
- ├── unit_tests/ # Pytest test suite (704 tests)
384
+ ├── unit_tests/ # Pytest test suite (725 tests)
387
385
 
388
386
  ├── pyproject.toml # Package config, dependencies
389
387
  ├── CHANGELOG.md # Version history
@@ -727,6 +725,104 @@ seed: 2025
727
725
 
728
726
  </details>
729
727
 
728
+ <details>
729
+ <summary><b>Physical Constraints</b> — Enforce Physics During Training</summary>
730
+
731
+ Add penalty terms to the loss function to enforce physical laws:
732
+
733
+ ```
734
+ Total Loss = Data Loss + weight × penalty(violation)
735
+ ```
736
+
737
+ ### Expression Constraints
738
+
739
+ ```bash
740
+ # Positivity
741
+ --constraint "y0 > 0"
742
+
743
+ # Bounds
744
+ --constraint "y0 >= 0" "y0 <= 1"
745
+
746
+ # Equations (penalize deviations from zero)
747
+ --constraint "y2 - y0 * y1"
748
+
749
+ # Input-dependent constraints
750
+ --constraint "y0 - 2*x[0]"
751
+
752
+ # Multiple constraints with different weights
753
+ --constraint "y0 > 0" "y1 - y2" --constraint_weight 0.1 1.0
754
+ ```
755
+
756
+ ### Custom Python Constraints
757
+
758
+ For complex physics (matrix operations, implicit equations):
759
+
760
+ ```python
761
+ # my_constraint.py
762
+ import torch
763
+
764
+ def constraint(pred, inputs=None):
765
+ """
766
+ Args:
767
+ pred: (batch, num_outputs)
768
+ inputs: (batch, features) or (batch, C, H, W) or (batch, C, D, H, W)
769
+ Returns:
770
+ (batch,) — violation per sample (0 = satisfied)
771
+ """
772
+ # Outputs (same for all data types)
773
+ y0, y1, y2 = pred[:, 0], pred[:, 1], pred[:, 2]
774
+
775
+ # Inputs — Tabular: (batch, features)
776
+ # x0 = inputs[:, 0] # Feature 0
777
+ # x_sum = inputs.sum(dim=1) # Sum all features
778
+
779
+ # Inputs — Images: (batch, C, H, W)
780
+ # pixel = inputs[:, 0, 3, 5] # Pixel at (3,5), channel 0
781
+ # img_mean = inputs.mean(dim=(1,2,3)) # Mean over C,H,W
782
+
783
+ # Inputs — 3D Volumes: (batch, C, D, H, W)
784
+ # voxel = inputs[:, 0, 2, 3, 5] # Voxel at (2,3,5), channel 0
785
+
786
+ # Example constraints:
787
+ # return y2 - y0 * y1 # Wave equation
788
+ # return y0 - 2 * inputs[:, 0] # Output = 2×input
789
+ # return inputs[:, 0, 3, 5] * y0 + inputs[:, 0, 6, 7] * y1 # Mixed
790
+
791
+ return y0 - y1 * y2
792
+ ```
793
+
794
+ ```bash
795
+ --constraint_file my_constraint.py --constraint_weight 1.0
796
+ ```
797
+
798
+ ---
799
+
800
+ ### Reference
801
+
802
+ | Argument | Default | Description |
803
+ |----------|---------|-------------|
804
+ | `--constraint` | — | Expression(s): `"y0 > 0"`, `"y0 - y1*y2"` |
805
+ | `--constraint_file` | — | Python file with `constraint(pred, inputs)` |
806
+ | `--constraint_weight` | `0.1` | Penalty weight(s) |
807
+ | `--constraint_reduction` | `mse` | `mse` (squared) or `mae` (linear) |
808
+
809
+ #### Expression Syntax
810
+
811
+ | Variable | Meaning |
812
+ |----------|---------|
813
+ | `y0`, `y1`, ... | Model outputs |
814
+ | `x[0]`, `x[1]`, ... | Input values (1D tabular) |
815
+ | `x[i,j]`, `x[i,j,k]` | Input values (2D/3D: images, volumes) |
816
+ | `x_mean`, `x_sum`, `x_max`, `x_min`, `x_std` | Input aggregates |
817
+
818
+ **Operators:** `+`, `-`, `*`, `/`, `**`, `>`, `<`, `>=`, `<=`, `==`
819
+
820
+ **Functions:** `sin`, `cos`, `exp`, `log`, `sqrt`, `sigmoid`, `softplus`, `tanh`, `relu`, `abs`
821
+
822
+ </details>
823
+
824
+
825
+
730
826
  <details>
731
827
  <summary><b>Hyperparameter Search (HPO)</b></summary>
732
828
 
@@ -734,18 +830,20 @@ Automatically find the best training configuration using [Optuna](https://optuna
734
830
 
735
831
  **Run HPO:**
736
832
 
737
- You specify which models to search and how many trials to run:
738
833
  ```bash
739
- # Search 3 models with 100 trials
740
- python -m wavedl.hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 100
834
+ # Basic HPO (auto-detects GPUs for parallel trials)
835
+ wavedl-hpo --data_path train.npz --models cnn --n_trials 100
741
836
 
742
- # Search 1 model (faster)
743
- python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
837
+ # Search multiple models
838
+ wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
744
839
 
745
- # Search all your candidate models
746
- python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
840
+ # Quick mode (fewer parameters, faster)
841
+ wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
747
842
  ```
748
843
 
844
+ > [!TIP]
845
+ > **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
846
+
749
847
  **Train with best parameters**
750
848
 
751
849
  After HPO completes, it prints the optimal command:
@@ -764,7 +862,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
764
862
  | Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
765
863
  | Losses | [all 6](#loss-functions) | `--losses X Y` |
766
864
  | Learning rate | 1e-5 → 1e-2 | (always searched) |
767
- | Batch size | 64, 128, 256, 512 | (always searched) |
865
+ | Batch size | 16, 32, 64, 128 | (always searched) |
768
866
 
769
867
  **Quick Mode** (`--quick`):
770
868
  - Uses minimal defaults: cnn + adamw + plateau + mse
@@ -784,7 +882,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
784
882
  | `--optimizers` | all 6 | Optimizers to search |
785
883
  | `--schedulers` | all 8 | Schedulers to search |
786
884
  | `--losses` | all 6 | Losses to search |
787
- | `--n_jobs` | `1` | Parallel trials (multi-GPU) |
885
+ | `--n_jobs` | `-1` | Parallel trials (-1 = auto-detect GPUs) |
788
886
  | `--max_epochs` | `50` | Max epochs per trial |
789
887
  | `--output` | `hpo_results.json` | Output file |
790
888
 
@@ -936,12 +1034,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
936
1034
  ```bash
937
1035
  # Run inference on the example data
938
1036
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
939
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
1037
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
940
1038
  --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
941
1039
 
942
1040
  # Export to ONNX (already included as model.onnx)
943
1041
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
944
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
1042
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
945
1043
  --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
946
1044
  ```
947
1045
 
@@ -950,7 +1048,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
950
1048
  | File | Description |
951
1049
  |------|-------------|
952
1050
  | `best_checkpoint/` | Pre-trained CNN checkpoint |
953
- | `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
1051
+ | `Test_data_500.mat` | 500 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
954
1052
  | `model.onnx` | ONNX export with embedded de-normalization |
955
1053
  | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
956
1054
  | `training_curves.png` | Training/validation loss and learning rate plot |
@@ -961,7 +1059,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
961
1059
 
962
1060
  <p align="center">
963
1061
  <img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
964
- <em>Training and validation loss over 162 epochs with learning rate schedule</em>
1062
+ <em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
965
1063
  </p>
966
1064
 
967
1065
  **Inference Results:**
@@ -4,7 +4,7 @@
4
4
 
5
5
  ### A Scalable Deep Learning Framework for Wave-Based Inverse Problems
6
6
 
7
- [![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
7
+ [![Python 3.11+](https://img.shields.io/badge/Python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
8
8
  [![PyTorch 2.x](https://img.shields.io/badge/PyTorch-2.x-ee4c2c.svg?style=plastic&logo=pytorch&logoColor=white)](https://pytorch.org/)
9
9
  [![Accelerate](https://img.shields.io/badge/Accelerate-Enabled-yellow.svg?style=plastic&logo=huggingface&logoColor=white)](https://huggingface.co/docs/accelerate/)
10
10
  <br>
@@ -12,7 +12,7 @@
12
12
  [![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
13
13
  [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
14
14
  <br>
15
- [![Downloads](https://img.shields.io/pepy/dt/wavedl?style=plastic&logo=pypi&logoColor=white&color=9ACD32)](https://pepy.tech/project/wavedl)
15
+ [![Downloads](https://img.shields.io/badge/dynamic/json?url=https://pypistats.org/api/packages/wavedl/recent?period=month%26mirrors=false&query=data.last_month&style=plastic&logo=pypi&logoColor=white&color=9ACD32&label=Downloads&suffix=/month)](https://pypistats.org/packages/wavedl)
16
16
  [![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
17
17
  [![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)
18
18
 
@@ -68,14 +68,12 @@ Train on datasets larger than RAM:
68
68
  </td>
69
69
  <td width="50%" valign="top">
70
70
 
71
- **🧠 One-Line Model Registration**
71
+ **🧠 Models? We've Got Options**
72
72
 
73
- Plug in any architecture:
74
- ```python
75
- @register_model("my_net")
76
- class MyNet(BaseModel): ...
77
- ```
78
- Design your model. Register with one line.
73
+ 38 architectures, ready to go:
74
+ - CNNs, ResNets, ViTs, EfficientNets...
75
+ - All adapted for regression
76
+ - [Add your own](#adding-custom-models) in one line
79
77
 
80
78
  </td>
81
79
  </tr>
@@ -92,12 +90,12 @@ Multi-GPU training without the pain:
92
90
  </td>
93
91
  <td width="50%" valign="top">
94
92
 
95
- **📊 Publish-Ready Output**
93
+ **🔬 Physics-Constrained Training**
96
94
 
97
- Results go straight to your paper:
98
- - 11 diagnostic plots with LaTeX styling
99
- - Multi-format export (PNG, PDF, SVG, ...)
100
- - MAE in physical units per parameter
95
+ Make your model respect the laws:
96
+ - Enforce bounds, positivity, equations
97
+ - Simple expression syntax or Python
98
+ - [Custom constraints](#physical-constraints) for various laws
101
99
 
102
100
  </td>
103
101
  </tr>
@@ -338,7 +336,7 @@ WaveDL/
338
336
  ├── configs/ # YAML config templates
339
337
  ├── examples/ # Ready-to-run examples
340
338
  ├── notebooks/ # Jupyter notebooks
341
- ├── unit_tests/ # Pytest test suite (704 tests)
339
+ ├── unit_tests/ # Pytest test suite (725 tests)
342
340
 
343
341
  ├── pyproject.toml # Package config, dependencies
344
342
  ├── CHANGELOG.md # Version history
@@ -682,6 +680,104 @@ seed: 2025
682
680
 
683
681
  </details>
684
682
 
683
+ <details>
684
+ <summary><b>Physical Constraints</b> — Enforce Physics During Training</summary>
685
+
686
+ Add penalty terms to the loss function to enforce physical laws:
687
+
688
+ ```
689
+ Total Loss = Data Loss + weight × penalty(violation)
690
+ ```
691
+
692
+ ### Expression Constraints
693
+
694
+ ```bash
695
+ # Positivity
696
+ --constraint "y0 > 0"
697
+
698
+ # Bounds
699
+ --constraint "y0 >= 0" "y0 <= 1"
700
+
701
+ # Equations (penalize deviations from zero)
702
+ --constraint "y2 - y0 * y1"
703
+
704
+ # Input-dependent constraints
705
+ --constraint "y0 - 2*x[0]"
706
+
707
+ # Multiple constraints with different weights
708
+ --constraint "y0 > 0" "y1 - y2" --constraint_weight 0.1 1.0
709
+ ```
710
+
711
+ ### Custom Python Constraints
712
+
713
+ For complex physics (matrix operations, implicit equations):
714
+
715
+ ```python
716
+ # my_constraint.py
717
+ import torch
718
+
719
+ def constraint(pred, inputs=None):
720
+ """
721
+ Args:
722
+ pred: (batch, num_outputs)
723
+ inputs: (batch, features) or (batch, C, H, W) or (batch, C, D, H, W)
724
+ Returns:
725
+ (batch,) — violation per sample (0 = satisfied)
726
+ """
727
+ # Outputs (same for all data types)
728
+ y0, y1, y2 = pred[:, 0], pred[:, 1], pred[:, 2]
729
+
730
+ # Inputs — Tabular: (batch, features)
731
+ # x0 = inputs[:, 0] # Feature 0
732
+ # x_sum = inputs.sum(dim=1) # Sum all features
733
+
734
+ # Inputs — Images: (batch, C, H, W)
735
+ # pixel = inputs[:, 0, 3, 5] # Pixel at (3,5), channel 0
736
+ # img_mean = inputs.mean(dim=(1,2,3)) # Mean over C,H,W
737
+
738
+ # Inputs — 3D Volumes: (batch, C, D, H, W)
739
+ # voxel = inputs[:, 0, 2, 3, 5] # Voxel at (2,3,5), channel 0
740
+
741
+ # Example constraints:
742
+ # return y2 - y0 * y1 # Wave equation
743
+ # return y0 - 2 * inputs[:, 0] # Output = 2×input
744
+ # return inputs[:, 0, 3, 5] * y0 + inputs[:, 0, 6, 7] * y1 # Mixed
745
+
746
+ return y0 - y1 * y2
747
+ ```
748
+
749
+ ```bash
750
+ --constraint_file my_constraint.py --constraint_weight 1.0
751
+ ```
752
+
753
+ ---
754
+
755
+ ### Reference
756
+
757
+ | Argument | Default | Description |
758
+ |----------|---------|-------------|
759
+ | `--constraint` | — | Expression(s): `"y0 > 0"`, `"y0 - y1*y2"` |
760
+ | `--constraint_file` | — | Python file with `constraint(pred, inputs)` |
761
+ | `--constraint_weight` | `0.1` | Penalty weight(s) |
762
+ | `--constraint_reduction` | `mse` | `mse` (squared) or `mae` (linear) |
763
+
764
+ #### Expression Syntax
765
+
766
+ | Variable | Meaning |
767
+ |----------|---------|
768
+ | `y0`, `y1`, ... | Model outputs |
769
+ | `x[0]`, `x[1]`, ... | Input values (1D tabular) |
770
+ | `x[i,j]`, `x[i,j,k]` | Input values (2D/3D: images, volumes) |
771
+ | `x_mean`, `x_sum`, `x_max`, `x_min`, `x_std` | Input aggregates |
772
+
773
+ **Operators:** `+`, `-`, `*`, `/`, `**`, `>`, `<`, `>=`, `<=`, `==`
774
+
775
+ **Functions:** `sin`, `cos`, `exp`, `log`, `sqrt`, `sigmoid`, `softplus`, `tanh`, `relu`, `abs`
776
+
777
+ </details>
778
+
779
+
780
+
685
781
  <details>
686
782
  <summary><b>Hyperparameter Search (HPO)</b></summary>
687
783
 
@@ -689,18 +785,20 @@ Automatically find the best training configuration using [Optuna](https://optuna
689
785
 
690
786
  **Run HPO:**
691
787
 
692
- You specify which models to search and how many trials to run:
693
788
  ```bash
694
- # Search 3 models with 100 trials
695
- python -m wavedl.hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 100
789
+ # Basic HPO (auto-detects GPUs for parallel trials)
790
+ wavedl-hpo --data_path train.npz --models cnn --n_trials 100
696
791
 
697
- # Search 1 model (faster)
698
- python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
792
+ # Search multiple models
793
+ wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
699
794
 
700
- # Search all your candidate models
701
- python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
795
+ # Quick mode (fewer parameters, faster)
796
+ wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
702
797
  ```
703
798
 
799
+ > [!TIP]
800
+ > **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
801
+
704
802
  **Train with best parameters**
705
803
 
706
804
  After HPO completes, it prints the optimal command:
@@ -719,7 +817,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
719
817
  | Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
720
818
  | Losses | [all 6](#loss-functions) | `--losses X Y` |
721
819
  | Learning rate | 1e-5 → 1e-2 | (always searched) |
722
- | Batch size | 64, 128, 256, 512 | (always searched) |
820
+ | Batch size | 16, 32, 64, 128 | (always searched) |
723
821
 
724
822
  **Quick Mode** (`--quick`):
725
823
  - Uses minimal defaults: cnn + adamw + plateau + mse
@@ -739,7 +837,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
739
837
  | `--optimizers` | all 6 | Optimizers to search |
740
838
  | `--schedulers` | all 8 | Schedulers to search |
741
839
  | `--losses` | all 6 | Losses to search |
742
- | `--n_jobs` | `1` | Parallel trials (multi-GPU) |
840
+ | `--n_jobs` | `-1` | Parallel trials (-1 = auto-detect GPUs) |
743
841
  | `--max_epochs` | `50` | Max epochs per trial |
744
842
  | `--output` | `hpo_results.json` | Output file |
745
843
 
@@ -891,12 +989,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
891
989
  ```bash
892
990
  # Run inference on the example data
893
991
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
894
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
992
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
895
993
  --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
896
994
 
897
995
  # Export to ONNX (already included as model.onnx)
898
996
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
899
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
997
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
900
998
  --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
901
999
  ```
902
1000
 
@@ -905,7 +1003,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
905
1003
  | File | Description |
906
1004
  |------|-------------|
907
1005
  | `best_checkpoint/` | Pre-trained CNN checkpoint |
908
- | `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
1006
+ | `Test_data_500.mat` | 500 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
909
1007
  | `model.onnx` | ONNX export with embedded de-normalization |
910
1008
  | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
911
1009
  | `training_curves.png` | Training/validation loss and learning rate plot |
@@ -916,7 +1014,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
916
1014
 
917
1015
  <p align="center">
918
1016
  <img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
919
- <em>Training and validation loss over 162 epochs with learning rate schedule</em>
1017
+ <em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
920
1018
  </p>
921
1019
 
922
1020
  **Inference Results:**
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.4.5"
21
+ __version__ = "1.5.0"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
@@ -174,7 +174,9 @@ Environment Variables:
174
174
  return args, remaining
175
175
 
176
176
 
177
- def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
177
+ def print_summary(
178
+ exit_code: int, wandb_enabled: bool, wandb_mode: str, wandb_dir: str
179
+ ) -> None:
178
180
  """Print post-training summary and instructions."""
179
181
  print()
180
182
  print("=" * 40)
@@ -183,7 +185,8 @@ def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
183
185
  print("✅ Training completed successfully!")
184
186
  print("=" * 40)
185
187
 
186
- if wandb_mode == "offline":
188
+ # Only show WandB sync instructions if user enabled wandb
189
+ if wandb_enabled and wandb_mode == "offline":
187
190
  print()
188
191
  print("📊 WandB Sync Instructions:")
189
192
  print(" From the login node, run:")
@@ -237,6 +240,10 @@ def main() -> int:
237
240
  f"--dynamo_backend={args.dynamo_backend}",
238
241
  ]
239
242
 
243
+ # Explicitly set multi_gpu to suppress accelerate auto-detection warning
244
+ if num_gpus > 1:
245
+ cmd.append("--multi_gpu")
246
+
240
247
  # Add multi-node networking args if specified (required for some clusters)
241
248
  if args.main_process_ip:
242
249
  cmd.append(f"--main_process_ip={args.main_process_ip}")
@@ -263,8 +270,10 @@ def main() -> int:
263
270
  exit_code = 130
264
271
 
265
272
  # Print summary
273
+ wandb_enabled = "--wandb" in train_args
266
274
  print_summary(
267
275
  exit_code,
276
+ wandb_enabled,
268
277
  os.environ.get("WANDB_MODE", "offline"),
269
278
  os.environ.get("WANDB_DIR", "/tmp/wandb"),
270
279
  )
@@ -31,7 +31,7 @@ try:
31
31
  import optuna
32
32
  from optuna.trial import TrialState
33
33
  except ImportError:
34
- print("Error: Optuna not installed. Run: pip install -e '.[hpo]'")
34
+ print("Error: Optuna not installed. Run: pip install wavedl")
35
35
  sys.exit(1)
36
36
 
37
37
 
@@ -89,7 +89,8 @@ def create_objective(args):
89
89
  # Suggest hyperparameters
90
90
  model = trial.suggest_categorical("model", models)
91
91
  lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
92
- batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
92
+ batch_sizes = args.batch_sizes or [16, 32, 64, 128]
93
+ batch_size = trial.suggest_categorical("batch_size", batch_sizes)
93
94
  optimizer = trial.suggest_categorical("optimizer", optimizers)
94
95
  scheduler = trial.suggest_categorical("scheduler", schedulers)
95
96
  loss = trial.suggest_categorical("loss", losses)
@@ -147,6 +148,32 @@ def create_objective(args):
147
148
  cmd.extend(["--output_dir", tmpdir])
148
149
  history_file = Path(tmpdir) / "training_history.csv"
149
150
 
151
+ # GPU isolation for parallel trials: assign each trial to a specific GPU
152
+ # This prevents multiple trials from competing for all GPUs
153
+ env = None
154
+ if args.n_jobs > 1:
155
+ import os
156
+
157
+ # Detect available GPUs
158
+ n_gpus = 1
159
+ try:
160
+ import subprocess as sp
161
+
162
+ result_gpu = sp.run(
163
+ ["nvidia-smi", "--list-gpus"],
164
+ capture_output=True,
165
+ text=True,
166
+ )
167
+ if result_gpu.returncode == 0:
168
+ n_gpus = len(result_gpu.stdout.strip().split("\n"))
169
+ except Exception:
170
+ pass
171
+
172
+ # Assign trial to a specific GPU (round-robin)
173
+ gpu_id = trial.number % n_gpus
174
+ env = os.environ.copy()
175
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
176
+
150
177
  # Run training
151
178
  try:
152
179
  result = subprocess.run(
@@ -155,6 +182,7 @@ def create_objective(args):
155
182
  text=True,
156
183
  timeout=args.timeout,
157
184
  cwd=Path(__file__).parent,
185
+ env=env,
158
186
  )
159
187
 
160
188
  # Read best val_loss from training_history.csv (reliable machine-readable)
@@ -248,7 +276,10 @@ Examples:
248
276
  "--n_trials", type=int, default=50, help="Number of HPO trials (default: 50)"
249
277
  )
250
278
  parser.add_argument(
251
- "--n_jobs", type=int, default=1, help="Parallel trials (default: 1)"
279
+ "--n_jobs",
280
+ type=int,
281
+ default=-1,
282
+ help="Parallel trials (-1 = auto-detect GPUs, default: -1)",
252
283
  )
253
284
  parser.add_argument(
254
285
  "--quick",
@@ -287,6 +318,13 @@ Examples:
287
318
  default=None,
288
319
  help=f"Losses to search (default: {DEFAULT_LOSSES})",
289
320
  )
321
+ parser.add_argument(
322
+ "--batch_sizes",
323
+ type=int,
324
+ nargs="+",
325
+ default=None,
326
+ help="Batch sizes to search (default: 16 32 64 128)",
327
+ )
290
328
 
291
329
  # Training settings for each trial
292
330
  parser.add_argument(
@@ -315,11 +353,30 @@ Examples:
315
353
 
316
354
  args = parser.parse_args()
317
355
 
356
+ # Convert to absolute path (child processes may run in different cwd)
357
+ args.data_path = str(Path(args.data_path).resolve())
358
+
318
359
  # Validate data path
319
360
  if not Path(args.data_path).exists():
320
361
  print(f"Error: Data file not found: {args.data_path}")
321
362
  sys.exit(1)
322
363
 
364
+ # Auto-detect GPUs for n_jobs if not specified
365
+ if args.n_jobs == -1:
366
+ try:
367
+ result_gpu = subprocess.run(
368
+ ["nvidia-smi", "--list-gpus"],
369
+ capture_output=True,
370
+ text=True,
371
+ )
372
+ if result_gpu.returncode == 0:
373
+ args.n_jobs = max(1, len(result_gpu.stdout.strip().split("\n")))
374
+ else:
375
+ args.n_jobs = 1
376
+ except Exception:
377
+ args.n_jobs = 1
378
+ print(f"Auto-detected {args.n_jobs} GPU(s) for parallel trials")
379
+
323
380
  # Create study
324
381
  print("=" * 60)
325
382
  print("WaveDL Hyperparameter Optimization")