wavedl 1.5.5__tar.gz → 1.5.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {wavedl-1.5.5/src/wavedl.egg-info → wavedl-1.5.7}/PKG-INFO +37 -27
  2. {wavedl-1.5.5 → wavedl-1.5.7}/README.md +36 -26
  3. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/__init__.py +1 -1
  4. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/efficientnet.py +24 -7
  5. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/efficientnetv2.py +29 -6
  6. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/mobilenetv3.py +31 -8
  7. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/regnet.py +29 -6
  8. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/swin.py +38 -6
  9. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/tcn.py +22 -2
  10. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/vit.py +85 -25
  11. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/test.py +7 -3
  12. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/train.py +79 -18
  13. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/utils/constraints.py +11 -5
  14. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/utils/data.py +130 -39
  15. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/utils/metrics.py +287 -326
  16. {wavedl-1.5.5 → wavedl-1.5.7/src/wavedl.egg-info}/PKG-INFO +37 -27
  17. {wavedl-1.5.5 → wavedl-1.5.7}/LICENSE +0 -0
  18. {wavedl-1.5.5 → wavedl-1.5.7}/pyproject.toml +0 -0
  19. {wavedl-1.5.5 → wavedl-1.5.7}/setup.cfg +0 -0
  20. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/hpc.py +0 -0
  21. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/hpo.py +0 -0
  22. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/__init__.py +0 -0
  23. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/_template.py +0 -0
  24. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/base.py +0 -0
  25. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/cnn.py +0 -0
  26. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/convnext.py +0 -0
  27. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/densenet.py +0 -0
  28. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/registry.py +0 -0
  29. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/resnet.py +0 -0
  30. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/resnet3d.py +0 -0
  31. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/models/unet.py +0 -0
  32. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/utils/__init__.py +0 -0
  33. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/utils/config.py +0 -0
  34. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/utils/cross_validation.py +0 -0
  35. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/utils/distributed.py +0 -0
  36. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/utils/losses.py +0 -0
  37. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/utils/optimizers.py +0 -0
  38. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl/utils/schedulers.py +0 -0
  39. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl.egg-info/SOURCES.txt +0 -0
  40. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl.egg-info/dependency_links.txt +0 -0
  41. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl.egg-info/entry_points.txt +0 -0
  42. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl.egg-info/requires.txt +0 -0
  43. {wavedl-1.5.5 → wavedl-1.5.7}/src/wavedl.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.5.5
3
+ Version: 1.5.7
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -388,7 +388,7 @@ WaveDL/
388
388
  ├── configs/ # YAML config templates
389
389
  ├── examples/ # Ready-to-run examples
390
390
  ├── notebooks/ # Jupyter notebooks
391
- ├── unit_tests/ # Pytest test suite (731 tests)
391
+ ├── unit_tests/ # Pytest test suite (903 tests)
392
392
 
393
393
  ├── pyproject.toml # Package config, dependencies
394
394
  ├── CHANGELOG.md # Version history
@@ -470,6 +470,7 @@ WaveDL/
470
470
  ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
471
471
  - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
472
472
  - **Size**: ~20–350 MB per model depending on architecture
473
+ - **Train from scratch**: Use `--no_pretrained` to disable pretrained weights
473
474
 
474
475
  **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
475
476
 
@@ -1030,37 +1031,46 @@ print(f"✓ Output: {data['output_train'].shape} {data['output_train'].dtype}")
1030
1031
 
1031
1032
  ## 📦 Examples [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
1032
1033
 
1033
- The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained CNN predicts three physical parameters from Lamb wave dispersion curves:
1034
+ The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained MobileNetV3 predicts three physical parameters from Lamb wave dispersion curves:
1034
1035
 
1035
1036
  | Parameter | Unit | Description |
1036
1037
  |-----------|------|-------------|
1037
- | *h* | mm | Plate thickness |
1038
- | √(*E*/ρ) | km/s | Square root of Young's modulus over density |
1039
- | *ν* | — | Poisson's ratio |
1038
+ | $h$ | mm | Plate thickness |
1039
+ | $\sqrt{E/\rho}$ | km/s | Square root of Young's modulus over density |
1040
+ | $\nu$ | — | Poisson's ratio |
1040
1041
 
1041
1042
  > [!NOTE]
