wavedl 1.5.4.tar.gz → 1.5.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {wavedl-1.5.4/src/wavedl.egg-info → wavedl-1.5.6}/PKG-INFO +24 -23
  2. {wavedl-1.5.4 → wavedl-1.5.6}/README.md +23 -22
  3. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/__init__.py +1 -1
  4. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/hpo.py +2 -1
  5. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/vit.py +85 -25
  6. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/train.py +46 -14
  7. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/utils/data.py +135 -49
  8. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/utils/metrics.py +22 -1
  9. {wavedl-1.5.4 → wavedl-1.5.6/src/wavedl.egg-info}/PKG-INFO +24 -23
  10. {wavedl-1.5.4 → wavedl-1.5.6}/LICENSE +0 -0
  11. {wavedl-1.5.4 → wavedl-1.5.6}/pyproject.toml +0 -0
  12. {wavedl-1.5.4 → wavedl-1.5.6}/setup.cfg +0 -0
  13. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/hpc.py +0 -0
  14. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/__init__.py +0 -0
  15. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/_template.py +0 -0
  16. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/base.py +0 -0
  17. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/cnn.py +0 -0
  18. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/convnext.py +0 -0
  19. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/densenet.py +0 -0
  20. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/efficientnet.py +0 -0
  21. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/efficientnetv2.py +0 -0
  22. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/mobilenetv3.py +0 -0
  23. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/registry.py +0 -0
  24. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/regnet.py +0 -0
  25. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/resnet.py +0 -0
  26. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/resnet3d.py +0 -0
  27. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/swin.py +0 -0
  28. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/tcn.py +0 -0
  29. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/models/unet.py +0 -0
  30. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/test.py +0 -0
  31. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/utils/__init__.py +0 -0
  32. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/utils/config.py +0 -0
  33. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/utils/constraints.py +0 -0
  34. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/utils/cross_validation.py +0 -0
  35. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/utils/distributed.py +0 -0
  36. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/utils/losses.py +0 -0
  37. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/utils/optimizers.py +0 -0
  38. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl/utils/schedulers.py +0 -0
  39. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl.egg-info/SOURCES.txt +0 -0
  40. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl.egg-info/dependency_links.txt +0 -0
  41. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl.egg-info/entry_points.txt +0 -0
  42. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl.egg-info/requires.txt +0 -0
  43. {wavedl-1.5.4 → wavedl-1.5.6}/src/wavedl.egg-info/top_level.txt +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.4
+Version: 1.5.6
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -388,7 +388,7 @@ WaveDL/
 ├── configs/          # YAML config templates
 ├── examples/         # Ready-to-run examples
 ├── notebooks/        # Jupyter notebooks
-├── unit_tests/       # Pytest test suite (725 tests)
+├── unit_tests/       # Pytest test suite (731 tests)

 ├── pyproject.toml    # Package config, dependencies
 ├── CHANGELOG.md      # Version history
@@ -470,6 +470,7 @@ WaveDL/
 ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
 - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
 - **Size**: ~20–350 MB per model depending on architecture
+- **Train from scratch**: Use `--no_pretrained` to disable pretrained weights

 **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:

@@ -1030,7 +1031,7 @@ print(f"✓ Output: {data['output_train'].shape} {data['output_train'].dtype}")

 ## 📦 Examples [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)

-The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained CNN predicts three physical parameters from Lamb wave dispersion curves:
+The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained MobileNetV3 predicts three physical parameters from Lamb wave dispersion curves:

 | Parameter | Unit | Description |
 |-----------|------|-------------|
@@ -1045,22 +1046,22 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater

 ```bash
 # Run inference on the example data
-python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
-  --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
-  --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
+python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
+  --data_path ./examples/elasticity_prediction/Test_data_100.mat \
+  --plot --save_predictions --output_dir ./examples/elasticity_prediction/test_results

 # Export to ONNX (already included as model.onnx)
-python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
-  --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
-  --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
+python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
+  --data_path ./examples/elasticity_prediction/Test_data_100.mat \
+  --export onnx --export_path ./examples/elasticity_prediction/model.onnx
 ```

 **What's Included:**

 | File | Description |
 |------|-------------|
-| `best_checkpoint/` | Pre-trained CNN checkpoint |
-| `Test_data_500.mat` | 500 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
+| `best_checkpoint/` | Pre-trained MobileNetV3 checkpoint |
+| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
 | `model.onnx` | ONNX export with embedded de-normalization |
 | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
 | `training_curves.png` | Training/validation loss and learning rate plot |
@@ -1070,59 +1071,59 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
 **Training Progress:**

 <p align="center">
-  <img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
-  <em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
+  <img src="examples/elasticity_prediction/training_curves.png" alt="Training curves" width="600"><br>
+  <em>Training and validation loss with <code>plateau</code> learning rate schedule</em>
 </p>

 **Inference Results:**

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
   <em>Figure 1: Predictions vs ground truth for all three elastic parameters</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
   <em>Figure 2: Distribution of prediction errors showing near-zero mean bias</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/residuals.png" alt="Residual plot" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/residuals.png" alt="Residual plot" width="700"><br>
   <em>Figure 3: Residuals vs predicted values (no heteroscedasticity detected)</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
   <em>Figure 4: Bland-Altman analysis with ±1.96 SD limits of agreement</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
   <em>Figure 5: Q-Q plots confirming normally distributed prediction errors</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
+  <img src="examples/elasticity_prediction/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
   <em>Figure 6: Error correlation matrix between parameters</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/relative_error.png" alt="Relative error" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/relative_error.png" alt="Relative error" width="700"><br>
   <em>Figure 7: Relative error (%) vs true value for each parameter</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
+  <img src="examples/elasticity_prediction/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
   <em>Figure 8: Cumulative error distribution — 95% of predictions within indicated bounds</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
   <em>Figure 9: True vs predicted values by sample index</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
+  <img src="examples/elasticity_prediction/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
   <em>Figure 10: Error distribution summary (median, quartiles, outliers)</em>
 </p>

README.md

@@ -342,7 +342,7 @@ WaveDL/
 ├── configs/          # YAML config templates
 ├── examples/         # Ready-to-run examples
 ├── notebooks/        # Jupyter notebooks
-├── unit_tests/       # Pytest test suite (725 tests)
+├── unit_tests/       # Pytest test suite (731 tests)

 ├── pyproject.toml    # Package config, dependencies
 ├── CHANGELOG.md      # Version history
@@ -424,6 +424,7 @@ WaveDL/
 ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
 - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
 - **Size**: ~20–350 MB per model depending on architecture
+- **Train from scratch**: Use `--no_pretrained` to disable pretrained weights

 **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:

@@ -984,7 +985,7 @@ print(f"✓ Output: {data['output_train'].shape} {data['output_train'].dtype}")

 ## 📦 Examples [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)

-The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained CNN predicts three physical parameters from Lamb wave dispersion curves:
+The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained MobileNetV3 predicts three physical parameters from Lamb wave dispersion curves:

 | Parameter | Unit | Description |
 |-----------|------|-------------|
@@ -999,22 +1000,22 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater

 ```bash
 # Run inference on the example data
-python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
-  --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
-  --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
+python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
+  --data_path ./examples/elasticity_prediction/Test_data_100.mat \
+  --plot --save_predictions --output_dir ./examples/elasticity_prediction/test_results

 # Export to ONNX (already included as model.onnx)
-python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
-  --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
-  --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
+python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
+  --data_path ./examples/elasticity_prediction/Test_data_100.mat \
+  --export onnx --export_path ./examples/elasticity_prediction/model.onnx
 ```

 **What's Included:**

 | File | Description |
 |------|-------------|
-| `best_checkpoint/` | Pre-trained CNN checkpoint |
-| `Test_data_500.mat` | 500 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
+| `best_checkpoint/` | Pre-trained MobileNetV3 checkpoint |
+| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
 | `model.onnx` | ONNX export with embedded de-normalization |
 | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
 | `training_curves.png` | Training/validation loss and learning rate plot |
@@ -1024,59 +1025,59 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
 **Training Progress:**

 <p align="center">
-  <img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
-  <em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
+  <img src="examples/elasticity_prediction/training_curves.png" alt="Training curves" width="600"><br>
+  <em>Training and validation loss with <code>plateau</code> learning rate schedule</em>
 </p>

 **Inference Results:**

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
   <em>Figure 1: Predictions vs ground truth for all three elastic parameters</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
   <em>Figure 2: Distribution of prediction errors showing near-zero mean bias</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/residuals.png" alt="Residual plot" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/residuals.png" alt="Residual plot" width="700"><br>
   <em>Figure 3: Residuals vs predicted values (no heteroscedasticity detected)</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
   <em>Figure 4: Bland-Altman analysis with ±1.96 SD limits of agreement</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
   <em>Figure 5: Q-Q plots confirming normally distributed prediction errors</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
+  <img src="examples/elasticity_prediction/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
   <em>Figure 6: Error correlation matrix between parameters</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/relative_error.png" alt="Relative error" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/relative_error.png" alt="Relative error" width="700"><br>
   <em>Figure 7: Relative error (%) vs true value for each parameter</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
+  <img src="examples/elasticity_prediction/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
   <em>Figure 8: Cumulative error distribution — 95% of predictions within indicated bounds</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
+  <img src="examples/elasticity_prediction/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
   <em>Figure 9: True vs predicted values by sample index</em>
 </p>

 <p align="center">
-  <img src="examples/elastic_cnn_example/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
+  <img src="examples/elasticity_prediction/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
   <em>Figure 10: Error distribution summary (median, quartiles, outliers)</em>
 </p>

src/wavedl/__init__.py

@@ -18,7 +18,7 @@ For inference:
    # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
 """

-__version__ = "1.5.4"
+__version__ = "1.5.6"
 __author__ = "Ductho Le"
 __email__ = "ductho.le@outlook.com"

src/wavedl/hpo.py

@@ -175,13 +175,14 @@ def create_objective(args):
         env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

         # Run training
+        # Note: We inherit the user's cwd instead of setting cwd=Path(__file__).parent
+        # because site-packages may be read-only and train.py creates cache directories
         try:
             result = subprocess.run(
                 cmd,
                 capture_output=True,
                 text=True,
                 timeout=args.timeout,
-                cwd=Path(__file__).parent,
                 env=env,
             )

src/wavedl/models/vit.py

@@ -42,47 +42,89 @@ class PatchEmbed(nn.Module):
     Supports 1D and 2D inputs:
     - 1D: Input (B, 1, L) → (B, num_patches, embed_dim)
     - 2D: Input (B, 1, H, W) → (B, num_patches, embed_dim)
+
+    Args:
+        in_shape: Spatial shape (L,) for 1D or (H, W) for 2D
+        patch_size: Size of each patch
+        embed_dim: Embedding dimension
+        pad_if_needed: If True, pad input to nearest patch-aligned size instead of
+            dropping edge pixels. Important for NDE/QUS applications where edge
+            effects matter. Default: False (original behavior with warning).
     """

-    def __init__(self, in_shape: SpatialShape, patch_size: int, embed_dim: int):
+    def __init__(
+        self,
+        in_shape: SpatialShape,
+        patch_size: int,
+        embed_dim: int,
+        pad_if_needed: bool = False,
+    ):
         super().__init__()

         self.dim = len(in_shape)
         self.patch_size = patch_size
         self.embed_dim = embed_dim
+        self.pad_if_needed = pad_if_needed
+        self._padding = None  # Will be set if padding is needed

         if self.dim == 1:
             # 1D: segment patches
             L = in_shape[0]
-            if L % patch_size != 0:
-                import warnings
-
-                warnings.warn(
-                    f"Input length {L} not divisible by patch_size {patch_size}. "
-                    f"Last {L % patch_size} elements will be dropped. "
-                    f"Consider padding input to {((L // patch_size) + 1) * patch_size}.",
-                    UserWarning,
-                    stacklevel=2,
-                )
-            self.num_patches = L // patch_size
+            remainder = L % patch_size
+            if remainder != 0:
+                if pad_if_needed:
+                    # Pad to next multiple of patch_size
+                    pad_amount = patch_size - remainder
+                    self._padding = (0, pad_amount)  # (left, right)
+                    L_padded = L + pad_amount
+                    self.num_patches = L_padded // patch_size
+                else:
+                    import warnings
+
+                    warnings.warn(
+                        f"Input length {L} not divisible by patch_size {patch_size}. "
+                        f"Last {remainder} elements will be dropped. "
+                        f"Consider using pad_if_needed=True or padding input to "
+                        f"{((L // patch_size) + 1) * patch_size}.",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                    self.num_patches = L // patch_size
+            else:
+                self.num_patches = L // patch_size
             self.proj = nn.Conv1d(
                 1, embed_dim, kernel_size=patch_size, stride=patch_size
             )
         elif self.dim == 2:
             # 2D: grid patches
             H, W = in_shape
-            if H % patch_size != 0 or W % patch_size != 0:
-                import warnings
-
-                warnings.warn(
-                    f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
-                    f"Border pixels will be dropped (H: {H % patch_size}, W: {W % patch_size}). "
-                    f"Consider padding to ({((H // patch_size) + 1) * patch_size}, "
-                    f"{((W // patch_size) + 1) * patch_size}).",
-                    UserWarning,
-                    stacklevel=2,
-                )
-            self.num_patches = (H // patch_size) * (W // patch_size)
+            h_rem, w_rem = H % patch_size, W % patch_size
+            if h_rem != 0 or w_rem != 0:
+                if pad_if_needed:
+                    # Pad to next multiple of patch_size
+                    h_pad = (patch_size - h_rem) % patch_size
+                    w_pad = (patch_size - w_rem) % patch_size
+                    # Padding format: (left, right, top, bottom)
+                    self._padding = (0, w_pad, 0, h_pad)
+                    H_padded, W_padded = H + h_pad, W + w_pad
+                    self.num_patches = (H_padded // patch_size) * (
+                        W_padded // patch_size
+                    )
+                else:
+                    import warnings
+
+                    warnings.warn(
+                        f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
+                        f"Border pixels will be dropped (H: {h_rem}, W: {w_rem}). "
+                        f"Consider using pad_if_needed=True or padding to "
+                        f"({((H // patch_size) + 1) * patch_size}, "
+                        f"{((W // patch_size) + 1) * patch_size}).",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                    self.num_patches = (H // patch_size) * (W // patch_size)
+            else:
+                self.num_patches = (H // patch_size) * (W // patch_size)
             self.proj = nn.Conv2d(
                 1, embed_dim, kernel_size=patch_size, stride=patch_size
             )
@@ -97,6 +139,10 @@ class PatchEmbed(nn.Module):
         Returns:
             Patch embeddings (B, num_patches, embed_dim)
         """
+        # Apply padding if configured
+        if self._padding is not None:
+            x = nn.functional.pad(x, self._padding, mode="constant", value=0)
+
         x = self.proj(x)  # (B, embed_dim, ..reduced_spatial..)
         x = x.flatten(2)  # (B, embed_dim, num_patches)
         x = x.transpose(1, 2)  # (B, num_patches, embed_dim)
@@ -185,6 +231,18 @@ class ViTBase(BaseModel):
     3. Transformer encoder blocks
     4. Extract CLS token
     5. Regression head
+
+    Args:
+        in_shape: Spatial shape (L,) for 1D or (H, W) for 2D
+        out_size: Number of regression targets
+        patch_size: Size of each patch (default: 16)
+        embed_dim: Embedding dimension (default: 768)
+        depth: Number of transformer blocks (default: 12)
+        num_heads: Number of attention heads (default: 12)
+        mlp_ratio: MLP hidden dim multiplier (default: 4.0)
+        dropout_rate: Dropout rate (default: 0.1)
+        pad_if_needed: If True, pad input to nearest patch-aligned size instead
+            of dropping edge pixels. Important for NDE/QUS applications.
     """

     def __init__(
@@ -197,6 +255,7 @@ class ViTBase(BaseModel):
         num_heads: int = 12,
         mlp_ratio: float = 4.0,
         dropout_rate: float = 0.1,
+        pad_if_needed: bool = False,
         **kwargs,
     ):
         super().__init__(in_shape, out_size)
@@ -207,9 +266,10 @@
         self.num_heads = num_heads
         self.dropout_rate = dropout_rate
         self.dim = len(in_shape)
+        self.pad_if_needed = pad_if_needed

         # Patch embedding
-        self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim)
+        self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim, pad_if_needed)
         num_patches = self.patch_embed.num_patches

         # Learnable CLS token and position embeddings
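In the new `pad_if_needed` path, the constructor only records a padding tuple and the padded `num_patches`; the actual `nn.functional.pad` call happens in `forward` (second hunk above). A minimal, self-contained sketch of the same arithmetic in plain PyTorch, using the 500×500 dispersion-map size from the bundled example (an illustration, not wavedl code):

```python
import torch
import torch.nn.functional as F

H, W, patch_size, embed_dim = 500, 500, 16, 768

# Same formula as PatchEmbed: pad up to the next multiple of patch_size
h_pad = (patch_size - H % patch_size) % patch_size  # 12, so H becomes 512
w_pad = (patch_size - W % patch_size) % patch_size  # 12, so W becomes 512

x = torch.randn(1, 1, H, W)
x = F.pad(x, (0, w_pad, 0, h_pad), mode="constant", value=0)  # (left, right, top, bottom)

# Strided conv turns each 16x16 patch into one embedding vector
proj = torch.nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
tokens = proj(x).flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 1024, 768]): (512 // 16) ** 2 = 1024 patches
```

Without padding, the strided convolution on the same input would silently drop a 4-pixel border and yield 31 × 31 = 961 patches, which is exactly the edge loss the new flag avoids.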
src/wavedl/train.py

@@ -162,12 +162,22 @@ except ImportError:
 os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
 os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")

-# Suppress non-critical warnings for cleaner training logs
-warnings.filterwarnings("ignore", category=UserWarning)
+# Suppress warnings from known-noisy libraries, but preserve legitimate warnings
+# from torch/numpy about NaN, dtype, and numerical issues.
 warnings.filterwarnings("ignore", category=FutureWarning)
 warnings.filterwarnings("ignore", category=DeprecationWarning)
+# Pydantic v1/v2 compatibility warnings
 warnings.filterwarnings("ignore", module="pydantic")
 warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
+# Transformer library warnings (loading configs, etc.)
+warnings.filterwarnings("ignore", module="transformers")
+# Accelerate verbose messages
+warnings.filterwarnings("ignore", module="accelerate")
+# torch.compile backend selection warnings
+warnings.filterwarnings("ignore", message=".*TorchDynamo.*")
+warnings.filterwarnings("ignore", message=".*Dynamo is not supported.*")
+# Note: UserWarning from torch/numpy core is NOT suppressed to preserve
+# legitimate warnings about NaN values, dtype mismatches, etc.

 # ==============================================================================
 # GPU PERFORMANCE OPTIMIZATIONS (Ampere/Hopper: A100, H100)
@@ -228,6 +238,18 @@ def parse_args() -> argparse.Namespace:
         default=[],
         help="Python modules to import before training (for custom models)",
     )
+    parser.add_argument(
+        "--pretrained",
+        action="store_true",
+        default=True,
+        help="Use pretrained weights (default: True)",
+    )
+    parser.add_argument(
+        "--no_pretrained",
+        dest="pretrained",
+        action="store_false",
+        help="Train from scratch without pretrained weights",
+    )

     # Configuration File
     parser.add_argument(
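The `--pretrained` / `--no_pretrained` pair added above is the standard argparse idiom of two flags writing to one destination. A self-contained sketch of the resulting behavior (note that since the default is already `True`, passing `--pretrained` is effectively a no-op; the flag exists so the choice can be stated explicitly):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--pretrained", action="store_true", default=True)
parser.add_argument("--no_pretrained", dest="pretrained", action="store_false")

print(parser.parse_args([]).pretrained)                   # True (default)
print(parser.parse_args(["--no_pretrained"]).pretrained)  # False
```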
@@ -543,15 +565,11 @@ def main():
     data_format = DataSource.detect_format(args.data_path)
     source = get_data_source(data_format)

-    # Use memory-mapped loading when available
+    # Use memory-mapped loading when available (now returns LazyDataHandle for all formats)
     _cv_handle = None
     if hasattr(source, "load_mmap"):
-        result = source.load_mmap(args.data_path)
-        if hasattr(result, "inputs"):
-            _cv_handle = result
-            X, y = result.inputs, result.outputs
-        else:
-            X, y = result  # NPZ returns tuple directly
+        _cv_handle = source.load_mmap(args.data_path)
+        X, y = _cv_handle.inputs, _cv_handle.outputs
     else:
         X, y = source.load(args.data_path)
@@ -684,7 +702,9 @@ def main():
     )

     # Build model using registry
-    model = build_model(args.model, in_shape=in_shape, out_size=out_dim)
+    model = build_model(
+        args.model, in_shape=in_shape, out_size=out_dim, pretrained=args.pretrained
+    )

     if accelerator.is_main_process:
         param_info = model.parameter_summary()
@@ -861,10 +881,22 @@ def main():
         milestones=milestones,
         warmup_epochs=args.warmup_epochs,
     )
-    # Prepare everything together
-    model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
-        model, optimizer, train_dl, val_dl, scheduler
-    )
+
+    # For ReduceLROnPlateau: DON'T include scheduler in accelerator.prepare()
+    # because accelerator wraps scheduler.step() to sync across processes,
+    # which defeats our rank-0-only stepping for correct patience counting.
+    # Other schedulers are safe to prepare (no internal state affected by multi-call).
+    if args.scheduler == "plateau":
+        model, optimizer, train_dl, val_dl = accelerator.prepare(
+            model, optimizer, train_dl, val_dl
+        )
+        # Scheduler stays unwrapped - we handle sync manually in training loop
+        # But register it for checkpointing so state is saved/loaded on resume
+        accelerator.register_for_checkpointing(scheduler)
+    else:
+        model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
+            model, optimizer, train_dl, val_dl, scheduler
+        )

 # ==========================================================================
 # AUTO-RESUME / RESUME FROM CHECKPOINT
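The comment in that hunk is worth unpacking: `ReduceLROnPlateau` counts every `step(metric)` call toward its patience, so if each of N processes steps the scheduler with the same validation loss, patience is exhausted roughly N times faster than intended. A self-contained sketch of that failure mode in plain PyTorch, simulating four ranks stepping per epoch (an illustration of the rationale, not the package's training loop):

```python
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=1.0)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.1, patience=2)

flat_loss = 1.0  # validation metric that never improves
for epoch in range(3):
    for _rank in range(4):   # four processes each calling step() per "epoch"
        sched.step(flat_loss)

# Three LR cuts after only 3 epochs; a single stepping process would still be at 1.0
print(opt.param_groups[0]["lr"])  # 0.001
```

Keeping the scheduler out of `prepare()` while still calling `accelerator.register_for_checkpointing(scheduler)` preserves its state across resume, which is the other half of the change.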
src/wavedl/utils/data.py

@@ -13,6 +13,7 @@ Version: 1.0.0
 """

 import gc
+import hashlib
 import logging
 import os
 import pickle
@@ -49,6 +50,29 @@ INPUT_KEYS = ["input_train", "input_test", "X", "data", "inputs", "features", "x
 OUTPUT_KEYS = ["output_train", "output_test", "Y", "labels", "outputs", "targets", "y"]


+def _compute_file_hash(path: str, chunk_size: int = 8 * 1024 * 1024) -> str:
+    """
+    Compute SHA256 hash of a file for cache validation.
+
+    Uses chunked reading to handle large files efficiently without loading
+    the entire file into memory. This is more reliable than mtime for detecting
+    actual content changes, especially with cloud sync services (Dropbox, etc.)
+    that may touch files without modifying content.
+
+    Args:
+        path: Path to file to hash
+        chunk_size: Read buffer size (default 8MB for fast I/O)
+
+    Returns:
+        Hex string of SHA256 hash
+    """
+    hasher = hashlib.sha256()
+    with open(path, "rb") as f:
+        while chunk := f.read(chunk_size):
+            hasher.update(chunk)
+    return hasher.hexdigest()
+
+
 class LazyDataHandle:
     """
     Context manager wrapper for memory-mapped data handles.
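A quick self-contained illustration of why the cache key moved from `st_mtime` to a content hash (see the `prepare_data` hunks further down): rewriting a file with identical bytes, as cloud-sync tools sometimes do, bumps the mtime but leaves the SHA-256 unchanged, so the cache correctly survives. The helper below mirrors `_compute_file_hash`:

```python
import hashlib
import os
import tempfile

def compute_file_hash(path: str, chunk_size: int = 8 * 1024 * 1024) -> str:
    hasher = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):  # chunked read: constant memory
            hasher.update(chunk)
    return hasher.hexdigest()

with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b"dispersion data v1")
    path = f.name

cached_hash = compute_file_hash(path)     # stored in the cache metadata
with open(path, "wb") as f:               # rewrite with identical content:
    f.write(b"dispersion data v1")        # mtime changes, hash does not
print(compute_file_hash(path) == cached_hash)  # True: cache still valid
os.unlink(path)
```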
@@ -207,6 +231,10 @@ class NPZSource(DataSource):

         The error for object arrays happens at ACCESS time, not load time.
         So we need to probe the keys to detect if pickle is required.
+
+        WARNING: When mmap_mode is not None, the returned NpzFile must be kept
+        open for arrays to remain valid. Caller is responsible for closing.
+        For non-mmap loading, use _load_and_copy() instead to avoid leaks.
         """
         data = np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
         try:
@@ -222,6 +250,26 @@ class NPZSource(DataSource):
                 return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
             raise

+    @staticmethod
+    def _load_and_copy(path: str, keys: list[str]) -> dict[str, np.ndarray]:
+        """Load NPZ and copy arrays, ensuring file is properly closed.
+
+        This prevents file descriptor leaks by copying arrays before closing.
+        Use this for eager loading; use _safe_load for memory-mapped access.
+        """
+        data = NPZSource._safe_load(path, keys, mmap_mode=None)
+        try:
+            result = {}
+            for key in keys:
+                if key in data:
+                    arr = data[key]
+                    # Copy ensures we don't hold reference to mmap
+                    result[key] = arr.copy() if hasattr(arr, "copy") else arr
+            return result
+        finally:
+            if hasattr(data, "close"):
+                data.close()
+
     def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
         """Load NPZ file (pickle enabled only for sparse matrices)."""
         # First pass to find keys without loading data
@@ -238,7 +286,7 @@ class NPZSource(DataSource):
                 f"Found: {keys}"
             )

-        data = self._safe_load(path, [input_key, output_key])
+        data = self._load_and_copy(path, [input_key, output_key])
         inp = data[input_key]
         outp = data[output_key]

@@ -248,13 +296,21 @@ class NPZSource(DataSource):

         return inp, outp

-    def load_mmap(self, path: str) -> tuple[np.ndarray, np.ndarray]:
+    def load_mmap(self, path: str) -> LazyDataHandle:
         """
         Load data using memory-mapped mode for zero-copy access.

         This allows processing large datasets without loading them entirely
         into RAM. Critical for HPC environments with memory constraints.

+        Returns a LazyDataHandle for consistent API across all data sources.
+        The NpzFile is kept open for lazy access.
+
+        Usage:
+            with source.load_mmap(path) as (inputs, outputs):
+                # Use inputs and outputs
+                pass  # File automatically closed
+
         Note: Returns memory-mapped arrays - do NOT modify them.
         """
         # First pass to find keys without loading data
@@ -271,11 +327,13 @@ class NPZSource(DataSource):
                 f"Found: {keys}"
             )

+        # Keep NpzFile open for lazy access (like HDF5/MATSource)
         data = self._safe_load(path, [input_key, output_key], mmap_mode="r")
         inp = data[input_key]
         outp = data[output_key]

-        return inp, outp
+        # Return LazyDataHandle for consistent API with HDF5Source/MATSource
+        return LazyDataHandle(inp, outp, file_handle=data)

     def load_outputs_only(self, path: str) -> np.ndarray:
         """Load only targets from NPZ (avoids loading large input arrays)."""
@@ -290,7 +348,7 @@ class NPZSource(DataSource):
                 f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
             )

-        data = self._safe_load(path, [output_key])
+        data = self._load_and_copy(path, [output_key])
         return data[output_key]


@@ -527,9 +585,17 @@ class MATSource(DataSource):
                 inp = self._load_dataset(f, input_key)
                 outp = self._load_dataset(f, output_key)

-                # Handle 1D outputs that become (1, N) after transpose
-                if outp.ndim == 2 and outp.shape[0] == 1:
-                    outp = outp.T
+                # Handle transposed outputs from MATLAB.
+                # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
+                # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
+                num_samples = inp.shape[0]  # inp is already transposed
+                if outp.ndim == 2:
+                    if outp.shape[0] == 1 and outp.shape[1] == num_samples:
+                        # 1D vector: (1, N) → (N, 1)
+                        outp = outp.T
+                    elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
+                        # Single sample with multiple targets: (T, 1) → (1, T)
+                        outp = outp.T

         except OSError as e:
             raise ValueError(
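The sample-count heuristic above can be exercised in isolation. A self-contained sketch with a hypothetical `fix_outp` helper (not the package API) showing both MATLAB transpose cases:

```python
import numpy as np

def fix_outp(inp: np.ndarray, outp: np.ndarray) -> np.ndarray:
    n = inp.shape[0]  # number of samples; inputs are already transposed
    if outp.ndim == 2:
        if outp.shape[0] == 1 and outp.shape[1] == n:
            return outp.T  # (1, N) -> (N, 1): N samples, one target each
        if outp.shape[1] == 1 and outp.shape[0] != n:
            return outp.T  # (T, 1) -> (1, T): one sample, T targets
    return outp

# Case 1: 100 samples with one target, saved by MATLAB as a row vector
print(fix_outp(np.zeros((100, 500, 500)), np.zeros((1, 100))).shape)  # (100, 1)
# Case 2: a single sample with three targets, saved as a column vector
print(fix_outp(np.zeros((1, 500, 500)), np.zeros((3, 1))).shape)      # (1, 3)
```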
@@ -614,7 +680,10 @@ class MATSource(DataSource):
             # Load with sparse matrix support
             outp = self._load_dataset(f, output_key)

-            # Handle 1D outputs
+            # Handle 1D outputs that become (1, N) after transpose.
+            # Note: This method has no input to compare against, so we can't
+            # distinguish single-sample outputs. This is acceptable for training
+            # data where single-sample is unlikely. For inference, use load_test_data.
             if outp.ndim == 2 and outp.shape[0] == 1:
                 outp = outp.T

@@ -775,7 +844,7 @@ def load_test_data(
             raise KeyError(
                 f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
             )
-        data = NPZSource._safe_load(
+        data = NPZSource._load_and_copy(
             path, [inp_key] + ([out_key] if out_key else [])
         )
         inp = data[inp_key]
@@ -824,8 +893,17 @@ def load_test_data(
             inp = mat_source._load_dataset(f, inp_key)
             if out_key:
                 outp = mat_source._load_dataset(f, out_key)
-                if outp.ndim == 2 and outp.shape[0] == 1:
-                    outp = outp.T
+                # Handle transposed outputs from MATLAB
+                # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
+                # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
+                num_samples = inp.shape[0]
+                if outp.ndim == 2:
+                    if outp.shape[0] == 1 and outp.shape[1] == num_samples:
+                        # 1D vector: (1, N) → (N, 1)
+                        outp = outp.T
+                    elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
+                        # Single sample with multiple targets: (T, 1) → (1, T)
+                        outp = outp.T
             else:
                 outp = None
         else:
@@ -844,7 +922,7 @@ def load_test_data(
             )
         out_key = DataSource._find_key(keys, custom_output_keys)
         keys_to_probe = [inp_key] + ([out_key] if out_key else [])
-        data = NPZSource._safe_load(path, keys_to_probe)
+        data = NPZSource._load_and_copy(path, keys_to_probe)
         inp = data[inp_key]
         if inp.dtype == object:
             inp = np.array(
@@ -894,9 +972,17 @@ def load_test_data(
             out_key = DataSource._find_key(keys, custom_output_keys)
             if out_key:
                 outp = mat_source._load_dataset(f, out_key)
-                # Handle 1D outputs that become (1, N) after transpose
-                if outp.ndim == 2 and outp.shape[0] == 1:
-                    outp = outp.T
+                # Handle transposed outputs from MATLAB
+                # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
+                # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
+                num_samples = inp.shape[0]
+                if outp.ndim == 2:
+                    if outp.shape[0] == 1 and outp.shape[1] == num_samples:
+                        # 1D vector: (1, N) → (N, 1)
+                        outp = outp.T
+                    elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
+                        # Single sample with multiple targets: (T, 1) → (1, T)
+                        outp = outp.T
             else:
                 outp = None
         else:
@@ -1096,32 +1182,21 @@ def prepare_data(
         and os.path.exists(META_FILE)
     )

-    # Validate cache matches current data_path (prevents stale cache corruption)
+    # Validate cache using content hash (portable across folders/machines)
+    # File size is a fast pre-check, content hash is definitive validation
     if cache_exists:
         try:
             with open(META_FILE, "rb") as f:
                 meta = pickle.load(f)
-            cached_data_path = meta.get("data_path", None)
             cached_file_size = meta.get("file_size", None)
-            cached_file_mtime = meta.get("file_mtime", None)
+            cached_content_hash = meta.get("content_hash", None)

             # Get current file stats
             current_stats = os.stat(args.data_path)
             current_size = current_stats.st_size
-            current_mtime = current_stats.st_mtime

-            # Check if data path changed
-            if cached_data_path != os.path.abspath(args.data_path):
-                if accelerator.is_main_process:
-                    logger.warning(
-                        f"⚠️ Cache was created from different data file!\n"
-                        f"   Cached: {cached_data_path}\n"
-                        f"   Current: {os.path.abspath(args.data_path)}\n"
-                        f"   Invalidating cache and regenerating..."
-                    )
-                cache_exists = False
-            # Check if file was modified (size or mtime changed)
-            elif cached_file_size is not None and cached_file_size != current_size:
+            # Check if file size changed (fast check before expensive hash)
+            if cached_file_size is not None and cached_file_size != current_size:
                 if accelerator.is_main_process:
                     logger.warning(
                         f"⚠️ Data file size changed!\n"
@@ -1130,13 +1205,16 @@ def prepare_data(
                         f"   Invalidating cache and regenerating..."
                     )
                 cache_exists = False
-            elif cached_file_mtime is not None and cached_file_mtime != current_mtime:
-                if accelerator.is_main_process:
-                    logger.warning(
-                        "⚠️ Data file was modified!\n"
-                        "   Cache may be stale, regenerating..."
-                    )
-                cache_exists = False
+            # Content hash check (robust against cloud sync mtime changes)
+            elif cached_content_hash is not None:
+                current_hash = _compute_file_hash(args.data_path)
+                if cached_content_hash != current_hash:
+                    if accelerator.is_main_process:
+                        logger.warning(
+                            "⚠️ Data file content changed!\n"
+                            "   Cache is stale, regenerating..."
+                        )
+                    cache_exists = False
         except Exception:
             cache_exists = False

@@ -1153,6 +1231,18 @@ def prepare_data(
                     logger.warning(
                         f"   Failed to remove stale cache {stale_file}: {e}"
                     )
+
+        # Fail explicitly if stale cache files couldn't be removed
+        # This prevents silent reuse of outdated data
+        remaining_stale = [
+            f for f in [CACHE_FILE, SCALER_FILE] if os.path.exists(f)
+        ]
+        if remaining_stale:
+            raise RuntimeError(
+                f"Cannot regenerate cache: stale files could not be removed. "
+                f"Please manually delete: {remaining_stale}"
+            )
+
         # RANK 0: Create cache (can take a long time for large datasets)
         # Other ranks will wait at the barrier below

@@ -1170,16 +1260,11 @@ def prepare_data(

         # Load raw data using memory-mapped mode for all formats
         # This avoids loading the entire dataset into RAM at once
+        # All load_mmap() methods now return LazyDataHandle consistently
+        _lazy_handle = None
         try:
-            if data_format == "npz":
-                source = NPZSource()
-                inp, outp = source.load_mmap(args.data_path)
-            elif data_format == "hdf5":
-                source = HDF5Source()
-                _lazy_handle = source.load_mmap(args.data_path)
-                inp, outp = _lazy_handle.inputs, _lazy_handle.outputs
-            elif data_format == "mat":
-                source = MATSource()
+            source = get_data_source(data_format)
+            if hasattr(source, "load_mmap"):
                 _lazy_handle = source.load_mmap(args.data_path)
                 inp, outp = _lazy_handle.inputs, _lazy_handle.outputs
             else:
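The refactor above replaces per-format `if/elif` branches with a capability check: any source that exposes `load_mmap` now returns a `LazyDataHandle`, and everything else falls back to eager `load()`. A toy sketch of that dispatch pattern (illustrative class names, not wavedl's):

```python
class EagerSource:
    def load(self, path):
        return ("X", "y")  # stands in for (inputs, outputs) arrays

class MmapSource(EagerSource):
    def load_mmap(self, path):
        return "lazy-handle"  # stands in for a LazyDataHandle

for source in (EagerSource(), MmapSource()):
    if hasattr(source, "load_mmap"):      # capability check, not format check
        print(source.load_mmap("data.h5"))
    else:
        print(source.load("data.npz"))
```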
@@ -1243,8 +1328,9 @@ def prepare_data(
             f"   Shape Detected: {full_shape} [{dim_type}] | Output Dim: {out_dim}"
         )

-        # Save metadata (including data path, size, mtime for cache validation)
+        # Save metadata (including data path, size, content hash for cache validation)
         file_stats = os.stat(args.data_path)
+        content_hash = _compute_file_hash(args.data_path)
         with open(META_FILE, "wb") as f:
             pickle.dump(
                 {
@@ -1252,7 +1338,7 @@ def prepare_data(
                     "out_dim": out_dim,
                     "data_path": os.path.abspath(args.data_path),
                     "file_size": file_stats.st_size,
-                    "file_mtime": file_stats.st_mtime,
+                    "content_hash": content_hash,
                 },
                 f,
             )
src/wavedl/utils/metrics.py

@@ -815,7 +815,28 @@ def plot_qq(

         # Standardize errors for QQ plot
         err = errors[:, i]
-        standardized = (err - np.mean(err)) / np.std(err)
+        std_err = np.std(err)
+
+        # Guard against zero variance (constant errors)
+        if std_err < 1e-10:
+            title = (
+                param_names[i] if param_names and i < len(param_names) else f"Param {i}"
+            )
+            ax.text(
+                0.5,
+                0.5,
+                "Zero variance\n(constant errors)",
+                ha="center",
+                va="center",
+                fontsize=10,
+                transform=ax.transAxes,
+            )
+            ax.set_title(f"{title}\n(zero variance)")
+            ax.set_xlabel("Theoretical Quantiles")
+            ax.set_ylabel("Sample Quantiles")
+            continue
+
+        standardized = (err - np.mean(err)) / std_err

         # Calculate theoretical quantiles and sample quantiles
         (osm, osr), (slope, intercept, r) = stats.probplot(standardized, dist="norm")
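The guard exists because standardizing a zero-variance sample divides by zero and feeds NaNs into `scipy.stats.probplot`. A minimal reproduction of the check, independent of the plotting code:

```python
import numpy as np
from scipy import stats

err = np.full(100, 0.5)  # constant errors: std is exactly 0
std_err = np.std(err)
if std_err < 1e-10:
    print("skip Q-Q panel: zero variance (constant errors)")
else:
    standardized = (err - np.mean(err)) / std_err  # would be 0/0 without the guard
    (osm, osr), (slope, intercept, r) = stats.probplot(standardized, dist="norm")
```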
src/wavedl.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.4
+Version: 1.5.6
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT

The remaining hunks in this file repeat, verbatim, the embedded-README changes shown in the PKG-INFO diff at the top (test count 725 → 731, the `--no_pretrained` note, the CNN → MobileNetV3 example description, and the `elastic_cnn_example` → `elasticity_prediction` paths).