wavedl 1.4.6__tar.gz → 1.5.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavedl-1.4.6/src/wavedl.egg-info → wavedl-1.5.1}/PKG-INFO +122 -19
- {wavedl-1.4.6 → wavedl-1.5.1}/README.md +121 -18
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/__init__.py +1 -1
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/hpo.py +9 -1
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/vit.py +21 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/test.py +28 -5
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/train.py +122 -15
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/__init__.py +11 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/config.py +10 -0
- wavedl-1.5.1/src/wavedl/utils/constraints.py +470 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/cross_validation.py +12 -2
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/data.py +26 -7
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/metrics.py +49 -2
- {wavedl-1.4.6 → wavedl-1.5.1/src/wavedl.egg-info}/PKG-INFO +122 -19
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl.egg-info/SOURCES.txt +1 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/LICENSE +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/pyproject.toml +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/setup.cfg +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/hpc.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/__init__.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/_template.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/base.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/cnn.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/convnext.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/densenet.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/efficientnet.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/efficientnetv2.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/mobilenetv3.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/registry.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/regnet.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/resnet.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/resnet3d.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/swin.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/tcn.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/unet.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/distributed.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/losses.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/optimizers.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/schedulers.py +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl.egg-info/dependency_links.txt +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl.egg-info/entry_points.txt +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl.egg-info/requires.txt +0 -0
- {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.5.1
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -99,6 +99,7 @@ The framework handles the engineering challenges of large-scale deep learning
|
|
|
99
99
|
|
|
100
100
|
## ✨ Features
|
|
101
101
|
|
|
102
|
+
<div align="center">
|
|
102
103
|
<table width="100%">
|
|
103
104
|
<tr>
|
|
104
105
|
<td width="50%" valign="top">
|
|
@@ -113,14 +114,12 @@ Train on datasets larger than RAM:
|
|
|
113
114
|
</td>
|
|
114
115
|
<td width="50%" valign="top">
|
|
115
116
|
|
|
116
|
-
**🧠
|
|
117
|
+
**🧠 Models? We've Got Options**
|
|
117
118
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
```
|
|
123
|
-
Design your model. Register with one line.
|
|
119
|
+
38 architectures, ready to go:
|
|
120
|
+
- CNNs, ResNets, ViTs, EfficientNets...
|
|
121
|
+
- All adapted for regression
|
|
122
|
+
- [Add your own](#adding-custom-models) in one line
|
|
124
123
|
|
|
125
124
|
</td>
|
|
126
125
|
</tr>
|
|
@@ -137,12 +136,12 @@ Multi-GPU training without the pain:
|
|
|
137
136
|
</td>
|
|
138
137
|
<td width="50%" valign="top">
|
|
139
138
|
|
|
140
|
-
|
|
139
|
+
**🔬 Physics-Constrained Training**
|
|
141
140
|
|
|
142
|
-
|
|
143
|
-
-
|
|
144
|
-
-
|
|
145
|
-
-
|
|
141
|
+
Make your model respect physical laws:
|
|
142
|
+
- Enforce bounds, positivity, equations
|
|
143
|
+
- Simple expression syntax or Python
|
|
144
|
+
- [Custom constraints](#physical-constraints) for your own physical laws
|
|
146
145
|
|
|
147
146
|
</td>
|
|
148
147
|
</tr>
|
|
@@ -191,6 +190,7 @@ Deploy models anywhere:
|
|
|
191
190
|
</td>
|
|
192
191
|
</tr>
|
|
193
192
|
</table>
|
|
193
|
+
</div>
|
|
194
194
|
|
|
195
195
|
---
|
|
196
196
|
|
|
@@ -279,6 +279,10 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
|
279
279
|
# Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
|
|
280
280
|
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
281
281
|
--export onnx --export_path <output_file.onnx>
|
|
282
|
+
|
|
283
|
+
# For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
|
|
284
|
+
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
285
|
+
--input_channels 1
|
|
282
286
|
```
|
|
283
287
|
|
|
284
288
|
**Output:**
|
|
@@ -374,6 +378,7 @@ WaveDL/
|
|
|
374
378
|
│ └── utils/ # Utilities
|
|
375
379
|
│ ├── data.py # Memory-mapped data pipeline
|
|
376
380
|
│ ├── metrics.py # R², Pearson, visualization
|
|
381
|
+
│ ├── constraints.py # Physical constraints for training
|
|
377
382
|
│ ├── distributed.py # DDP synchronization
|
|
378
383
|
│ ├── losses.py # Loss function factory
|
|
379
384
|
│ ├── optimizers.py # Optimizer factory
|
|
@@ -383,7 +388,7 @@ WaveDL/
|
|
|
383
388
|
├── configs/ # YAML config templates
|
|
384
389
|
├── examples/ # Ready-to-run examples
|
|
385
390
|
├── notebooks/ # Jupyter notebooks
|
|
386
|
-
├── unit_tests/ # Pytest test suite (
|
|
391
|
+
├── unit_tests/ # Pytest test suite (725 tests)
|
|
387
392
|
│
|
|
388
393
|
├── pyproject.toml # Package config, dependencies
|
|
389
394
|
├── CHANGELOG.md # Version history
|
|
@@ -727,6 +732,104 @@ seed: 2025
|
|
|
727
732
|
|
|
728
733
|
</details>
|
|
729
734
|
|
|
735
|
+
<details>
|
|
736
|
+
<summary><b>Physical Constraints</b> — Enforce Physics During Training</summary>
|
|
737
|
+
|
|
738
|
+
Add penalty terms to the loss function to enforce physical laws:
|
|
739
|
+
|
|
740
|
+
```
|
|
741
|
+
Total Loss = Data Loss + weight × penalty(violation)
|
|
742
|
+
```
|
|
743
|
+
|
|
744
|
+
### Expression Constraints
|
|
745
|
+
|
|
746
|
+
```bash
|
|
747
|
+
# Positivity
|
|
748
|
+
--constraint "y0 > 0"
|
|
749
|
+
|
|
750
|
+
# Bounds
|
|
751
|
+
--constraint "y0 >= 0" "y0 <= 1"
|
|
752
|
+
|
|
753
|
+
# Equations (penalize deviations from zero)
|
|
754
|
+
--constraint "y2 - y0 * y1"
|
|
755
|
+
|
|
756
|
+
# Input-dependent constraints
|
|
757
|
+
--constraint "y0 - 2*x[0]"
|
|
758
|
+
|
|
759
|
+
# Multiple constraints with different weights
|
|
760
|
+
--constraint "y0 > 0" "y1 - y2" --constraint_weight 0.1 1.0
|
|
761
|
+
```
|
|
762
|
+
|
|
763
|
+
### Custom Python Constraints
|
|
764
|
+
|
|
765
|
+
For complex physics (matrix operations, implicit equations):
|
|
766
|
+
|
|
767
|
+
```python
|
|
768
|
+
# my_constraint.py
|
|
769
|
+
import torch
|
|
770
|
+
|
|
771
|
+
def constraint(pred, inputs=None):
|
|
772
|
+
"""
|
|
773
|
+
Args:
|
|
774
|
+
pred: (batch, num_outputs)
|
|
775
|
+
inputs: (batch, features) or (batch, C, H, W) or (batch, C, D, H, W)
|
|
776
|
+
Returns:
|
|
777
|
+
(batch,) — violation per sample (0 = satisfied)
|
|
778
|
+
"""
|
|
779
|
+
# Outputs (same for all data types)
|
|
780
|
+
y0, y1, y2 = pred[:, 0], pred[:, 1], pred[:, 2]
|
|
781
|
+
|
|
782
|
+
# Inputs — Tabular: (batch, features)
|
|
783
|
+
# x0 = inputs[:, 0] # Feature 0
|
|
784
|
+
# x_sum = inputs.sum(dim=1) # Sum all features
|
|
785
|
+
|
|
786
|
+
# Inputs — Images: (batch, C, H, W)
|
|
787
|
+
# pixel = inputs[:, 0, 3, 5] # Pixel at (3,5), channel 0
|
|
788
|
+
# img_mean = inputs.mean(dim=(1,2,3)) # Mean over C,H,W
|
|
789
|
+
|
|
790
|
+
# Inputs — 3D Volumes: (batch, C, D, H, W)
|
|
791
|
+
# voxel = inputs[:, 0, 2, 3, 5] # Voxel at (2,3,5), channel 0
|
|
792
|
+
|
|
793
|
+
# Example constraints:
|
|
794
|
+
# return y2 - y0 * y1 # Wave equation
|
|
795
|
+
# return y0 - 2 * inputs[:, 0] # Output = 2×input
|
|
796
|
+
# return inputs[:, 0, 3, 5] * y0 + inputs[:, 0, 6, 7] * y1 # Mixed
|
|
797
|
+
|
|
798
|
+
return y0 - y1 * y2
|
|
799
|
+
```
|
|
800
|
+
|
|
801
|
+
```bash
|
|
802
|
+
--constraint_file my_constraint.py --constraint_weight 1.0
|
|
803
|
+
```
|
|
804
|
+
|
|
805
|
+
---
|
|
806
|
+
|
|
807
|
+
### Reference
|
|
808
|
+
|
|
809
|
+
| Argument | Default | Description |
|
|
810
|
+
|----------|---------|-------------|
|
|
811
|
+
| `--constraint` | — | Expression(s): `"y0 > 0"`, `"y0 - y1*y2"` |
|
|
812
|
+
| `--constraint_file` | — | Python file with `constraint(pred, inputs)` |
|
|
813
|
+
| `--constraint_weight` | `0.1` | Penalty weight(s) |
|
|
814
|
+
| `--constraint_reduction` | `mse` | `mse` (squared) or `mae` (linear) |
|
|
815
|
+
|
|
816
|
+
#### Expression Syntax
|
|
817
|
+
|
|
818
|
+
| Variable | Meaning |
|
|
819
|
+
|----------|---------|
|
|
820
|
+
| `y0`, `y1`, ... | Model outputs |
|
|
821
|
+
| `x[0]`, `x[1]`, ... | Input values (1D tabular) |
|
|
822
|
+
| `x[i,j]`, `x[i,j,k]` | Input values (2D/3D: images, volumes) |
|
|
823
|
+
| `x_mean`, `x_sum`, `x_max`, `x_min`, `x_std` | Input aggregates |
|
|
824
|
+
|
|
825
|
+
**Operators:** `+`, `-`, `*`, `/`, `**`, `>`, `<`, `>=`, `<=`, `==`
|
|
826
|
+
|
|
827
|
+
**Functions:** `sin`, `cos`, `exp`, `log`, `sqrt`, `sigmoid`, `softplus`, `tanh`, `relu`, `abs`
|
|
828
|
+
|
|
829
|
+
</details>
|
|
830
|
+
|
|
831
|
+
|
|
832
|
+
|
|
730
833
|
<details>
|
|
731
834
|
<summary><b>Hyperparameter Search (HPO)</b></summary>
|
|
732
835
|
|
|
@@ -766,7 +869,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
|
|
|
766
869
|
| Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
|
|
767
870
|
| Losses | [all 6](#loss-functions) | `--losses X Y` |
|
|
768
871
|
| Learning rate | 1e-5 → 1e-2 | (always searched) |
|
|
769
|
-
| Batch size |
|
|
872
|
+
| Batch size | 16, 32, 64, 128 | (always searched) |
|
|
770
873
|
|
|
771
874
|
**Quick Mode** (`--quick`):
|
|
772
875
|
- Uses minimal defaults: cnn + adamw + plateau + mse
|
|
@@ -938,12 +1041,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
|
|
|
938
1041
|
```bash
|
|
939
1042
|
# Run inference on the example data
|
|
940
1043
|
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
|
|
941
|
-
--data_path ./examples/elastic_cnn_example/
|
|
1044
|
+
--data_path ./examples/elastic_cnn_example/Test_data_500.mat \
|
|
942
1045
|
--plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
|
|
943
1046
|
|
|
944
1047
|
# Export to ONNX (already included as model.onnx)
|
|
945
1048
|
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
|
|
946
|
-
--data_path ./examples/elastic_cnn_example/
|
|
1049
|
+
--data_path ./examples/elastic_cnn_example/Test_data_500.mat \
|
|
947
1050
|
--export onnx --export_path ./examples/elastic_cnn_example/model.onnx
|
|
948
1051
|
```
|
|
949
1052
|
|
|
@@ -952,7 +1055,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
|
|
|
952
1055
|
| File | Description |
|
|
953
1056
|
|------|-------------|
|
|
954
1057
|
| `best_checkpoint/` | Pre-trained CNN checkpoint |
|
|
955
|
-
| `
|
|
1058
|
+
| `Test_data_500.mat` | 500-sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
|
|
956
1059
|
| `model.onnx` | ONNX export with embedded de-normalization |
|
|
957
1060
|
| `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
|
|
958
1061
|
| `training_curves.png` | Training/validation loss and learning rate plot |
|
|
@@ -963,7 +1066,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
|
|
|
963
1066
|
|
|
964
1067
|
<p align="center">
|
|
965
1068
|
<img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
|
|
966
|
-
<em>Training and validation loss over
|
|
1069
|
+
<em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
|
|
967
1070
|
</p>
|
|
968
1071
|
|
|
969
1072
|
**Inference Results:**
|
|
@@ -54,6 +54,7 @@ The framework handles the engineering challenges of large-scale deep learning
|
|
|
54
54
|
|
|
55
55
|
## ✨ Features
|
|
56
56
|
|
|
57
|
+
<div align="center">
|
|
57
58
|
<table width="100%">
|
|
58
59
|
<tr>
|
|
59
60
|
<td width="50%" valign="top">
|
|
@@ -68,14 +69,12 @@ Train on datasets larger than RAM:
|
|
|
68
69
|
</td>
|
|
69
70
|
<td width="50%" valign="top">
|
|
70
71
|
|
|
71
|
-
**🧠
|
|
72
|
+
**🧠 Models? We've Got Options**
|
|
72
73
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
```
|
|
78
|
-
Design your model. Register with one line.
|
|
74
|
+
38 architectures, ready to go:
|
|
75
|
+
- CNNs, ResNets, ViTs, EfficientNets...
|
|
76
|
+
- All adapted for regression
|
|
77
|
+
- [Add your own](#adding-custom-models) in one line
|
|
79
78
|
|
|
80
79
|
</td>
|
|
81
80
|
</tr>
|
|
@@ -92,12 +91,12 @@ Multi-GPU training without the pain:
|
|
|
92
91
|
</td>
|
|
93
92
|
<td width="50%" valign="top">
|
|
94
93
|
|
|
95
|
-
|
|
94
|
+
**🔬 Physics-Constrained Training**
|
|
96
95
|
|
|
97
|
-
|
|
98
|
-
-
|
|
99
|
-
-
|
|
100
|
-
-
|
|
96
|
+
Make your model respect physical laws:
|
|
97
|
+
- Enforce bounds, positivity, equations
|
|
98
|
+
- Simple expression syntax or Python
|
|
99
|
+
- [Custom constraints](#physical-constraints) for your own physical laws
|
|
101
100
|
|
|
102
101
|
</td>
|
|
103
102
|
</tr>
|
|
@@ -146,6 +145,7 @@ Deploy models anywhere:
|
|
|
146
145
|
</td>
|
|
147
146
|
</tr>
|
|
148
147
|
</table>
|
|
148
|
+
</div>
|
|
149
149
|
|
|
150
150
|
---
|
|
151
151
|
|
|
@@ -234,6 +234,10 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
|
234
234
|
# Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
|
|
235
235
|
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
236
236
|
--export onnx --export_path <output_file.onnx>
|
|
237
|
+
|
|
238
|
+
# For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
|
|
239
|
+
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
240
|
+
--input_channels 1
|
|
237
241
|
```
|
|
238
242
|
|
|
239
243
|
**Output:**
|
|
@@ -329,6 +333,7 @@ WaveDL/
|
|
|
329
333
|
│ └── utils/ # Utilities
|
|
330
334
|
│ ├── data.py # Memory-mapped data pipeline
|
|
331
335
|
│ ├── metrics.py # R², Pearson, visualization
|
|
336
|
+
│ ├── constraints.py # Physical constraints for training
|
|
332
337
|
│ ├── distributed.py # DDP synchronization
|
|
333
338
|
│ ├── losses.py # Loss function factory
|
|
334
339
|
│ ├── optimizers.py # Optimizer factory
|
|
@@ -338,7 +343,7 @@ WaveDL/
|
|
|
338
343
|
├── configs/ # YAML config templates
|
|
339
344
|
├── examples/ # Ready-to-run examples
|
|
340
345
|
├── notebooks/ # Jupyter notebooks
|
|
341
|
-
├── unit_tests/ # Pytest test suite (
|
|
346
|
+
├── unit_tests/ # Pytest test suite (725 tests)
|
|
342
347
|
│
|
|
343
348
|
├── pyproject.toml # Package config, dependencies
|
|
344
349
|
├── CHANGELOG.md # Version history
|
|
@@ -682,6 +687,104 @@ seed: 2025
|
|
|
682
687
|
|
|
683
688
|
</details>
|
|
684
689
|
|
|
690
|
+
<details>
|
|
691
|
+
<summary><b>Physical Constraints</b> — Enforce Physics During Training</summary>
|
|
692
|
+
|
|
693
|
+
Add penalty terms to the loss function to enforce physical laws:
|
|
694
|
+
|
|
695
|
+
```
|
|
696
|
+
Total Loss = Data Loss + weight × penalty(violation)
|
|
697
|
+
```
|
|
698
|
+
|
|
699
|
+
### Expression Constraints
|
|
700
|
+
|
|
701
|
+
```bash
|
|
702
|
+
# Positivity
|
|
703
|
+
--constraint "y0 > 0"
|
|
704
|
+
|
|
705
|
+
# Bounds
|
|
706
|
+
--constraint "y0 >= 0" "y0 <= 1"
|
|
707
|
+
|
|
708
|
+
# Equations (penalize deviations from zero)
|
|
709
|
+
--constraint "y2 - y0 * y1"
|
|
710
|
+
|
|
711
|
+
# Input-dependent constraints
|
|
712
|
+
--constraint "y0 - 2*x[0]"
|
|
713
|
+
|
|
714
|
+
# Multiple constraints with different weights
|
|
715
|
+
--constraint "y0 > 0" "y1 - y2" --constraint_weight 0.1 1.0
|
|
716
|
+
```
|
|
717
|
+
|
|
718
|
+
### Custom Python Constraints
|
|
719
|
+
|
|
720
|
+
For complex physics (matrix operations, implicit equations):
|
|
721
|
+
|
|
722
|
+
```python
|
|
723
|
+
# my_constraint.py
|
|
724
|
+
import torch
|
|
725
|
+
|
|
726
|
+
def constraint(pred, inputs=None):
|
|
727
|
+
"""
|
|
728
|
+
Args:
|
|
729
|
+
pred: (batch, num_outputs)
|
|
730
|
+
inputs: (batch, features) or (batch, C, H, W) or (batch, C, D, H, W)
|
|
731
|
+
Returns:
|
|
732
|
+
(batch,) — violation per sample (0 = satisfied)
|
|
733
|
+
"""
|
|
734
|
+
# Outputs (same for all data types)
|
|
735
|
+
y0, y1, y2 = pred[:, 0], pred[:, 1], pred[:, 2]
|
|
736
|
+
|
|
737
|
+
# Inputs — Tabular: (batch, features)
|
|
738
|
+
# x0 = inputs[:, 0] # Feature 0
|
|
739
|
+
# x_sum = inputs.sum(dim=1) # Sum all features
|
|
740
|
+
|
|
741
|
+
# Inputs — Images: (batch, C, H, W)
|
|
742
|
+
# pixel = inputs[:, 0, 3, 5] # Pixel at (3,5), channel 0
|
|
743
|
+
# img_mean = inputs.mean(dim=(1,2,3)) # Mean over C,H,W
|
|
744
|
+
|
|
745
|
+
# Inputs — 3D Volumes: (batch, C, D, H, W)
|
|
746
|
+
# voxel = inputs[:, 0, 2, 3, 5] # Voxel at (2,3,5), channel 0
|
|
747
|
+
|
|
748
|
+
# Example constraints:
|
|
749
|
+
# return y2 - y0 * y1 # Wave equation
|
|
750
|
+
# return y0 - 2 * inputs[:, 0] # Output = 2×input
|
|
751
|
+
# return inputs[:, 0, 3, 5] * y0 + inputs[:, 0, 6, 7] * y1 # Mixed
|
|
752
|
+
|
|
753
|
+
return y0 - y1 * y2
|
|
754
|
+
```
|
|
755
|
+
|
|
756
|
+
```bash
|
|
757
|
+
--constraint_file my_constraint.py --constraint_weight 1.0
|
|
758
|
+
```
|
|
759
|
+
|
|
760
|
+
---
|
|
761
|
+
|
|
762
|
+
### Reference
|
|
763
|
+
|
|
764
|
+
| Argument | Default | Description |
|
|
765
|
+
|----------|---------|-------------|
|
|
766
|
+
| `--constraint` | — | Expression(s): `"y0 > 0"`, `"y0 - y1*y2"` |
|
|
767
|
+
| `--constraint_file` | — | Python file with `constraint(pred, inputs)` |
|
|
768
|
+
| `--constraint_weight` | `0.1` | Penalty weight(s) |
|
|
769
|
+
| `--constraint_reduction` | `mse` | `mse` (squared) or `mae` (linear) |
|
|
770
|
+
|
|
771
|
+
#### Expression Syntax
|
|
772
|
+
|
|
773
|
+
| Variable | Meaning |
|
|
774
|
+
|----------|---------|
|
|
775
|
+
| `y0`, `y1`, ... | Model outputs |
|
|
776
|
+
| `x[0]`, `x[1]`, ... | Input values (1D tabular) |
|
|
777
|
+
| `x[i,j]`, `x[i,j,k]` | Input values (2D/3D: images, volumes) |
|
|
778
|
+
| `x_mean`, `x_sum`, `x_max`, `x_min`, `x_std` | Input aggregates |
|
|
779
|
+
|
|
780
|
+
**Operators:** `+`, `-`, `*`, `/`, `**`, `>`, `<`, `>=`, `<=`, `==`
|
|
781
|
+
|
|
782
|
+
**Functions:** `sin`, `cos`, `exp`, `log`, `sqrt`, `sigmoid`, `softplus`, `tanh`, `relu`, `abs`
|
|
783
|
+
|
|
784
|
+
</details>
|
|
785
|
+
|
|
786
|
+
|
|
787
|
+
|
|
685
788
|
<details>
|
|
686
789
|
<summary><b>Hyperparameter Search (HPO)</b></summary>
|
|
687
790
|
|
|
@@ -721,7 +824,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
|
|
|
721
824
|
| Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
|
|
722
825
|
| Losses | [all 6](#loss-functions) | `--losses X Y` |
|
|
723
826
|
| Learning rate | 1e-5 → 1e-2 | (always searched) |
|
|
724
|
-
| Batch size |
|
|
827
|
+
| Batch size | 16, 32, 64, 128 | (always searched) |
|
|
725
828
|
|
|
726
829
|
**Quick Mode** (`--quick`):
|
|
727
830
|
- Uses minimal defaults: cnn + adamw + plateau + mse
|
|
@@ -893,12 +996,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
|
|
|
893
996
|
```bash
|
|
894
997
|
# Run inference on the example data
|
|
895
998
|
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
|
|
896
|
-
--data_path ./examples/elastic_cnn_example/
|
|
999
|
+
--data_path ./examples/elastic_cnn_example/Test_data_500.mat \
|
|
897
1000
|
--plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
|
|
898
1001
|
|
|
899
1002
|
# Export to ONNX (already included as model.onnx)
|
|
900
1003
|
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
|
|
901
|
-
--data_path ./examples/elastic_cnn_example/
|
|
1004
|
+
--data_path ./examples/elastic_cnn_example/Test_data_500.mat \
|
|
902
1005
|
--export onnx --export_path ./examples/elastic_cnn_example/model.onnx
|
|
903
1006
|
```
|
|
904
1007
|
|
|
@@ -907,7 +1010,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
|
|
|
907
1010
|
| File | Description |
|
|
908
1011
|
|------|-------------|
|
|
909
1012
|
| `best_checkpoint/` | Pre-trained CNN checkpoint |
|
|
910
|
-
| `
|
|
1013
|
+
| `Test_data_500.mat` | 500-sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
|
|
911
1014
|
| `model.onnx` | ONNX export with embedded de-normalization |
|
|
912
1015
|
| `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
|
|
913
1016
|
| `training_curves.png` | Training/validation loss and learning rate plot |
|
|
@@ -918,7 +1021,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
|
|
|
918
1021
|
|
|
919
1022
|
<p align="center">
|
|
920
1023
|
<img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
|
|
921
|
-
<em>Training and validation loss over
|
|
1024
|
+
<em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
|
|
922
1025
|
</p>
|
|
923
1026
|
|
|
924
1027
|
**Inference Results:**
|
|
@@ -89,7 +89,8 @@ def create_objective(args):
|
|
|
89
89
|
# Suggest hyperparameters
|
|
90
90
|
model = trial.suggest_categorical("model", models)
|
|
91
91
|
lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
|
|
92
|
-
|
|
92
|
+
batch_sizes = args.batch_sizes or [16, 32, 64, 128]
|
|
93
|
+
batch_size = trial.suggest_categorical("batch_size", batch_sizes)
|
|
93
94
|
optimizer = trial.suggest_categorical("optimizer", optimizers)
|
|
94
95
|
scheduler = trial.suggest_categorical("scheduler", schedulers)
|
|
95
96
|
loss = trial.suggest_categorical("loss", losses)
|
|
@@ -317,6 +318,13 @@ Examples:
|
|
|
317
318
|
default=None,
|
|
318
319
|
help=f"Losses to search (default: {DEFAULT_LOSSES})",
|
|
319
320
|
)
|
|
321
|
+
parser.add_argument(
|
|
322
|
+
"--batch_sizes",
|
|
323
|
+
type=int,
|
|
324
|
+
nargs="+",
|
|
325
|
+
default=None,
|
|
326
|
+
help="Batch sizes to search (default: 16 32 64 128)",
|
|
327
|
+
)
|
|
320
328
|
|
|
321
329
|
# Training settings for each trial
|
|
322
330
|
parser.add_argument(
|
|
@@ -54,6 +54,16 @@ class PatchEmbed(nn.Module):
|
|
|
54
54
|
if self.dim == 1:
|
|
55
55
|
# 1D: segment patches
|
|
56
56
|
L = in_shape[0]
|
|
57
|
+
if L % patch_size != 0:
|
|
58
|
+
import warnings
|
|
59
|
+
|
|
60
|
+
warnings.warn(
|
|
61
|
+
f"Input length {L} not divisible by patch_size {patch_size}. "
|
|
62
|
+
f"Last {L % patch_size} elements will be dropped. "
|
|
63
|
+
f"Consider padding input to {((L // patch_size) + 1) * patch_size}.",
|
|
64
|
+
UserWarning,
|
|
65
|
+
stacklevel=2,
|
|
66
|
+
)
|
|
57
67
|
self.num_patches = L // patch_size
|
|
58
68
|
self.proj = nn.Conv1d(
|
|
59
69
|
1, embed_dim, kernel_size=patch_size, stride=patch_size
|
|
@@ -61,6 +71,17 @@ class PatchEmbed(nn.Module):
|
|
|
61
71
|
elif self.dim == 2:
|
|
62
72
|
# 2D: grid patches
|
|
63
73
|
H, W = in_shape
|
|
74
|
+
if H % patch_size != 0 or W % patch_size != 0:
|
|
75
|
+
import warnings
|
|
76
|
+
|
|
77
|
+
warnings.warn(
|
|
78
|
+
f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
|
|
79
|
+
f"Border pixels will be dropped (H: {H % patch_size}, W: {W % patch_size}). "
|
|
80
|
+
f"Consider padding to ({((H // patch_size) + 1) * patch_size}, "
|
|
81
|
+
f"{((W // patch_size) + 1) * patch_size}).",
|
|
82
|
+
UserWarning,
|
|
83
|
+
stacklevel=2,
|
|
84
|
+
)
|
|
64
85
|
self.num_patches = (H // patch_size) * (W // patch_size)
|
|
65
86
|
self.proj = nn.Conv2d(
|
|
66
87
|
1, embed_dim, kernel_size=patch_size, stride=patch_size
|
|
@@ -166,6 +166,13 @@ def parse_args() -> argparse.Namespace:
|
|
|
166
166
|
default=None,
|
|
167
167
|
help="Parameter names for output (e.g., 'h' 'v11' 'v12')",
|
|
168
168
|
)
|
|
169
|
+
parser.add_argument(
|
|
170
|
+
"--input_channels",
|
|
171
|
+
type=int,
|
|
172
|
+
default=None,
|
|
173
|
+
help="Explicit number of input channels. Bypasses auto-detection heuristics "
|
|
174
|
+
"for ambiguous 4D shapes (e.g., 3D volumes with small depth).",
|
|
175
|
+
)
|
|
169
176
|
|
|
170
177
|
# Inference options
|
|
171
178
|
parser.add_argument(
|
|
@@ -235,6 +242,7 @@ def load_data_for_inference(
|
|
|
235
242
|
format: str = "auto",
|
|
236
243
|
input_key: str | None = None,
|
|
237
244
|
output_key: str | None = None,
|
|
245
|
+
input_channels: int | None = None,
|
|
238
246
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
239
247
|
"""
|
|
240
248
|
Load test data for inference using the unified data loading pipeline.
|
|
@@ -278,7 +286,11 @@ def load_data_for_inference(
|
|
|
278
286
|
|
|
279
287
|
# Use the unified loader from utils.data
|
|
280
288
|
X, y = load_test_data(
|
|
281
|
-
file_path,
|
|
289
|
+
file_path,
|
|
290
|
+
format=format,
|
|
291
|
+
input_key=input_key,
|
|
292
|
+
output_key=output_key,
|
|
293
|
+
input_channels=input_channels,
|
|
282
294
|
)
|
|
283
295
|
|
|
284
296
|
# Log results
|
|
@@ -452,7 +464,12 @@ def run_inference(
|
|
|
452
464
|
predictions: Numpy array (N, out_size) - still in normalized space
|
|
453
465
|
"""
|
|
454
466
|
if device is None:
|
|
455
|
-
|
|
467
|
+
if torch.cuda.is_available():
|
|
468
|
+
device = torch.device("cuda")
|
|
469
|
+
elif torch.backends.mps.is_available():
|
|
470
|
+
device = torch.device("mps")
|
|
471
|
+
else:
|
|
472
|
+
device = torch.device("cpu")
|
|
456
473
|
|
|
457
474
|
model = model.to(device)
|
|
458
475
|
model.eval()
|
|
@@ -463,7 +480,7 @@ def run_inference(
|
|
|
463
480
|
batch_size=batch_size,
|
|
464
481
|
shuffle=False,
|
|
465
482
|
num_workers=num_workers,
|
|
466
|
-
pin_memory=device.type
|
|
483
|
+
pin_memory=device.type in ("cuda", "mps"),
|
|
467
484
|
)
|
|
468
485
|
|
|
469
486
|
predictions = []
|
|
@@ -919,8 +936,13 @@ def main():
|
|
|
919
936
|
)
|
|
920
937
|
logger = logging.getLogger("Tester")
|
|
921
938
|
|
|
922
|
-
# Device
|
|
923
|
-
|
|
939
|
+
# Device (CUDA > MPS > CPU)
|
|
940
|
+
if torch.cuda.is_available():
|
|
941
|
+
device = torch.device("cuda")
|
|
942
|
+
elif torch.backends.mps.is_available():
|
|
943
|
+
device = torch.device("mps")
|
|
944
|
+
else:
|
|
945
|
+
device = torch.device("cpu")
|
|
924
946
|
logger.info(f"Using device: {device}")
|
|
925
947
|
|
|
926
948
|
# Load test data
|
|
@@ -929,6 +951,7 @@ def main():
|
|
|
929
951
|
format=args.format,
|
|
930
952
|
input_key=args.input_key,
|
|
931
953
|
output_key=args.output_key,
|
|
954
|
+
input_channels=args.input_channels,
|
|
932
955
|
)
|
|
933
956
|
in_shape = tuple(X_test.shape[2:])
|
|
934
957
|
|