1042
- > This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"Deep learning-based ultrasonic assessment of plate thickness and elasticity"*](https://spie.org/spie-smart-structures-and-materials-nondestructive-evaluation/presentation/Deep-learningbased-ultrasonic-assessment-of-plate-thickness-and-elasticity/13951-4) (Paper 13951-4, to appear).
1043
+ > This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"A lightweight deep learning model for ultrasonic assessment of plate thickness and elasticity"*](https://spie.org/spie-smart-structures-and-materials-nondestructive-evaluation/presentation/A-lightweight-deep-learning-model-for-ultrasonic-assessment-of-plate/13951-4) (Paper 13951-4, to appear).
1045
+
1046
+ **Sample Dispersion Data:**
1047
+
1048
+ <p align="center">
1049
+ <img src="examples/elasticity_prediction/dispersion_samples.png" alt="Dispersion curve samples" width="700"><br>
1050
+ <em>Test samples showing the wavenumber-frequency relationship for different plate properties</em>
1051
+ </p>
1043
1052
 
1044
1053
  **Try it yourself:**
1045
1054
 
1046
1055
  ```bash
1047
1056
  # Run inference on the example data
1048
- python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
1049
- --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
1050
- --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
1057
+ python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
1058
+ --data_path ./examples/elasticity_prediction/Test_data_100.mat \
1059
+ --plot --save_predictions --output_dir ./examples/elasticity_prediction/test_results
1051
1060
 
1052
1061
  # Export to ONNX (already included as model.onnx)
1053
- python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
1054
- --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
1055
- --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
1062
+ python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
1063
+ --data_path ./examples/elasticity_prediction/Test_data_100.mat \
1064
+ --export onnx --export_path ./examples/elasticity_prediction/model.onnx
1056
1065
  ```
1057
1066
 
1058
1067
  **What's Included:**
1059
1068
 
1060
1069
  | File | Description |
1061
1070
  |------|-------------|
1062
- | `best_checkpoint/` | Pre-trained CNN checkpoint |
1063
- | `Test_data_500.mat` | 500 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
1071
+ | `best_checkpoint/` | Pre-trained MobileNetV3 checkpoint |
1072
+ | `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → $h$, $\sqrt{E/\rho}$, $\nu$) |
1073
+ | `dispersion_samples.png` | Visualization of sample dispersion curves with material parameters |
1064
1074
  | `model.onnx` | ONNX export with embedded de-normalization |
1065
1075
  | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
1066
1076
  | `training_curves.png` | Training/validation loss and learning rate plot |
@@ -1070,59 +1080,59 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
1070
1080
  **Training Progress:**
1071
1081
 
1072
1082
  <p align="center">
1073
- <img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
1074
- <em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
1083
+ <img src="examples/elasticity_prediction/training_curves.png" alt="Training curves" width="600"><br>
1084
+ <em>Training and validation loss with <code>plateau</code> learning rate schedule</em>
1075
1085
  </p>
1076
1086
 
1077
1087
  **Inference Results:**
1078
1088
 
1079
1089
  <p align="center">
1080
- <img src="examples/elastic_cnn_example/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
1090
+ <img src="examples/elasticity_prediction/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
1081
1091
  <em>Figure 1: Predictions vs ground truth for all three elastic parameters</em>
1082
1092
  </p>
1083
1093
 
1084
1094
  <p align="center">
1085
- <img src="examples/elastic_cnn_example/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
1095
+ <img src="examples/elasticity_prediction/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
1086
1096
  <em>Figure 2: Distribution of prediction errors showing near-zero mean bias</em>
1087
1097
  </p>
1088
1098
 
1089
1099
  <p align="center">
1090
- <img src="examples/elastic_cnn_example/test_results/residuals.png" alt="Residual plot" width="700"><br>
1100
+ <img src="examples/elasticity_prediction/test_results/residuals.png" alt="Residual plot" width="700"><br>
1091
1101
  <em>Figure 3: Residuals vs predicted values (no heteroscedasticity detected)</em>
1092
1102
  </p>
1093
1103
 
1094
1104
  <p align="center">
1095
- <img src="examples/elastic_cnn_example/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
1105
+ <img src="examples/elasticity_prediction/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
1096
1106
  <em>Figure 4: Bland-Altman analysis with ±1.96 SD limits of agreement</em>
1097
1107
  </p>
1098
1108
 
1099
1109
  <p align="center">
1100
- <img src="examples/elastic_cnn_example/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
1110
+ <img src="examples/elasticity_prediction/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
1101
1111
  <em>Figure 5: Q-Q plots confirming normally distributed prediction errors</em>
1102
1112
  </p>
1103
1113
 
1104
1114
  <p align="center">
1105
- <img src="examples/elastic_cnn_example/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
1115
+ <img src="examples/elasticity_prediction/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
1106
1116
  <em>Figure 6: Error correlation matrix between parameters</em>
1107
1117
  </p>
1108
1118
 
1109
1119
  <p align="center">
1110
- <img src="examples/elastic_cnn_example/test_results/relative_error.png" alt="Relative error" width="700"><br>
1120
+ <img src="examples/elasticity_prediction/test_results/relative_error.png" alt="Relative error" width="700"><br>
1111
1121
  <em>Figure 7: Relative error (%) vs true value for each parameter</em>
1112
1122
  </p>
1113
1123
 
1114
1124
  <p align="center">
1115
- <img src="examples/elastic_cnn_example/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
1125
+ <img src="examples/elasticity_prediction/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
1116
1126
  <em>Figure 8: Cumulative error distribution — 95% of predictions within indicated bounds</em>
1117
1127
  </p>
1118
1128
 
1119
1129
  <p align="center">
1120
- <img src="examples/elastic_cnn_example/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
1130
+ <img src="examples/elasticity_prediction/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
1121
1131
  <em>Figure 9: True vs predicted values by sample index</em>
1122
1132
  </p>
1123
1133
 
1124
1134
  <p align="center">
1125
- <img src="examples/elastic_cnn_example/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
1135
+ <img src="examples/elasticity_prediction/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
1126
1136
  <em>Figure 10: Error distribution summary (median, quartiles, outliers)</em>
1127
1137
  </p>
1128
1138
 
@@ -342,7 +342,7 @@ WaveDL/
342
342
  ├── configs/ # YAML config templates
343
343
  ├── examples/ # Ready-to-run examples
344
344
  ├── notebooks/ # Jupyter notebooks
345
- ├── unit_tests/ # Pytest test suite (731 tests)
345
+ ├── unit_tests/ # Pytest test suite (903 tests)
346
346
 
347
347
  ├── pyproject.toml # Package config, dependencies
348
348
  ├── CHANGELOG.md # Version history
@@ -424,6 +424,7 @@ WaveDL/
424
424
  ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
425
425
  - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
426
426
  - **Size**: ~20–350 MB per model depending on architecture
427
+ - **Train from scratch**: Use `--no_pretrained` to disable pretrained weights
427
428
 
428
429
  **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
429
430
 
@@ -984,37 +985,46 @@ print(f"✓ Output: {data['output_train'].shape} {data['output_train'].dtype}")
984
985
 
985
986
  ## 📦 Examples [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
986
987
 
987
- The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained CNN predicts three physical parameters from Lamb wave dispersion curves:
988
+ The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained MobileNetV3 predicts three physical parameters from Lamb wave dispersion curves:
988
989
 
989
990
  | Parameter | Unit | Description |
990
991
  |-----------|------|-------------|
991
- | *h* | mm | Plate thickness |
992
- | √(*E*/ρ) | km/s | Square root of Young's modulus over density |
993
- | *ν* | — | Poisson's ratio |
992
+ | $h$ | mm | Plate thickness |
993
+ | $\sqrt{E/\rho}$ | km/s | Square root of Young's modulus over density |
994
+ | $\nu$ | — | Poisson's ratio |
994
995
 
995
996
  > [!NOTE]
996
- > This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"Deep learning-based ultrasonic assessment of plate thickness and elasticity"*](https://spie.org/spie-smart-structures-and-materials-nondestructive-evaluation/presentation/Deep-learningbased-ultrasonic-assessment-of-plate-thickness-and-elasticity/13951-4) (Paper 13951-4, to appear).
997
+ > This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"A lightweight deep learning model for ultrasonic assessment of plate thickness and elasticity"*](https://spie.org/spie-smart-structures-and-materials-nondestructive-evaluation/presentation/A-lightweight-deep-learning-model-for-ultrasonic-assessment-of-plate/13951-4) (Paper 13951-4, to appear).
999
+
1000
+ **Sample Dispersion Data:**
1001
+
1002
+ <p align="center">
1003
+ <img src="examples/elasticity_prediction/dispersion_samples.png" alt="Dispersion curve samples" width="700"><br>
1004
+ <em>Test samples showing the wavenumber-frequency relationship for different plate properties</em>
1005
+ </p>
997
1006
 
998
1007
  **Try it yourself:**
999
1008
 
1000
1009
  ```bash
1001
1010
  # Run inference on the example data
1002
- python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
1003
- --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
1004
- --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
1011
+ python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
1012
+ --data_path ./examples/elasticity_prediction/Test_data_100.mat \
1013
+ --plot --save_predictions --output_dir ./examples/elasticity_prediction/test_results
1005
1014
 
1006
1015
  # Export to ONNX (already included as model.onnx)
1007
- python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
1008
- --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
1009
- --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
1016
+ python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
1017
+ --data_path ./examples/elasticity_prediction/Test_data_100.mat \
1018
+ --export onnx --export_path ./examples/elasticity_prediction/model.onnx
1010
1019
  ```
1011
1020
 
1012
1021
  **What's Included:**
1013
1022
 
1014
1023
  | File | Description |
1015
1024
  |------|-------------|
1016
- | `best_checkpoint/` | Pre-trained CNN checkpoint |
1017
- | `Test_data_500.mat` | 500 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
1025
+ | `best_checkpoint/` | Pre-trained MobileNetV3 checkpoint |
1026
+ | `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → $h$, $\sqrt{E/\rho}$, $\nu$) |
1027
+ | `dispersion_samples.png` | Visualization of sample dispersion curves with material parameters |
1018
1028
  | `model.onnx` | ONNX export with embedded de-normalization |
