wavedl 1.2.0__py3-none-any.whl
- wavedl/__init__.py +43 -0
- wavedl/hpo.py +366 -0
- wavedl/models/__init__.py +86 -0
- wavedl/models/_template.py +157 -0
- wavedl/models/base.py +173 -0
- wavedl/models/cnn.py +249 -0
- wavedl/models/convnext.py +425 -0
- wavedl/models/densenet.py +406 -0
- wavedl/models/efficientnet.py +236 -0
- wavedl/models/registry.py +104 -0
- wavedl/models/resnet.py +555 -0
- wavedl/models/unet.py +304 -0
- wavedl/models/vit.py +372 -0
- wavedl/test.py +1069 -0
- wavedl/train.py +1079 -0
- wavedl/utils/__init__.py +151 -0
- wavedl/utils/config.py +269 -0
- wavedl/utils/cross_validation.py +509 -0
- wavedl/utils/data.py +1220 -0
- wavedl/utils/distributed.py +138 -0
- wavedl/utils/losses.py +216 -0
- wavedl/utils/metrics.py +1236 -0
- wavedl/utils/optimizers.py +216 -0
- wavedl/utils/schedulers.py +251 -0
- wavedl-1.2.0.dist-info/LICENSE +21 -0
- wavedl-1.2.0.dist-info/METADATA +991 -0
- wavedl-1.2.0.dist-info/RECORD +30 -0
- wavedl-1.2.0.dist-info/WHEEL +5 -0
- wavedl-1.2.0.dist-info/entry_points.txt +4 -0
- wavedl-1.2.0.dist-info/top_level.txt +1 -0

Metadata-Version: 2.2
Name: wavedl
Version: 1.2.0
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
Author: Ductho Le
License: MIT
Project-URL: Homepage, https://github.com/ductho-le/WaveDL
Project-URL: Repository, https://github.com/ductho-le/WaveDL
Project-URL: Documentation, https://github.com/ductho-le/WaveDL#readme
Project-URL: Issues, https://github.com/ductho-le/WaveDL/issues
Keywords: deep-learning,inverse-problems,wave-propagation,ultrasonic,guided-waves,non-destructive-testing,machine-learning,pytorch,regression
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Physics
Requires-Python: >=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: torchvision>=0.15.0
Requires-Dist: accelerate>=0.20.0
Requires-Dist: numpy>=1.24.0
Requires-Dist: scipy>=1.10.0
Requires-Dist: scikit-learn>=1.2.0
Requires-Dist: pandas>=2.0.0
Requires-Dist: matplotlib>=3.7.0
Requires-Dist: tqdm>=4.65.0
Requires-Dist: wandb>=0.15.0
Requires-Dist: pyyaml>=6.0.0
Requires-Dist: h5py>=3.8.0
Requires-Dist: safetensors>=0.3.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
Requires-Dist: ruff>=0.8.0; extra == "dev"
Requires-Dist: pre-commit>=3.5.0; extra == "dev"
Provides-Extra: onnx
Requires-Dist: onnx>=1.14.0; extra == "onnx"
Requires-Dist: onnxruntime>=1.15.0; extra == "onnx"
Provides-Extra: compile
Requires-Dist: triton; extra == "compile"
Provides-Extra: hpo
Requires-Dist: optuna>=3.0.0; extra == "hpo"
Provides-Extra: all
Requires-Dist: pytest>=7.0.0; extra == "all"
Requires-Dist: pytest-xdist>=3.5.0; extra == "all"
Requires-Dist: ruff>=0.8.0; extra == "all"
Requires-Dist: pre-commit>=3.5.0; extra == "all"
Requires-Dist: onnx>=1.14.0; extra == "all"
Requires-Dist: onnxruntime>=1.15.0; extra == "all"
Requires-Dist: triton; extra == "all"
Requires-Dist: optuna>=3.0.0; extra == "all"

<div align="center">

<img src="logos/wavedl_logo.png" alt="WaveDL Logo" width="500">

### A Scalable Deep Learning Framework for Wave-Based Inverse Problems

[Python](https://www.python.org/downloads/)
[PyTorch](https://pytorch.org/)
[Accelerate](https://huggingface.co/docs/accelerate/)
<br>
[Tests](https://github.com/ductho-le/WaveDL/actions/workflows/test.yml)
[Lint](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
[Open in Colab](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
<br>
[License: MIT](LICENSE)
[DOI: 10.5281/zenodo.18012338](https://doi.org/10.5281/zenodo.18012338)

**Production-ready • Multi-GPU DDP • Memory-Efficient • Plug-and-Play**

[Getting Started](#-getting-started) •
[Documentation](#-documentation) •
[Examples](#-examples) •
[Discussions](https://github.com/ductho-le/WaveDL/discussions) •
[Citation](#-citation)

---

**Plug in your model, load your data, and let WaveDL do the heavy lifting 💪**

</div>

---

## 💡 What is WaveDL?

WaveDL is a **deep learning framework** built for **wave-based inverse problems** — from ultrasonic NDE and geophysics to biomedical tissue characterization. It provides a robust, scalable training pipeline for mapping multi-dimensional data (1D/2D/3D) to physical quantities.

```
Input: Waveforms, spectrograms, B-scans, dispersion curves, ...
↓
Output: Material properties, defect dimensions, damage locations, ...
```

The framework handles the engineering challenges of large-scale deep learning — big datasets, distributed training, and HPC deployment — so you can focus on the science, not the infrastructure.

**Built for researchers who need:**
- 📊 Multi-target regression with reproducibility and fair benchmarking
- 🚀 Seamless multi-GPU training on HPC clusters
- 💾 Memory-efficient handling of large-scale datasets
- 🔧 Easy integration of custom model architectures

---

## ✨ Features

<table width="100%">
<tr>
<td width="50%" valign="top">

**⚡ Load All Data — No More Bottleneck**

Train on datasets larger than RAM:
- Memory-mapped, zero-copy streaming
- Full random shuffling at GPU speed
- Your GPU stays fed — always

</td>
<td width="50%" valign="top">

**🧠 One-Line Model Registration**

Plug in any architecture:
```python
@register_model("my_net")
class MyNet(BaseModel): ...
```
Design your model. Register with one line.

</td>
</tr>
<tr>
<td width="50%" valign="top">

**🛡️ DDP That Actually Works**

Multi-GPU training without the pain:
- Synchronized early stopping
- Deadlock-free checkpointing
- Correct metric aggregation

</td>
<td width="50%" valign="top">

**📊 Publish-Ready Output**

Results go straight to your paper:
- 11 diagnostic plots with LaTeX styling
- Multi-format export (PNG, PDF, SVG, ...)
- MAE in physical units per parameter

</td>
</tr>
<tr>
<td width="50%" valign="top">

**🖥️ HPC-Native Design**

Built for high-performance clusters:
- Automatic GPU detection
- WandB experiment tracking
- BF16/FP16 mixed precision

</td>
<td width="50%" valign="top">

**🔄 Crash-Proof Training**

Never lose your progress:
- Full state checkpoints
- Resume from any point
- Emergency saves on interrupt

</td>
</tr>
<tr>
<td width="50%" valign="top">

**🎛️ Flexible & Reproducible Training**

Fully configurable via CLI flags or YAML:
- Loss functions, optimizers, schedulers
- K-fold cross-validation
- See [Configuration](#️-configuration) for details

</td>
<td width="50%" valign="top">

**📦 ONNX Export**

Deploy models anywhere:
- One-command export to ONNX
- LabVIEW, MATLAB, C++ compatible
- Validated PyTorch↔ONNX outputs

</td>
</tr>
</table>

---

## 🚀 Getting Started

### Installation

```bash
git clone https://github.com/ductho-le/WaveDL.git
cd WaveDL

# Basic install (training + inference)
pip install -e .

# Full install (adds ONNX export, torch.compile, HPO, dev tools)
pip install -e ".[all]"
```

> [!NOTE]
> Dependencies are managed in `pyproject.toml`. Python 3.11+ required.
>
> For development setup (running tests, contributing), see [CONTRIBUTING.md](.github/CONTRIBUTING.md).

### Quick Start

> [!TIP]
> In all examples below, replace `<...>` placeholders with your values. See [Configuration](#️-configuration) for defaults and options.

#### Option 1: Using the Helper Script (Recommended for HPC)

The `run_training.sh` wrapper automatically configures the environment for HPC systems:

```bash
# Make executable (first time only)
chmod +x run_training.sh

# Basic training (auto-detects available GPUs)
./run_training.sh --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>

# Detailed configuration
./run_training.sh --model <model_name> --data_path <train_data> --batch_size <number> \
  --lr <number> --epochs <number> --patience <number> --compile --output_dir <output_folder>
```

#### Option 2: Direct Accelerate Launch

```bash
# Local - auto-detects GPUs
accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>

# Resume training (automatic - just re-run with same output_dir)
# Manual resume from specific checkpoint:
accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --resume <checkpoint_folder> --output_dir <output_folder>

# Force fresh start (ignores existing checkpoints)
accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --output_dir <output_folder> --fresh

# List available models
python -m wavedl.train --list_models
```

> [!TIP]
> **Auto-Resume**: If training crashes or is interrupted, simply re-run with the same `--output_dir`. The framework automatically detects incomplete training and resumes from the last checkpoint. Use `--fresh` to force a fresh start.
>
> **GPU Auto-Detection**: By default, `run_training.sh` automatically detects available GPUs using `nvidia-smi`. Set `NUM_GPUS` to override this behavior.
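
Auto-resume detection can work roughly as sketched below. The sketch assumes, hypothetically, that checkpoints are saved as subfolders named `checkpoint_<epoch>` inside `--output_dir`. WaveDL's real checkpoint layout and `find_latest_checkpoint` are not taken from the source and may differ.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

# Hypothetical layout: <output_dir>/checkpoint_<epoch>/ (WaveDL's real naming may differ)
_CKPT_PATTERN = re.compile(r"^checkpoint_(\d+)$")


def find_latest_checkpoint(output_dir: str, fresh: bool = False) -> Optional[Path]:
    """Return the newest checkpoint folder in output_dir, or None to start fresh."""
    root = Path(output_dir)
    if fresh or not root.is_dir():
        return None  # --fresh, or a brand-new output directory
    found = []
    for child in root.iterdir():
        match = _CKPT_PATTERN.match(child.name)
        if match and child.is_dir():
            found.append((int(match.group(1)), child))
    if not found:
        return None
    # Compare epoch numbers, not strings, so checkpoint_100 beats checkpoint_99
    return max(found, key=lambda item: item[0])[1]


with tempfile.TemporaryDirectory() as demo_dir:
    for epoch in (5, 50, 100):
        (Path(demo_dir) / f"checkpoint_{epoch}").mkdir()
    latest = find_latest_checkpoint(demo_dir)
    skipped = find_latest_checkpoint(demo_dir, fresh=True)
    print(latest.name)  # checkpoint_100
```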

### Testing & Inference

After training, use `wavedl.test` to evaluate your model on test data:

```bash
# Basic inference
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data>

# With visualization, CSV export, and multiple file formats
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
  --plot --plot_format png pdf --save_predictions --output_dir <output_folder>

# With custom parameter names
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
  --param_names '$p_1$' '$p_2$' '$p_3$' --plot

# Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
  --export onnx --export_path <output_file.onnx>
```

**Output:**
- **Console**: R², Pearson correlation, MAE per parameter
- **CSV** (with `--save_predictions`): True, predicted, error, and absolute error for all parameters
- **Plots** (with `--plot`): 10 publication-quality plots (scatter, histogram, residuals, Bland-Altman, Q-Q, correlation, relative error, CDF, index plot, box plot)
- **Format** (with `--plot_format`): one or more of `png` (default), `pdf` (vector), `svg` (vector), `eps` (for LaTeX), `tiff`, `jpg`, `ps`
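
The console metrics have standard definitions. This minimal pure-Python sketch is not WaveDL's `metrics.py`. It computes them column by column on `(N, T)` true and predicted values in physical units. The helper name `per_target_metrics` is illustrative, and constant target columns (zero variance) are not handled.

```python
import math
from typing import List, Sequence


def per_target_metrics(y_true: Sequence[Sequence[float]],
                       y_pred: Sequence[Sequence[float]]) -> List[dict]:
    """Return R², Pearson r, and MAE for each of the T target columns."""
    results = []
    for t in range(len(y_true[0])):
        yt = [row[t] for row in y_true]
        yp = [row[t] for row in y_pred]
        n = len(yt)
        mean_t, mean_p = sum(yt) / n, sum(yp) / n
        ss_res = sum((a - b) ** 2 for a, b in zip(yt, yp))
        ss_tot = sum((a - mean_t) ** 2 for a in yt)
        cov = sum((a - mean_t) * (b - mean_p) for a, b in zip(yt, yp))
        norm_t = math.sqrt(ss_tot)
        norm_p = math.sqrt(sum((b - mean_p) ** 2 for b in yp))
        results.append({
            "r2": 1.0 - ss_res / ss_tot,                          # coefficient of determination
            "pearson": cov / (norm_t * norm_p),                   # linear correlation
            "mae": sum(abs(a - b) for a, b in zip(yt, yp)) / n,   # same units as the target
        })
    return results


y_true = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
y_pred = [[1.1, 12.0], [1.9, 18.0], [3.2, 30.0], [3.8, 40.0]]
metrics = per_target_metrics(y_true, y_pred)
for i, m in enumerate(metrics):
    print(f"target {i}: R²={m['r2']:.3f}  r={m['pearson']:.3f}  MAE={m['mae']:.3f}")
```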

> [!NOTE]
> `wavedl.test` auto-detects the model architecture from checkpoint metadata. If unavailable, it falls back to folder name parsing. Use `--model` to override if needed.

---

## 📁 Project Structure

```
WaveDL/
├── src/
│   └── wavedl/                   # Main package (namespaced)
│       ├── __init__.py           # Package init with __version__
│       ├── train.py              # Training entry point
│       ├── test.py               # Testing & inference script
│       ├── hpo.py                # Hyperparameter optimization
│       │
│       ├── models/               # Model architectures
│       │   ├── registry.py       # Model factory (@register_model)
│       │   ├── base.py           # Abstract base class
│       │   ├── cnn.py            # Baseline CNN
│       │   ├── resnet.py         # ResNet-18/34/50 (1D/2D/3D)
│       │   ├── efficientnet.py   # EfficientNet-B0/B1/B2
│       │   ├── vit.py            # Vision Transformer (1D/2D)
│       │   ├── convnext.py       # ConvNeXt (1D/2D/3D)
│       │   ├── densenet.py       # DenseNet-121/169 (1D/2D/3D)
│       │   └── unet.py           # U-Net / U-Net Regression
│       │
│       └── utils/                # Utilities
│           ├── data.py           # Memory-mapped data pipeline
│           ├── metrics.py        # R², Pearson, visualization
│           ├── distributed.py    # DDP synchronization
│           ├── losses.py         # Loss function factory
│           ├── optimizers.py     # Optimizer factory
│           ├── schedulers.py     # LR scheduler factory
│           └── config.py         # YAML configuration support
│
├── run_training.sh               # HPC helper script
├── configs/                      # YAML config templates
├── examples/                     # Ready-to-run examples
├── notebooks/                    # Jupyter notebooks
├── unit_tests/                   # Pytest test suite (422 tests)
│
├── pyproject.toml                # Package config, dependencies
├── CHANGELOG.md                  # Version history
└── CITATION.cff                  # Citation metadata
```

---

## ⚙️ Configuration

> [!NOTE]
> All configuration options below work with **both** `run_training.sh` and direct `accelerate launch`. The wrapper script passes all arguments directly to `train.py`.
>
> **Examples:**
> ```bash
> # Using run_training.sh
> ./run_training.sh --model cnn --batch_size 256 --lr 5e-4 --compile
>
> # Using accelerate launch directly
> accelerate launch -m wavedl.train --model cnn --batch_size 256 --lr 5e-4 --compile
> ```

<details>
<summary><b>Available Models</b> — 21 pre-built architectures</summary>

| Model | Best For | Params (2D) | Dimensionality |
|-------|----------|-------------|----------------|
| `cnn` | Baseline, lightweight | 1.7M | 1D/2D/3D |
| `resnet18` | Fast training, smaller datasets | 11.4M | 1D/2D/3D |
| `resnet34` | Balanced performance | 21.5M | 1D/2D/3D |
| `resnet50` | High capacity, complex patterns | 24.6M | 1D/2D/3D |
| `resnet18_pretrained` | **Transfer learning** ⭐ | 11.4M | 2D only |
| `resnet50_pretrained` | **Transfer learning** ⭐ | 24.6M | 2D only |
| `efficientnet_b0` | Efficient, **pretrained** ⭐ | 4.7M | 2D only |
| `efficientnet_b1` | Efficient, **pretrained** ⭐ | 7.2M | 2D only |
| `efficientnet_b2` | Efficient, **pretrained** ⭐ | 8.4M | 2D only |
| `vit_tiny` | Transformer, small datasets | 5.4M | 1D/2D |
| `vit_small` | Transformer, balanced | 21.5M | 1D/2D |
| `vit_base` | Transformer, high capacity | 85.5M | 1D/2D |
| `convnext_tiny` | Modern CNN, transformer-inspired | 28.2M | 1D/2D/3D |
| `convnext_tiny_pretrained` | **Transfer learning** ⭐ | 28.2M | 2D only |
| `convnext_small` | Modern CNN, balanced | 49.8M | 1D/2D/3D |
| `convnext_base` | Modern CNN, high capacity | 88.1M | 1D/2D/3D |
| `densenet121` | Feature reuse, small data | 7.5M | 1D/2D/3D |
| `densenet121_pretrained` | **Transfer learning** ⭐ | 7.5M | 2D only |
| `densenet169` | Deeper DenseNet | 13.3M | 1D/2D/3D |
| `unet` | Spatial output (velocity fields) | 31.0M | 1D/2D/3D |
| `unet_regression` | Multi-scale features for regression | 31.1M | 1D/2D/3D |

> ⭐ **Pretrained models** use ImageNet weights for transfer learning.

</details>

<details>
<summary><b>Training Parameters</b></summary>

| Argument | Default | Description |
|----------|---------|-------------|
| `--model` | `cnn` | Model architecture |
| `--batch_size` | `128` | Per-GPU batch size |
| `--lr` | `1e-3` | Learning rate |
| `--epochs` | `1000` | Maximum epochs |
| `--patience` | `20` | Early stopping patience |
| `--weight_decay` | `1e-4` | AdamW regularization |
| `--grad_clip` | `1.0` | Gradient clipping |
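
`--patience` drives early stopping. This is a minimal single-process sketch of the usual patience logic: track the best validation loss and stop after `patience` consecutive epochs without improvement. `min_delta` is an illustrative extra, not a documented WaveDL flag. Under DDP the stop decision must also be identical on every rank (the "synchronized early stopping" listed under Features), which this sketch does not show.

```python
class EarlyStopping:
    """Stop training after `patience` epochs without validation improvement."""

    def __init__(self, patience: int = 20, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta  # hypothetical: minimum decrease that counts as improvement
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0  # any improvement resets the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=3)
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.69]
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)  # 5 0.7
```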

</details>

<details>
<summary><b>Data & I/O</b></summary>

| Argument | Default | Description |
|----------|---------|-------------|
| `--data_path` | `train_data.npz` | Dataset path |
| `--workers` | `-1` | DataLoader workers per GPU (-1=auto-detect) |
| `--seed` | `2025` | Random seed |
| `--output_dir` | `.` | Output directory for checkpoints |
| `--resume` | `None` | Checkpoint to resume (auto-detected if not set) |
| `--save_every` | `50` | Checkpoint frequency |
| `--fresh` | `False` | Force fresh training, ignore existing checkpoints |
| `--single_channel` | `False` | Confirm data is single-channel (for shallow 3D volumes like `(8, 128, 128)`) |

</details>

<details>
<summary><b>Performance</b></summary>

| Argument | Default | Description |
|----------|---------|-------------|
| `--compile` | `False` | Enable `torch.compile` |
| `--precision` | `bf16` | Mixed precision mode (`bf16`, `fp16`, `no`) |
| `--wandb` | `False` | Enable W&B logging |
| `--project_name` | `DL-Training` | W&B project name |
| `--run_name` | `None` | W&B run name (auto-generated if not set) |

</details>

<details>
<summary><b>Environment Variables (run_training.sh)</b></summary>

| Variable | Default | Description |
|----------|---------|-------------|
| `NUM_GPUS` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override (e.g., `NUM_GPUS=2`) |
| `NUM_MACHINES` | `1` | Number of machines in distributed setup |
| `MIXED_PRECISION` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
| `DYNAMO_BACKEND` | `no` | PyTorch Dynamo backend |
| `WANDB_MODE` | `offline` | WandB mode: `offline` or `online` |
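
The `NUM_GPUS` override and `nvidia-smi` auto-detection can be mirrored in Python, for example when launching jobs from a script. This minimal sketch assumes `nvidia-smi --list-gpus` prints one `GPU <n>: ...` line per device. `detect_num_gpus` is hypothetical and may not match the logic in `run_training.sh`; in particular, the fallback of `0` when no driver tools are found is this sketch's choice.

```python
import os
import shutil
import subprocess


def detect_num_gpus() -> int:
    """Return NUM_GPUS if set, else count GPUs listed by nvidia-smi (0 if unavailable)."""
    override = os.environ.get("NUM_GPUS")
    if override:
        return int(override)  # an explicit setting wins, e.g. NUM_GPUS=2
    if shutil.which("nvidia-smi") is None:
        return 0  # NVIDIA driver tools not on PATH
    try:
        out = subprocess.run(
            ["nvidia-smi", "--list-gpus"],
            capture_output=True, text=True, check=True, timeout=10,
        ).stdout
    except (subprocess.SubprocessError, OSError):
        return 0  # non-zero exit, timeout, or launch failure
    return sum(1 for line in out.splitlines() if line.strip().startswith("GPU"))


print(detect_num_gpus())
```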

</details>

<details>
<summary><b>Loss Functions</b></summary>

| Loss | Flag | Best For | Notes |
|------|------|----------|-------|
| `mse` | `--loss mse` | Default, smooth gradients | Standard Mean Squared Error |
| `mae` | `--loss mae` | Outlier-robust, linear penalty | Mean Absolute Error (L1) |
| `huber` | `--loss huber --huber_delta 1.0` | Best of MSE + MAE | Robust, smooth transition |
| `smooth_l1` | `--loss smooth_l1` | Similar to Huber | PyTorch native implementation |
| `log_cosh` | `--loss log_cosh` | Smooth approximation to MAE | Differentiable everywhere |
| `weighted_mse` | `--loss weighted_mse --loss_weights "2.0,1.0,1.0"` | Prioritize specific targets | Per-target weighting |

**Example:**
```bash
# Use Huber loss for noisy NDE data
accelerate launch -m wavedl.train --model cnn --loss huber --huber_delta 0.5

# Weighted MSE: prioritize thickness (first target)
accelerate launch -m wavedl.train --model cnn --loss weighted_mse --loss_weights "2.0,1.0,1.0"
```
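
For reference, the per-sample forms behind `huber`, `log_cosh`, and `weighted_mse` are sketched below in pure Python. The framework itself works on PyTorch tensors, and the reductions or weight normalization in `losses.py` may differ. Huber uses the standard definition with threshold `delta`, which plays the role of `--huber_delta`. Log-cosh uses a numerically stable identity.

```python
import math
from typing import Sequence


def huber(error: float, delta: float = 1.0) -> float:
    """Quadratic for |e| <= delta, linear beyond; value and slope match at |e| = delta."""
    a = abs(error)
    return 0.5 * a * a if a <= delta else delta * (a - 0.5 * delta)


def log_cosh(error: float) -> float:
    """log(cosh(e)), computed stably as |e| + log1p(exp(-2|e|)) - log(2)."""
    a = abs(error)
    return a + math.log1p(math.exp(-2.0 * a)) - math.log(2.0)


def weighted_mse(y_true: Sequence[float], y_pred: Sequence[float],
                 weights: Sequence[float]) -> float:
    """Mean of per-target squared errors, each scaled by its weight (one sample)."""
    terms = [w * (t - p) ** 2 for t, p, w in zip(y_true, y_pred, weights)]
    return sum(terms) / len(terms)


print(huber(0.5), huber(3.0), round(log_cosh(2.0), 4))
# Weights "2.0,1.0,1.0" double the penalty on the first target:
print(weighted_mse([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [2.0, 1.0, 1.0]))
```

A smaller `delta` switches to the linear (MAE-like) regime sooner, which is why a value below 1.0 is used for noisy data in the example above.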

</details>

<details>
<summary><b>Optimizers</b></summary>

| Optimizer | Flag | Best For | Key Parameters |
|-----------|------|----------|----------------|
| `adamw` | `--optimizer adamw` | Default, most cases | `--betas "0.9,0.999"` |
| `adam` | `--optimizer adam` | Legacy compatibility | `--betas "0.9,0.999"` |
| `sgd` | `--optimizer sgd` | Better generalization | `--momentum 0.9 --nesterov` |
| `nadam` | `--optimizer nadam` | Adam + Nesterov | Faster convergence |
| `radam` | `--optimizer radam` | Variance-adaptive | More stable training |
| `rmsprop` | `--optimizer rmsprop` | RNN/LSTM models | `--momentum 0.9` |

**Example:**
```bash
# SGD with Nesterov momentum (often better generalization)
accelerate launch -m wavedl.train --model cnn --optimizer sgd --lr 0.01 --momentum 0.9 --nesterov

# RAdam for more stable training
accelerate launch -m wavedl.train --model cnn --optimizer radam --lr 1e-3
```
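
Several flags take comma-separated strings, including `--betas`, `--loss_weights`, and `--milestones`. This minimal sketch shows one way to parse such values with `argparse`. The `csv_of` converter is hypothetical, not WaveDL's parser. `argparse` reports an `ArgumentTypeError` raised by a `type=` callable as a normal usage error.

```python
import argparse
from typing import Callable, List, TypeVar

T = TypeVar("T")


def csv_of(cast: Callable[[str], T]) -> Callable[[str], List[T]]:
    """Build an argparse `type=` converter for strings like "0.9,0.999"."""
    def convert(text: str) -> List[T]:
        try:
            return [cast(part.strip()) for part in text.split(",") if part.strip()]
        except ValueError as exc:
            raise argparse.ArgumentTypeError(f"invalid list {text!r}: {exc}") from exc
    return convert


parser = argparse.ArgumentParser()
parser.add_argument("--betas", type=csv_of(float), default=[0.9, 0.999])
parser.add_argument("--milestones", type=csv_of(int), default=[30, 60, 90])
args = parser.parse_args(["--betas", "0.9,0.98", "--milestones", "100,200,300"])
print(args.betas, args.milestones)  # [0.9, 0.98] [100, 200, 300]
```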

</details>

<details>
<summary><b>Learning Rate Schedulers</b></summary>

| Scheduler | Flag | Best For | Key Parameters |
|-----------|------|----------|----------------|
| `plateau` | `--scheduler plateau` | Default, adaptive | `--scheduler_patience 10 --scheduler_factor 0.5` |
| `cosine` | `--scheduler cosine` | Long training, smooth decay | `--min_lr 1e-6` |
| `cosine_restarts` | `--scheduler cosine_restarts` | Escape local minima | Warm restarts |
| `onecycle` | `--scheduler onecycle` | Fast convergence | Super-convergence |
| `step` | `--scheduler step` | Simple decay | `--step_size 30 --scheduler_factor 0.1` |
| `multistep` | `--scheduler multistep` | Custom milestones | `--milestones "30,60,90"` |
| `exponential` | `--scheduler exponential` | Continuous decay | `--scheduler_factor 0.95` |
| `linear_warmup` | `--scheduler linear_warmup` | Warmup phase | `--warmup_epochs 5` |

**Example:**
```bash
# Cosine annealing for 1000 epochs
accelerate launch -m wavedl.train --model cnn --scheduler cosine --epochs 1000 --min_lr 1e-7

# OneCycleLR for super-convergence
accelerate launch -m wavedl.train --model cnn --scheduler onecycle --lr 1e-2 --epochs 50

# MultiStep with custom milestones
accelerate launch -m wavedl.train --model cnn --scheduler multistep --milestones "100,200,300"
```
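
The `cosine` and `linear_warmup` schedules have simple closed forms. This per-epoch sketch makes two assumptions. First, cosine annealing decays from `--lr` to `--min_lr` over `--epochs`, using the closed form PyTorch documents for `CosineAnnealingLR`: lr_t = min_lr + ½(lr − min_lr)(1 + cos(πt/T)). Second, warmup ramps linearly over `--warmup_epochs`. How WaveDL steps or combines these schedules may differ.

```python
import math


def cosine_lr(epoch: int, base_lr: float, min_lr: float, total_epochs: int) -> float:
    """Cosine annealing: base_lr at epoch 0, min_lr at epoch total_epochs."""
    t = min(epoch, total_epochs) / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * t))


def linear_warmup_lr(epoch: int, base_lr: float, warmup_epochs: int) -> float:
    """Ramp linearly to base_lr over the first warmup_epochs, then hold it."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs  # epoch 0 already gets a non-zero lr
    return base_lr


# Mirrors the example above: --lr 1e-3 (default), --min_lr 1e-7, --epochs 1000
for e in (0, 250, 500, 750, 1000):
    print(e, f"{cosine_lr(e, 1e-3, 1e-7, 1000):.3e}")
warmup = [linear_warmup_lr(e, 1e-3, 5) for e in range(7)]
```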

</details>

<details>
<summary><b>Cross-Validation</b></summary>

For robust model evaluation, add the `--cv` flag:

```bash
# 5-fold cross-validation (works with either launcher)
./run_training.sh --model cnn --cv 5 --data_path train_data.npz
# OR
accelerate launch -m wavedl.train --model cnn --cv 5 --data_path train_data.npz

# Stratified CV (recommended when target values are unevenly distributed)
./run_training.sh --model cnn --cv 5 --cv_stratify --loss huber --epochs 100

# Full configuration
./run_training.sh --model cnn --cv 5 --cv_stratify \
  --loss huber --optimizer adamw --scheduler cosine \
  --output_dir ./cv_results
```

| Argument | Default | Description |
|----------|---------|-------------|
| `--cv` | `0` | Number of CV folds (0=disabled, normal training) |
| `--cv_stratify` | `False` | Use stratified splitting (bins targets) |
| `--cv_bins` | `10` | Number of bins for stratified CV |

**Output:**
- `cv_summary.json`: Aggregated metrics (mean ± std)
- `cv_results.csv`: Per-fold detailed results
- `fold_*/`: Individual fold models and scalers
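
With `--cv_stratify`, targets are binned so that each fold sees a similar spread of values. This minimal pure-Python sketch shows one common approach: quantile (equal-count) bins on a single target column, with members dealt round-robin into folds. `stratified_folds` is illustrative. WaveDL's `cross_validation.py` may bin differently, for example with scikit-learn's `StratifiedKFold` on binned labels.

```python
import random
from typing import List, Sequence


def stratified_folds(targets: Sequence[float], n_folds: int = 5,
                     n_bins: int = 10, seed: int = 2025) -> List[List[int]]:
    """Split sample indices into folds with a similar target distribution per fold."""
    order = sorted(range(len(targets)), key=lambda i: targets[i])
    # Quantile binning: consecutive runs of sorted indices share a bin id
    bins: List[List[int]] = [[] for _ in range(n_bins)]
    for rank, idx in enumerate(order):
        bins[rank * n_bins // len(order)].append(idx)
    rng = random.Random(seed)
    folds: List[List[int]] = [[] for _ in range(n_folds)]
    offset = 0
    for members in bins:
        rng.shuffle(members)  # randomize order within each bin
        for j, idx in enumerate(members):
            folds[(offset + j) % n_folds].append(idx)  # deal round-robin
        offset += len(members)  # continue dealing where the last bin stopped
    return folds


targets = [float(i) for i in range(100)]  # e.g. one target column, such as thickness
folds = stratified_folds(targets, n_folds=5, n_bins=10)
print([len(f) for f in folds])  # [20, 20, 20, 20, 20]
```

Each fold then serves once as the validation set while the remaining folds are used for training.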

</details>

<details>
<summary><b>Configuration Files (YAML)</b></summary>

Use YAML files for reproducible experiments. CLI arguments can override any config value.

```bash
# Use a config file
accelerate launch -m wavedl.train --config configs/config.yaml --data_path train.npz

# Override specific values from config
accelerate launch -m wavedl.train --config configs/config.yaml --lr 5e-4 --epochs 500
```

**Example config (`configs/config.yaml`):**
```yaml
# Model & Training
model: cnn
batch_size: 128
lr: 0.001
epochs: 1000
patience: 20

# Loss, Optimizer, Scheduler
loss: mse
optimizer: adamw
scheduler: plateau

# Cross-Validation (0 = disabled)
cv: 0

# Performance
precision: bf16
compile: false
seed: 2025
```

> [!TIP]
> See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.
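
One way to implement "CLI arguments override config values" is to give every `argparse` option a default of `None`. Only config entries the user actually passed on the command line are then overwritten. This minimal sketch assumes the YAML is already loaded into a dict, for example via `yaml.safe_load`. `merge_config` and `DEFAULTS` are hypothetical and not WaveDL's `config.py`.

```python
import argparse
from typing import Any, Dict, List

# Illustrative built-in defaults (values taken from the tables above)
DEFAULTS: Dict[str, Any] = {"model": "cnn", "batch_size": 128, "lr": 1e-3, "epochs": 1000}


def merge_config(config: Dict[str, Any], argv: List[str]) -> Dict[str, Any]:
    """Precedence: built-in defaults < YAML config < explicit CLI flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--epochs", type=int, default=None)
    cli = vars(parser.parse_args(argv))
    merged = {**DEFAULTS, **config}
    merged.update({k: v for k, v in cli.items() if v is not None})  # only flags actually given
    return merged


yaml_values = {"model": "resnet18", "lr": 0.001, "epochs": 1000}  # e.g. yaml.safe_load(...)
final = merge_config(yaml_values, ["--lr", "5e-4", "--epochs", "500"])
print(final)  # {'model': 'resnet18', 'batch_size': 128, 'lr': 0.0005, 'epochs': 500}
```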

</details>

<details>
<summary><b>Hyperparameter Search (HPO)</b></summary>

Automatically find the best training configuration using [Optuna](https://optuna.org/).

**Step 1: Install**
```bash
pip install -e ".[hpo]"
```

**Step 2: Run HPO**

Specify which models to search and how many trials to run:
```bash
# Search 3 models with 100 trials
python -m wavedl.hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 100

# Search 1 model (faster)
python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50

# Search a larger set of candidate models
python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
```

**Step 3: Train with best parameters**

After HPO completes, it prints a ready-to-run training command with the best hyperparameters found:
```bash
accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
```

---

**What Gets Searched:**

| Parameter | Default | You Can Override With |
|-----------|---------|----------------------|
| Models | cnn, resnet18, resnet34 | `--models X Y Z` |
| Optimizers | [all 6](#optimizers) | `--optimizers X Y` |
| Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
| Losses | [all 6](#loss-functions) | `--losses X Y` |
| Learning rate | 1e-5 → 1e-2 | (always searched) |
| Batch size | 64, 128, 256, 512 | (always searched) |
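
The sketch below makes the search space concrete. Learning rate is drawn log-uniformly between 1e-5 and 1e-2, so each decade is equally likely, and batch size comes from a fixed set. It uses plain random sampling from the standard library purely to illustrate the space. Optuna's samplers (TPE by default) instead choose later trials based on earlier results. `SEARCH_SPACE` and `sample_trial` are illustrative names.

```python
import math
import random
from typing import Any, Dict

SEARCH_SPACE = {
    "model": ["cnn", "resnet18", "resnet34"],  # HPO default models
    "optimizer": ["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"],
    "batch_size": [64, 128, 256, 512],
}


def sample_trial(rng: random.Random) -> Dict[str, Any]:
    """Draw one configuration; lr is log-uniform over [1e-5, 1e-2]."""
    trial = {name: rng.choice(options) for name, options in SEARCH_SPACE.items()}
    trial["lr"] = 10 ** rng.uniform(math.log10(1e-5), math.log10(1e-2))
    return trial


rng = random.Random(2025)
trials = [sample_trial(rng) for _ in range(200)]
print(trials[0])
```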

**Quick Mode** (`--quick`):
- Uses minimal defaults: cnn + adamw + plateau + mse
- Useful for checking your setup before running a full search
- You can still override any option with the flags above

---

**All Arguments:**

| Argument | Default | Description |
|----------|---------|-------------|
| `--data_path` | (required) | Training data file |
| `--models` | 3 defaults | Models to search (specify any number) |
| `--n_trials` | `50` | Number of trials to run |
| `--quick` | `False` | Use minimal defaults (faster) |
| `--optimizers` | all 6 | Optimizers to search |
| `--schedulers` | all 8 | Schedulers to search |
| `--losses` | all 6 | Losses to search |
| `--n_jobs` | `1` | Parallel trials (multi-GPU) |
| `--max_epochs` | `50` | Max epochs per trial |
| `--output` | `hpo_results.json` | Output file |

> [!TIP]
> See [Available Models](#available-models) for all 21 architectures you can search.

</details>

---

## 📈 Data Preparation

WaveDL supports multiple data formats for training and inference:

| Format | Extension | Key Advantages |
|--------|-----------|----------------|
| **NPZ** | `.npz` | Native NumPy, fast loading, recommended |
| **HDF5** | `.h5`, `.hdf5` | Large datasets, hierarchical, cross-platform |
| **MAT** | `.mat` | MATLAB compatibility (**v7.3+ only**, saved with `-v7.3` flag) |

**The framework automatically detects file format and data dimensionality** (1D, 2D, or 3D), so you only need to choose a model that supports that dimensionality.

| Key | Shape | Type | Description |
|-----|-------|------|-------------|
| `input_train` / `input_test` | `(N, L)`, `(N, H, W)`, or `(N, D, H, W)` | `float32` | N samples of 1D/2D/3D representations |
| `output_train` / `output_test` | `(N, T)` | `float32` | N samples with T regression targets |
|
|
686
|
+
|
|
687
|
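Format and dimensionality detection of this kind can be approximated by dispatching on the file extension and reading the rank of the input array. The sketch below is an illustration, not WaveDL's loader: `detect_format` and `spatial_dims` are hypothetical helpers, and actually reading HDF5 or MAT v7.3 files would additionally need an HDF5 reader such as h5py.

```python
from pathlib import Path

import numpy as np

# Extension -> format name, following the table above (hypothetical mapping)
FORMATS = {".npz": "npz", ".h5": "hdf5", ".hdf5": "hdf5", ".mat": "mat"}


def detect_format(path: str) -> str:
    """Return the data format implied by the file extension (case-insensitive)."""
    suffix = Path(path).suffix.lower()
    if suffix not in FORMATS:
        raise ValueError(f"Unsupported extension: {suffix!r}")
    return FORMATS[suffix]


def spatial_dims(inputs: np.ndarray) -> int:
    """Inputs are (N, L), (N, H, W), or (N, D, H, W); drop the sample axis N."""
    if inputs.ndim not in (2, 3, 4):
        raise ValueError(f"Expected 2 to 4 axes including N, got shape {inputs.shape}")
    return inputs.ndim - 1


if __name__ == "__main__":
    X = np.zeros((8, 64, 64), dtype=np.float32)  # 8 samples of 64x64 inputs
    print(detect_format("train_data.npz"))  # npz
    print(spatial_dims(X))                  # 2
```
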
> [!TIP]
> - **Flexible Key Names**: WaveDL auto-detects common key pairs:
>   - `input_train`/`output_train`, `input_test`/`output_test` (WaveDL standard)
>   - `X`/`Y`, `x`/`y` (ML convention)
>   - `data`/`labels`, `inputs`/`outputs`, `features`/`targets`
> - **Automatic Dimension Detection**: The channel dimension is added automatically, so no manual reshaping is required.
> - **Sparse Matrix Support**: NPZ and MAT v7.3 files containing SciPy or MATLAB sparse matrices are automatically converted to dense arrays.
> - **Auto-Normalization**: Target values are automatically standardized during training; MAE is reported in the original physical units.

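The three conveniences in the tip can be sketched in a few lines of NumPy: pick the first matching key pair, insert a singleton channel axis after the sample axis, and z-score the targets while keeping the statistics needed to map predictions back to physical units. The function names and the key-search order are assumptions for illustration; WaveDL's internal implementation may differ.

```python
import numpy as np

# Candidate (input, output) key pairs, in the order listed in the tip above
KEY_PAIRS = [
    ("input_train", "output_train"),
    ("input_test", "output_test"),
    ("X", "Y"),
    ("x", "y"),
    ("data", "labels"),
    ("inputs", "outputs"),
    ("features", "targets"),
]


def find_key_pair(keys) -> tuple:
    """Return the first (input, output) pair present in `keys` (hypothetical helper)."""
    available = set(keys)
    for inp, out in KEY_PAIRS:
        if inp in available and out in available:
            return inp, out
    raise KeyError(f"No known input/output key pair in {sorted(available)}")


def add_channel_axis(x: np.ndarray) -> np.ndarray:
    """(N, L) -> (N, 1, L), (N, H, W) -> (N, 1, H, W), and so on."""
    return np.expand_dims(x, axis=1)


def standardize(y: np.ndarray):
    """Z-score each target column; return scaled targets plus mean and std."""
    mean = y.mean(axis=0)
    std = y.std(axis=0)
    std = np.where(std == 0, 1.0, std)  # guard against constant target columns
    return (y - mean) / std, mean, std


def to_physical(y_scaled: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Undo standardization so predictions are in the original units."""
    return y_scaled * std + mean


if __name__ == "__main__":
    arrays = {"X": np.zeros((10, 32, 32)), "Y": np.arange(30.0).reshape(10, 3)}
    inp, out = find_key_pair(arrays.keys())
    x = add_channel_axis(arrays[inp])
    y_scaled, mean, std = standardize(arrays[out])
    print(inp, out, x.shape)  # X Y (10, 1, 32, 32)
```

Because standardization is affine, an MAE computed on scaled targets converts to physical units by multiplying each column by its `std`; the mean cancels out of the difference.
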
> [!IMPORTANT]
> **MATLAB Users**: MAT files must be saved with the `-v7.3` flag for memory-efficient loading:
> ```matlab
> save('data.mat', 'input_train', 'output_train', '-v7.3')
> ```
> Older MAT formats (v5/v7) are not supported; convert them to NPZ instead.

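If you are unsure which version a `.mat` file was saved with, checking its header before training avoids a failed run. This sketch assumes the usual v7.3 layout, in which a 512-byte MATLAB text header (beginning `MATLAB 7.3 MAT-file`) precedes the HDF5 superblock, so the 8-byte HDF5 signature sits at byte offset 512. The helper `is_mat_v73` is hypothetical and not part of WaveDL.

```python
# HDF5 files start with this 8-byte signature; MAT v7.3 files place it after a
# 512-byte MATLAB header (the HDF5 "user block"). Assumed layout, see lead-in.
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"
MAT73_USER_BLOCK = 512


def is_mat_v73(path) -> bool:
    """Return True if the .mat file looks HDF5-based (i.e., saved with -v7.3)."""
    with open(path, "rb") as f:
        f.seek(MAT73_USER_BLOCK)
        # Reading past EOF returns b"", which correctly yields False
        return f.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE


if __name__ == "__main__":
    import sys

    for p in sys.argv[1:]:
        status = "OK (v7.3)" if is_mat_v73(p) else "re-save with -v7.3 or convert to NPZ"
        print(f"{p}: {status}")
```
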
<details>
<summary><b>Example: Basic Preparation</b></summary>

```python
import numpy as np

# `images` (list of 2D arrays) and `labels` (list of target vectors) are your own data
X = np.array(images, dtype=np.float32)  # (N, H, W)
y = np.array(labels, dtype=np.float32)  # (N, T)

np.savez('train_data.npz', input_train=X, output_train=y)
```

</details>

<details>
<summary><b>Example: From Image Files + CSV</b></summary>

```python
import numpy as np
from PIL import Image
from pathlib import Path
import pandas as pd

# Load images (all must share the same size so they can be stacked)
images = [np.array(Image.open(f).convert('L'), dtype=np.float32)
          for f in sorted(Path("images/").glob("*.png"))]
X = np.stack(images)  # (N, H, W)

# Load labels: one CSV row per image, in the same (sorted) order as the files;
# every CSV column is treated as a regression target
y = pd.read_csv("labels.csv").values.astype(np.float32)  # (N, T)
assert len(X) == len(y), "Number of images and label rows must match"

np.savez('train_data.npz', input_train=X, output_train=y)
```

</details>

<details>
<summary><b>Example: From MATLAB (.mat)</b></summary>

```python
import numpy as np
from scipy.io import loadmat

# scipy.io.loadmat reads v5/v7 MAT files; it cannot open HDF5-based v7.3 files
data = loadmat('simulation_data.mat')
X = data['spectrograms'].astype(np.float32)  # Adjust key
y = data['parameters'].astype(np.float32)    # Adjust key; expected shape (N, T)

# Transpose if needed: MATLAB often stores samples last, (H, W, N) → (N, H, W)
n_samples = y.shape[0]
if X.ndim == 3 and X.shape[0] != n_samples and X.shape[2] == n_samples:
    X = np.transpose(X, (2, 0, 1))

np.savez('train_data.npz', input_train=X, output_train=y)
```

</details>

<details>
<summary><b>Example: Synthetic Test Data</b></summary>

```python
import numpy as np

# Random inputs and targets: useful only for smoke-testing your setup
X = np.random.randn(1000, 256, 256).astype(np.float32)  # (N, H, W)
y = np.random.randn(1000, 5).astype(np.float32)         # (N, T)

np.savez('test_data.npz', input_train=X, output_train=y)
```

</details>

<details>
<summary><b>Validation Script</b></summary>

```python
import numpy as np

data = np.load('train_data.npz')
X, y = data['input_train'], data['output_train']

# Inputs may be 1D, 2D, or 3D per sample: (N, L), (N, H, W), or (N, D, H, W)
assert X.ndim in (2, 3, 4), "Input must be (N, L), (N, H, W), or (N, D, H, W)"
assert y.ndim == 2, "Output must be 2D: (N, T)"
assert len(X) == len(y), "Sample mismatch"

print(f"✓ Input: {X.shape} {X.dtype}")
print(f"✓ Output: {y.shape} {y.dtype}")
```

</details>

---

## 📦 Examples [](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)

The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained CNN predicts three physical parameters from Lamb wave dispersion curves:

| Parameter | Unit | Description |
|-----------|------|-------------|
| *h* | mm | Plate thickness |
| √(*E*/ρ) | km/s | Square root of Young's modulus over density |
| *ν* | — | Poisson's ratio |

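As a sanity check on the units in the table, √(*E*/ρ) has units of velocity (it is the longitudinal wave speed in a thin bar). Using textbook values for a typical aluminum alloy, *E* ≈ 70 GPa and ρ ≈ 2700 kg/m³ (illustrative numbers, not taken from the example dataset), it works out to roughly 5.09 km/s:

```python
import math

E = 70e9      # Young's modulus in Pa (typical aluminum alloy, illustrative)
rho = 2700.0  # density in kg/m^3 (illustrative)

c = math.sqrt(E / rho)         # m/s, since Pa / (kg/m^3) = m^2/s^2
print(f"{c / 1000:.2f} km/s")  # 5.09 km/s
```
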
> [!NOTE]
> This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"Deep learning-based ultrasonic assessment of plate thickness and elasticity"*](https://spie.org/spie-smart-structures-and-materials-nondestructive-evaluation/presentation/Deep-learningbased-ultrasonic-assessment-of-plate-thickness-and-elasticity/13951-4) (Paper 13951-4, to appear).

**Try it yourself:**

```bash
# Run inference on the example data
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
  --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
  --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results

# Export to ONNX (already included as model.onnx)
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
  --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
  --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
```

**What's Included:**

| File | Description |
|------|-------------|
| `best_checkpoint/` | Pre-trained CNN checkpoint |
| `Test_data_100.mat` | 100-sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
| `model.onnx` | ONNX export with embedded de-normalization |
| `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
| `training_curves.png` | Training/validation loss and learning rate plot |
| `test_results/` | Example predictions and diagnostic plots |
| `WaveDL_ONNX_Inference.m` | MATLAB script for ONNX inference |

**Training Progress:**

<p align="center">
<img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
<em>Training and validation loss over 162 epochs with learning rate schedule</em>
</p>

**Inference Results:**

<p align="center">
<img src="examples/elastic_cnn_example/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
<em>Figure 1: Predictions vs ground truth for all three elastic parameters</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
<em>Figure 2: Distribution of prediction errors showing near-zero mean bias</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/residuals.png" alt="Residual plot" width="700"><br>
<em>Figure 3: Residuals vs predicted values (no heteroscedasticity detected)</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
<em>Figure 4: Bland-Altman analysis with ±1.96 SD limits of agreement</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
<em>Figure 5: Q-Q plots confirming normally distributed prediction errors</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
<em>Figure 6: Error correlation matrix between parameters</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/relative_error.png" alt="Relative error" width="700"><br>
<em>Figure 7: Relative error (%) vs true value for each parameter</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
<em>Figure 8: Cumulative error distribution — 95% of predictions within indicated bounds</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
<em>Figure 9: True vs predicted values by sample index</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
<em>Figure 10: Error distribution summary (median, quartiles, outliers)</em>
</p>

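Most of the statistics behind these diagnostic figures reduce to a few NumPy expressions on the prediction errors. The sketch below computes, per parameter, the mean bias and ±1.96 SD Bland-Altman limits of agreement (as in Figure 4), the MAE, the 95th percentile of absolute error that a CDF plot like Figure 8 visualizes, and the error correlation matrix (as in Figure 6). It assumes `y_true` and `y_pred` are `(N, T)` arrays in physical units; `error_summary` is a hypothetical helper, not WaveDL's metrics code, and the data in the demo is synthetic.

```python
import numpy as np


def error_summary(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    """Per-parameter error statistics for (N, T) arrays (illustrative sketch)."""
    err = y_pred - y_true                # signed errors, shape (N, T)
    bias = err.mean(axis=0)              # Bland-Altman mean difference
    sd = err.std(axis=0, ddof=1)         # sample standard deviation of the errors
    return {
        "bias": bias,
        "loa_low": bias - 1.96 * sd,     # lower limit of agreement
        "loa_high": bias + 1.96 * sd,    # upper limit of agreement
        "mae": np.abs(err).mean(axis=0),
        "abs_err_p95": np.percentile(np.abs(err), 95, axis=0),
        "err_corr": np.corrcoef(err, rowvar=False),  # (T, T) correlation matrix
    }


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    y_true = rng.uniform(1.0, 10.0, size=(100, 3))             # synthetic targets
    y_pred = y_true + rng.normal(0.0, 0.1, size=y_true.shape)  # synthetic predictions
    for name, value in error_summary(y_true, y_pred).items():
        print(name, np.round(value, 3))
```
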
---

## 🔬 Broader Applications

Beyond the material characterization example above, the WaveDL pipeline can be adapted for a wide range of **wave-based inverse problems** across multiple domains:

### 🏗️ Non-Destructive Evaluation & Structural Health Monitoring

| Application | Input | Output |
|-------------|-------|--------|
| Defect Sizing | A-scans, phased array images, FMC/TFM, ... | Crack length, depth, ... |
| Corrosion Estimation | Thickness maps, resonance spectra, ... | Wall thickness, corrosion rate, ... |
| Weld Quality Assessment | Phased array images, TOFD, ... | Porosity %, penetration depth, ... |
| Remaining Useful Life (RUL) Prediction | Acoustic emission (AE), vibration spectra, ... | Cycles to failure, ... |
| Damage Localization | Wavefield images, DAS/DVS data, ... | Damage coordinates (x, y, z) |

### 🌍 Geophysics & Seismology

| Application | Input | Output |
|-------------|-------|--------|
| Seismic Inversion | Shot gathers, seismograms, ... | Velocity models, density profiles, ... |
| Subsurface Characterization | Surface wave dispersion, receiver functions, ... | Layer thickness, shear modulus, ... |
| Earthquake Source Parameters | Waveforms, spectrograms, ... | Magnitude, depth, focal mechanism, ... |
| Reservoir Characterization | Reflection seismic, AVO attributes, ... | Porosity, fluid saturation, ... |

### 🩺 Biomedical Ultrasound & Elastography

| Application | Input | Output |
|-------------|-------|--------|
| Tissue Elastography | Shear wave data, strain images, ... | Shear modulus, Young's modulus, ... |
| Liver Fibrosis Staging | Elastography images, ultrasound RF data, ... | Stiffness (kPa), fibrosis score, ... |
| Tumor Characterization | B-mode + elastography, ARFI data, ... | Lesion stiffness, size, ... |
| Bone Quantitative Ultrasound (QUS) | Axial-transmission signals, ... | Porosity, cortical thickness, elastic modulus, ... |

> [!NOTE]
> Adapting WaveDL to these applications requires preparing your own dataset and choosing a model architecture that matches your input dimensionality.

---

## 📚 Documentation

| Resource | Description |
|----------|-------------|
| Technical Paper | In-depth framework description *(coming soon)* |
| [`_template.py`](models/_template.py) | Template for new architectures |

---

## 📜 Citation

If you use WaveDL in your research, please cite:

```bibtex
@software{le2025wavedl,
  author = {Le, Ductho},
  title = {{WaveDL}: A Scalable Deep Learning Framework for Wave-Based Inverse Problems},
  year = {2025},
  publisher = {Zenodo},
  doi = {10.5281/zenodo.18012338},
  url = {https://doi.org/10.5281/zenodo.18012338}
}
```

Or in APA format:
> Le, D. (2025). *WaveDL: A Scalable Deep Learning Framework for Wave-Based Inverse Problems*. Zenodo. https://doi.org/10.5281/zenodo.18012338

---

## 🙏 Acknowledgments

Ductho Le would like to acknowledge [NSERC](https://www.nserc-crsng.gc.ca/) and [Alberta Innovates](https://albertainnovates.ca/) for supporting his study and research by means of a research assistantship and a graduate doctoral fellowship.

This research was enabled in part by support provided by [Compute Ontario](https://www.computeontario.ca/), [Calcul Québec](https://www.calculquebec.ca/), and the [Digital Research Alliance of Canada](https://alliancecan.ca/).

<br>

<p align="center">
<a href="https://www.ualberta.ca/"><img src="logos/ualberta.png" alt="University of Alberta" height="60"></a>
&nbsp;&nbsp;
<a href="https://albertainnovates.ca/"><img src="logos/alberta_innovates.png" alt="Alberta Innovates" height="60"></a>
&nbsp;&nbsp;
<a href="https://www.nserc-crsng.gc.ca/"><img src="logos/nserc.png" alt="NSERC" height="60"></a>
</p>

<p align="center">
<a href="https://alliancecan.ca/"><img src="logos/drac.png" alt="Digital Research Alliance of Canada" height="50"></a>
</p>

---

<div align="center">

**[Ductho Le](mailto:ductho.le@outlook.com)** · University of Alberta

[](https://orcid.org/0000-0002-3073-1416)
[](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
[](https://www.researchgate.net/profile/Ductho-Le)

<sub>Released under the MIT License</sub>

</div>