wavedl 1.4.0__tar.gz → 1.4.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavedl-1.4.0/src/wavedl.egg-info → wavedl-1.4.2}/PKG-INFO +26 -6
- {wavedl-1.4.0 → wavedl-1.4.2}/README.md +25 -5
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/__init__.py +1 -1
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/hpc.py +22 -4
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/hpo.py +46 -19
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/_template.py +28 -40
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/base.py +49 -1
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/train.py +65 -27
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/utils/config.py +88 -2
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/utils/data.py +79 -2
- {wavedl-1.4.0 → wavedl-1.4.2/src/wavedl.egg-info}/PKG-INFO +26 -6
- {wavedl-1.4.0 → wavedl-1.4.2}/LICENSE +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/pyproject.toml +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/setup.cfg +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/__init__.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/cnn.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/convnext.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/densenet.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/efficientnet.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/efficientnetv2.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/mobilenetv3.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/registry.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/regnet.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/resnet.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/resnet3d.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/swin.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/tcn.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/unet.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/vit.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/test.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/utils/__init__.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/utils/cross_validation.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/utils/distributed.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/utils/losses.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/utils/metrics.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/utils/optimizers.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/utils/schedulers.py +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl.egg-info/SOURCES.txt +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl.egg-info/dependency_links.txt +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl.egg-info/entry_points.txt +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl.egg-info/requires.txt +0 -0
- {wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl.egg-info/top_level.txt +0 -0
**{wavedl-1.4.0/src/wavedl.egg-info → wavedl-1.4.2}/PKG-INFO**

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.4.0
+Version: 1.4.2
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -502,12 +502,32 @@ WaveDL/
 
 | Argument | Default | Description |
 |----------|---------|-------------|
-| `--compile` | `False` | Enable `torch.compile` |
+| `--compile` | `False` | Enable `torch.compile` (recommended for long runs) |
 | `--precision` | `bf16` | Mixed precision mode (`bf16`, `fp16`, `no`) |
+| `--workers` | `-1` | DataLoader workers per GPU (-1=auto, up to 16) |
 | `--wandb` | `False` | Enable W&B logging |
+| `--wandb_watch` | `False` | Enable W&B gradient watching (adds overhead) |
 | `--project_name` | `DL-Training` | W&B project name |
 | `--run_name` | `None` | W&B run name (auto-generated if not set) |
 
+**Automatic GPU Optimizations:**
+
+WaveDL automatically enables performance optimizations for modern GPUs:
+
+| Optimization | Effect | GPU Support |
+|--------------|--------|-------------|
+| **TF32 precision** | ~2x speedup for float32 matmul | A100, H100 (Ampere+) |
+| **cuDNN benchmark** | Auto-tuned convolutions | All NVIDIA GPUs |
+| **Worker scaling** | Up to 16 workers per GPU | All systems |
+
+> [!NOTE]
+> These optimizations are **backward compatible** — they have no effect on older GPUs (V100, T4, GTX) or CPU-only systems. No configuration needed.
+
+**HPC Best Practices:**
+- Stage data to `$SLURM_TMPDIR` (local NVMe) for maximum I/O throughput
+- Use `--compile` for training runs > 50 epochs
+- Increase `--workers` manually if auto-detection is suboptimal
+
 </details>
 
 <details>
@@ -690,7 +710,7 @@ python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
 python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
 ```
 
-**
+**Train with best parameters**
 
 After HPO completes, it prints the optimal command:
 ```bash
@@ -837,7 +857,7 @@ import numpy as np
 X = np.random.randn(1000, 256, 256).astype(np.float32)
 y = np.random.randn(1000, 5).astype(np.float32)
 
-np.savez('test_data.npz',
+np.savez('test_data.npz', input_test=X, output_test=y)
 ```
 
 </details>
@@ -849,7 +869,7 @@ np.savez('test_data.npz', input_train=X, output_train=y)
 import numpy as np
 
 data = np.load('train_data.npz')
-assert data['input_train'].ndim
+assert data['input_train'].ndim >= 2, "Input must be at least 2D: (N, ...) "
 assert data['output_train'].ndim == 2, "Output must be 2D: (N, T)"
 assert len(data['input_train']) == len(data['output_train']), "Sample mismatch"
 
@@ -1004,7 +1024,7 @@ Beyond the material characterization example above, the WaveDL pipeline can be a
 | Resource | Description |
 |----------|-------------|
 | Technical Paper | In-depth framework description *(coming soon)* |
-| [`_template.py`](models/_template.py) | Template for
+| [`_template.py`](src/wavedl/models/_template.py) | Template for custom architectures |
 
 ---
 
````
**{wavedl-1.4.0 → wavedl-1.4.2}/README.md**

````diff
@@ -457,12 +457,32 @@ WaveDL/
 
 | Argument | Default | Description |
 |----------|---------|-------------|
-| `--compile` | `False` | Enable `torch.compile` |
+| `--compile` | `False` | Enable `torch.compile` (recommended for long runs) |
 | `--precision` | `bf16` | Mixed precision mode (`bf16`, `fp16`, `no`) |
+| `--workers` | `-1` | DataLoader workers per GPU (-1=auto, up to 16) |
 | `--wandb` | `False` | Enable W&B logging |
+| `--wandb_watch` | `False` | Enable W&B gradient watching (adds overhead) |
 | `--project_name` | `DL-Training` | W&B project name |
 | `--run_name` | `None` | W&B run name (auto-generated if not set) |
 
+**Automatic GPU Optimizations:**
+
+WaveDL automatically enables performance optimizations for modern GPUs:
+
+| Optimization | Effect | GPU Support |
+|--------------|--------|-------------|
+| **TF32 precision** | ~2x speedup for float32 matmul | A100, H100 (Ampere+) |
+| **cuDNN benchmark** | Auto-tuned convolutions | All NVIDIA GPUs |
+| **Worker scaling** | Up to 16 workers per GPU | All systems |
+
+> [!NOTE]
+> These optimizations are **backward compatible** — they have no effect on older GPUs (V100, T4, GTX) or CPU-only systems. No configuration needed.
+
+**HPC Best Practices:**
+- Stage data to `$SLURM_TMPDIR` (local NVMe) for maximum I/O throughput
+- Use `--compile` for training runs > 50 epochs
+- Increase `--workers` manually if auto-detection is suboptimal
+
 </details>
 
 <details>
@@ -645,7 +665,7 @@ python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
 python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
 ```
 
-**
+**Train with best parameters**
 
 After HPO completes, it prints the optimal command:
 ```bash
@@ -792,7 +812,7 @@ import numpy as np
 X = np.random.randn(1000, 256, 256).astype(np.float32)
 y = np.random.randn(1000, 5).astype(np.float32)
 
-np.savez('test_data.npz',
+np.savez('test_data.npz', input_test=X, output_test=y)
 ```
 
 </details>
@@ -804,7 +824,7 @@ np.savez('test_data.npz', input_train=X, output_train=y)
 import numpy as np
 
 data = np.load('train_data.npz')
-assert data['input_train'].ndim
+assert data['input_train'].ndim >= 2, "Input must be at least 2D: (N, ...) "
 assert data['output_train'].ndim == 2, "Output must be 2D: (N, T)"
 assert len(data['input_train']) == len(data['output_train']), "Sample mismatch"
 
@@ -959,7 +979,7 @@ Beyond the material characterization example above, the WaveDL pipeline can be a
 | Resource | Description |
 |----------|-------------|
 | Technical Paper | In-depth framework description *(coming soon)* |
-| [`_template.py`](models/_template.py) | Template for
+| [`_template.py`](src/wavedl/models/_template.py) | Template for custom architectures |
 
 ---
 
````
**{wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/hpc.py**

```diff
@@ -130,6 +130,18 @@ Environment Variables:
         default=0,
         help="Rank of this machine in multi-node setup (default: 0)",
     )
+    parser.add_argument(
+        "--main_process_ip",
+        type=str,
+        default=None,
+        help="IP address of the main process for multi-node training",
+    )
+    parser.add_argument(
+        "--main_process_port",
+        type=int,
+        default=None,
+        help="Port for multi-node communication (default: accelerate auto-selects)",
+    )
     parser.add_argument(
         "--mixed_precision",
         type=str,
@@ -207,12 +219,18 @@ def main() -> int:
         "launch",
         f"--num_processes={num_gpus}",
         f"--num_machines={args.num_machines}",
-        "--machine_rank=
+        f"--machine_rank={args.machine_rank}",
         f"--mixed_precision={args.mixed_precision}",
         f"--dynamo_backend={args.dynamo_backend}",
-
-
-
+    ]
+
+    # Add multi-node networking args if specified (required for some clusters)
+    if args.main_process_ip:
+        cmd.append(f"--main_process_ip={args.main_process_ip}")
+    if args.main_process_port:
+        cmd.append(f"--main_process_port={args.main_process_port}")
+
+    cmd += ["-m", "wavedl.train"] + train_args
 
     # Create output directory if specified
     for i, arg in enumerate(train_args):
```
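To make the command-assembly change above concrete, here is a minimal standalone sketch of the same logic with hypothetical values (2 nodes, 4 GPUs, a made-up head-node IP) — an illustration of the pattern, not the package's actual code:

```python
# Sketch: how hpc.py's new flags compose an accelerate launch command.
# The values below (IP, port, counts) are hypothetical.
num_gpus, num_machines, machine_rank = 4, 2, 0
main_process_ip, main_process_port = "10.0.0.1", 29500  # None unless the user passes them

cmd = [
    "accelerate", "launch",
    f"--num_processes={num_gpus}",
    f"--num_machines={num_machines}",
    f"--machine_rank={machine_rank}",
    "--mixed_precision=bf16",
]
# Networking args are appended only when specified, mirroring the diff above
if main_process_ip:
    cmd.append(f"--main_process_ip={main_process_ip}")
if main_process_port:
    cmd.append(f"--main_process_port={main_process_port}")
cmd += ["-m", "wavedl.train", "--model", "cnn"]

print(" ".join(cmd))
```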
**{wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/hpo.py**

```diff
@@ -145,6 +145,7 @@ def create_objective(args):
         # Use temporary directory for trial output
         with tempfile.TemporaryDirectory() as tmpdir:
             cmd.extend(["--output_dir", tmpdir])
+            history_file = Path(tmpdir) / "training_history.csv"
 
             # Run training
             try:
@@ -156,29 +157,55 @@ def create_objective(args):
                     cwd=Path(__file__).parent,
                 )
 
-                #
-                # Look for "Best val_loss: X.XXXX" in stdout
+                # Read best val_loss from training_history.csv (reliable machine-readable)
                 val_loss = None
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                if history_file.exists():
+                    try:
+                        import csv
+
+                        with open(history_file) as f:
+                            reader = csv.DictReader(f)
+                            val_losses = []
+                            for row in reader:
+                                if "val_loss" in row:
+                                    try:
+                                        val_losses.append(float(row["val_loss"]))
+                                    except (ValueError, TypeError):
+                                        pass
+                            if val_losses:
+                                val_loss = min(val_losses)  # Best (minimum) val_loss
+                    except Exception as e:
+                        print(f"Trial {trial.number}: Error reading history: {e}")
+
+                if val_loss is None:
+                    # Fallback: parse stdout for training log format
+                    # Pattern: "epoch | train_loss | val_loss | ..."
+                    # Use regex to avoid false positives from unrelated lines
+                    import re
+
+                    # Match lines like: "  42 | 0.0123 | 0.0156 | ..."
+                    log_pattern = re.compile(
+                        r"^\s*\d+\s*\|\s*[\d.]+\s*\|\s*([\d.]+)\s*\|"
+                    )
+                    val_losses_stdout = []
+                    for line in result.stdout.split("\n"):
+                        match = log_pattern.match(line)
+                        if match:
+                            try:
+                                val_losses_stdout.append(float(match.group(1)))
+                            except ValueError:
+                                continue
+                    if val_losses_stdout:
+                        val_loss = min(val_losses_stdout)
 
                 if val_loss is None:
                     # Training failed or no loss found
-                    print(f"Trial {trial.number}: Training failed")
+                    print(f"Trial {trial.number}: Training failed (no val_loss found)")
+                    if result.returncode != 0:
+                        # Show last few lines of stderr for debugging
+                        stderr_lines = result.stderr.strip().split("\n")[-3:]
+                        for line in stderr_lines:
+                            print(f"  stderr: {line}")
                     return float("inf")
 
                 print(f"Trial {trial.number}: val_loss={val_loss:.6f}")
```
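The CSV path in the new objective is easy to test in isolation. A self-contained sketch of the same "best val_loss" extraction, assuming only that `training_history.csv` has a `val_loss` column as the diff implies:

```python
import csv
from pathlib import Path


def best_val_loss(history_file: Path) -> float | None:
    """Return the minimum val_loss recorded in a training_history.csv, or None."""
    if not history_file.exists():
        return None
    losses = []
    with open(history_file) as f:
        for row in csv.DictReader(f):
            try:
                losses.append(float(row["val_loss"]))
            except (KeyError, ValueError, TypeError):
                continue  # skip malformed or missing entries, as the trial code does
    return min(losses) if losses else None
```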
**{wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/_template.py**

```diff
@@ -1,23 +1,26 @@
 """
-Model Template for
-
+Model Template for Custom Architectures
+========================================
 
-Copy this file and modify to add
+Copy this file and modify to add custom model architectures to WaveDL.
 The model will be automatically registered and available via --model flag.
 
-
-1. Copy this file to
-2. Rename the class and update @register_model("
-3. Implement
-4.
-
-
+Quick Start:
+    1. Copy this file to your project: cp _template.py my_model.py
+    2. Rename the class and update @register_model("my_model")
+    3. Implement your architecture in __init__ and forward
+    4. Train: wavedl-train --import my_model --model my_model --data_path data.npz
+
+Requirements (your model MUST):
+    1. Inherit from BaseModel
+    2. Accept (in_shape, out_size, **kwargs) in __init__
+    3. Return tensor of shape (batch, out_size) from forward()
+
+See README.md "Adding Custom Models" section for more details.
 
 Author: Ductho Le (ductho.le@outlook.com)
 """
 
-from typing import Any
-
 import torch
 import torch.nn as nn
@@ -25,7 +28,7 @@ from wavedl.models.base import BaseModel
 
 
 # Uncomment the decorator to register this model
-# @register_model("
+# @register_model("my_model")
 class TemplateModel(BaseModel):
     """
     Template Model Architecture.
@@ -34,14 +37,16 @@ class TemplateModel(BaseModel):
     The first line will appear in --list_models output.
 
     Args:
-        in_shape: Input spatial dimensions (
-
+        in_shape: Input spatial dimensions (auto-detected from data)
+            - 1D: (L,) for signals
+            - 2D: (H, W) for images
+            - 3D: (D, H, W) for volumes
+        out_size: Number of regression targets (auto-detected from data)
         hidden_dim: Size of hidden layers (default: 256)
-        num_layers: Number of convolutional layers (default: 4)
         dropout: Dropout rate (default: 0.1)
 
     Input Shape:
-        (B, 1,
+        (B, 1, *in_shape) - e.g., (B, 1, 64, 64) for 2D
 
     Output Shape:
         (B, out_size) - Regression predictions
@@ -49,10 +54,9 @@ class TemplateModel(BaseModel):
 
     def __init__(
         self,
-        in_shape: tuple
+        in_shape: tuple,
         out_size: int,
         hidden_dim: int = 256,
-        num_layers: int = 4,
         dropout: float = 0.1,
         **kwargs,  # Accept extra kwargs for flexibility
     ):
@@ -61,14 +65,13 @@ class TemplateModel(BaseModel):
 
         # Store hyperparameters as attributes (optional but recommended)
         self.hidden_dim = hidden_dim
-        self.num_layers = num_layers
         self.dropout_rate = dropout
 
         # =================================================================
         # BUILD YOUR ARCHITECTURE HERE
         # =================================================================
 
-        # Example: Simple CNN encoder
+        # Example: Simple CNN encoder (assumes 2D input with 1 channel)
         self.encoder = nn.Sequential(
             # Layer 1
             nn.Conv2d(1, 32, kernel_size=3, padding=1),
@@ -106,10 +109,10 @@ class TemplateModel(BaseModel):
         """
         Forward pass of the model.
 
-        REQUIRED: Must accept (B, C,
+        REQUIRED: Must accept (B, C, *spatial) and return (B, out_size)
 
         Args:
-            x: Input tensor of shape (B, 1,
+            x: Input tensor of shape (B, 1, *in_shape)
 
         Returns:
             Output tensor of shape (B, out_size)
@@ -122,35 +125,20 @@ class TemplateModel(BaseModel):
 
         return output
 
-    @classmethod
-    def get_default_config(cls) -> dict[str, Any]:
-        """
-        Return default hyperparameters for this model.
-
-        OPTIONAL: Override to provide model-specific defaults.
-        These can be used for documentation or config files.
-        """
-        return {
-            "hidden_dim": 256,
-            "num_layers": 4,
-            "dropout": 0.1,
-        }
-
 
 # =============================================================================
 # USAGE EXAMPLE
 # =============================================================================
 if __name__ == "__main__":
     # Quick test of the model
-    model = TemplateModel(in_shape=(
+    model = TemplateModel(in_shape=(64, 64), out_size=5)
 
     # Print model summary
     print(f"Model: {model.__class__.__name__}")
     print(f"Parameters: {model.count_parameters():,}")
-    print(f"Default config: {model.get_default_config()}")
 
     # Test forward pass
-    dummy_input = torch.randn(2, 1,
+    dummy_input = torch.randn(2, 1, 64, 64)
     output = model(dummy_input)
     print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")
```
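Putting the template's contract together, a hypothetical minimal custom model. The `register_model` import path and the `BaseModel.__init__` signature are assumptions inferred from the template and the repo layout (`models/registry.py`), not confirmed by this diff:

```python
import torch
import torch.nn as nn

from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model  # assumed import path


@register_model("tiny_mlp")  # hypothetical model name
class TinyMLP(BaseModel):
    """Tiny MLP baseline for small inputs (illustration only)."""

    def __init__(self, in_shape: tuple, out_size: int, hidden_dim: int = 128, **kwargs):
        super().__init__(in_shape, out_size)  # assumed BaseModel signature
        n_in = 1
        for d in in_shape:  # flatten (C=1, *in_shape) to a single feature vector
            n_in *= d
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_in, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)  # (B, out_size), as the contract requires
```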
**{wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/models/base.py**

```diff
@@ -75,13 +75,61 @@ class BaseModel(nn.Module, ABC):
         Forward pass of the model.
 
         Args:
-            x: Input tensor of shape (B, C,
+            x: Input tensor of shape (B, C, *spatial_dims)
+                - 1D: (B, C, L)
+                - 2D: (B, C, H, W)
+                - 3D: (B, C, D, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
         pass
 
+    def validate_input_shape(self, x: torch.Tensor) -> None:
+        """
+        Validate input tensor shape against model's expected shape.
+
+        Call this at the start of forward() for explicit shape contract enforcement.
+        Provides clear, actionable error messages instead of cryptic Conv layer errors.
+
+        Args:
+            x: Input tensor to validate
+
+        Raises:
+            ValueError: If shape doesn't match expected dimensions
+
+        Example:
+            def forward(self, x):
+                self.validate_input_shape(x)  # Optional but recommended
+                return self.model(x)
+        """
+        expected_ndim = len(self.in_shape) + 2  # +2 for (batch, channel)
+
+        if x.ndim != expected_ndim:
+            dim_names = {
+                3: "1D (B, C, L)",
+                4: "2D (B, C, H, W)",
+                5: "3D (B, C, D, H, W)",
+            }
+            expected_name = dim_names.get(expected_ndim, f"{expected_ndim}D")
+            actual_name = dim_names.get(x.ndim, f"{x.ndim}D")
+            raise ValueError(
+                f"Input shape mismatch: model expects {expected_name} input, "
+                f"got {actual_name} with shape {tuple(x.shape)}.\n"
+                f"Expected in_shape: {self.in_shape} -> input should be (B, C, {', '.join(map(str, self.in_shape))})\n"
+                f"Hint: Check your data preprocessing - you may need to add/remove dimensions."
+            )
+
+        # Validate spatial dimensions match
+        spatial_dims = tuple(x.shape[2:])  # Skip batch and channel
+        if spatial_dims != tuple(self.in_shape):
+            raise ValueError(
+                f"Spatial dimension mismatch: model expects {self.in_shape}, "
+                f"got {spatial_dims}.\n"
+                f"Full input shape: {tuple(x.shape)} (B={x.shape[0]}, C={x.shape[1]})\n"
+                f"Hint: Ensure your data dimensions match the model's in_shape."
+            )
+
     def count_parameters(self, trainable_only: bool = True) -> int:
         """
         Count the number of parameters in the model.
```
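The two checks in `validate_input_shape` are simple enough to mirror standalone, which also shows what they catch. A minimal sketch (the function below is a hypothetical re-implementation for illustration, not the package API):

```python
import torch


def check_shape(x: torch.Tensor, in_shape: tuple) -> None:
    """Standalone mirror of BaseModel.validate_input_shape's two checks."""
    expected_ndim = len(in_shape) + 2  # + (batch, channel)
    if x.ndim != expected_ndim:
        raise ValueError(f"expected {expected_ndim}D input, got {x.ndim}D {tuple(x.shape)}")
    if tuple(x.shape[2:]) != tuple(in_shape):
        raise ValueError(f"expected spatial dims {in_shape}, got {tuple(x.shape[2:])}")


check_shape(torch.randn(8, 1, 64, 64), (64, 64))    # passes silently
# check_shape(torch.randn(8, 1, 32, 32), (64, 64))  # -> ValueError with a clear message
```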
**{wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/train.py**

```diff
@@ -12,11 +12,11 @@ A modular training framework for wave-based inverse problems and regression:
     6. Deep Observability: WandB integration with scatter analysis
 
 Usage:
-    # Recommended: Using the HPC launcher
-    wavedl-hpc --model cnn --batch_size 128 --wandb
+    # Recommended: Using the HPC launcher (handles accelerate configuration)
+    wavedl-hpc --model cnn --batch_size 128 --mixed_precision bf16 --wandb
 
-    # Or
-    accelerate launch -m wavedl.train --model cnn --batch_size 128 --
+    # Or direct training module (use --precision, not --mixed_precision)
+    accelerate launch -m wavedl.train --model cnn --batch_size 128 --precision bf16
 
     # Multi-GPU with explicit config
     wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb
@@ -28,9 +28,9 @@ Usage:
     wavedl-train --list_models
 
 Note:
-
-
-
+    - wavedl-hpc: Uses --mixed_precision (passed to accelerate launch)
+    - wavedl.train: Uses --precision (internal module flag)
+    Both control the same behavior; use the appropriate flag for your entry point.
 
 Author: Ductho Le (ductho.le@outlook.com)
 """
@@ -97,6 +97,18 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", module="pydantic")
 warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
 
+# ==============================================================================
+# GPU PERFORMANCE OPTIMIZATIONS (Ampere/Hopper: A100, H100)
+# ==============================================================================
+# Enable TF32 for faster matmul (safe precision for training, ~2x speedup)
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+torch.set_float32_matmul_precision("high")  # Use TF32 for float32 ops
+
+# Enable cuDNN autotuning for fixed-size inputs (CNN-like models benefit most)
+# Note: First few batches may be slower due to benchmarking
+torch.backends.cudnn.benchmark = True
+
 
 # ==============================================================================
 # ARGUMENT PARSING
@@ -298,11 +310,24 @@ def parse_args() -> argparse.Namespace:
         choices=["bf16", "fp16", "no"],
         help="Mixed precision mode",
     )
+    # Alias for consistency with wavedl-hpc (--mixed_precision)
+    parser.add_argument(
+        "--mixed_precision",
+        dest="precision",
+        type=str,
+        choices=["bf16", "fp16", "no"],
+        help=argparse.SUPPRESS,  # Hidden: use --precision instead
+    )
 
     # Logging
     parser.add_argument(
         "--wandb", action="store_true", help="Enable Weights & Biases logging"
     )
+    parser.add_argument(
+        "--wandb_watch",
+        action="store_true",
+        help="Enable WandB gradient watching (adds overhead, useful for debugging)",
+    )
     parser.add_argument(
         "--project_name", type=str, default="DL-Training", help="WandB project name"
     )
@@ -467,8 +492,8 @@ def main():
         if _cv_handle is not None and hasattr(_cv_handle, "close"):
             try:
                 _cv_handle.close()
-            except Exception:
-
+            except Exception as e:
+                logging.debug(f"Failed to close CV data handle: {e}")
         return
 
     # ==========================================================================
@@ -496,9 +521,9 @@ def main():
     if args.workers < 0:
         cpu_count = os.cpu_count() or 4
         num_gpus = accelerator.num_processes
-        # Heuristic: 4-
-        #
-        args.workers = min(
+        # Heuristic: 4-16 workers per GPU, bounded by available CPU cores
+        # Increased cap from 8 to 16 for high-throughput GPUs (H100, A100)
+        args.workers = min(16, max(2, (cpu_count - 2) // num_gpus))
         if accelerator.is_main_process:
             logger.info(
                 f"⚙️ Auto-detected workers: {args.workers} per GPU "
@@ -544,9 +569,15 @@ def main():
     )
     logger.info(f"   Model Size: {param_info['total_mb']:.2f} MB")
 
-    # Optional WandB model watching
-    if
+    # Optional WandB model watching (opt-in due to overhead on large models)
+    if (
+        args.wandb
+        and args.wandb_watch
+        and WANDB_AVAILABLE
+        and accelerator.is_main_process
+    ):
         wandb.watch(model, log="gradients", log_freq=100)
+        logger.info("   📊 WandB gradient watching enabled")
 
     # Torch 2.0 compilation (requires compatible Triton on GPU)
     if args.compile:
@@ -820,7 +851,7 @@ def main():
         val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
         val_samples = 0
 
-        # Accumulate predictions locally
+        # Accumulate predictions locally ON CPU to prevent GPU OOM
         local_preds = []
         local_targets = []
@@ -836,17 +867,19 @@ def main():
             mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
             val_mae_sum += mae_batch
 
-            # Store
-            local_preds.append(pred)
-            local_targets.append(y)
+            # Store on CPU (critical for large val sets)
+            local_preds.append(pred.detach().cpu())
+            local_targets.append(y.detach().cpu())
+
+        # Concatenate locally on CPU (no GPU memory spike)
+        cpu_preds = torch.cat(local_preds)
+        cpu_targets = torch.cat(local_targets)
 
-        #
-
-
-        all_preds = accelerator.gather_for_metrics(all_local_preds)
-        all_targets = accelerator.gather_for_metrics(all_local_targets)
+        # Gather to rank 0 only via gather_object (avoids all-gather to every rank)
+        # gather_object returns list of objects from each rank: [(preds0, targs0), (preds1, targs1), ...]
+        gathered = accelerator.gather_object((cpu_preds, cpu_targets))
 
-        # Synchronize validation metrics
+        # Synchronize validation metrics (scalars only - efficient)
         val_loss_scalar = val_loss_sum.item()
         val_metrics = torch.cat(
             [
@@ -869,9 +902,14 @@ def main():
 
         # ==================== LOGGING & CHECKPOINTING ====================
         if accelerator.is_main_process:
-            #
-
-
+            # Concatenate gathered tensors from all ranks (only on rank 0)
+            # gathered is list of tuples: [(preds_rank0, targs_rank0), (preds_rank1, targs_rank1), ...]
+            all_preds = torch.cat([item[0] for item in gathered])
+            all_targets = torch.cat([item[1] for item in gathered])
+
+            # Scientific metrics - cast to float32 before numpy
+            y_pred = all_preds.float().numpy()
+            y_true = all_targets.float().numpy()
 
             # Trim DDP padding
             real_len = len(val_dl.dataset)
```
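Two of the train.py changes above are worth unpacking. The worker heuristic is plain arithmetic: with, say, 64 CPU cores and 4 GPUs, `min(16, max(2, (64 - 2) // 4))` gives 15 workers per GPU. And the hidden `--mixed_precision` alias works purely through argparse's `dest` mechanism — both flags write to `args.precision`, and the alias deliberately omits a `default` so it cannot clobber `--precision`'s. A self-contained sketch using only standard argparse (no WaveDL imports):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--precision", type=str, default="bf16",
                    choices=["bf16", "fp16", "no"])
# Alias with the same dest: it silently feeds the canonical attribute.
# No default here, so the canonical flag's default "bf16" is preserved.
parser.add_argument("--mixed_precision", dest="precision",
                    choices=["bf16", "fp16", "no"], help=argparse.SUPPRESS)

print(parser.parse_args([]).precision)                              # bf16
print(parser.parse_args(["--precision", "fp16"]).precision)         # fp16
print(parser.parse_args(["--mixed_precision", "fp16"]).precision)   # fp16
```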
**{wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/utils/config.py**

```diff
@@ -116,7 +116,13 @@ def merge_config_with_args(
     """
     # Get parser defaults to detect which args were explicitly set by user
     if parser is not None:
-
+        # Safe extraction: iterate actions instead of parse_args([])
+        # This avoids failures if required arguments are added later
+        defaults = {
+            action.dest: action.default
+            for action in parser._actions
+            if action.dest != "help"
+        }
     else:
         # Fallback: reconstruct defaults from known patterns
         # This works because argparse stores actual values, and we compare
@@ -141,6 +147,9 @@ def merge_config_with_args(
             setattr(args, key, value)
         elif not ignore_unknown:
             logging.warning(f"Unknown config key: {key}")
+        else:
+            # Even in ignore_unknown mode, log for discoverability
+            logging.debug(f"Config key '{key}' ignored: not a valid argument")
 
     return args
 
@@ -188,12 +197,15 @@ def save_config(
     return str(output_path)
 
 
-def validate_config(config: dict[str, Any]) -> list[str]:
+def validate_config(
+    config: dict[str, Any], known_keys: list[str] | None = None
+) -> list[str]:
     """
     Validate configuration values against known options.
 
     Args:
         config: Configuration dictionary
+        known_keys: Optional list of valid keys (if None, uses defaults from parser args)
 
     Returns:
         List of warning messages (empty if valid)
@@ -229,9 +241,83 @@ def validate_config(config: dict[str, Any]) -> list[str]:
     for key, (min_val, max_val, msg) in numeric_checks.items():
         if key in config:
             val = config[key]
+            # Type check: ensure value is numeric before comparison
+            if not isinstance(val, (int, float)):
+                warnings.append(
+                    f"Invalid type for '{key}': expected number, got {type(val).__name__} ({val!r})"
+                )
+                continue
             if not (min_val <= val <= max_val):
                 warnings.append(f"{msg}: got {val}")
 
+    # Check for unknown/unrecognized keys (helps catch typos)
+    # Default known keys based on common training arguments
+    default_known_keys = {
+        # Model
+        "model",
+        "import_modules",
+        # Hyperparameters
+        "batch_size",
+        "lr",
+        "epochs",
+        "patience",
+        "weight_decay",
+        "grad_clip",
+        # Loss
+        "loss",
+        "huber_delta",
+        "loss_weights",
+        # Optimizer
+        "optimizer",
+        "momentum",
+        "nesterov",
+        "betas",
+        # Scheduler
+        "scheduler",
+        "scheduler_patience",
+        "min_lr",
+        "scheduler_factor",
+        "warmup_epochs",
+        "step_size",
+        "milestones",
+        # Data
+        "data_path",
+        "workers",
+        "seed",
+        "single_channel",
+        # Cross-validation
+        "cv",
+        "cv_stratify",
+        "cv_bins",
+        # Checkpointing
+        "resume",
+        "save_every",
+        "output_dir",
+        "fresh",
+        # Performance
+        "compile",
+        "precision",
+        "mixed_precision",
+        # Logging
+        "wandb",
+        "wandb_watch",
+        "project_name",
+        "run_name",
+        # Config
+        "config",
+        "list_models",
+        # Metadata (internal)
+        "_metadata",
+    }
+
+    check_keys = set(known_keys) if known_keys else default_known_keys
+
+    for key in config:
+        if key not in check_keys:
+            warnings.append(
+                f"Unknown config key: '{key}' - check for typos or see wavedl-train --help"
+            )
+
     return warnings
 
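```

A quick usage sketch of the extended `validate_config`: the new unknown-key check turns a silent typo into a visible warning. This assumes only the signature and message format shown in the diff above:

```python
from wavedl.utils.config import validate_config  # module path per this package

config = {"model": "cnn", "batch_szie": 64}  # note the typo in batch_size
for warning in validate_config(config):
    print(warning)
# -> Unknown config key: 'batch_szie' - check for typos or see wavedl-train --help
```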
**{wavedl-1.4.0 → wavedl-1.4.2}/src/wavedl/utils/data.py**

```diff
@@ -735,7 +735,7 @@ def load_test_data(
     try:
         inp, outp = source.load(path)
     except KeyError:
-        # Try with just inputs if outputs not found
+        # Try with just inputs if outputs not found (inference-only mode)
         if format == "npz":
             data = np.load(path, allow_pickle=True)
             keys = list(data.keys())
@@ -751,6 +751,54 @@ def load_test_data(
                 )
             out_key = DataSource._find_key(keys, custom_output_keys)
             outp = data[out_key] if out_key else None
+        elif format == "hdf5":
+            # HDF5: input-only loading for inference
+            with h5py.File(path, "r") as f:
+                keys = list(f.keys())
+                inp_key = DataSource._find_key(keys, custom_input_keys)
+                if inp_key is None:
+                    raise KeyError(
+                        f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
+                    )
+                # Check size - load_test_data is eager, large files should use DataLoader
+                n_samples = f[inp_key].shape[0]
+                if n_samples > 100000:
+                    raise ValueError(
+                        f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                        f"everything into RAM which may cause OOM. For large inference "
+                        f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
+                    )
+                inp = f[inp_key][:]
+                out_key = DataSource._find_key(keys, custom_output_keys)
+                outp = f[out_key][:] if out_key else None
+        elif format == "mat":
+            # MAT v7.3: input-only loading with proper sparse handling
+            mat_source = MATSource()
+            with h5py.File(path, "r") as f:
+                keys = list(f.keys())
+                inp_key = DataSource._find_key(keys, custom_input_keys)
+                if inp_key is None:
+                    raise KeyError(
+                        f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
+                    )
+                # Check size - load_test_data is eager, large files should use DataLoader
+                n_samples = f[inp_key].shape[-1]  # MAT is transposed
+                if n_samples > 100000:
+                    raise ValueError(
+                        f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                        f"everything into RAM which may cause OOM. For large inference "
+                        f"sets, use a DataLoader with MATSource.load_mmap() instead."
+                    )
+                # Use _load_dataset for sparse support and proper transpose
+                inp = mat_source._load_dataset(f, inp_key)
+                out_key = DataSource._find_key(keys, custom_output_keys)
+                if out_key:
+                    outp = mat_source._load_dataset(f, out_key)
+                    # Handle 1D outputs that become (1, N) after transpose
+                    if outp.ndim == 2 and outp.shape[0] == 1:
+                        outp = outp.T
+                else:
+                    outp = None
         else:
             raise
 
@@ -949,6 +997,15 @@ def prepare_data(
             with open(META_FILE, "rb") as f:
                 meta = pickle.load(f)
             cached_data_path = meta.get("data_path", None)
+            cached_file_size = meta.get("file_size", None)
+            cached_file_mtime = meta.get("file_mtime", None)
+
+            # Get current file stats
+            current_stats = os.stat(args.data_path)
+            current_size = current_stats.st_size
+            current_mtime = current_stats.st_mtime
+
+            # Check if data path changed
             if cached_data_path != os.path.abspath(args.data_path):
                 if accelerator.is_main_process:
                     logger.warning(
@@ -958,6 +1015,23 @@ def prepare_data(
                         f"   Invalidating cache and regenerating..."
                     )
                 cache_exists = False
+            # Check if file was modified (size or mtime changed)
+            elif cached_file_size is not None and cached_file_size != current_size:
+                if accelerator.is_main_process:
+                    logger.warning(
+                        f"⚠️ Data file size changed!\n"
+                        f"   Cached size: {cached_file_size:,} bytes\n"
+                        f"   Current size: {current_size:,} bytes\n"
+                        f"   Invalidating cache and regenerating..."
+                    )
+                cache_exists = False
+            elif cached_file_mtime is not None and cached_file_mtime != current_mtime:
+                if accelerator.is_main_process:
+                    logger.warning(
+                        "⚠️ Data file was modified!\n"
+                        "   Cache may be stale, regenerating..."
+                    )
+                cache_exists = False
         except Exception:
             cache_exists = False
 
@@ -1053,13 +1127,16 @@ def prepare_data(
             f"   Shape Detected: {full_shape} [{dim_type}] | Output Dim: {out_dim}"
         )
 
-        # Save metadata (including data path for cache validation)
+        # Save metadata (including data path, size, mtime for cache validation)
+        file_stats = os.stat(args.data_path)
         with open(META_FILE, "wb") as f:
            pickle.dump(
                 {
                     "shape": full_shape,
                     "out_dim": out_dim,
                     "data_path": os.path.abspath(args.data_path),
+                    "file_size": file_stats.st_size,
+                    "file_mtime": file_stats.st_mtime,
                 },
                 f,
             )
```
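The cache-invalidation change above keys staleness on file size and mtime. A standalone sketch of the same test using only the stdlib — a hypothetical helper for illustration, mirroring the checks `prepare_data` now performs:

```python
import os


def cache_is_stale(data_path: str, meta: dict) -> bool:
    """Mirror of prepare_data's path/size/mtime cache checks."""
    stats = os.stat(data_path)
    if meta.get("data_path") != os.path.abspath(data_path):
        return True   # different dataset entirely
    if meta.get("file_size") is not None and meta["file_size"] != stats.st_size:
        return True   # file grew or shrank since caching
    if meta.get("file_mtime") is not None and meta["file_mtime"] != stats.st_mtime:
        return True   # file touched since caching
    return False
```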
**{wavedl-1.4.0 → wavedl-1.4.2/src/wavedl.egg-info}/PKG-INFO**

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.4.0
+Version: 1.4.2
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -502,12 +502,32 @@ WaveDL/
 
 | Argument | Default | Description |
 |----------|---------|-------------|
-| `--compile` | `False` | Enable `torch.compile` |
+| `--compile` | `False` | Enable `torch.compile` (recommended for long runs) |
 | `--precision` | `bf16` | Mixed precision mode (`bf16`, `fp16`, `no`) |
+| `--workers` | `-1` | DataLoader workers per GPU (-1=auto, up to 16) |
 | `--wandb` | `False` | Enable W&B logging |
+| `--wandb_watch` | `False` | Enable W&B gradient watching (adds overhead) |
 | `--project_name` | `DL-Training` | W&B project name |
 | `--run_name` | `None` | W&B run name (auto-generated if not set) |
 
+**Automatic GPU Optimizations:**
+
+WaveDL automatically enables performance optimizations for modern GPUs:
+
+| Optimization | Effect | GPU Support |
+|--------------|--------|-------------|
+| **TF32 precision** | ~2x speedup for float32 matmul | A100, H100 (Ampere+) |
+| **cuDNN benchmark** | Auto-tuned convolutions | All NVIDIA GPUs |
+| **Worker scaling** | Up to 16 workers per GPU | All systems |
+
+> [!NOTE]
+> These optimizations are **backward compatible** — they have no effect on older GPUs (V100, T4, GTX) or CPU-only systems. No configuration needed.
+
+**HPC Best Practices:**
+- Stage data to `$SLURM_TMPDIR` (local NVMe) for maximum I/O throughput
+- Use `--compile` for training runs > 50 epochs
+- Increase `--workers` manually if auto-detection is suboptimal
+
 </details>
 
 <details>
@@ -690,7 +710,7 @@ python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
 python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
 ```
 
-**
+**Train with best parameters**
 
 After HPO completes, it prints the optimal command:
 ```bash
@@ -837,7 +857,7 @@ import numpy as np
 X = np.random.randn(1000, 256, 256).astype(np.float32)
 y = np.random.randn(1000, 5).astype(np.float32)
 
-np.savez('test_data.npz',
+np.savez('test_data.npz', input_test=X, output_test=y)
 ```
 
 </details>
@@ -849,7 +869,7 @@ np.savez('test_data.npz', input_train=X, output_train=y)
 import numpy as np
 
 data = np.load('train_data.npz')
-assert data['input_train'].ndim
+assert data['input_train'].ndim >= 2, "Input must be at least 2D: (N, ...) "
 assert data['output_train'].ndim == 2, "Output must be 2D: (N, T)"
 assert len(data['input_train']) == len(data['output_train']), "Sample mismatch"
 
@@ -1004,7 +1024,7 @@ Beyond the material characterization example above, the WaveDL pipeline can be a
 | Resource | Description |
 |----------|-------------|
 | Technical Paper | In-depth framework description *(coming soon)* |
-| [`_template.py`](models/_template.py) | Template for
+| [`_template.py`](src/wavedl/models/_template.py) | Template for custom architectures |
 
 ---
 
````