1019
1029
  | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
1020
1030
  | `training_curves.png` | Training/validation loss and learning rate plot |
@@ -1024,59 +1034,59 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
1024
1034
  **Training Progress:**
1025
1035
 
1026
1036
  <p align="center">
1027
- <img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
1028
- <em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
1037
+ <img src="examples/elasticity_prediction/training_curves.png" alt="Training curves" width="600"><br>
1038
+ <em>Training and validation loss with <code>plateau</code> learning rate schedule</em>
1029
1039
  </p>
1030
1040
 
1031
1041
  **Inference Results:**
1032
1042
 
1033
1043
  <p align="center">
1034
- <img src="examples/elastic_cnn_example/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
1044
+ <img src="examples/elasticity_prediction/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
1035
1045
  <em>Figure 1: Predictions vs ground truth for all three elastic parameters</em>
1036
1046
  </p>
1037
1047
 
1038
1048
  <p align="center">
1039
- <img src="examples/elastic_cnn_example/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
1049
+ <img src="examples/elasticity_prediction/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
1040
1050
  <em>Figure 2: Distribution of prediction errors showing near-zero mean bias</em>
1041
1051
  </p>
1042
1052
 
1043
1053
  <p align="center">
1044
- <img src="examples/elastic_cnn_example/test_results/residuals.png" alt="Residual plot" width="700"><br>
1054
+ <img src="examples/elasticity_prediction/test_results/residuals.png" alt="Residual plot" width="700"><br>
1045
1055
  <em>Figure 3: Residuals vs predicted values (no heteroscedasticity detected)</em>
