wavedl 1.4.6__tar.gz → 1.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {wavedl-1.4.6/src/wavedl.egg-info → wavedl-1.5.1}/PKG-INFO +122 -19
  2. {wavedl-1.4.6 → wavedl-1.5.1}/README.md +121 -18
  3. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/__init__.py +1 -1
  4. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/hpo.py +9 -1
  5. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/vit.py +21 -0
  6. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/test.py +28 -5
  7. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/train.py +122 -15
  8. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/__init__.py +11 -0
  9. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/config.py +10 -0
  10. wavedl-1.5.1/src/wavedl/utils/constraints.py +470 -0
  11. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/cross_validation.py +12 -2
  12. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/data.py +26 -7
  13. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/metrics.py +49 -2
  14. {wavedl-1.4.6 → wavedl-1.5.1/src/wavedl.egg-info}/PKG-INFO +122 -19
  15. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl.egg-info/SOURCES.txt +1 -0
  16. {wavedl-1.4.6 → wavedl-1.5.1}/LICENSE +0 -0
  17. {wavedl-1.4.6 → wavedl-1.5.1}/pyproject.toml +0 -0
  18. {wavedl-1.4.6 → wavedl-1.5.1}/setup.cfg +0 -0
  19. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/hpc.py +0 -0
  20. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/__init__.py +0 -0
  21. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/_template.py +0 -0
  22. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/base.py +0 -0
  23. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/cnn.py +0 -0
  24. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/convnext.py +0 -0
  25. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/densenet.py +0 -0
  26. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/efficientnet.py +0 -0
  27. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/efficientnetv2.py +0 -0
  28. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/mobilenetv3.py +0 -0
  29. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/registry.py +0 -0
  30. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/regnet.py +0 -0
  31. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/resnet.py +0 -0
  32. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/resnet3d.py +0 -0
  33. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/swin.py +0 -0
  34. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/tcn.py +0 -0
  35. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/models/unet.py +0 -0
  36. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/distributed.py +0 -0
  37. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/losses.py +0 -0
  38. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/optimizers.py +0 -0
  39. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl/utils/schedulers.py +0 -0
  40. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl.egg-info/dependency_links.txt +0 -0
  41. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl.egg-info/entry_points.txt +0 -0
  42. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl.egg-info/requires.txt +0 -0
  43. {wavedl-1.4.6 → wavedl-1.5.1}/src/wavedl.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.4.6
3
+ Version: 1.5.1
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -99,6 +99,7 @@ The framework handles the engineering challenges of large-scale deep learning
99
99
 
100
100
  ## ✨ Features
101
101
 
102
+ <div align="center">
102
103
  <table width="100%">
103
104
  <tr>
104
105
  <td width="50%" valign="top">
@@ -113,14 +114,12 @@ Train on datasets larger than RAM:
113
114
  </td>
114
115
  <td width="50%" valign="top">
115
116
 
116
- **🧠 One-Line Model Registration**
117
+ **🧠 Models? We've Got Options**
117
118
 
