wavedl 1.2.0__py3-none-any.whl

Metadata-Version: 2.2
Name: wavedl
Version: 1.2.0
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
Author: Ductho Le
License: MIT
Project-URL: Homepage, https://github.com/ductho-le/WaveDL
Project-URL: Repository, https://github.com/ductho-le/WaveDL
Project-URL: Documentation, https://github.com/ductho-le/WaveDL#readme
Project-URL: Issues, https://github.com/ductho-le/WaveDL/issues
Keywords: deep-learning,inverse-problems,wave-propagation,ultrasonic,guided-waves,non-destructive-testing,machine-learning,pytorch,regression
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Physics
Requires-Python: >=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: torchvision>=0.15.0
Requires-Dist: accelerate>=0.20.0
Requires-Dist: numpy>=1.24.0
Requires-Dist: scipy>=1.10.0
Requires-Dist: scikit-learn>=1.2.0
Requires-Dist: pandas>=2.0.0
Requires-Dist: matplotlib>=3.7.0
Requires-Dist: tqdm>=4.65.0
Requires-Dist: wandb>=0.15.0
Requires-Dist: pyyaml>=6.0.0
Requires-Dist: h5py>=3.8.0
Requires-Dist: safetensors>=0.3.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
Requires-Dist: ruff>=0.8.0; extra == "dev"
Requires-Dist: pre-commit>=3.5.0; extra == "dev"
Provides-Extra: onnx
Requires-Dist: onnx>=1.14.0; extra == "onnx"
Requires-Dist: onnxruntime>=1.15.0; extra == "onnx"
Provides-Extra: compile
Requires-Dist: triton; extra == "compile"
Provides-Extra: hpo
Requires-Dist: optuna>=3.0.0; extra == "hpo"
Provides-Extra: all
Requires-Dist: pytest>=7.0.0; extra == "all"
Requires-Dist: pytest-xdist>=3.5.0; extra == "all"
Requires-Dist: ruff>=0.8.0; extra == "all"
Requires-Dist: pre-commit>=3.5.0; extra == "all"
Requires-Dist: onnx>=1.14.0; extra == "all"
Requires-Dist: onnxruntime>=1.15.0; extra == "all"
Requires-Dist: triton; extra == "all"
Requires-Dist: optuna>=3.0.0; extra == "all"

<div align="center">

<img src="logos/wavedl_logo.png" alt="WaveDL Logo" width="500">

### A Scalable Deep Learning Framework for Wave-Based Inverse Problems

[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
[![PyTorch 2.x](https://img.shields.io/badge/PyTorch-2.x-ee4c2c.svg?style=plastic&logo=pytorch&logoColor=white)](https://pytorch.org/)
[![Accelerate](https://img.shields.io/badge/Accelerate-Enabled-yellow.svg?style=plastic&logo=huggingface&logoColor=white)](https://huggingface.co/docs/accelerate/)
<br>
[![Tests](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/test.yml?branch=main&style=plastic&logo=githubactions&logoColor=white&label=Tests)](https://github.com/ductho-le/WaveDL/actions/workflows/test.yml)
[![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
[![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
<br>
[![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
[![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)

**Production-ready • Multi-GPU DDP • Memory-Efficient • Plug-and-Play**

[Getting Started](#-getting-started) •
[Documentation](#-documentation) •
[Examples](#-examples) •
[Discussions](https://github.com/ductho-le/WaveDL/discussions) •
[Citation](#-citation)

---

**Plug in your model, load your data, and let WaveDL do the heavy lifting 💪**

</div>

---

## 💡 What is WaveDL?

WaveDL is a **deep learning framework** built for **wave-based inverse problems** — from ultrasonic NDE and geophysics to biomedical tissue characterization. It provides a robust, scalable training pipeline for mapping multi-dimensional data (1D/2D/3D) to physical quantities.

```
Input:  Waveforms, spectrograms, B-scans, dispersion curves, ...

Output: Material properties, defect dimensions, damage locations, ...
```

The framework handles the engineering challenges of large-scale deep learning — big datasets, distributed training, and HPC deployment — so you can focus on the science, not the infrastructure.

**Built for researchers who need:**
- 📊 Multi-target regression with reproducibility and fair benchmarking
- 🚀 Seamless multi-GPU training on HPC clusters
- 💾 Memory-efficient handling of large-scale datasets
- 🔧 Easy integration of custom model architectures

---

## ✨ Features

<table width="100%">
<tr>
<td width="50%" valign="top">

**⚡ Load All Data — No More Bottleneck**

Train on datasets larger than RAM:
- Memory-mapped, zero-copy streaming
- Full random shuffling at GPU speed
- Your GPU stays fed — always

</td>
<td width="50%" valign="top">

**🧠 One-Line Model Registration**

Plug in any architecture:
```python
@register_model("my_net")
class MyNet(BaseModel): ...
```
Design your model. Register with one line.

</td>
</tr>
<tr>
<td width="50%" valign="top">

**🛡️ DDP That Actually Works**

Multi-GPU training without the pain:
- Synchronized early stopping
- Deadlock-free checkpointing
- Correct metric aggregation

</td>
<td width="50%" valign="top">

**📊 Publish-Ready Output**

Results go straight to your paper:
- 11 diagnostic plots with LaTeX styling
- Multi-format export (PNG, PDF, SVG, ...)
- MAE in physical units per parameter

</td>
</tr>
<tr>
<td width="50%" valign="top">

**🖥️ HPC-Native Design**

Built for high-performance clusters:
- Automatic GPU detection
- WandB experiment tracking
- BF16/FP16 mixed precision

</td>
<td width="50%" valign="top">

**🔄 Crash-Proof Training**

Never lose your progress:
- Full state checkpoints
- Resume from any point
- Emergency saves on interrupt

</td>
</tr>
<tr>
<td width="50%" valign="top">

**🎛️ Flexible & Reproducible Training**

Fully configurable via CLI flags or YAML:
- Loss functions, optimizers, schedulers
- K-fold cross-validation
- See [Configuration](#️-configuration) for details

</td>
<td width="50%" valign="top">

**📦 ONNX Export**

Deploy models anywhere:
- One-command export to ONNX
- LabVIEW, MATLAB, C++ compatible
- Validated PyTorch↔ONNX outputs

</td>
</tr>
</table>
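
The one-line registration above relies on a decorator-based registry. A minimal, self-contained sketch of the pattern (illustrative names only — `MODEL_REGISTRY` and `build_model` are stand-ins, not WaveDL's actual API, which lives in `wavedl/models/registry.py`):

```python
# Sketch of a decorator registry: map string names to model classes.
MODEL_REGISTRY = {}

def register_model(name):
    """Record a class under `name` at import time, then return it unchanged."""
    def wrapper(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return wrapper

def build_model(name, **kwargs):
    """Look up a registered class by name and instantiate it."""
    return MODEL_REGISTRY[name](**kwargs)

@register_model("my_net")
class MyNet:
    def __init__(self, in_channels=1):
        self.in_channels = in_channels

model = build_model("my_net", in_channels=3)
print(type(model).__name__, model.in_channels)  # MyNet 3
```

The payoff is that a CLI flag like `--model my_net` can resolve to your class without the training loop importing it explicitly.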

---

## 🚀 Getting Started

### Installation

```bash
git clone https://github.com/ductho-le/WaveDL.git
cd WaveDL

# Basic install (training + inference)
pip install -e .

# Full install (adds ONNX export, torch.compile, HPO, dev tools)
pip install -e ".[all]"
```

> [!NOTE]
> Dependencies are managed in `pyproject.toml`. Python 3.11+ is required.
>
> For development setup (running tests, contributing), see [CONTRIBUTING.md](.github/CONTRIBUTING.md).

### Quick Start

> [!TIP]
> In all examples below, replace `<...>` placeholders with your values. See [Configuration](#️-configuration) for defaults and options.

#### Option 1: Using the Helper Script (Recommended for HPC)

The `run_training.sh` wrapper automatically configures the environment for HPC systems:

```bash
# Make executable (first time only)
chmod +x run_training.sh

# Basic training (auto-detects available GPUs)
./run_training.sh --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>

# Detailed configuration
./run_training.sh --model <model_name> --data_path <train_data> --batch_size <number> \
    --lr <number> --epochs <number> --patience <number> --compile --output_dir <output_folder>
```

#### Option 2: Direct Accelerate Launch

```bash
# Local - auto-detects GPUs
accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>

# Resume training (automatic - just re-run with same output_dir)
# Manual resume from specific checkpoint:
accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --resume <checkpoint_folder> --output_dir <output_folder>

# Force fresh start (ignores existing checkpoints)
accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --output_dir <output_folder> --fresh

# List available models
python -m wavedl.train --list_models
```

> [!TIP]
> **Auto-Resume**: If training crashes or is interrupted, simply re-run with the same `--output_dir`. The framework automatically detects incomplete training and resumes from the last checkpoint. Use `--fresh` to force a fresh start.
>
> **GPU Auto-Detection**: By default, `run_training.sh` automatically detects available GPUs using `nvidia-smi`. Set `NUM_GPUS` to override this behavior.

### Testing & Inference

After training, use `wavedl.test` to evaluate your model on test data:

```bash
# Basic inference
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data>

# With visualization, CSV export, and multiple file formats
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
    --plot --plot_format png pdf --save_predictions --output_dir <output_folder>

# With custom parameter names
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
    --param_names '$p_1$' '$p_2$' '$p_3$' --plot

# Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
    --export onnx --export_path <output_file.onnx>
```

**Output:**
- **Console**: R², Pearson correlation, MAE per parameter
- **CSV** (with `--save_predictions`): True, predicted, error, and absolute error for all parameters
- **Plots** (with `--plot`): 10 publication-quality plots (scatter, histogram, residuals, Bland-Altman, Q-Q, correlation, relative error, CDF, index plot, box plot)
- **Formats** (with `--plot_format`): `png` (default), `pdf` (vector), `svg` (vector), `eps` (LaTeX), `tiff`, `jpg`, `ps`

> [!NOTE]
> `wavedl.test` auto-detects the model architecture from checkpoint metadata. If unavailable, it falls back to folder name parsing. Use `--model` to override if needed.

---

## 📁 Project Structure

```
WaveDL/
├── src/
│   └── wavedl/                 # Main package (namespaced)
│       ├── __init__.py         # Package init with __version__
│       ├── train.py            # Training entry point
│       ├── test.py             # Testing & inference script
│       ├── hpo.py              # Hyperparameter optimization
│       │
│       ├── models/             # Model architectures
│       │   ├── registry.py     # Model factory (@register_model)
│       │   ├── base.py         # Abstract base class
│       │   ├── cnn.py          # Baseline CNN
│       │   ├── resnet.py       # ResNet-18/34/50 (1D/2D/3D)
│       │   ├── efficientnet.py # EfficientNet-B0/B1/B2
│       │   ├── vit.py          # Vision Transformer (1D/2D)
│       │   ├── convnext.py     # ConvNeXt (1D/2D/3D)
│       │   ├── densenet.py     # DenseNet-121/169 (1D/2D/3D)
│       │   └── unet.py         # U-Net / U-Net Regression
│       │
│       └── utils/              # Utilities
│           ├── data.py         # Memory-mapped data pipeline
│           ├── metrics.py      # R², Pearson, visualization
│           ├── distributed.py  # DDP synchronization
│           ├── losses.py       # Loss function factory
│           ├── optimizers.py   # Optimizer factory
│           ├── schedulers.py   # LR scheduler factory
│           └── config.py       # YAML configuration support
│
├── run_training.sh             # HPC helper script
├── configs/                    # YAML config templates
├── examples/                   # Ready-to-run examples
├── notebooks/                  # Jupyter notebooks
├── unit_tests/                 # Pytest test suite (422 tests)
│
├── pyproject.toml              # Package config, dependencies
├── CHANGELOG.md                # Version history
└── CITATION.cff                # Citation metadata
```

---

## ⚙️ Configuration

> [!NOTE]
> All configuration options below work with **both** `run_training.sh` and direct `accelerate launch`. The wrapper script passes all arguments directly to `train.py`.
>
> **Examples:**
> ```bash
> # Using run_training.sh
> ./run_training.sh --model cnn --batch_size 256 --lr 5e-4 --compile
>
> # Using accelerate launch directly
> accelerate launch -m wavedl.train --model cnn --batch_size 256 --lr 5e-4 --compile
> ```

<details>
<summary><b>Available Models</b> — 21 pre-built architectures</summary>

| Model | Best For | Params (2D) | Dimensionality |
|-------|----------|-------------|----------------|
| `cnn` | Baseline, lightweight | 1.7M | 1D/2D/3D |
| `resnet18` | Fast training, smaller datasets | 11.4M | 1D/2D/3D |
| `resnet34` | Balanced performance | 21.5M | 1D/2D/3D |
| `resnet50` | High capacity, complex patterns | 24.6M | 1D/2D/3D |
| `resnet18_pretrained` | **Transfer learning** ⭐ | 11.4M | 2D only |
| `resnet50_pretrained` | **Transfer learning** ⭐ | 24.6M | 2D only |
| `efficientnet_b0` | Efficient, **pretrained** ⭐ | 4.7M | 2D only |
| `efficientnet_b1` | Efficient, **pretrained** ⭐ | 7.2M | 2D only |
| `efficientnet_b2` | Efficient, **pretrained** ⭐ | 8.4M | 2D only |
| `vit_tiny` | Transformer, small datasets | 5.4M | 1D/2D |
| `vit_small` | Transformer, balanced | 21.5M | 1D/2D |
| `vit_base` | Transformer, high capacity | 85.5M | 1D/2D |
| `convnext_tiny` | Modern CNN, transformer-inspired | 28.2M | 1D/2D/3D |
| `convnext_tiny_pretrained` | **Transfer learning** ⭐ | 28.2M | 2D only |
| `convnext_small` | Modern CNN, balanced | 49.8M | 1D/2D/3D |
| `convnext_base` | Modern CNN, high capacity | 88.1M | 1D/2D/3D |
| `densenet121` | Feature reuse, small data | 7.5M | 1D/2D/3D |
| `densenet121_pretrained` | **Transfer learning** ⭐ | 7.5M | 2D only |
| `densenet169` | Deeper DenseNet | 13.3M | 1D/2D/3D |
| `unet` | Spatial output (velocity fields) | 31.0M | 1D/2D/3D |
| `unet_regression` | Multi-scale features for regression | 31.1M | 1D/2D/3D |

> ⭐ **Pretrained models** use ImageNet weights for transfer learning.

</details>

<details>
<summary><b>Training Parameters</b></summary>

| Argument | Default | Description |
|----------|---------|-------------|
| `--model` | `cnn` | Model architecture |
| `--batch_size` | `128` | Per-GPU batch size |
| `--lr` | `1e-3` | Learning rate |
| `--epochs` | `1000` | Maximum epochs |
| `--patience` | `20` | Early stopping patience |
| `--weight_decay` | `1e-4` | AdamW regularization |
| `--grad_clip` | `1.0` | Gradient clipping |

</details>

<details>
<summary><b>Data & I/O</b></summary>

| Argument | Default | Description |
|----------|---------|-------------|
| `--data_path` | `train_data.npz` | Dataset path |
| `--workers` | `-1` | DataLoader workers per GPU (-1=auto-detect) |
| `--seed` | `2025` | Random seed |
| `--output_dir` | `.` | Output directory for checkpoints |
| `--resume` | `None` | Checkpoint to resume (auto-detected if not set) |
| `--save_every` | `50` | Checkpoint frequency |
| `--fresh` | `False` | Force fresh training, ignore existing checkpoints |
| `--single_channel` | `False` | Confirm data is single-channel (for shallow 3D volumes like `(8, 128, 128)`) |

</details>

<details>
<summary><b>Performance & Logging</b></summary>

| Argument | Default | Description |
|----------|---------|-------------|
| `--compile` | `False` | Enable `torch.compile` |
| `--precision` | `bf16` | Mixed precision mode (`bf16`, `fp16`, `no`) |
| `--wandb` | `False` | Enable W&B logging |
| `--project_name` | `DL-Training` | W&B project name |
| `--run_name` | `None` | W&B run name (auto-generated if not set) |

</details>

<details>
<summary><b>Environment Variables (run_training.sh)</b></summary>

| Variable | Default | Description |
|----------|---------|-------------|
| `NUM_GPUS` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override (e.g., `NUM_GPUS=2`) |
| `NUM_MACHINES` | `1` | Number of machines in distributed setup |
| `MIXED_PRECISION` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
| `DYNAMO_BACKEND` | `no` | PyTorch Dynamo backend |
| `WANDB_MODE` | `offline` | WandB mode: `offline` or `online` |

</details>

<details>
<summary><b>Loss Functions</b></summary>

| Loss | Flag | Best For | Notes |
|------|------|----------|-------|
| `mse` | `--loss mse` | Default, smooth gradients | Standard Mean Squared Error |
| `mae` | `--loss mae` | Outlier-robust, linear penalty | Mean Absolute Error (L1) |
| `huber` | `--loss huber --huber_delta 1.0` | Best of MSE + MAE | Robust, smooth transition |
| `smooth_l1` | `--loss smooth_l1` | Similar to Huber | PyTorch native implementation |
| `log_cosh` | `--loss log_cosh` | Smooth approximation to MAE | Differentiable everywhere |
| `weighted_mse` | `--loss weighted_mse --loss_weights "2.0,1.0,1.0"` | Prioritize specific targets | Per-target weighting |

**Example:**
```bash
# Use Huber loss for noisy NDE data
accelerate launch -m wavedl.train --model cnn --loss huber --huber_delta 0.5

# Weighted MSE: prioritize thickness (first target)
accelerate launch -m wavedl.train --model cnn --loss weighted_mse --loss_weights "2.0,1.0,1.0"
```

</details>
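
For intuition, per-target weighting of the kind `weighted_mse` describes can be sketched in NumPy (a hypothetical stand-in for the framework's PyTorch implementation, not its actual code):

```python
import numpy as np

def weighted_mse(pred, target, weights):
    """Mean over all entries of the weighted squared error.
    `weights` has one entry per regression target, e.g. "2.0,1.0,1.0"."""
    pred, target = np.asarray(pred), np.asarray(target)
    w = np.asarray(weights, dtype=np.float64)
    return float(np.mean(w * (pred - target) ** 2))

pred   = np.array([[1.0, 2.0, 3.0]])
target = np.array([[0.0, 2.0, 3.0]])
# Error only in the first target, which carries weight 2.0:
print(weighted_mse(pred, target, [2.0, 1.0, 1.0]))  # 2*1^2 / 3 ≈ 0.667
```

Doubling a weight doubles that target's contribution to the loss, which is why the thickness-first example above uses `"2.0,1.0,1.0"`.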

<details>
<summary><b>Optimizers</b></summary>

| Optimizer | Flag | Best For | Key Parameters |
|-----------|------|----------|----------------|
| `adamw` | `--optimizer adamw` | Default, most cases | `--betas "0.9,0.999"` |
| `adam` | `--optimizer adam` | Legacy compatibility | `--betas "0.9,0.999"` |
| `sgd` | `--optimizer sgd` | Better generalization | `--momentum 0.9 --nesterov` |
| `nadam` | `--optimizer nadam` | Adam + Nesterov | Faster convergence |
| `radam` | `--optimizer radam` | Variance-adaptive | More stable training |
| `rmsprop` | `--optimizer rmsprop` | RNN/LSTM models | `--momentum 0.9` |

**Example:**
```bash
# SGD with Nesterov momentum (often better generalization)
accelerate launch -m wavedl.train --model cnn --optimizer sgd --lr 0.01 --momentum 0.9 --nesterov

# RAdam for more stable training
accelerate launch -m wavedl.train --model cnn --optimizer radam --lr 1e-3
```

</details>
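
For reference, the update that `--optimizer sgd --momentum 0.9 --nesterov` selects can be written out in plain Python, following PyTorch's convention (momentum buffer `b ← μb + g`, then step along `g + μb`); this is a generic sketch, not WaveDL code:

```python
def sgd_nesterov_step(p, grad, buf, lr=0.01, momentum=0.9):
    """One scalar parameter update, PyTorch-style SGD with Nesterov momentum."""
    buf = momentum * buf + grad      # update the momentum buffer
    step = grad + momentum * buf     # Nesterov look-ahead direction
    return p - lr * step, buf

# Minimize f(p) = p^2 (gradient 2p) starting from p = 1.0
p, buf = 1.0, 0.0
for _ in range(100):
    p, buf = sgd_nesterov_step(p, 2 * p, buf)
print(p)  # close to the minimum at 0
```

The look-ahead term is what distinguishes Nesterov from plain ("heavy-ball") momentum, which would step along `buf` alone.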

<details>
<summary><b>Learning Rate Schedulers</b></summary>

| Scheduler | Flag | Best For | Key Parameters |
|-----------|------|----------|----------------|
| `plateau` | `--scheduler plateau` | Default, adaptive | `--scheduler_patience 10 --scheduler_factor 0.5` |
| `cosine` | `--scheduler cosine` | Long training, smooth decay | `--min_lr 1e-6` |
| `cosine_restarts` | `--scheduler cosine_restarts` | Escape local minima | Warm restarts |
| `onecycle` | `--scheduler onecycle` | Fast convergence | Super-convergence |
| `step` | `--scheduler step` | Simple decay | `--step_size 30 --scheduler_factor 0.1` |
| `multistep` | `--scheduler multistep` | Custom milestones | `--milestones "30,60,90"` |
| `exponential` | `--scheduler exponential` | Continuous decay | `--scheduler_factor 0.95` |
| `linear_warmup` | `--scheduler linear_warmup` | Warmup phase | `--warmup_epochs 5` |

**Example:**
```bash
# Cosine annealing for 1000 epochs
accelerate launch -m wavedl.train --model cnn --scheduler cosine --epochs 1000 --min_lr 1e-7

# OneCycleLR for super-convergence
accelerate launch -m wavedl.train --model cnn --scheduler onecycle --lr 1e-2 --epochs 50

# MultiStep with custom milestones
accelerate launch -m wavedl.train --model cnn --scheduler multistep --milestones "100,200,300"
```

</details>
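
The standard cosine-annealing formula that schedulers like `cosine` are based on is easy to write down directly (a generic sketch, with `base_lr` and `min_lr` playing the roles of `--lr` and `--min_lr`):

```python
import math

def cosine_lr(epoch, total_epochs, base_lr=1e-3, min_lr=1e-6):
    """eta_min + 0.5*(eta_max - eta_min)*(1 + cos(pi * t / T))."""
    cos = math.cos(math.pi * epoch / total_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + cos)

print(cosine_lr(0, 1000))     # base_lr at the start
print(cosine_lr(500, 1000))   # midpoint: halfway between base_lr and min_lr
print(cosine_lr(1000, 1000))  # min_lr at the end
```

The slow decay near the start and end is what makes cosine annealing attractive for long runs compared to step decay.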

<details>
<summary><b>Cross-Validation</b></summary>

For robust model evaluation, simply add the `--cv` flag:

```bash
# 5-fold cross-validation (works with both methods!)
./run_training.sh --model cnn --cv 5 --data_path train_data.npz
# OR
accelerate launch -m wavedl.train --model cnn --cv 5 --data_path train_data.npz

# Stratified CV (recommended for unbalanced data)
./run_training.sh --model cnn --cv 5 --cv_stratify --loss huber --epochs 100

# Full configuration
./run_training.sh --model cnn --cv 5 --cv_stratify \
    --loss huber --optimizer adamw --scheduler cosine \
    --output_dir ./cv_results
```

| Argument | Default | Description |
|----------|---------|-------------|
| `--cv` | `0` | Number of CV folds (0=disabled, normal training) |
| `--cv_stratify` | `False` | Use stratified splitting (bins targets) |
| `--cv_bins` | `10` | Number of bins for stratified CV |

**Output:**
- `cv_summary.json`: Aggregated metrics (mean ± std)
- `cv_results.csv`: Per-fold detailed results
- `fold_*/`: Individual fold models and scalers
</details>
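
Stratifying a *continuous* regression target requires binning it first, which is what `--cv_bins` controls. The binning idea can be sketched with NumPy (illustrative, not WaveDL's exact implementation):

```python
import numpy as np

def stratify_bins(y, n_bins=10):
    """Assign each continuous target value to a quantile bin; stratified
    splitting on these bin labels keeps the target distribution similar
    across folds."""
    edges = np.quantile(y, np.linspace(0, 1, n_bins + 1)[1:-1])  # interior edges
    return np.digitize(y, edges)

rng = np.random.default_rng(0)
y = rng.normal(size=1000)            # e.g. one regression target
bins = stratify_bins(y, n_bins=10)
# Quantile bins are equally populated, so every fold sees the full target range
print(np.bincount(bins))
```

Quantile edges (rather than equal-width bins) are the usual choice because they stay balanced even for skewed targets.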

<details>
<summary><b>Configuration Files (YAML)</b></summary>

Use YAML files for reproducible experiments. CLI arguments can override any config value.

```bash
# Use a config file
accelerate launch -m wavedl.train --config configs/config.yaml --data_path train.npz

# Override specific values from config
accelerate launch -m wavedl.train --config configs/config.yaml --lr 5e-4 --epochs 500
```

**Example config (`configs/config.yaml`):**
```yaml
# Model & Training
model: cnn
batch_size: 128
lr: 0.001
epochs: 1000
patience: 20

# Loss, Optimizer, Scheduler
loss: mse
optimizer: adamw
scheduler: plateau

# Cross-Validation (0 = disabled)
cv: 0

# Performance
precision: bf16
compile: false
seed: 2025
```

> [!TIP]
> See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.

</details>
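
At its core, "CLI overrides YAML" is a dictionary merge with precedence. A sketch of the idea with a hypothetical `merge_config` helper (WaveDL's actual logic lives in `wavedl/utils/config.py` and may differ):

```python
def merge_config(yaml_cfg, cli_args):
    """YAML values act as defaults; CLI entries that were actually given
    (i.e. are not None) win."""
    merged = dict(yaml_cfg)
    merged.update({k: v for k, v in cli_args.items() if v is not None})
    return merged

yaml_cfg = {"model": "cnn", "lr": 1e-3, "epochs": 1000}
cli_args = {"lr": 5e-4, "epochs": None}   # only --lr was passed on the CLI
print(merge_config(yaml_cfg, cli_args))   # lr overridden, epochs kept from YAML
```

The `None` check matters: argparse defaults of `None` mark "flag not given", so untouched YAML values survive the merge.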

<details>
<summary><b>Hyperparameter Search (HPO)</b></summary>

Automatically find the best training configuration using [Optuna](https://optuna.org/).

**Step 1: Install**

```bash
pip install -e ".[hpo]"
```

**Step 2: Run HPO**

You specify which models to search and how many trials to run:

```bash
# Search 3 models with 100 trials
python -m wavedl.hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 100

# Search 1 model (faster)
python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50

# Search all your candidate models
python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
```

**Step 3: Train with best parameters**

After HPO completes, it prints the optimal command:

```bash
accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
```

---

**What Gets Searched:**

| Parameter | Default | You Can Override With |
|-----------|---------|----------------------|
| Models | cnn, resnet18, resnet34 | `--models X Y Z` |
| Optimizers | [all 6](#optimizers) | `--optimizers X Y` |
| Schedulers | [all 8](#learning-rate-schedulers) | `--schedulers X Y` |
| Losses | [all 6](#loss-functions) | `--losses X Y` |
| Learning rate | 1e-5 → 1e-2 | (always searched) |
| Batch size | 64, 128, 256, 512 | (always searched) |

**Quick Mode** (`--quick`):
- Uses minimal defaults: cnn + adamw + plateau + mse
- Faster for testing your setup before running full search
- You can still override any option with the flags above

---

**All Arguments:**

| Argument | Default | Description |
|----------|---------|-------------|
| `--data_path` | (required) | Training data file |
| `--models` | 3 defaults | Models to search (specify any number) |
| `--n_trials` | `50` | Number of trials to run |
| `--quick` | `False` | Use minimal defaults (faster) |
| `--optimizers` | all 6 | Optimizers to search |
| `--schedulers` | all 8 | Schedulers to search |
| `--losses` | all 6 | Losses to search |
| `--n_jobs` | `1` | Parallel trials (multi-GPU) |
| `--max_epochs` | `50` | Max epochs per trial |
| `--output` | `hpo_results.json` | Output file |

> [!TIP]
> See [Available Models](#available-models) for all 21 architectures you can search.

</details>
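
Conceptually, each HPO trial samples a configuration from the search space, scores it, and the best score wins. A toy random-search sketch of that loop (the objective below is a stand-in; real HPO trains the model per trial, and Optuna additionally prunes and adapts its sampling):

```python
import random

# Illustrative slice of the search space described above
search_space = {
    "lr": [1e-5, 1e-4, 1e-3, 1e-2],
    "batch_size": [64, 128, 256, 512],
}

def objective(cfg):
    """Stand-in for validation loss after training with `cfg`."""
    return abs(cfg["lr"] - 1e-3) + abs(cfg["batch_size"] - 128) / 1000

random.seed(0)
best = min(
    ({k: random.choice(v) for k, v in search_space.items()} for _ in range(50)),
    key=objective,
)
print(best)
```

The practical takeaway: `--n_trials` bounds how many such configurations get scored, so narrowing `--models`/`--optimizers`/`--losses` spends those trials more densely.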

---

## 📈 Data Preparation

WaveDL supports multiple data formats for training and inference:

| Format | Extension | Key Advantages |
|--------|-----------|----------------|
| **NPZ** | `.npz` | Native NumPy, fast loading, recommended |
| **HDF5** | `.h5`, `.hdf5` | Large datasets, hierarchical, cross-platform |
| **MAT** | `.mat` | MATLAB compatibility (**v7.3+ only**, saved with `-v7.3` flag) |

**The framework automatically detects file format and data dimensionality** (1D, 2D, or 3D) — you only need to provide the appropriate model architecture. Your dataset should contain the following arrays:

| Key | Shape | Type | Description |
|-----|-------|------|-------------|
| `input_train` / `input_test` | `(N, L)`, `(N, H, W)`, or `(N, D, H, W)` | `float32` | N samples of 1D/2D/3D representations |
| `output_train` / `output_test` | `(N, T)` | `float32` | N samples with T regression targets |

> [!TIP]
> - **Flexible Key Names**: WaveDL auto-detects common key pairs:
>   - `input_train`/`output_train`, `input_test`/`output_test` (WaveDL standard)
>   - `X`/`Y`, `x`/`y` (ML convention)
>   - `data`/`labels`, `inputs`/`outputs`, `features`/`targets`
> - **Automatic Dimension Detection**: Channel dimension is added automatically. No manual reshaping required!
> - **Sparse Matrix Support**: NPZ and MAT v7.3 files with scipy/MATLAB sparse matrices are automatically converted to dense arrays.
> - **Auto-Normalization**: Target values are automatically standardized during training. MAE is reported in original physical units.

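What auto-normalization amounts to can be sketched in NumPy: standardize the targets for training, then invert the transform so errors come out in physical units (illustrative only; WaveDL handles this internally, and the prediction values below are made up):

```python
import numpy as np

y = np.array([[2.1, 5400.0], [1.9, 6100.0], [2.4, 5800.0]])  # targets, physical units

mu, sigma = y.mean(axis=0), y.std(axis=0)
y_std = (y - mu) / sigma              # what the network actually regresses

pred_std = y_std + 0.1                # pretend predictions, standardized space
pred = pred_std * sigma + mu          # invert back to physical units

mae = np.abs(pred - y).mean(axis=0)   # per-parameter MAE in original units
print(mae)                            # = 0.1 * sigma for each target
```

Standardizing keeps targets with very different scales (here mm vs. m/s) from dominating the loss, while the inverse transform keeps reported MAE interpretable.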
> [!IMPORTANT]
> **MATLAB Users**: MAT files must be saved with the `-v7.3` flag for memory-efficient loading:
> ```matlab
> save('data.mat', 'input_train', 'output_train', '-v7.3')
> ```
> Older MAT formats (v5/v7) are not supported. Convert to NPZ for best compatibility.

<details>
<summary><b>Example: Basic Preparation</b></summary>

```python
import numpy as np

X = np.array(images, dtype=np.float32)  # (N, H, W)
y = np.array(labels, dtype=np.float32)  # (N, T)

np.savez('train_data.npz', input_train=X, output_train=y)
```

</details>

<details>
<summary><b>Example: From Image Files + CSV</b></summary>

```python
import numpy as np
from PIL import Image
from pathlib import Path
import pandas as pd

# Load images
images = [np.array(Image.open(f).convert('L'), dtype=np.float32)
          for f in sorted(Path("images/").glob("*.png"))]
X = np.stack(images)

# Load labels
y = pd.read_csv("labels.csv").values.astype(np.float32)

np.savez('train_data.npz', input_train=X, output_train=y)
```

</details>

<details>
<summary><b>Example: From MATLAB (.mat)</b></summary>

```python
import numpy as np
from scipy.io import loadmat

data = loadmat('simulation_data.mat')
X = data['spectrograms'].astype(np.float32)  # Adjust key names to match your file
y = data['parameters'].astype(np.float32)

# Transpose if needed: (H, W, N) → (N, H, W)
if X.ndim == 3 and X.shape[2] < X.shape[0]:
    X = np.transpose(X, (2, 0, 1))

np.savez('train_data.npz', input_train=X, output_train=y)
```

</details>

<details>
<summary><b>Example: Synthetic Test Data</b></summary>

```python
import numpy as np

X = np.random.randn(1000, 256, 256).astype(np.float32)
y = np.random.randn(1000, 5).astype(np.float32)

np.savez('test_data.npz', input_train=X, output_train=y)
```

</details>

<details>
<summary><b>Validation Script</b></summary>

```python
import numpy as np

# Sanity-check a 2D dataset; adjust the expected ndim for 1D or 3D inputs
data = np.load('train_data.npz')
assert data['input_train'].ndim == 3, "Input must be 3D: (N, H, W)"
assert data['output_train'].ndim == 2, "Output must be 2D: (N, T)"
assert len(data['input_train']) == len(data['output_train']), "Sample mismatch"

print(f"✓ Input:  {data['input_train'].shape} {data['input_train'].dtype}")
print(f"✓ Output: {data['output_train'].shape} {data['output_train'].dtype}")
```

</details>

---

## 📦 Examples [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)

The `examples/` folder contains a **complete, ready-to-run example** of **material characterization of isotropic plates**. The pre-trained CNN predicts three physical parameters from Lamb wave dispersion curves:

| Parameter | Unit | Description |
|-----------|------|-------------|
| *h* | mm | Plate thickness |
| √(*E*/ρ) | km/s | Square root of Young's modulus over density |
| *ν* | — | Poisson's ratio |

> [!NOTE]
> This example is based on our paper at **SPIE Smart Structures + NDE 2026**: [*"Deep learning-based ultrasonic assessment of plate thickness and elasticity"*](https://spie.org/spie-smart-structures-and-materials-nondestructive-evaluation/presentation/Deep-learningbased-ultrasonic-assessment-of-plate-thickness-and-elasticity/13951-4) (Paper 13951-4, to appear).
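
The predicted quantities can be post-processed into engineering constants. Since the network outputs √(*E*/ρ) directly, Young's modulus follows from *E* = ρ·*v*² whenever the density is known from another source. A minimal sketch (the density value below is illustrative, not from the example data):

```python
def youngs_modulus(v_kms: float, rho: float) -> float:
    """Recover Young's modulus E (Pa) from a predicted sqrt(E/rho) in km/s,
    given an independently known density rho (kg/m^3): E = rho * v^2."""
    v = v_kms * 1e3  # km/s -> m/s
    return rho * v**2

# e.g. an aluminum-like plate: rho = 2700 kg/m^3, predicted sqrt(E/rho) = 5.0 km/s
E = youngs_modulus(5.0, 2700.0)
print(f"E = {E / 1e9:.1f} GPa")  # -> E = 67.5 GPa
```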

**Try it yourself:**

```bash
# Run inference on the example data
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
    --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
    --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results

# Export to ONNX (already included as model.onnx)
python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
    --data_path ./examples/elastic_cnn_example/Test_data_100.mat \
    --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
```

**What's Included:**

| File | Description |
|------|-------------|
| `best_checkpoint/` | Pre-trained CNN checkpoint |
| `Test_data_100.mat` | 100-sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
| `model.onnx` | ONNX export with embedded de-normalization |
| `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
| `training_curves.png` | Training/validation loss and learning rate plot |
| `test_results/` | Example predictions and diagnostic plots |
| `WaveDL_ONNX_Inference.m` | MATLAB script for ONNX inference |
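
For Python-side inference, the exported `model.onnx` can be loaded with `onnxruntime` (not a WaveDL dependency). This is a hedged sketch: the `(batch, channel, 500, 500)` input layout is assumed from the test-set description above, and the input name is queried from the session rather than hardcoded:

```python
import numpy as np
from pathlib import Path

def predict_onnx(model_path: str, x: np.ndarray):
    """Run an exported ONNX model on a batch of inputs.
    Returns None if onnxruntime or the model file is unavailable."""
    try:
        import onnxruntime as ort
    except ImportError:
        return None
    if not Path(model_path).exists():
        return None
    sess = ort.InferenceSession(model_path)
    name = sess.get_inputs()[0].name  # query rather than hardcode the input name
    (pred,) = sess.run(None, {name: x.astype(np.float32)})
    return pred  # de-normalization is embedded in the exported graph

x = np.random.randn(1, 1, 500, 500).astype(np.float32)  # assumed NCHW layout
pred = predict_onnx("examples/elastic_cnn_example/model.onnx", x)
print(pred if pred is None else pred.shape)
```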

**Training Progress:**

<p align="center">
<img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
<em>Training and validation loss over 162 epochs with learning rate schedule</em>
</p>

**Inference Results:**

<p align="center">
<img src="examples/elastic_cnn_example/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
<em>Figure 1: Predictions vs ground truth for all three elastic parameters</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
<em>Figure 2: Distribution of prediction errors showing near-zero mean bias</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/residuals.png" alt="Residual plot" width="700"><br>
<em>Figure 3: Residuals vs predicted values (no heteroscedasticity detected)</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
<em>Figure 4: Bland-Altman analysis with ±1.96 SD limits of agreement</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
<em>Figure 5: Q-Q plots confirming normally distributed prediction errors</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
<em>Figure 6: Error correlation matrix between parameters</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/relative_error.png" alt="Relative error" width="700"><br>
<em>Figure 7: Relative error (%) vs true value for each parameter</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
<em>Figure 8: Cumulative error distribution — 95% of predictions within indicated bounds</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
<em>Figure 9: True vs predicted values by sample index</em>
</p>

<p align="center">
<img src="examples/elastic_cnn_example/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
<em>Figure 10: Error distribution summary (median, quartiles, outliers)</em>
</p>

---

## 🔬 Broader Applications

Beyond the material characterization example above, the WaveDL pipeline can be adapted to a wide range of **wave-based inverse problems** across multiple domains:

### 🏗️ Non-Destructive Evaluation & Structural Health Monitoring

| Application | Input | Output |
|-------------|-------|--------|
| Defect Sizing | A-scans, phased array images, FMC/TFM, ... | Crack length, depth, ... |
| Corrosion Estimation | Thickness maps, resonance spectra, ... | Wall thickness, corrosion rate, ... |
| Weld Quality Assessment | Phased array images, TOFD, ... | Porosity %, penetration depth, ... |
| RUL Prediction | Acoustic emission (AE), vibration spectra, ... | Cycles to failure, ... |
| Damage Localization | Wavefield images, DAS/DVS data, ... | Damage coordinates (x, y, z) |

### 🌍 Geophysics & Seismology

| Application | Input | Output |
|-------------|-------|--------|
| Seismic Inversion | Shot gathers, seismograms, ... | Velocity models, density profiles, ... |
| Subsurface Characterization | Surface wave dispersion, receiver functions, ... | Layer thickness, shear modulus, ... |
| Earthquake Source Parameters | Waveforms, spectrograms, ... | Magnitude, depth, focal mechanism, ... |
| Reservoir Characterization | Reflection seismic, AVO attributes, ... | Porosity, fluid saturation, ... |

### 🩺 Biomedical Ultrasound & Elastography

| Application | Input | Output |
|-------------|-------|--------|
| Tissue Elastography | Shear wave data, strain images, ... | Shear modulus, Young's modulus, ... |
| Liver Fibrosis Staging | Elastography images, US RF data, ... | Stiffness (kPa), fibrosis score, ... |
| Tumor Characterization | B-mode + elastography, ARFI data, ... | Lesion stiffness, size, ... |
| Bone QUS | Axial-transmission signals, ... | Porosity, cortical thickness, elastic modulus, ... |

> [!NOTE]
> Adapting WaveDL to these applications requires preparing your own dataset and choosing a model architecture that matches your input dimensionality.
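
As a concrete starting point for one of these adaptations, time-domain signals such as A-scans can be converted to time-frequency images and packaged in the same `input_train`/`output_train` format described above. A sketch with synthetic waveforms (the `scipy.signal.spectrogram` parameters are placeholders and would need tuning for real data):

```python
import numpy as np
from scipy.signal import spectrogram

# Synthetic stand-ins for N recorded waveforms and their target parameters
rng = np.random.default_rng(0)
signals = rng.standard_normal((100, 4096)).astype(np.float32)  # (N, samples)
targets = rng.standard_normal((100, 2)).astype(np.float32)     # (N, n_targets)

# One spectrogram image per waveform -> stacked to (N, H, W)
images = []
for s in signals:
    _, _, Sxx = spectrogram(s, fs=1e6, nperseg=256, noverlap=128)
    images.append(np.log1p(Sxx).astype(np.float32))  # log scale tames dynamic range
X = np.stack(images)

np.savez('train_data.npz', input_train=X, output_train=targets)
print(X.shape)
```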

---

## 📚 Documentation

| Resource | Description |
|----------|-------------|
| Technical Paper | In-depth framework description *(coming soon)* |
| [`_template.py`](models/_template.py) | Template for new architectures |

---

## 📜 Citation

If you use WaveDL in your research, please cite:

```bibtex
@software{le2025wavedl,
  author    = {Le, Ductho},
  title     = {{WaveDL}: A Scalable Deep Learning Framework for Wave-Based Inverse Problems},
  year      = {2025},
  publisher = {Zenodo},
  doi       = {10.5281/zenodo.18012338},
  url       = {https://doi.org/10.5281/zenodo.18012338}
}
```

Or in APA format:

> Le, D. (2025). *WaveDL: A Scalable Deep Learning Framework for Wave-Based Inverse Problems*. Zenodo. https://doi.org/10.5281/zenodo.18012338
956
+
957
+ ---
958
+
959
+ ## 🙏 Acknowledgments
960
+
961
+ Ductho Le would like to acknowledge [NSERC](https://www.nserc-crsng.gc.ca/) and [Alberta Innovates](https://albertainnovates.ca/) for supporting his study and research by means of a research assistantship and a graduate doctoral fellowship.
962
+
963
+ This research was enabled in part by support provided by [Compute Ontario](https://www.computeontario.ca/), [Calcul Québec](https://www.calculquebec.ca/), and the [Digital Research Alliance of Canada](https://alliancecan.ca/).
964
+
965
+ <br>
966
+
967
+ <p align="center">
968
+ <a href="https://www.ualberta.ca/"><img src="logos/ualberta.png" alt="University of Alberta" height="60"></a>
969
+ &emsp;&emsp;
970
+ <a href="https://albertainnovates.ca/"><img src="logos/alberta_innovates.png" alt="Alberta Innovates" height="60"></a>
971
+ &emsp;&emsp;
972
+ <a href="https://www.nserc-crsng.gc.ca/"><img src="logos/nserc.png" alt="NSERC" height="60"></a>
973
+ </p>
974
+
975
+ <p align="center">
976
+ <a href="https://alliancecan.ca/"><img src="logos/drac.png" alt="Digital Research Alliance of Canada" height="50"></a>
977
+ </p>
978
+
979
+ ---
980
+
981
+ <div align="center">
982
+
983
+ **[Ductho Le](mailto:ductho.le@outlook.com)** · University of Alberta
984
+
985
+ [![ORCID](https://img.shields.io/badge/ORCID-0000--0002--3073--1416-a6ce39?style=plastic&logo=orcid&logoColor=white)](https://orcid.org/0000-0002-3073-1416)
986
+ [![Google Scholar](https://img.shields.io/badge/Google_Scholar-4285F4?style=plastic&logo=google-scholar&logoColor=white)](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
987
+ [![ResearchGate](https://img.shields.io/badge/ResearchGate-00CCBB?style=plastic&logo=researchgate&logoColor=white)](https://www.researchgate.net/profile/Ductho-Le)
988
+
989
+ <sub>Released under the MIT License</sub>
990
+
991
+ </div>