1046
1056
  </p>
1047
1057
 
1048
1058
  <p align="center">
1049
- <img src="examples/elastic_cnn_example/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
1059
+ <img src="examples/elasticity_prediction/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
1050
1060
  <em>Figure 4: Bland-Altman analysis with ±1.96 SD limits of agreement</em>
1051
1061
  </p>
1052
1062
 
1053
1063
  <p align="center">
1054
- <img src="examples/elastic_cnn_example/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
1064
+ <img src="examples/elasticity_prediction/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
1055
1065
  <em>Figure 5: Q-Q plots confirming normally distributed prediction errors</em>
1056
1066
  </p>
1057
1067
 
1058
1068
  <p align="center">
1059
- <img src="examples/elastic_cnn_example/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
1069
+ <img src="examples/elasticity_prediction/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
1060
1070
  <em>Figure 6: Error correlation matrix between parameters</em>
1061
1071
  </p>
1062
1072
 
1063
1073
  <p align="center">
1064
- <img src="examples/elastic_cnn_example/test_results/relative_error.png" alt="Relative error" width="700"><br>
1074
+ <img src="examples/elasticity_prediction/test_results/relative_error.png" alt="Relative error" width="700"><br>
1065
1075
  <em>Figure 7: Relative error (%) vs true value for each parameter</em>
1066
1076
  </p>
1067
1077
 
1068
1078
  <p align="center">