118
- Plug in any architecture:
119
- ```python
120
- @register_model("my_net")
121
- class MyNet(BaseModel): ...
122
- ```
123
- Design your model. Register with one line.
119
+ 38 architectures, ready to go:
120
+ - CNNs, ResNets, ViTs, EfficientNets...
121
+ - All adapted for regression
122
+ - [Add your own](#adding-custom-models) in one line
124
123
 
125
124
  </td>
126
125
  </tr>
@@ -137,12 +136,12 @@ Multi-GPU training without the pain:
137
136
  </td>
138
137
  <td width="50%" valign="top">
139
138
 
140
- **📊 Publish-Ready Output**
139
+ **🔬 Physics-Constrained Training**
141
140
 
142
- Results go straight to your paper:
143
- - 11 diagnostic plots with LaTeX styling
144
- - Multi-format export (PNG, PDF, SVG, ...)
145
- - MAE in physical units per parameter
141
+ Make your model respect the laws:
142
+ - Enforce bounds, positivity, equations
143
+ - Simple expression syntax or Python
144
+ - [Custom constraints](#physical-constraints) for various laws
146
145
 
147
146
  </td>
148
147
  </tr>
@@ -191,6 +190,7 @@ Deploy models anywhere:
191
190
  </td>
192
191
  </tr>
193
192
  </table>
193
+ </div>
194
194
 
195
195
  ---
196
196
 
@@ -279,6 +279,10 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
279
279
  # Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
280
280
  python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
281
281
  --export onnx --export_path <output_file.onnx>
282
+
283
+ # For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
284
+ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
285
+ --input_channels 1
282
286
  ```
283
287
 
284
288
  **Output:**
@@ -374,6 +378,7 @@ WaveDL/
374
378
  │ └── utils/ # Utilities
375
379
  │ ├── data.py # Memory-mapped data pipeline
376
380
  │ ├── metrics.py # R², Pearson, visualization
381
+ │ ├── constraints.py # Physical constraints for training
377
382
  │ ├── distributed.py # DDP synchronization
378
383
  │ ├── losses.py # Loss function factory
379
384
  │ ├── optimizers.py # Optimizer factory
@@ -383,7 +388,7 @@ WaveDL/
383
388
  ├── configs/ # YAML config templates
384
389
  ├── examples/ # Ready-to-run examples
385
390
  ├── notebooks/ # Jupyter notebooks
386
- ├── unit_tests/ # Pytest test suite (704 tests)
391
+ ├── unit_tests/ # Pytest test suite (725 tests)
387
392
 
388
393
  ├── pyproject.toml # Package config, dependencies
389
394
  ├── CHANGELOG.md # Version history
@@ -727,6 +732,104 @@ seed: 2025
727
732
 
728
733
  </details>
729
734
 
735
+ <details>
736
+ <summary><b>Physical Constraints</b> — Enforce Physics During Training</summary>
737
+
738
+ Add penalty terms to the loss function to enforce physical laws:
739
+
740
+ ```
741
+ Total Loss = Data Loss + weight × penalty(violation)
742
+ ```
743
+
744
+ ### Expression Constraints
745
+
746
+ ```bash
747
+ # Positivity
748
+ --constraint "y0 > 0"
749
+
750
+ # Bounds
751
+ --constraint "y0 >= 0" "y0 <= 1"
752
+
753
+ # Equations (penalize deviations from zero)
754
+ --constraint "y2 - y0 * y1"
755
+
756
+ # Input-dependent constraints
757
+ --constraint "y0 - 2*x[0]"
758
+
759
+ # Multiple constraints with different weights
760
+ --constraint "y0 > 0" "y1 - y2" --constraint_weight 0.1 1.0
761
+ ```
762
+
763
+ ### Custom Python Constraints
764
+
765
+ For complex physics (matrix operations, implicit equations):
766
+
767
+ ```python
768
+ # my_constraint.py
769
+ import torch
770
+
771
+ def constraint(pred, inputs=None):
772
+ """
773
+ Args:
774
+ pred: (batch, num_outputs)
775
+ inputs: (batch, features) or (batch, C, H, W) or (batch, C, D, H, W)
776
+ Returns:
777
+ (batch,) — violation per sample (0 = satisfied)
778
+ """
779
+ # Outputs (same for all data types)
780
+ y0, y1, y2 = pred[:, 0], pred[:, 1], pred[:, 2]
781
+
782
+ # Inputs — Tabular: (batch, features)
783
+ # x0 = inputs[:, 0] # Feature 0
784
+ # x_sum = inputs.sum(dim=1) # Sum all features
785
+
786
+ # Inputs — Images: (batch, C, H, W)
787
+ # pixel = inputs[:, 0, 3, 5] # Pixel at (3,5), channel 0
788
+ # img_mean = inputs.mean(dim=(1,2,3)) # Mean over C,H,W
789
+
790
+ # Inputs — 3D Volumes: (batch, C, D, H, W)
791
+ # voxel = inputs[:, 0, 2, 3, 5] # Voxel at (2,3,5), channel 0
792
+
793
+ # Example constraints:
794
+ # return y2 - y0 * y1 # Wave equation
795
+ # return y0 - 2 * inputs[:, 0] # Output = 2×input
796
+ # return inputs[:, 0, 3, 5] * y0 + inputs[:, 0, 6, 7] * y1 # Mixed
797
+
798
+ return y0 - y1 * y2
799
+ ```
800
+
801
+ ```bash
802
+ --constraint_file my_constraint.py --constraint_weight 1.0
803
+ ```
804
+
805
+ ---
806
+
807
+ ### Reference
808
+
809
+ | Argument | Default | Description |
810
+ |----------|---------|-------------|
811
+ | `--constraint` | — | Expression(s): `"y0 > 0"`, `"y0 - y1*y2"` |
812
+ | `--constraint_file` | — | Python file with `constraint(pred, inputs)` |
813
+ | `--constraint_weight` | `0.1` | Penalty weight(s) |
814
+ | `--constraint_reduction` | `mse` | `mse` (squared) or `mae` (linear) |
815
+
816
+ #### Expression Syntax
817
+
818
+ | Variable | Meaning |
819
+ |----------|---------|
820
+ | `y0`, `y1`, ... | Model outputs |
821
+ | `x[0]`, `x[1]`, ... | Input values (1D tabular) |
822
+ | `x[i,j]`, `x[i,j,k]` | Input values (2D/3D: images, volumes) |
823
+ | `x_mean`, `x_sum`, `x_max`, `x_min`, `x_std` | Input aggregates |
824
+
825
+ **Operators:** `+`, `-`, `*`, `/`, `**`, `>`, `<`, `>=`, `<=`, `==`
826
+
827
+ **Functions:** `sin`, `cos`, `exp`, `log`, `sqrt`, `sigmoid`, `softplus`, `tanh`, `relu`, `abs`
828
+
829
+ </details>
830
+
831
+
832
+
730
833
  <details>
731
834
  <summary><b>Hyperparameter Search (HPO)</b></summary>
732
835
 
@@ -766,7 +869,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
766
869
  | Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
767
870
  | Losses | [all 6](#loss-functions) | `--losses X Y` |
768
871
  | Learning rate | 1e-5 → 1e-2 | (always searched) |
769
- | Batch size | 64, 128, 256, 512 | (always searched) |
872
+ | Batch size | 16, 32, 64, 128 | (always searched) |
770
873
 
771
874
  **Quick Mode** (`--quick`):
772
875
  - Uses minimal defaults: cnn + adamw + plateau + mse
@@ -938,12 +1041,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
938
1041
  ```bash
939
1042
  # Run inference on the example data
940
1043
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
941
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
1044
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
942
1045
  --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
943
1046
 
944
1047
  # Export to ONNX (already included as model.onnx)
945
1048
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
946
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
1049
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
947
1050
  --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
948
1051
  ```
949
1052
 
@@ -952,7 +1055,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
952
1055
  | File | Description |
953
1056
  |------|-------------|
954
1057
  | `best_checkpoint/` | Pre-trained CNN checkpoint |
955
- | `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
1058
+ | `Test_data_500.mat` | 500 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
956
1059
  | `model.onnx` | ONNX export with embedded de-normalization |
957
1060
  | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
958
1061
  | `training_curves.png` | Training/validation loss and learning rate plot |
@@ -963,7 +1066,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
963
1066
 
964
1067
  <p align="center">
965
1068
  <img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
966
- <em>Training and validation loss over 162 epochs with learning rate schedule</em>
1069
+ <em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
967
1070
  </p>
968
1071
 
969
1072
  **Inference Results:**
@@ -54,6 +54,7 @@ The framework handles the engineering challenges of large-scale deep learning
54
54
 
55
55
  ## ✨ Features
56
56
 
57
+ <div align="center">
57
58
  <table width="100%">
58
59
  <tr>
59
60
  <td width="50%" valign="top">
@@ -68,14 +69,12 @@ Train on datasets larger than RAM:
68
69
  </td>
69
70
  <td width="50%" valign="top">
70
71
 
71
- **🧠 One-Line Model Registration**
72
+ **🧠 Models? We've Got Options**
72
73
 
73
- Plug in any architecture:
74
- ```python
75
- @register_model("my_net")
76
- class MyNet(BaseModel): ...
77
- ```
78
- Design your model. Register with one line.
74
+ 38 architectures, ready to go:
75
+ - CNNs, ResNets, ViTs, EfficientNets...
76
+ - All adapted for regression
77
+ - [Add your own](#adding-custom-models) in one line
79
78
 
80
79
  </td>
81
80
  </tr>
@@ -92,12 +91,12 @@ Multi-GPU training without the pain:
92
91
  </td>
93
92
  <td width="50%" valign="top">
94
93
 
95
- **📊 Publish-Ready Output**
94
+ **🔬 Physics-Constrained Training**
96
95
 
97
- Results go straight to your paper:
98
- - 11 diagnostic plots with LaTeX styling
99
- - Multi-format export (PNG, PDF, SVG, ...)
100
- - MAE in physical units per parameter
96
+ Make your model respect the laws:
97
+ - Enforce bounds, positivity, equations
98
+ - Simple expression syntax or Python
99
+ - [Custom constraints](#physical-constraints) for various laws
101
100
 
102
101
  </td>
103
102
  </tr>
@@ -146,6 +145,7 @@ Deploy models anywhere:
146
145
  </td>
147
146
  </tr>
148
147
  </table>
148
+ </div>
149
149
 
150
150
  ---
151
151
 
@@ -234,6 +234,10 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
234
234
  # Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
235
235
  python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
236
236
  --export onnx --export_path <output_file.onnx>
237
+
238
+ # For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
239
+ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
240
+ --input_channels 1
237
241
  ```
238
242
 
239
243
  **Output:**
@@ -329,6 +333,7 @@ WaveDL/
329
333
  │ └── utils/ # Utilities
330
334
  │ ├── data.py # Memory-mapped data pipeline
331
335
  │ ├── metrics.py # R², Pearson, visualization
336
+ │ ├── constraints.py # Physical constraints for training
332
337
  │ ├── distributed.py # DDP synchronization
333
338
  │ ├── losses.py # Loss function factory
334
339
  │ ├── optimizers.py # Optimizer factory
@@ -338,7 +343,7 @@ WaveDL/
338
343
  ├── configs/ # YAML config templates
339
344
  ├── examples/ # Ready-to-run examples
340
345
  ├── notebooks/ # Jupyter notebooks
341
- ├── unit_tests/ # Pytest test suite (704 tests)
346
+ ├── unit_tests/ # Pytest test suite (725 tests)
342
347
 
343
348
  ├── pyproject.toml # Package config, dependencies
344
349
  ├── CHANGELOG.md # Version history
@@ -682,6 +687,104 @@ seed: 2025
682
687
 
683
688
  </details>
684
689
 
690
+ <details>
691
+ <summary><b>Physical Constraints</b> — Enforce Physics During Training</summary>
692
+
693
+ Add penalty terms to the loss function to enforce physical laws:
694
+
695
+ ```
696
+ Total Loss = Data Loss + weight × penalty(violation)
697
+ ```
698
+
699
+ ### Expression Constraints
700
+
701
+ ```bash
702
+ # Positivity
703
+ --constraint "y0 > 0"
704
+
705
+ # Bounds
706
+ --constraint "y0 >= 0" "y0 <= 1"
707
+
708
+ # Equations (penalize deviations from zero)
709
+ --constraint "y2 - y0 * y1"
710
+
711
+ # Input-dependent constraints
712
+ --constraint "y0 - 2*x[0]"
713
+
714
+ # Multiple constraints with different weights
715
+ --constraint "y0 > 0" "y1 - y2" --constraint_weight 0.1 1.0
716
+ ```
717
+
718
+ ### Custom Python Constraints
719
+
720
+ For complex physics (matrix operations, implicit equations):
721
+
722
+ ```python
723
+ # my_constraint.py
724
+ import torch
725
+
726
+ def constraint(pred, inputs=None):
727
+ """
728
+ Args:
729
+ pred: (batch, num_outputs)
730
+ inputs: (batch, features) or (batch, C, H, W) or (batch, C, D, H, W)
731
+ Returns:
732
+ (batch,) — violation per sample (0 = satisfied)
733
+ """
734
+ # Outputs (same for all data types)
735
+ y0, y1, y2 = pred[:, 0], pred[:, 1], pred[:, 2]
736
+
737
+ # Inputs — Tabular: (batch, features)
738
+ # x0 = inputs[:, 0] # Feature 0
739
+ # x_sum = inputs.sum(dim=1) # Sum all features
740
+
741
+ # Inputs — Images: (batch, C, H, W)
742
+ # pixel = inputs[:, 0, 3, 5] # Pixel at (3,5), channel 0
743
+ # img_mean = inputs.mean(dim=(1,2,3)) # Mean over C,H,W
744
+
745
+ # Inputs — 3D Volumes: (batch, C, D, H, W)
746
+ # voxel = inputs[:, 0, 2, 3, 5] # Voxel at (2,3,5), channel 0
747
+
748
+ # Example constraints:
749
+ # return y2 - y0 * y1 # Wave equation
750
+ # return y0 - 2 * inputs[:, 0] # Output = 2×input
751
+ # return inputs[:, 0, 3, 5] * y0 + inputs[:, 0, 6, 7] * y1 # Mixed
752
+
753
+ return y0 - y1 * y2
754
+ ```
755
+
756
+ ```bash
757
+ --constraint_file my_constraint.py --constraint_weight 1.0
758
+ ```
759
+
760
+ ---
761
+
762
+ ### Reference
763
+
764
+ | Argument | Default | Description |
765
+ |----------|---------|-------------|
766
+ | `--constraint` | — | Expression(s): `"y0 > 0"`, `"y0 - y1*y2"` |
767
+ | `--constraint_file` | — | Python file with `constraint(pred, inputs)` |
768
+ | `--constraint_weight` | `0.1` | Penalty weight(s) |
769
+ | `--constraint_reduction` | `mse` | `mse` (squared) or `mae` (linear) |
770
+
771
+ #### Expression Syntax
772
+
773
+ | Variable | Meaning |
774
+ |----------|---------|
775
+ | `y0`, `y1`, ... | Model outputs |
776
+ | `x[0]`, `x[1]`, ... | Input values (1D tabular) |
777
+ | `x[i,j]`, `x[i,j,k]` | Input values (2D/3D: images, volumes) |
778
+ | `x_mean`, `x_sum`, `x_max`, `x_min`, `x_std` | Input aggregates |
779
+
780
+ **Operators:** `+`, `-`, `*`, `/`, `**`, `>`, `<`, `>=`, `<=`, `==`
781
+
782
+ **Functions:** `sin`, `cos`, `exp`, `log`, `sqrt`, `sigmoid`, `softplus`, `tanh`, `relu`, `abs`
783
+
784
+ </details>
785
+
786
+
787
+
685
788
  <details>
686
789
  <summary><b>Hyperparameter Search (HPO)</b></summary>
687
790
 
@@ -721,7 +824,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
721
824
  | Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
722
825
  | Losses | [all 6](#loss-functions) | `--losses X Y` |
723
826
  | Learning rate | 1e-5 → 1e-2 | (always searched) |
724
- | Batch size | 64, 128, 256, 512 | (always searched) |
827
+ | Batch size | 16, 32, 64, 128 | (always searched) |
725
828
 
726
829
  **Quick Mode** (`--quick`):
727
830
  - Uses minimal defaults: cnn + adamw + plateau + mse
@@ -893,12 +996,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
893
996
  ```bash
894
997
  # Run inference on the example data
895
998
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
896
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
999
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
897
1000
  --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
898
1001
 
899
1002
  # Export to ONNX (already included as model.onnx)
900
1003
  python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
901
- --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
1004
+ --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
902
1005
  --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
903
1006
  ```
904
1007
 
@@ -907,7 +1010,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
907
1010
  | File | Description |
908
1011
  |------|-------------|
909
1012
  | `best_checkpoint/` | Pre-trained CNN checkpoint |
910
- | `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
1013
+ | `Test_data_500.mat` | 500 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
911
1014
  | `model.onnx` | ONNX export with embedded de-normalization |
912
1015
  | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
913
1016
  | `training_curves.png` | Training/validation loss and learning rate plot |
@@ -918,7 +1021,7 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
918
1021
 
919
1022
  <p align="center">
920
1023
  <img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
921
- <em>Training and validation loss over 162 epochs with learning rate schedule</em>
1024
+ <em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
922
1025
  </p>
923
1026
 
924
1027
  **Inference Results:**
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.4.6"
21
+ __version__ = "1.5.1"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
@@ -89,7 +89,8 @@ def create_objective(args):
89
89
  # Suggest hyperparameters
90
90
  model = trial.suggest_categorical("model", models)
91
91
  lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
92
- batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
92
+ batch_sizes = args.batch_sizes or [16, 32, 64, 128]
93
+ batch_size = trial.suggest_categorical("batch_size", batch_sizes)
93
94
  optimizer = trial.suggest_categorical("optimizer", optimizers)
94
95
  scheduler = trial.suggest_categorical("scheduler", schedulers)
95
96
  loss = trial.suggest_categorical("loss", losses)
@@ -317,6 +318,13 @@ Examples:
317
318
  default=None,
318
319
  help=f"Losses to search (default: {DEFAULT_LOSSES})",
319
320
  )
321
+ parser.add_argument(
322
+ "--batch_sizes",
323
+ type=int,
324
+ nargs="+",
325
+ default=None,
326
+ help="Batch sizes to search (default: 16 32 64 128)",
327
+ )
320
328
 
321
329
  # Training settings for each trial
322
330
  parser.add_argument(
@@ -54,6 +54,16 @@ class PatchEmbed(nn.Module):
54
54
  if self.dim == 1:
55
55
  # 1D: segment patches
56
56
  L = in_shape[0]
57
+ if L % patch_size != 0:
58
+ import warnings
59
+
60
+ warnings.warn(
61
+ f"Input length {L} not divisible by patch_size {patch_size}. "
62
+ f"Last {L % patch_size} elements will be dropped. "
63
+ f"Consider padding input to {((L // patch_size) + 1) * patch_size}.",
64
+ UserWarning,
65
+ stacklevel=2,
66
+ )
57
67
  self.num_patches = L // patch_size
58
68
  self.proj = nn.Conv1d(
59
69
  1, embed_dim, kernel_size=patch_size, stride=patch_size
@@ -61,6 +71,17 @@ class PatchEmbed(nn.Module):
61
71
  elif self.dim == 2:
62
72
  # 2D: grid patches
63
73
  H, W = in_shape
74
+ if H % patch_size != 0 or W % patch_size != 0:
75
+ import warnings
76
+
77
+ warnings.warn(
78
+ f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
79
+ f"Border pixels will be dropped (H: {H % patch_size}, W: {W % patch_size}). "
80
+ f"Consider padding to ({((H // patch_size) + 1) * patch_size}, "
81
+ f"{((W // patch_size) + 1) * patch_size}).",
82
+ UserWarning,
83
+ stacklevel=2,
84
+ )
64
85
  self.num_patches = (H // patch_size) * (W // patch_size)
65
86
  self.proj = nn.Conv2d(
66
87
  1, embed_dim, kernel_size=patch_size, stride=patch_size
@@ -166,6 +166,13 @@ def parse_args() -> argparse.Namespace:
166
166
  default=None,
167
167
  help="Parameter names for output (e.g., 'h' 'v11' 'v12')",
168
168
  )
169
+ parser.add_argument(
170
+ "--input_channels",
171
+ type=int,
172
+ default=None,
173
+ help="Explicit number of input channels. Bypasses auto-detection heuristics "
174
+ "for ambiguous 4D shapes (e.g., 3D volumes with small depth).",
175
+ )
169
176
 
170
177
  # Inference options
171
178
  parser.add_argument(
@@ -235,6 +242,7 @@ def load_data_for_inference(
235
242
  format: str = "auto",
236
243
  input_key: str | None = None,
237
244
  output_key: str | None = None,
245
+ input_channels: int | None = None,
238
246
  ) -> tuple[torch.Tensor, torch.Tensor | None]:
239
247
  """
240
248
  Load test data for inference using the unified data loading pipeline.
@@ -278,7 +286,11 @@ def load_data_for_inference(
278
286
 
279
287
  # Use the unified loader from utils.data
280
288
  X, y = load_test_data(
281
- file_path, format=format, input_key=input_key, output_key=output_key
289
+ file_path,
290
+ format=format,
291
+ input_key=input_key,
292
+ output_key=output_key,
293
+ input_channels=input_channels,
282
294
  )
283
295
 
284
296
  # Log results
@@ -452,7 +464,12 @@ def run_inference(
452
464
  predictions: Numpy array (N, out_size) - still in normalized space
453
465
  """
454
466
  if device is None:
455
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
467
+ if torch.cuda.is_available():
468
+ device = torch.device("cuda")
469
+ elif torch.backends.mps.is_available():
470
+ device = torch.device("mps")
471
+ else:
472
+ device = torch.device("cpu")
456
473
 
457
474
  model = model.to(device)
458
475
  model.eval()
@@ -463,7 +480,7 @@ def run_inference(
463
480
  batch_size=batch_size,
464
481
  shuffle=False,
465
482
  num_workers=num_workers,
466
- pin_memory=device.type == "cuda",
483
+ pin_memory=device.type in ("cuda", "mps"),
467
484
  )
468
485
 
469
486
  predictions = []
@@ -919,8 +936,13 @@ def main():
919
936
  )
920
937
  logger = logging.getLogger("Tester")
921
938
 
922
- # Device
923
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
939
+ # Device (CUDA > MPS > CPU)
940
+ if torch.cuda.is_available():
941
+ device = torch.device("cuda")
942
+ elif torch.backends.mps.is_available():
943
+ device = torch.device("mps")
944
+ else:
945
+ device = torch.device("cpu")
924
946
  logger.info(f"Using device: {device}")
925
947
 
926
948
  # Load test data
@@ -929,6 +951,7 @@ def main():
929
951
  format=args.format,
930
952
  input_key=args.input_key,
931
953
  output_key=args.output_key,
954
+ input_channels=args.input_channels,
932
955
  )
933
956
  in_shape = tuple(X_test.shape[2:])
934
957