wavedl 1.4.6__tar.gz → 1.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavedl-1.4.6/src/wavedl.egg-info → wavedl-1.5.0}/PKG-INFO +115 -19
- {wavedl-1.4.6 → wavedl-1.5.0}/README.md +114 -18
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/__init__.py +1 -1
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/hpo.py +9 -1
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/train.py +73 -6
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/__init__.py +11 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/config.py +10 -0
- wavedl-1.5.0/src/wavedl/utils/constraints.py +470 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/metrics.py +49 -2
- {wavedl-1.4.6 → wavedl-1.5.0/src/wavedl.egg-info}/PKG-INFO +115 -19
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl.egg-info/SOURCES.txt +1 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/LICENSE +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/pyproject.toml +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/setup.cfg +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/hpc.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/__init__.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/_template.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/base.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/cnn.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/convnext.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/densenet.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/efficientnet.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/efficientnetv2.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/mobilenetv3.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/registry.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/regnet.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/resnet.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/resnet3d.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/swin.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/tcn.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/unet.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/models/vit.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/test.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/cross_validation.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/data.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/distributed.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/losses.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/optimizers.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl/utils/schedulers.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl.egg-info/dependency_links.txt +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl.egg-info/entry_points.txt +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl.egg-info/requires.txt +0 -0
- {wavedl-1.4.6 → wavedl-1.5.0}/src/wavedl.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.4.6
|
|
3
|
+
Version: 1.5.0
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -113,14 +113,12 @@ Train on datasets larger than RAM:
|
|
|
113
113
|
</td>
|
|
114
114
|
<td width="50%" valign="top">
|
|
115
115
|
|
|
116
|
-
**🧠
|
|
116
|
+
**🧠 Models? We've Got Options**
|
|
117
117
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
```
|
|
123
|
-
Design your model. Register with one line.
|
|
118
|
+
38 architectures, ready to go:
|
|
119
|
+
- CNNs, ResNets, ViTs, EfficientNets...
|
|
120
|
+
- All adapted for regression
|
|
121
|
+
- [Add your own](#adding-custom-models) in one line
|
|
124
122
|
|
|
125
123
|
</td>
|
|
126
124
|
</tr>
|
|
@@ -137,12 +135,12 @@ Multi-GPU training without the pain:
|
|
|
137
135
|
</td>
|
|
138
136
|
<td width="50%" valign="top">
|
|
139
137
|
|
|
140
|
-
|
|
138
|
+
**🔬 Physics-Constrained Training**
|
|
141
139
|
|
|
142
|
-
|
|
143
|
-
-
|
|
144
|
-
-
|
|
145
|
-
-
|
|
140
|
+
Make your model respect physical laws:
|
|
141
|
+
- Enforce bounds, positivity, equations
|
|
142
|
+
- Simple expression syntax or Python
|
|
143
|
+
- [Custom constraints](#physical-constraints) for arbitrary physical laws
|
|
146
144
|
|
|
147
145
|
</td>
|
|
148
146
|
</tr>
|
|
@@ -383,7 +381,7 @@ WaveDL/
|
|
|
383
381
|
├── configs/ # YAML config templates
|
|
384
382
|
├── examples/ # Ready-to-run examples
|
|
385
383
|
├── notebooks/ # Jupyter notebooks
|
|
386
|
-
├── unit_tests/ # Pytest test suite (
|
|
384
|
+
├── unit_tests/ # Pytest test suite (725 tests)
|
|
387
385
|
│
|
|
388
386
|
├── pyproject.toml # Package config, dependencies
|
|
389
387
|
├── CHANGELOG.md # Version history
|
|
@@ -727,6 +725,104 @@ seed: 2025
|
|
|
727
725
|
|
|
728
726
|
</details>
|
|
729
727
|
|
|
728
|
+
<details>
|
|
729
|
+
<summary><b>Physical Constraints</b> — Enforce Physics During Training</summary>
|
|
730
|
+
|
|
731
|
+
Add penalty terms to the loss function to enforce physical laws:
|
|
732
|
+
|
|
733
|
+
```
|
|
734
|
+
Total Loss = Data Loss + Σᵢ weightᵢ × penalty(violationᵢ)
|
|
735
|
+
```
|
|
736
|
+
|
|
737
|
+
### Expression Constraints
|
|
738
|
+
|
|
739
|
+
```bash
|
|
740
|
+
# Positivity
|
|
741
|
+
--constraint "y0 > 0"
|
|
742
|
+
|
|
743
|
+
# Bounds
|
|
744
|
+
--constraint "y0 >= 0" "y0 <= 1"
|
|
745
|
+
|
|
746
|
+
# Equations (penalize deviations from zero)
|
|
747
|
+
--constraint "y2 - y0 * y1"
|
|
748
|
+
|
|
749
|
+
# Input-dependent constraints
|
|
750
|
+
--constraint "y0 - 2*x[0]"
|
|
751
|
+
|
|
752
|
+
# Multiple constraints with different weights
|
|
753
|
+
--constraint "y0 > 0" "y1 - y2" --constraint_weight 0.1 1.0
|
|
754
|
+
```
|
|
755
|
+
|
|
756
|
+
### Custom Python Constraints
|
|
757
|
+
|
|
758
|
+
For complex physics (matrix operations, implicit equations):
|
|
759
|
+
|
|
760
|
+
```python
|
|
761
|
+
# my_constraint.py
|
|
762
|
+
import torch
|
|
763
|
+
|
|
764
|
+
def constraint(pred, inputs=None):
|
|
765
|
+
"""
|
|
766
|
+
Args:
|
|
767
|
+
pred: (batch, num_outputs)
|
|
768
|
+
inputs: (batch, features) or (batch, C, H, W) or (batch, C, D, H, W)
|
|
769
|
+
Returns:
|
|
770
|
+
(batch,) — violation per sample (0 = satisfied)
|
|
771
|
+
"""
|
|
772
|
+
# Outputs (same for all data types)
|
|
773
|
+
y0, y1, y2 = pred[:, 0], pred[:, 1], pred[:, 2]
|
|
774
|
+
|
|
775
|
+
# Inputs — Tabular: (batch, features)
|
|
776
|
+
# x0 = inputs[:, 0] # Feature 0
|
|
777
|
+
# x_sum = inputs.sum(dim=1) # Sum all features
|
|
778
|
+
|
|
779
|
+
# Inputs — Images: (batch, C, H, W)
|
|
780
|
+
# pixel = inputs[:, 0, 3, 5] # Pixel at (3,5), channel 0
|
|
781
|
+
# img_mean = inputs.mean(dim=(1,2,3)) # Mean over C,H,W
|
|
782
|
+
|
|
783
|
+
# Inputs — 3D Volumes: (batch, C, D, H, W)
|
|
784
|
+
# voxel = inputs[:, 0, 2, 3, 5] # Voxel at (2,3,5), channel 0
|
|
785
|
+
|
|
786
|
+
# Example constraints:
|
|
787
|
+
# return y2 - y0 * y1 # Wave equation
|
|
788
|
+
# return y0 - 2 * inputs[:, 0] # Output = 2×input
|
|
789
|
+
# return inputs[:, 0, 3, 5] * y0 + inputs[:, 0, 6, 7] * y1 # Mixed
|
|
790
|
+
|
|
791
|
+
return y0 - y1 * y2
|
|
792
|
+
```
|
|
793
|
+
|
|
794
|
+
```bash
|
|
795
|
+
--constraint_file my_constraint.py --constraint_weight 1.0
|
|
796
|
+
```
|
|
797
|
+
|
|
798
|
+
---
|
|
799
|
+
|
|
800
|
+
### Reference
|
|
801
|
+
|
|
802
|
+
| Argument | Default | Description |
|
|
803
|
+
|----------|---------|-------------|
|
|
804
|
+
| `--constraint` | — | Expression(s): `"y0 > 0"`, `"y0 - y1*y2"` |
|
|
805
|
+
| `--constraint_file` | — | Python file with `constraint(pred, inputs)` |
|
|
806
|
+
| `--constraint_weight` | `0.1` | Penalty weight(s) |
|
|
807
|
+
| `--constraint_reduction` | `mse` | `mse` (squared) or `mae` (linear) |
|
|
808
|
+
|
|
809
|
+
#### Expression Syntax
|
|
810
|
+
|
|
811
|
+
| Variable | Meaning |
|
|
812
|
+
|----------|---------|
|
|
813
|
+
| `y0`, `y1`, ... | Model outputs |
|
|
814
|
+
| `x[0]`, `x[1]`, ... | Input values (1D tabular) |
|
|
815
|
+
| `x[i,j]`, `x[i,j,k]` | Input values (2D/3D: images, volumes) |
|
|
816
|
+
| `x_mean`, `x_sum`, `x_max`, `x_min`, `x_std` | Input aggregates |
|
|
817
|
+
|
|
818
|
+
**Operators:** `+`, `-`, `*`, `/`, `**`, `>`, `<`, `>=`, `<=`, `==`
|
|
819
|
+
|
|
820
|
+
**Functions:** `sin`, `cos`, `exp`, `log`, `sqrt`, `sigmoid`, `softplus`, `tanh`, `relu`, `abs`
|
|
821
|
+
|
|
822
|
+
</details>
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
|
|
730
826
|
<details>
|
|
731
827
|
<summary><b>Hyperparameter Search (HPO)</b></summary>
|
|
732
828
|
|
|
@@ -766,7 +862,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
|
|
|
766
862
|
| Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
|
|
767
863
|
| Losses | [all 6](#loss-functions) | `--losses X Y` |
|
|
768
864
|
| Learning rate | 1e-5 → 1e-2 | (always searched) |
|
|
769
|
-
| Batch size |
|
|
865
|
+
| Batch size | 16, 32, 64, 128 | `--batch_sizes X Y` |
|
|
770
866
|
|
|
771
867
|
**Quick Mode** (`--quick`):
|
|
772
868
|
- Uses minimal defaults: cnn + adamw + plateau + mse
|
|
@@ -938,12 +1034,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
|
|
|
938
1034
|
```bash
|
|
939
1035
|
# Run inference on the example data
|
|
940
1036
|
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
|
|
941
|
-
--data_path ./examples/elastic_cnn_example/
|
|
1037
|
+
--data_path ./examples/elastic_cnn_example/Test_data_500.mat \
|
|
942
1038
|
--plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
|
|
943
1039
|
|
|
944
1040
|
# Export to ONNX (already included as model.onnx)
|
|
945
1041
|
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
|
|
946
|
-
--data_path ./examples/elastic_cnn_example/
|
|
1042
|
+
--data_path ./examples/elastic_cnn_example/Test_data_500.mat \
|
|
947
1043
|
--export onnx --export_path ./examples/elastic_cnn_example/model.onnx
|
|
948
1044
|
```
|
|
949
1045
|
|
|
@@ -952,7 +1048,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
|
|
|
952
1048
|
| File | Description |
|
|
953
1049
|
|------|-------------|
|
|
954
1050
|
| `best_checkpoint/` | Pre-trained CNN checkpoint |
|
|
955
|
-
| `
|
|
1051
|
+
| `Test_data_500.mat` | 500-sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
|
|
956
1052
|
| `model.onnx` | ONNX export with embedded de-normalization |
|
|
957
1053
|
| `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
|
|
958
1054
|
| `training_curves.png` | Training/validation loss and learning rate plot |
|
|
@@ -963,7 +1059,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
|
|
|
963
1059
|
|
|
964
1060
|
<p align="center">
|
|
965
1061
|
<img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
|
|
966
|
-
<em>Training and validation loss over
|
|
1062
|
+
<em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
|
|
967
1063
|
</p>
|
|
968
1064
|
|
|
969
1065
|
**Inference Results:**
|
|
@@ -68,14 +68,12 @@ Train on datasets larger than RAM:
|
|
|
68
68
|
</td>
|
|
69
69
|
<td width="50%" valign="top">
|
|
70
70
|
|
|
71
|
-
**🧠
|
|
71
|
+
**🧠 Models? We've Got Options**
|
|
72
72
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
```
|
|
78
|
-
Design your model. Register with one line.
|
|
73
|
+
38 architectures, ready to go:
|
|
74
|
+
- CNNs, ResNets, ViTs, EfficientNets...
|
|
75
|
+
- All adapted for regression
|
|
76
|
+
- [Add your own](#adding-custom-models) in one line
|
|
79
77
|
|
|
80
78
|
</td>
|
|
81
79
|
</tr>
|
|
@@ -92,12 +90,12 @@ Multi-GPU training without the pain:
|
|
|
92
90
|
</td>
|
|
93
91
|
<td width="50%" valign="top">
|
|
94
92
|
|
|
95
|
-
|
|
93
|
+
**🔬 Physics-Constrained Training**
|
|
96
94
|
|
|
97
|
-
|
|
98
|
-
-
|
|
99
|
-
-
|
|
100
|
-
-
|
|
95
|
+
Make your model respect physical laws:
|
|
96
|
+
- Enforce bounds, positivity, equations
|
|
97
|
+
- Simple expression syntax or Python
|
|
98
|
+
- [Custom constraints](#physical-constraints) for arbitrary physical laws
|
|
101
99
|
|
|
102
100
|
</td>
|
|
103
101
|
</tr>
|
|
@@ -338,7 +336,7 @@ WaveDL/
|
|
|
338
336
|
├── configs/ # YAML config templates
|
|
339
337
|
├── examples/ # Ready-to-run examples
|
|
340
338
|
├── notebooks/ # Jupyter notebooks
|
|
341
|
-
├── unit_tests/ # Pytest test suite (
|
|
339
|
+
├── unit_tests/ # Pytest test suite (725 tests)
|
|
342
340
|
│
|
|
343
341
|
├── pyproject.toml # Package config, dependencies
|
|
344
342
|
├── CHANGELOG.md # Version history
|
|
@@ -682,6 +680,104 @@ seed: 2025
|
|
|
682
680
|
|
|
683
681
|
</details>
|
|
684
682
|
|
|
683
|
+
<details>
|
|
684
|
+
<summary><b>Physical Constraints</b> — Enforce Physics During Training</summary>
|
|
685
|
+
|
|
686
|
+
Add penalty terms to the loss function to enforce physical laws:
|
|
687
|
+
|
|
688
|
+
```
|
|
689
|
+
Total Loss = Data Loss + Σᵢ weightᵢ × penalty(violationᵢ)
|
|
690
|
+
```
|
|
691
|
+
|
|
692
|
+
### Expression Constraints
|
|
693
|
+
|
|
694
|
+
```bash
|
|
695
|
+
# Positivity
|
|
696
|
+
--constraint "y0 > 0"
|
|
697
|
+
|
|
698
|
+
# Bounds
|
|
699
|
+
--constraint "y0 >= 0" "y0 <= 1"
|
|
700
|
+
|
|
701
|
+
# Equations (penalize deviations from zero)
|
|
702
|
+
--constraint "y2 - y0 * y1"
|
|
703
|
+
|
|
704
|
+
# Input-dependent constraints
|
|
705
|
+
--constraint "y0 - 2*x[0]"
|
|
706
|
+
|
|
707
|
+
# Multiple constraints with different weights
|
|
708
|
+
--constraint "y0 > 0" "y1 - y2" --constraint_weight 0.1 1.0
|
|
709
|
+
```
|
|
710
|
+
|
|
711
|
+
### Custom Python Constraints
|
|
712
|
+
|
|
713
|
+
For complex physics (matrix operations, implicit equations):
|
|
714
|
+
|
|
715
|
+
```python
|
|
716
|
+
# my_constraint.py
|
|
717
|
+
import torch
|
|
718
|
+
|
|
719
|
+
def constraint(pred, inputs=None):
|
|
720
|
+
"""
|
|
721
|
+
Args:
|
|
722
|
+
pred: (batch, num_outputs)
|
|
723
|
+
inputs: (batch, features) or (batch, C, H, W) or (batch, C, D, H, W)
|
|
724
|
+
Returns:
|
|
725
|
+
(batch,) — violation per sample (0 = satisfied)
|
|
726
|
+
"""
|
|
727
|
+
# Outputs (same for all data types)
|
|
728
|
+
y0, y1, y2 = pred[:, 0], pred[:, 1], pred[:, 2]
|
|
729
|
+
|
|
730
|
+
# Inputs — Tabular: (batch, features)
|
|
731
|
+
# x0 = inputs[:, 0] # Feature 0
|
|
732
|
+
# x_sum = inputs.sum(dim=1) # Sum all features
|
|
733
|
+
|
|
734
|
+
# Inputs — Images: (batch, C, H, W)
|
|
735
|
+
# pixel = inputs[:, 0, 3, 5] # Pixel at (3,5), channel 0
|
|
736
|
+
# img_mean = inputs.mean(dim=(1,2,3)) # Mean over C,H,W
|
|
737
|
+
|
|
738
|
+
# Inputs — 3D Volumes: (batch, C, D, H, W)
|
|
739
|
+
# voxel = inputs[:, 0, 2, 3, 5] # Voxel at (2,3,5), channel 0
|
|
740
|
+
|
|
741
|
+
# Example constraints:
|
|
742
|
+
# return y2 - y0 * y1 # Wave equation
|
|
743
|
+
# return y0 - 2 * inputs[:, 0] # Output = 2×input
|
|
744
|
+
# return inputs[:, 0, 3, 5] * y0 + inputs[:, 0, 6, 7] * y1 # Mixed
|
|
745
|
+
|
|
746
|
+
return y0 - y1 * y2
|
|
747
|
+
```
|
|
748
|
+
|
|
749
|
+
```bash
|
|
750
|
+
--constraint_file my_constraint.py --constraint_weight 1.0
|
|
751
|
+
```
|
|
752
|
+
|
|
753
|
+
---
|
|
754
|
+
|
|
755
|
+
### Reference
|
|
756
|
+
|
|
757
|
+
| Argument | Default | Description |
|
|
758
|
+
|----------|---------|-------------|
|
|
759
|
+
| `--constraint` | — | Expression(s): `"y0 > 0"`, `"y0 - y1*y2"` |
|
|
760
|
+
| `--constraint_file` | — | Python file with `constraint(pred, inputs)` |
|
|
761
|
+
| `--constraint_weight` | `0.1` | Penalty weight(s) |
|
|
762
|
+
| `--constraint_reduction` | `mse` | `mse` (squared) or `mae` (linear) |
|
|
763
|
+
|
|
764
|
+
#### Expression Syntax
|
|
765
|
+
|
|
766
|
+
| Variable | Meaning |
|
|
767
|
+
|----------|---------|
|
|
768
|
+
| `y0`, `y1`, ... | Model outputs |
|
|
769
|
+
| `x[0]`, `x[1]`, ... | Input values (1D tabular) |
|
|
770
|
+
| `x[i,j]`, `x[i,j,k]` | Input values (2D/3D: images, volumes) |
|
|
771
|
+
| `x_mean`, `x_sum`, `x_max`, `x_min`, `x_std` | Input aggregates |
|
|
772
|
+
|
|
773
|
+
**Operators:** `+`, `-`, `*`, `/`, `**`, `>`, `<`, `>=`, `<=`, `==`
|
|
774
|
+
|
|
775
|
+
**Functions:** `sin`, `cos`, `exp`, `log`, `sqrt`, `sigmoid`, `softplus`, `tanh`, `relu`, `abs`
|
|
776
|
+
|
|
777
|
+
</details>
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
|
|
685
781
|
<details>
|
|
686
782
|
<summary><b>Hyperparameter Search (HPO)</b></summary>
|
|
687
783
|
|
|
@@ -721,7 +817,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
|
|
|
721
817
|
| Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
|
|
722
818
|
| Losses | [all 6](#loss-functions) | `--losses X Y` |
|
|
723
819
|
| Learning rate | 1e-5 → 1e-2 | (always searched) |
|
|
724
|
-
| Batch size |
|
|
820
|
+
| Batch size | 16, 32, 64, 128 | `--batch_sizes X Y` |
|
|
725
821
|
|
|
726
822
|
**Quick Mode** (`--quick`):
|
|
727
823
|
- Uses minimal defaults: cnn + adamw + plateau + mse
|
|
@@ -893,12 +989,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
|
|
|
893
989
|
```bash
|
|
894
990
|
# Run inference on the example data
|
|
895
991
|
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
|
|
896
|
-
--data_path ./examples/elastic_cnn_example/
|
|
992
|
+
--data_path ./examples/elastic_cnn_example/Test_data_500.mat \
|
|
897
993
|
--plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
|
|
898
994
|
|
|
899
995
|
# Export to ONNX (already included as model.onnx)
|
|
900
996
|
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
|
|
901
|
-
--data_path ./examples/elastic_cnn_example/
|
|
997
|
+
--data_path ./examples/elastic_cnn_example/Test_data_500.mat \
|
|
902
998
|
--export onnx --export_path ./examples/elastic_cnn_example/model.onnx
|
|
903
999
|
```
|
|
904
1000
|
|
|
@@ -907,7 +1003,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
|
|
|
907
1003
|
| File | Description |
|
|
908
1004
|
|------|-------------|
|
|
909
1005
|
| `best_checkpoint/` | Pre-trained CNN checkpoint |
|
|
910
|
-
| `
|
|
1006
|
+
| `Test_data_500.mat` | 500-sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
|
|
911
1007
|
| `model.onnx` | ONNX export with embedded de-normalization |
|
|
912
1008
|
| `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
|
|
913
1009
|
| `training_curves.png` | Training/validation loss and learning rate plot |
|
|
@@ -918,7 +1014,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
|
|
|
918
1014
|
|
|
919
1015
|
<p align="center">
|
|
920
1016
|
<img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
|
|
921
|
-
<em>Training and validation loss over
|
|
1017
|
+
<em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
|
|
922
1018
|
</p>
|
|
923
1019
|
|
|
924
1020
|
**Inference Results:**
|
|
@@ -89,7 +89,8 @@ def create_objective(args):
|
|
|
89
89
|
# Suggest hyperparameters
|
|
90
90
|
model = trial.suggest_categorical("model", models)
|
|
91
91
|
lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
|
|
92
|
-
|
|
92
|
+
batch_sizes = args.batch_sizes or [16, 32, 64, 128]
|
|
93
|
+
batch_size = trial.suggest_categorical("batch_size", batch_sizes)
|
|
93
94
|
optimizer = trial.suggest_categorical("optimizer", optimizers)
|
|
94
95
|
scheduler = trial.suggest_categorical("scheduler", schedulers)
|
|
95
96
|
loss = trial.suggest_categorical("loss", losses)
|
|
@@ -317,6 +318,13 @@ Examples:
|
|
|
317
318
|
default=None,
|
|
318
319
|
help=f"Losses to search (default: {DEFAULT_LOSSES})",
|
|
319
320
|
)
|
|
321
|
+
parser.add_argument(
|
|
322
|
+
"--batch_sizes",
|
|
323
|
+
type=int,
|
|
324
|
+
nargs="+",
|
|
325
|
+
default=None,
|
|
326
|
+
help="Batch sizes to search (default: 16 32 64 128)",
|
|
327
|
+
)
|
|
320
328
|
|
|
321
329
|
# Training settings for each trial
|
|
322
330
|
parser.add_argument(
|
|
@@ -375,6 +375,36 @@ def parse_args() -> argparse.Namespace:
|
|
|
375
375
|
help=argparse.SUPPRESS, # Hidden: use --precision instead
|
|
376
376
|
)
|
|
377
377
|
|
|
378
|
+
# Physical Constraints
|
|
379
|
+
parser.add_argument(
|
|
380
|
+
"--constraint",
|
|
381
|
+
type=str,
|
|
382
|
+
nargs="+",
|
|
383
|
+
default=[],
|
|
384
|
+
help="Soft constraint expressions: 'y0 - y1*y2' (penalize violations)",
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
parser.add_argument(
|
|
388
|
+
"--constraint_file",
|
|
389
|
+
type=str,
|
|
390
|
+
default=None,
|
|
391
|
+
help="Python file with constraint(pred, inputs) function",
|
|
392
|
+
)
|
|
393
|
+
parser.add_argument(
|
|
394
|
+
"--constraint_weight",
|
|
395
|
+
type=float,
|
|
396
|
+
nargs="+",
|
|
397
|
+
default=[0.1],
|
|
398
|
+
help="Weight(s) for soft constraints (one per constraint, or single shared weight)",
|
|
399
|
+
)
|
|
400
|
+
parser.add_argument(
|
|
401
|
+
"--constraint_reduction",
|
|
402
|
+
type=str,
|
|
403
|
+
default="mse",
|
|
404
|
+
choices=["mse", "mae"],
|
|
405
|
+
help="Reduction mode for constraint penalties",
|
|
406
|
+
)
|
|
407
|
+
|
|
378
408
|
# Logging
|
|
379
409
|
parser.add_argument(
|
|
380
410
|
"--wandb", action="store_true", help="Enable Weights & Biases logging"
|
|
@@ -553,7 +583,7 @@ def main():
|
|
|
553
583
|
return
|
|
554
584
|
|
|
555
585
|
# ==========================================================================
|
|
556
|
-
#
|
|
586
|
+
# SYSTEM INITIALIZATION
|
|
557
587
|
# ==========================================================================
|
|
558
588
|
# Initialize Accelerator for DDP and mixed precision
|
|
559
589
|
accelerator = Accelerator(
|
|
@@ -609,7 +639,7 @@ def main():
|
|
|
609
639
|
)
|
|
610
640
|
|
|
611
641
|
# ==========================================================================
|
|
612
|
-
#
|
|
642
|
+
# DATA & MODEL LOADING
|
|
613
643
|
# ==========================================================================
|
|
614
644
|
train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
|
|
615
645
|
args, logger, accelerator, cache_dir=args.output_dir
|
|
@@ -663,7 +693,7 @@ def main():
|
|
|
663
693
|
)
|
|
664
694
|
|
|
665
695
|
# ==========================================================================
|
|
666
|
-
#
|
|
696
|
+
# OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
|
|
667
697
|
# ==========================================================================
|
|
668
698
|
# Parse comma-separated arguments with validation
|
|
669
699
|
try:
|
|
@@ -707,6 +737,43 @@ def main():
|
|
|
707
737
|
# Move criterion to device (important for WeightedMSELoss buffer)
|
|
708
738
|
criterion = criterion.to(accelerator.device)
|
|
709
739
|
|
|
740
|
+
# ==========================================================================
|
|
741
|
+
# PHYSICAL CONSTRAINTS INTEGRATION
|
|
742
|
+
# ==========================================================================
|
|
743
|
+
from wavedl.utils.constraints import (
|
|
744
|
+
PhysicsConstrainedLoss,
|
|
745
|
+
build_constraints,
|
|
746
|
+
)
|
|
747
|
+
|
|
748
|
+
# Build soft constraints
|
|
749
|
+
soft_constraints = build_constraints(
|
|
750
|
+
expressions=args.constraint,
|
|
751
|
+
file_path=args.constraint_file,
|
|
752
|
+
reduction=args.constraint_reduction,
|
|
753
|
+
)
|
|
754
|
+
|
|
755
|
+
# Wrap criterion with PhysicsConstrainedLoss if we have soft constraints
|
|
756
|
+
if soft_constraints:
|
|
757
|
+
# Pass output scaler so constraints can be evaluated in physical space
|
|
758
|
+
output_mean = scaler.mean_ if hasattr(scaler, "mean_") else None
|
|
759
|
+
output_std = scaler.scale_ if hasattr(scaler, "scale_") else None
|
|
760
|
+
criterion = PhysicsConstrainedLoss(
|
|
761
|
+
criterion,
|
|
762
|
+
soft_constraints,
|
|
763
|
+
weights=args.constraint_weight,
|
|
764
|
+
output_mean=output_mean,
|
|
765
|
+
output_std=output_std,
|
|
766
|
+
)
|
|
767
|
+
if accelerator.is_main_process:
|
|
768
|
+
logger.info(
|
|
769
|
+
f" 🔬 Physical constraints: {len(soft_constraints)} constraint(s) "
|
|
770
|
+
f"with weight(s) {args.constraint_weight}"
|
|
771
|
+
)
|
|
772
|
+
if output_mean is not None:
|
|
773
|
+
logger.info(
|
|
774
|
+
" 📐 Constraints evaluated in physical space (denormalized)"
|
|
775
|
+
)
|
|
776
|
+
|
|
710
777
|
# Track if scheduler should step per batch (OneCycleLR) or per epoch
|
|
711
778
|
scheduler_step_per_batch = not is_epoch_based(args.scheduler)
|
|
712
779
|
|
|
@@ -762,7 +829,7 @@ def main():
|
|
|
762
829
|
)
|
|
763
830
|
|
|
764
831
|
# ==========================================================================
|
|
765
|
-
#
|
|
832
|
+
# AUTO-RESUME / RESUME FROM CHECKPOINT
|
|
766
833
|
# ==========================================================================
|
|
767
834
|
start_epoch = 0
|
|
768
835
|
best_val_loss = float("inf")
|
|
@@ -818,7 +885,7 @@ def main():
|
|
|
818
885
|
raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
|
|
819
886
|
|
|
820
887
|
# ==========================================================================
|
|
821
|
-
#
|
|
888
|
+
# PHYSICAL METRIC SETUP
|
|
822
889
|
# ==========================================================================
|
|
823
890
|
# Physical MAE = normalized MAE * scaler.scale_
|
|
824
891
|
phys_scale = torch.tensor(
|
|
@@ -826,7 +893,7 @@ def main():
|
|
|
826
893
|
)
|
|
827
894
|
|
|
828
895
|
# ==========================================================================
|
|
829
|
-
#
|
|
896
|
+
# TRAINING LOOP
|
|
830
897
|
# ==========================================================================
|
|
831
898
|
# Dynamic console header
|
|
832
899
|
if accelerator.is_main_process:
|
|
@@ -15,6 +15,12 @@ from .config import (
|
|
|
15
15
|
save_config,
|
|
16
16
|
validate_config,
|
|
17
17
|
)
|
|
18
|
+
from .constraints import (
|
|
19
|
+
ExpressionConstraint,
|
|
20
|
+
FileConstraint,
|
|
21
|
+
PhysicsConstrainedLoss,
|
|
22
|
+
build_constraints,
|
|
23
|
+
)
|
|
18
24
|
from .cross_validation import (
|
|
19
25
|
CVDataset,
|
|
20
26
|
run_cross_validation,
|
|
@@ -91,8 +97,11 @@ __all__ = [
|
|
|
91
97
|
"FIGURE_WIDTH_INCH",
|
|
92
98
|
"FONT_SIZE_TEXT",
|
|
93
99
|
"FONT_SIZE_TICKS",
|
|
100
|
+
# Constraints
|
|
94
101
|
"CVDataset",
|
|
95
102
|
"DataSource",
|
|
103
|
+
"ExpressionConstraint",
|
|
104
|
+
"FileConstraint",
|
|
96
105
|
"HDF5Source",
|
|
97
106
|
"LogCoshLoss",
|
|
98
107
|
"MATSource",
|
|
@@ -101,10 +110,12 @@ __all__ = [
|
|
|
101
110
|
# Metrics
|
|
102
111
|
"MetricTracker",
|
|
103
112
|
"NPZSource",
|
|
113
|
+
"PhysicsConstrainedLoss",
|
|
104
114
|
"WeightedMSELoss",
|
|
105
115
|
# Distributed
|
|
106
116
|
"broadcast_early_stop",
|
|
107
117
|
"broadcast_value",
|
|
118
|
+
"build_constraints",
|
|
108
119
|
"calc_pearson",
|
|
109
120
|
"calc_per_target_r2",
|
|
110
121
|
"configure_matplotlib_style",
|
|
@@ -306,6 +306,16 @@ def validate_config(
|
|
|
306
306
|
# Config
|
|
307
307
|
"config",
|
|
308
308
|
"list_models",
|
|
309
|
+
# Physical Constraints
|
|
310
|
+
"constraint",
|
|
311
|
+
"bounds",
|
|
312
|
+
"constraint_file",
|
|
313
|
+
"constraint_weight",
|
|
314
|
+
"constraint_reduction",
|
|
315
|
+
"positive",
|
|
316
|
+
"output_bounds",
|
|
317
|
+
"output_transform",
|
|
318
|
+
"output_formula",
|
|
309
319
|
# Metadata (internal)
|
|
310
320
|
"_metadata",
|
|
311
321
|
}
|