1069
- <img src="examples/elastic_cnn_example/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
1079
+ <img src="examples/elasticity_prediction/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
1070
1080
  <em>Figure 8: Cumulative error distribution — 95% of predictions within indicated bounds</em>
1071
1081
  </p>
1072
1082
 
1073
1083
  <p align="center">
1074
- <img src="examples/elastic_cnn_example/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
1084
+ <img src="examples/elasticity_prediction/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
1075
1085
  <em>Figure 9: True vs predicted values by sample index</em>
1076
1086
  </p>
1077
1087
 
1078
1088
  <p align="center">
1079
- <img src="examples/elastic_cnn_example/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
1089
+ <img src="examples/elasticity_prediction/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
1080
1090
  <em>Figure 10: Error distribution summary (median, quartiles, outliers)</em>
1081
1091
  </p>
1082
1092
 
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.5.5"
21
+ __version__ = "1.5.7"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
@@ -110,9 +110,30 @@ class EfficientNetBase(BaseModel):
110
110
  self._freeze_backbone()
111
111
 
112
112
  def _adapt_input_channels(self):
113
- """Modify first conv to handle single-channel input by expanding to 3ch."""
114
- # We'll handle this in forward by repeating channels
115
- pass
113
+ """Modify first conv to accept single-channel input.
114
+
115
+ Instead of expanding 1→3 channels in forward (which triples memory),
116
+ we replace the first conv layer with a 1-channel version and initialize
117
+ weights as the mean of the pretrained RGB filters.
118
+ """
119
+ # EfficientNet stem conv is at: features[0][0]
120
+ old_conv = self.backbone.features[0][0]
121
+ new_conv = nn.Conv2d(
122
+ 1, # Single channel input
123
+ old_conv.out_channels,
124
+ kernel_size=old_conv.kernel_size,
125
+ stride=old_conv.stride,
126
+ padding=old_conv.padding,
127
+ dilation=old_conv.dilation,
128
+ groups=old_conv.groups,
129
+ padding_mode=old_conv.padding_mode,
130
+ bias=old_conv.bias is not None,
131
+ )
132
+ if self.pretrained:
133
+ # Initialize with mean of pretrained RGB weights
134
+ with torch.no_grad():
135
+ new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
136
+ self.backbone.features[0][0] = new_conv
116
137
 
117
138
  def _freeze_backbone(self):
118
139
  """Freeze all backbone parameters except the classifier."""
@@ -130,10 +151,6 @@ class EfficientNetBase(BaseModel):
130
151
  Returns:
131
152
  Output tensor of shape (B, out_size)
132
153
  """
133
- # Expand single channel to 3 channels for pretrained weights
134
- if x.size(1) == 1:
135
- x = x.expand(-1, 3, -1, -1)
136
-
137
154
  return self.backbone(x)
138
155
 
139
156
  @classmethod
@@ -129,10 +129,37 @@ class EfficientNetV2Base(BaseModel):
129
129
  nn.Linear(regression_hidden // 2, out_size),
130
130
  )
131
131
 
132
- # Optionally freeze backbone for fine-tuning
132
+ # Adapt first conv for single-channel input (3× memory savings vs expand)
133
+ self._adapt_input_channels()
134
+
135
+ # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
133
136
  if freeze_backbone:
134
137
  self._freeze_backbone()
135
138
 
139
+ def _adapt_input_channels(self):
140
+ """Modify first conv to accept single-channel input.
141
+
142
+ Instead of expanding 1→3 channels in forward (which triples memory),
143
+ we replace the first conv layer with a 1-channel version and initialize
144
+ weights as the mean of the pretrained RGB filters.
145
+ """
146
+ old_conv = self.backbone.features[0][0]
147
+ new_conv = nn.Conv2d(
148
+ 1, # Single channel input
149
+ old_conv.out_channels,
150
+ kernel_size=old_conv.kernel_size,
151
+ stride=old_conv.stride,
152
+ padding=old_conv.padding,
153
+ dilation=old_conv.dilation,
154
+ groups=old_conv.groups,
155
+ padding_mode=old_conv.padding_mode,
156
+ bias=old_conv.bias is not None,
157
+ )
158
+ if self.pretrained:
159
+ with torch.no_grad():
160
+ new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
161
+ self.backbone.features[0][0] = new_conv
162
+
136
163
  def _freeze_backbone(self):
137
164
  """Freeze all backbone parameters except the classifier."""
138
165
  for name, param in self.backbone.named_parameters():
@@ -144,15 +171,11 @@ class EfficientNetV2Base(BaseModel):
144
171
  Forward pass.
145
172
 
146
173
  Args:
147
- x: Input tensor of shape (B, C, H, W) where C is 1 or 3
174
+ x: Input tensor of shape (B, 1, H, W)
148
175
 
149
176
  Returns:
150
177
  Output tensor of shape (B, out_size)
151
178
  """
