wavedl 1.4.6__tar.gz → 1.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. {wavedl-1.4.6/src/wavedl.egg-info → wavedl-1.5.0}/PKG-INFO +115 -19
  2. {wavedl-1.4.6 → wavedl-1.5.0}/README.md +114 -18
  3. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/__init__.py +1 -1
  4. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/hpo.py +9 -1
  5. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/train.py +73 -6
  6. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/__init__.py +11 -0
  7. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/config.py +10 -0
  8. wavedl-1.5.0/src/wavedl/utils/constraints.py +470 -0
  9. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/metrics.py +49 -2
  10. {wavedl-1.4.6 → wavedl-1.5.0/src/wavedl.egg-info}/PKG-INFO +115 -19
  11. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl.egg-info/SOURCES.txt +1 -0
  12. {wavedl-1.4.6 → wavedl-1.5.0}/LICENSE +0 -0
  13. {wavedl-1.4.6 → wavedl-1.5.0}/pyproject.toml +0 -0
  14. {wavedl-1.4.6 → wavedl-1.5.0}/setup.cfg +0 -0
  15. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/hpc.py +0 -0
  16. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/__init__.py +0 -0
  17. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/_template.py +0 -0
  18. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/base.py +0 -0
  19. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/cnn.py +0 -0
  20. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/convnext.py +0 -0
  21. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/densenet.py +0 -0
  22. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/efficientnet.py +0 -0
  23. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/efficientnetv2.py +0 -0
  24. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/mobilenetv3.py +0 -0
  25. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/registry.py +0 -0
  26. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/regnet.py +0 -0
  27. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/resnet.py +0 -0
  28. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/resnet3d.py +0 -0
  29. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/swin.py +0 -0
  30. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/tcn.py +0 -0
  31. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/unet.py +0 -0
  32. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/vit.py +0 -0
  33. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/test.py +0 -0
  34. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/cross_validation.py +0 -0
  35. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/data.py +0 -0
  36. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/distributed.py +0 -0
  37. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/losses.py +0 -0
  38. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/optimizers.py +0 -0
  39. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/schedulers.py +0 -0
  40. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl.egg-info/dependency_links.txt +0 -0
  41. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl.egg-info/entry_points.txt +0 -0
  42. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl.egg-info/requires.txt +0 -0
  43. {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.4.6
3
+ Version: 1.5.0
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -113,14 +113,12 @@ Train on datasets larger than RAM:
113
113
  </td>
114
114
  <td width="50%" valign="top">
115
115
 
116
- **🧠 One-Line Model Registration**
116
+ **🧠 Models? We've Got Options**
117
117
 
118
- Plug in any architecture:
119
- ```python
120
- @register_model("my_net")
121
- class MyNet(BaseModel): ...
122
- ```
123
- Design your model. Register with one line.
118
+ 38 architectures, ready to go:
119
+ - CNNs, ResNets, ViTs, EfficientNets...
120
+ - All adapted for regression
121
+ - [Add your own](#adding-custom-models) in one line
124
122
 
125
123
  </td>
126
124
  </tr>
@@ -137,12 +135,12 @@ Multi-GPU training without the pain:
137
135
  </td>
138
136
  <td width="50%" valign="top">
139
137
 
140
- **📊 Publish-Ready Output**
138
+ **🔬 Physics-Constrained Training**
141
139
 
142
- Results go straight to your paper:
143
- - 11 diagnostic plots with LaTeX styling
144
- - Multi-format export (PNG, PDF, SVG, ...)
145
- - MAE in physical units per parameter
140
+ Make your model respect the laws:
141
+ - Enforce bounds, positivity, equations
142
+ - Simple expression syntax or Python
143
+ - [Custom constraints](#physical-constraints) for arbitrary physical laws
146
144
 
147
145
  </td>
148
146
  </tr>
@@ -383,7 +381,7 @@ WaveDL/
383
381
  ├── configs/ # YAML config templates
384
382
  ├── examples/ # Ready-to-run examples
385
383
  ├── notebooks/ # Jupyter notebooks
386
- ├── unit_tests/ # Pytest test suite (704 tests)
384
+ ├── unit_tests/ # Pytest test suite (725 tests)
387
385
 
388
386
  ├── pyproject.toml # Package config, dependencies
389
387
  ├── CHANGELOG.md # Version history
@@ -727,6 +725,104 @@ seed: 2025
727
725
 
728
726
  </details>
729
727
 
728
+ <details>
729
+ <summary><b>Physical Constraints</b> — Enforce Physics During Training</summary>
730
+
731
+ Add penalty terms to the loss function to enforce physical laws:
732
+
733
+ ```
734
+ Total Loss = Data Loss + weight × penalty(violation)
735
+ ```
736
+
737
+ ### Expression Constraints
738
+
739
+ ```bash
740
+ # Positivity
741
+ --constraint "y0 > 0"
742
+
743
+ # Bounds
744
+ --constraint "y0 >= 0" "y0 <= 1"
745
+
746
+ # Equations (penalize deviations from zero)
747
+ --constraint "y2 - y0 * y1"
748
+
749
+ # Input-dependent constraints
750
+ --constraint "y0 - 2*x[0]"
751
+
752
+ # Multiple constraints with different weights
753
+ --constraint "y0 > 0" "y1 - y2" --constraint_weight 0.1 1.0
754
+ ```
755
+
756
+ ### Custom Python Constraints
757
+
758
+ For complex physics (matrix operations, implicit equations):
759
+
760
+ ```python
761
+ # my_constraint.py
762
+ import torch
763
+
764
+ def constraint(pred, inputs=None):
765
+ """
766
+ Args:
767
+ pred: (batch, num_outputs)
768
+ inputs: (batch, features) or (batch, C, H, W) or (batch, C, D, H, W)
769
+ Returns:
770
+ (batch,) — violation per sample (0 = satisfied)
771
+ """
772
+ # Outputs (same for all data types)
773
+ y0, y1, y2 = pred[:, 0], pred[:, 1], pred[:, 2]
774
+
775
+ # Inputs — Tabular: (batch, features)
776
+ # x0 = inputs[:, 0] # Feature 0
777
+ # x_sum = inputs.sum(dim=1) # Sum all features
778
+
779
+ # Inputs — Images: (batch, C, H, W)
780
+ # pixel = inputs[:, 0, 3, 5] # Pixel at (3,5), channel 0
781
+ # img_mean = inputs.mean(dim=(1,2,3)) # Mean over C,H,W
782
+
783
+ # Inputs — 3D Volumes: (batch, C, D, H, W)
784
+ # voxel = inputs[:, 0, 2, 3, 5] # Voxel at (2,3,5), channel 0
785
+
786
+ # Example constraints:
787
+ # return y2 - y0 * y1 # Wave equation
788
+ # return y0 - 2 * inputs[:, 0] # Output = 2×input
789
+ # return inputs[:, 0, 3, 5] * y0 + inputs[:, 0, 6, 7] * y1 # Mixed
790
+
791
+ return y0 - y1 * y2
792
+ ```
793
+
794
+ ```bash
795
+ --constraint_file my_constraint.py --constraint_weight 1.0
796
+ ```
797
+
798
+ ---
799
+
800
+ ### Reference
801
+
802
+ | Argument | Default | Description |
803
+ |----------|---------|-------------|
804
+ | `--constraint` | — | Expression(s): `"y0 > 0"`, `"y0 - y1*y2"` |
805
+ | `--constraint_file` | — | Python file with `constraint(pred, inputs)` |
806
+ | `--constraint_weight` | `0.1` | Penalty weight(s): one per constraint, or a single shared weight |
807
+ | `--constraint_reduction` | `mse` | Penalty reduction: `mse` (squared violation) or `mae` (absolute violation) |
808
+
809
+ #### Expression Syntax
810
+
811
+ | Variable | Meaning |
812
+ |----------|---------|
813
+ | `y0`, `y1`, ... | Model outputs |
814
+ | `x[0]`, `x[1]`, ... | Input values (1D tabular) |
815
+ | `x[i,j]`, `x[i,j,k]` | Input values (2D/3D: images, volumes) |
816
+ | `x_mean`, `x_sum`, `x_max`, `x_min`, `x_std` | Input aggregates |
817
+
818
+ **Operators:** `+`, `-`, `*`, `/`, `**`, `>`, `<`, `>=`, `<=`, `==`
819
+
820
+ **Functions:** `sin`, `cos`, `exp`, `log`, `sqrt`, `sigmoid`, `softplus`, `tanh`, `relu`, `abs`
821
+
822
+ </details>
823
+
824
+
825
+
730
826
  <details>
731
827
  <summary><b>Hyperparameter Search (HPO)</b></summary>
732
828
 
@@ -766,7 +862,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
766
862
  | Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
767
863
  | Losses | [all 6](#loss-functions) | `--losses X Y` |
768
864
  | Learning rate | 1e-5 → 1e-2 | (always searched) |
769
- | Batch size | 64, 128, 256, 512 | (always searched) |
865
+ | Batch size | 16, 32, 64, 128 | (always searched) |
770
866
 
771
867
  **Quick Mode** (`--quick`):
772
868
  - Uses minimal defaults: cnn + adamw + plateau + mse
@@ -938,12 +1034,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
938
1034
  ```bash
939
1035
  # Run inference on the example data
940
1036
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
941
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
1037
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
942
1038
  --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
943
1039
 
944
1040
  # Export to ONNX (already included as model.onnx)
945
1041
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
946
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
1042
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
947
1043
  --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
948
1044
  ```
949
1045
 
@@ -952,7 +1048,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
952
1048
  | File | Description |
953
1049
  |------|-------------|
954
1050
  | `best_checkpoint/` | Pre-trained CNN checkpoint |
955
- | `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
1051
+ | `Test_data_500.mat` | 500 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
956
1052
  | `model.onnx` | ONNX export with embedded de-normalization |
957
1053
  | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
958
1054
  | `training_curves.png` | Training/validation loss and learning rate plot |
@@ -963,7 +1059,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
963
1059
 
964
1060
  <p align="center">
965
1061
  <img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
966
- <em>Training and validation loss over 162 epochs with learning rate schedule</em>
1062
+ <em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
967
1063
  </p>
968
1064
 
969
1065
  **Inference Results:**
@@ -68,14 +68,12 @@ Train on datasets larger than RAM:
68
68
  </td>
69
69
  <td width="50%" valign="top">
70
70
 
71
- **🧠 One-Line Model Registration**
71
+ **🧠 Models? We've Got Options**
72
72
 
73
- Plug in any architecture:
74
- ```python
75
- @register_model("my_net")
76
- class MyNet(BaseModel): ...
77
- ```
78
- Design your model. Register with one line.
73
+ 38 architectures, ready to go:
74
+ - CNNs, ResNets, ViTs, EfficientNets...
75
+ - All adapted for regression
76
+ - [Add your own](#adding-custom-models) in one line
79
77
 
80
78
  </td>
81
79
  </tr>
@@ -92,12 +90,12 @@ Multi-GPU training without the pain:
92
90
  </td>
93
91
  <td width="50%" valign="top">
94
92
 
95
- **📊 Publish-Ready Output**
93
+ **🔬 Physics-Constrained Training**
96
94
 
97
- Results go straight to your paper:
98
- - 11 diagnostic plots with LaTeX styling
99
- - Multi-format export (PNG, PDF, SVG, ...)
100
- - MAE in physical units per parameter
95
+ Make your model respect the laws:
96
+ - Enforce bounds, positivity, equations
97
+ - Simple expression syntax or Python
98
+ - [Custom constraints](#physical-constraints) for arbitrary physical laws
101
99
 
102
100
  </td>
103
101
  </tr>
@@ -338,7 +336,7 @@ WaveDL/
338
336
  ├── configs/ # YAML config templates
339
337
  ├── examples/ # Ready-to-run examples
340
338
  ├── notebooks/ # Jupyter notebooks
341
- ├── unit_tests/ # Pytest test suite (704 tests)
339
+ ├── unit_tests/ # Pytest test suite (725 tests)
342
340
 
343
341
  ├── pyproject.toml # Package config, dependencies
344
342
  ├── CHANGELOG.md # Version history
@@ -682,6 +680,104 @@ seed: 2025
682
680
 
683
681
  </details>
684
682
 
683
+ <details>
684
+ <summary><b>Physical Constraints</b> — Enforce Physics During Training</summary>
685
+
686
+ Add penalty terms to the loss function to enforce physical laws:
687
+
688
+ ```
689
+ Total Loss = Data Loss + weight × penalty(violation)
690
+ ```
691
+
692
+ ### Expression Constraints
693
+
694
+ ```bash
695
+ # Positivity
696
+ --constraint "y0 > 0"
697
+
698
+ # Bounds
699
+ --constraint "y0 >= 0" "y0 <= 1"
700
+
701
+ # Equations (penalize deviations from zero)
702
+ --constraint "y2 - y0 * y1"
703
+
704
+ # Input-dependent constraints
705
+ --constraint "y0 - 2*x[0]"
706
+
707
+ # Multiple constraints with different weights
708
+ --constraint "y0 > 0" "y1 - y2" --constraint_weight 0.1 1.0
709
+ ```
710
+
711
+ ### Custom Python Constraints
712
+
713
+ For complex physics (matrix operations, implicit equations):
714
+
715
+ ```python
716
+ # my_constraint.py
717
+ import torch
718
+
719
+ def constraint(pred, inputs=None):
720
+ """
721
+ Args:
722
+ pred: (batch, num_outputs)
723
+ inputs: (batch, features) or (batch, C, H, W) or (batch, C, D, H, W)
724
+ Returns:
725
+ (batch,) — violation per sample (0 = satisfied)
726
+ """
727
+ # Outputs (same for all data types)
728
+ y0, y1, y2 = pred[:, 0], pred[:, 1], pred[:, 2]
729
+
730
+ # Inputs — Tabular: (batch, features)
731
+ # x0 = inputs[:, 0] # Feature 0
732
+ # x_sum = inputs.sum(dim=1) # Sum all features
733
+
734
+ # Inputs — Images: (batch, C, H, W)
735
+ # pixel = inputs[:, 0, 3, 5] # Pixel at (3,5), channel 0
736
+ # img_mean = inputs.mean(dim=(1,2,3)) # Mean over C,H,W
737
+
738
+ # Inputs — 3D Volumes: (batch, C, D, H, W)
739
+ # voxel = inputs[:, 0, 2, 3, 5] # Voxel at (2,3,5), channel 0
740
+
741
+ # Example constraints:
742
+ # return y2 - y0 * y1 # Wave equation
743
+ # return y0 - 2 * inputs[:, 0] # Output = 2×input
744
+ # return inputs[:, 0, 3, 5] * y0 + inputs[:, 0, 6, 7] * y1 # Mixed
745
+
746
+ return y0 - y1 * y2
747
+ ```
748
+
749
+ ```bash
750
+ --constraint_file my_constraint.py --constraint_weight 1.0
751
+ ```
752
+
753
+ ---
754
+
755
+ ### Reference
756
+
757
+ | Argument | Default | Description |
758
+ |----------|---------|-------------|
759
+ | `--constraint` | — | Expression(s): `"y0 > 0"`, `"y0 - y1*y2"` |
760
+ | `--constraint_file` | — | Python file with `constraint(pred, inputs)` |
761
+ | `--constraint_weight` | `0.1` | Penalty weight(s): one per constraint, or a single shared weight |
762
+ | `--constraint_reduction` | `mse` | Penalty reduction: `mse` (squared violation) or `mae` (absolute violation) |
763
+
764
+ #### Expression Syntax
765
+
766
+ | Variable | Meaning |
767
+ |----------|---------|
768
+ | `y0`, `y1`, ... | Model outputs |
769
+ | `x[0]`, `x[1]`, ... | Input values (1D tabular) |
770
+ | `x[i,j]`, `x[i,j,k]` | Input values (2D/3D: images, volumes) |
771
+ | `x_mean`, `x_sum`, `x_max`, `x_min`, `x_std` | Input aggregates |
772
+
773
+ **Operators:** `+`, `-`, `*`, `/`, `**`, `>`, `<`, `>=`, `<=`, `==`
774
+
775
+ **Functions:** `sin`, `cos`, `exp`, `log`, `sqrt`, `sigmoid`, `softplus`, `tanh`, `relu`, `abs`
776
+
777
+ </details>
778
+
779
+
780
+
685
781
  <details>
686
782
  <summary><b>Hyperparameter Search (HPO)</b></summary>
687
783
 
@@ -721,7 +817,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
721
817
  | Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
722
818
  | Losses | [all 6](#loss-functions) | `--losses X Y` |
723
819
  | Learning rate | 1e-5 → 1e-2 | (always searched) |
724
- | Batch size | 64, 128, 256, 512 | (always searched) |
820
+ | Batch size | 16, 32, 64, 128 | (always searched) |
725
821
 
726
822
  **Quick Mode** (`--quick`):
727
823
  - Uses minimal defaults: cnn + adamw + plateau + mse
@@ -893,12 +989,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
893
989
  ```bash
894
990
  # Run inference on the example data
895
991
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
896
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
992
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
897
993
  --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
898
994
 
899
995
  # Export to ONNX (already included as model.onnx)
900
996
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
901
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
997
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
902
998
  --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
903
999
  ```
904
1000
 
@@ -907,7 +1003,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
907
1003
  | File | Description |
908
1004
  |------|-------------|
909
1005
  | `best_checkpoint/` | Pre-trained CNN checkpoint |
910
- | `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
1006
+ | `Test_data_500.mat` | 500 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
911
1007
  | `model.onnx` | ONNX export with embedded de-normalization |
912
1008
  | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
913
1009
  | `training_curves.png` | Training/validation loss and learning rate plot |
@@ -918,7 +1014,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
918
1014
 
919
1015
  <p align="center">
920
1016
  <img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
921
- <em>Training and validation loss over 162 epochs with learning rate schedule</em>
1017
+ <em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
922
1018
  </p>
923
1019
 
924
1020
  **Inference Results:**
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.4.6"
21
+ __version__ = "1.5.0"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
@@ -89,7 +89,8 @@ def create_objective(args):
89
89
  # Suggest hyperparameters
90
90
  model = trial.suggest_categorical("model", models)
91
91
  lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
92
- batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
92
+ batch_sizes = args.batch_sizes or [16, 32, 64, 128]
93
+ batch_size = trial.suggest_categorical("batch_size", batch_sizes)
93
94
  optimizer = trial.suggest_categorical("optimizer", optimizers)
94
95
  scheduler = trial.suggest_categorical("scheduler", schedulers)
95
96
  loss = trial.suggest_categorical("loss", losses)
@@ -317,6 +318,13 @@ Examples:
317
318
  default=None,
318
319
  help=f"Losses to search (default: {DEFAULT_LOSSES})",
319
320
  )
321
+ parser.add_argument(
322
+ "--batch_sizes",
323
+ type=int,
324
+ nargs="+",
325
+ default=None,
326
+ help="Batch sizes to search (default: 16 32 64 128)",
327
+ )
320
328
 
321
329
  # Training settings for each trial
322
330
  parser.add_argument(
@@ -375,6 +375,36 @@ def parse_args() -> argparse.Namespace:
375
375
  help=argparse.SUPPRESS, # Hidden: use --precision instead
376
376
  )
377
377
 
378
+ # Physical Constraints
379
+ parser.add_argument(
380
+ "--constraint",
381
+ type=str,
382
+ nargs="+",
383
+ default=[],
384
+ help="Soft constraint expressions: 'y0 - y1*y2' (penalize violations)",
385
+ )
386
+
387
+ parser.add_argument(
388
+ "--constraint_file",
389
+ type=str,
390
+ default=None,
391
+ help="Python file with constraint(pred, inputs) function",
392
+ )
393
+ parser.add_argument(
394
+ "--constraint_weight",
395
+ type=float,
396
+ nargs="+",
397
+ default=[0.1],
398
+ help="Weight(s) for soft constraints (one per constraint, or single shared weight)",
399
+ )
400
+ parser.add_argument(
401
+ "--constraint_reduction",
402
+ type=str,
403
+ default="mse",
404
+ choices=["mse", "mae"],
405
+ help="Reduction mode for constraint penalties",
406
+ )
407
+
378
408
  # Logging
379
409
  parser.add_argument(
380
410
  "--wandb", action="store_true", help="Enable Weights & Biases logging"
@@ -553,7 +583,7 @@ def main():
553
583
  return
554
584
 
555
585
  # ==========================================================================
556
- # 1. SYSTEM INITIALIZATION
586
+ # SYSTEM INITIALIZATION
557
587
  # ==========================================================================
558
588
  # Initialize Accelerator for DDP and mixed precision
559
589
  accelerator = Accelerator(
@@ -609,7 +639,7 @@ def main():
609
639
  )
610
640
 
611
641
  # ==========================================================================
612
- # 2. DATA & MODEL LOADING
642
+ # DATA & MODEL LOADING
613
643
  # ==========================================================================
614
644
  train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
615
645
  args, logger, accelerator, cache_dir=args.output_dir
@@ -663,7 +693,7 @@ def main():
663
693
  )
664
694
 
665
695
  # ==========================================================================
666
- # 2.5. OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
696
+ # OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
667
697
  # ==========================================================================
668
698
  # Parse comma-separated arguments with validation
669
699
  try:
@@ -707,6 +737,43 @@ def main():
707
737
  # Move criterion to device (important for WeightedMSELoss buffer)
708
738
  criterion = criterion.to(accelerator.device)
709
739
 
740
+ # ==========================================================================
741
+ # PHYSICAL CONSTRAINTS INTEGRATION
742
+ # ==========================================================================
743
+ from wavedl.utils.constraints import (
744
+ PhysicsConstrainedLoss,
745
+ build_constraints,
746
+ )
747
+
748
+ # Build soft constraints
749
+ soft_constraints = build_constraints(
750
+ expressions=args.constraint,
751
+ file_path=args.constraint_file,
752
+ reduction=args.constraint_reduction,
753
+ )
754
+
755
+ # Wrap criterion with PhysicsConstrainedLoss if we have soft constraints
756
+ if soft_constraints:
757
+ # Pass output scaler so constraints can be evaluated in physical space
758
+ output_mean = scaler.mean_ if hasattr(scaler, "mean_") else None
759
+ output_std = scaler.scale_ if hasattr(scaler, "scale_") else None
760
+ criterion = PhysicsConstrainedLoss(
761
+ criterion,
762
+ soft_constraints,
763
+ weights=args.constraint_weight,
764
+ output_mean=output_mean,
765
+ output_std=output_std,
766
+ )
767
+ if accelerator.is_main_process:
768
+ logger.info(
769
+ f" 🔬 Physical constraints: {len(soft_constraints)} constraint(s) "
770
+ f"with weight(s) {args.constraint_weight}"
771
+ )
772
+ if output_mean is not None:
773
+ logger.info(
774
+ " 📐 Constraints evaluated in physical space (denormalized)"
775
+ )
776
+
710
777
  # Track if scheduler should step per batch (OneCycleLR) or per epoch
711
778
  scheduler_step_per_batch = not is_epoch_based(args.scheduler)
712
779
 
@@ -762,7 +829,7 @@ def main():
762
829
  )
763
830
 
764
831
  # ==========================================================================
765
- # 3. AUTO-RESUME / RESUME FROM CHECKPOINT
832
+ # AUTO-RESUME / RESUME FROM CHECKPOINT
766
833
  # ==========================================================================
767
834
  start_epoch = 0
768
835
  best_val_loss = float("inf")
@@ -818,7 +885,7 @@ def main():
818
885
  raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
819
886
 
820
887
  # ==========================================================================
821
- # 4. PHYSICAL METRIC SETUP
888
+ # PHYSICAL METRIC SETUP
822
889
  # ==========================================================================
823
890
  # Physical MAE = normalized MAE * scaler.scale_
824
891
  phys_scale = torch.tensor(
@@ -826,7 +893,7 @@ def main():
826
893
  )
827
894
 
828
895
  # ==========================================================================
829
- # 5. TRAINING LOOP
896
+ # TRAINING LOOP
830
897
  # ==========================================================================
831
898
  # Dynamic console header
832
899
  if accelerator.is_main_process:
@@ -15,6 +15,12 @@ from .config import (
15
15
  save_config,
16
16
  validate_config,
17
17
  )
18
+ from .constraints import (
19
+ ExpressionConstraint,
20
+ FileConstraint,
21
+ PhysicsConstrainedLoss,
22
+ build_constraints,
23
+ )
18
24
  from .cross_validation import (
19
25
  CVDataset,
20
26
  run_cross_validation,
@@ -91,8 +97,11 @@ __all__ = [
91
97
  "FIGURE_WIDTH_INCH",
92
98
  "FONT_SIZE_TEXT",
93
99
  "FONT_SIZE_TICKS",
100
+ # Constraints
94
101
  "CVDataset",
95
102
  "DataSource",
103
+ "ExpressionConstraint",
104
+ "FileConstraint",
96
105
  "HDF5Source",
97
106
  "LogCoshLoss",
98
107
  "MATSource",
@@ -101,10 +110,12 @@ __all__ = [
101
110
  # Metrics
102
111
  "MetricTracker",
103
112
  "NPZSource",
113
+ "PhysicsConstrainedLoss",
104
114
  "WeightedMSELoss",
105
115
  # Distributed
106
116
  "broadcast_early_stop",
107
117
  "broadcast_value",
118
+ "build_constraints",
108
119
  "calc_pearson",
109
120
  "calc_per_target_r2",
110
121
  "configure_matplotlib_style",
@@ -306,6 +306,16 @@ def validate_config(
306
306
  # Config
307
307
  "config",
308
308
  "list_models",
309
+ # Physical Constraints
310
+ "constraint",
311
+ "bounds",
312
+ "constraint_file",
313
+ "constraint_weight",
314
+ "constraint_reduction",
315
+ "positive",
316
+ "output_bounds",
317
+ "output_transform",
318
+ "output_formula",
309
319
  # Metadata (internal)
310
320
  "_metadata",
311
321
  }