152
- # Expand single channel to 3 channels for pretrained weights compatibility
153
- if x.size(1) == 1:
154
- x = x.expand(-1, 3, -1, -1)
155
-
156
179
  return self.backbone(x)
157
180
 
158
181
  @classmethod
@@ -136,10 +136,37 @@ class MobileNetV3Base(BaseModel):
136
136
  nn.Linear(regression_hidden, out_size),
137
137
  )
138
138
 
139
- # Optionally freeze backbone for fine-tuning
139
+ # Adapt first conv for single-channel input (3× memory savings vs expand)
140
+ self._adapt_input_channels()
141
+
142
+ # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
140
143
  if freeze_backbone:
141
144
  self._freeze_backbone()
142
145
 
146
+ def _adapt_input_channels(self):
147
+ """Modify first conv to accept single-channel input.
148
+
149
+ Instead of expanding 1→3 channels in forward (which triples memory),
150
+ we replace the first conv layer with a 1-channel version and initialize
151
+ weights as the mean of the pretrained RGB filters.
152
+ """
153
+ old_conv = self.backbone.features[0][0]
154
+ new_conv = nn.Conv2d(
155
+ 1, # Single channel input
156
+ old_conv.out_channels,
157
+ kernel_size=old_conv.kernel_size,
158
+ stride=old_conv.stride,
159
+ padding=old_conv.padding,
160
+ dilation=old_conv.dilation,
161
+ groups=old_conv.groups,
162
+ padding_mode=old_conv.padding_mode,
163
+ bias=old_conv.bias is not None,
164
+ )
165
+ if self.pretrained:
166
+ with torch.no_grad():
167
+ new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
168
+ self.backbone.features[0][0] = new_conv
169
+
143
170
  def _freeze_backbone(self):
144
171
  """Freeze all backbone parameters except the classifier."""
145
172
  for name, param in self.backbone.named_parameters():
@@ -151,15 +178,11 @@ class MobileNetV3Base(BaseModel):
151
178
  Forward pass.
152
179
 
153
180
  Args:
154
- x: Input tensor of shape (B, C, H, W) where C is 1 or 3
181
+ x: Input tensor of shape (B, 1, H, W)
155
182
 
156
183
  Returns:
157
184
  Output tensor of shape (B, out_size)
158
185
  """
159
- # Expand single channel to 3 channels for pretrained weights compatibility
160
- if x.size(1) == 1:
161
- x = x.expand(-1, 3, -1, -1)
162
-
163
186
  return self.backbone(x)
164
187
 
165
188
  @classmethod
@@ -194,7 +217,7 @@ class MobileNetV3Small(MobileNetV3Base):
194
217
 
195
218
  Performance (approximate):
196
219
  - CPU inference: ~6ms (single core)
197
- - Parameters: 2.5M
220
+ - Parameters: ~1.1M
198
221
  - MAdds: 56M
199
222
 
200
223
  Args:
@@ -241,7 +264,7 @@ class MobileNetV3Large(MobileNetV3Base):
241
264
 
242
265
  Performance (approximate):
243
266
  - CPU inference: ~20ms (single core)
244
- - Parameters: 5.4M
267
+ - Parameters: ~3.2M
245
268
  - MAdds: 219M
246
269
 
247
270
  Args:
@@ -140,10 +140,37 @@ class RegNetBase(BaseModel):
140
140
  nn.Linear(regression_hidden, out_size),
141
141
  )
142
142
 
143
- # Optionally freeze backbone for fine-tuning
143
+ # Adapt first conv for single-channel input (3× memory savings vs expand)
144
+ self._adapt_input_channels()
145
+
146
+ # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
144
147
  if freeze_backbone:
145
148
  self._freeze_backbone()
146
149
 
150
+ def _adapt_input_channels(self):
151
+ """Modify first conv to accept single-channel input.
152
+
153
+ Instead of expanding 1→3 channels in forward (which triples memory),
154
+ we replace the first conv layer with a 1-channel version and initialize
155
+ weights as the mean of the pretrained RGB filters.
156
+ """
157
+ old_conv = self.backbone.stem[0]
158
+ new_conv = nn.Conv2d(
159
+ 1, # Single channel input
160
+ old_conv.out_channels,
161
+ kernel_size=old_conv.kernel_size,
162
+ stride=old_conv.stride,
163
+ padding=old_conv.padding,
164
+ dilation=old_conv.dilation,
165
+ groups=old_conv.groups,
166
+ padding_mode=old_conv.padding_mode,
167
+ bias=old_conv.bias is not None,
168
+ )
169
+ if self.pretrained:
170
+ with torch.no_grad():
171
+ new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
172
+ self.backbone.stem[0] = new_conv
173
+
147
174
  def _freeze_backbone(self):
148
175
  """Freeze all backbone parameters except the fc layer."""
149
176
  for name, param in self.backbone.named_parameters():
@@ -155,15 +182,11 @@ class RegNetBase(BaseModel):
155
182
  Forward pass.
156
183
 
157
184
  Args:
158
- x: Input tensor of shape (B, C, H, W) where C is 1 or 3
185
+ x: Input tensor of shape (B, 1, H, W)
159
186
 
160
187
  Returns:
161
188
  Output tensor of shape (B, out_size)
162
189
  """
163
- # Expand single channel to 3 channels for pretrained weights compatibility
164
- if x.size(1) == 1:
165
- x = x.expand(-1, 3, -1, -1)
166
-
167
190
  return self.backbone(x)
168
191
 
169
192
  @classmethod
@@ -141,10 +141,46 @@ class SwinTransformerBase(BaseModel):
141
141
  nn.Linear(regression_hidden // 2, out_size),
142
142
  )
143
143
 
144
- # Optionally freeze backbone for fine-tuning
144
+ # Adapt patch embedding conv for single-channel input (3× memory savings vs expand)
145
+ self._adapt_input_channels()
146
+
147
+ # Optionally freeze backbone for fine-tuning (after adaptation so new conv is frozen too)
145
148
  if freeze_backbone:
146
149
  self._freeze_backbone()
147
150
 
151
+ def _adapt_input_channels(self):
152
+ """Modify patch embedding conv to accept single-channel input.
153
+
154
+ Instead of expanding 1→3 channels in forward (which triples memory),
155
+ we replace the patch embedding conv with a 1-channel version and
156
+ initialize weights as the mean of the pretrained RGB filters.
157
+ """
158
+ # Swin's patch embedding is at features[0][0]
159
+ try:
160
+ old_conv = self.backbone.features[0][0]
161
+ except (IndexError, AttributeError, TypeError) as e:
162
+ raise RuntimeError(
163
+ f"Swin patch embed structure changed in this torchvision version. "
164
+ f"Cannot adapt input channels. Error: {e}"
165
+ ) from e
166
+ new_conv = nn.Conv2d(
167
+ 1, # Single channel input
168
+ old_conv.out_channels,
169
+ kernel_size=old_conv.kernel_size,
170
+ stride=old_conv.stride,
171
+ padding=old_conv.padding,
172
+ dilation=old_conv.dilation,
173
+ groups=old_conv.groups,
174
+ padding_mode=old_conv.padding_mode,
175
+ bias=old_conv.bias is not None,
176
+ )
177
+ if self.pretrained:
178
+ with torch.no_grad():
179
+ new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
180
+ if old_conv.bias is not None:
181
+ new_conv.bias.copy_(old_conv.bias)
182
+ self.backbone.features[0][0] = new_conv
183
+
148
184
  def _freeze_backbone(self):
149
185
  """Freeze all backbone parameters except the head."""
150
186
  for name, param in self.backbone.named_parameters():
@@ -156,15 +192,11 @@ class SwinTransformerBase(BaseModel):
156
192
  Forward pass.
157
193
 
158
194
  Args:
159
- x: Input tensor of shape (B, C, H, W) where C is 1 or 3
195
+ x: Input tensor of shape (B, 1, H, W)
160
196
 
161
197
  Returns:
162
198
  Output tensor of shape (B, out_size)
163
199
  """
164
- # Expand single channel to 3 channels for pretrained weights compatibility
165
- if x.size(1) == 1:
166
- x = x.expand(-1, 3, -1, -1)
167
-
168
200
  return self.backbone(x)
169
201
 
170
202
  @classmethod