wavedl 1.4.4.tar.gz → 1.4.6.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavedl-1.4.4/src/wavedl.egg-info → wavedl-1.4.6}/PKG-INFO +51 -14
- {wavedl-1.4.4 → wavedl-1.4.6}/README.md +50 -13
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/__init__.py +1 -1
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/hpc.py +11 -2
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/hpo.py +51 -2
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/test.py +25 -11
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/train.py +27 -3
- {wavedl-1.4.4 → wavedl-1.4.6/src/wavedl.egg-info}/PKG-INFO +51 -14
- {wavedl-1.4.4 → wavedl-1.4.6}/LICENSE +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/pyproject.toml +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/setup.cfg +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/__init__.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/_template.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/base.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/cnn.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/convnext.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/densenet.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/efficientnet.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/efficientnetv2.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/mobilenetv3.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/registry.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/regnet.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/resnet.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/resnet3d.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/swin.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/tcn.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/unet.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/vit.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/__init__.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/config.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/cross_validation.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/data.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/distributed.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/losses.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/metrics.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/optimizers.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/schedulers.py +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl.egg-info/SOURCES.txt +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl.egg-info/dependency_links.txt +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl.egg-info/entry_points.txt +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl.egg-info/requires.txt +0 -0
- {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl.egg-info/top_level.txt +0 -0

{wavedl-1.4.4/src/wavedl.egg-info → wavedl-1.4.6}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.4.4
+Version: 1.4.6
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -49,7 +49,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 
 ### A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 
-[](https://www.python.org/downloads/)
 [](https://pytorch.org/)
 [](https://huggingface.co/docs/accelerate/)
 <br>
@@ -57,7 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 [](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
 [](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
 <br>
-[](https://pypistats.org/packages/wavedl)
 [](LICENSE)
 [](https://doi.org/10.5281/zenodo.18012338)
 
@@ -462,7 +462,43 @@ WaveDL/
 | **U-Net** — U-shaped Network |||
 | `unet_regression` | 31.1M | 1D/2D/3D |
 
-
+⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
+- **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
+- **Size**: ~20–350 MB per model depending on architecture
+
+**💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
+
+```bash
+# Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
+python -c "
+import os
+os.environ['TORCH_HOME'] = '.torch_cache'  # Match WaveDL's HPC cache location
+
+from torchvision import models as m
+from torchvision.models import video as v
+
+# Model name -> Weights class mapping
+weights = {
+    'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
+    'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
+    'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
+    'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
+    'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
+    'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
+    'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
+    'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
+    'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
+}
+for name, w in weights.items():
+    getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
+
+# 3D video models
+v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
+v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+print('\\n✓ All pretrained weights cached!')
+"
+```
+
 
 </details>
 
@@ -687,7 +723,6 @@ compile: false
 seed: 2025
 ```
 
-> [!TIP]
 > See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.
 
 </details>
@@ -699,18 +734,20 @@ Automatically find the best training configuration using [Optuna](https://optuna
 
 **Run HPO:**
 
-You specify which models to search and how many trials to run:
 ```bash
-#
-
+# Basic HPO (auto-detects GPUs for parallel trials)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 100
 
-# Search
-
+# Search multiple models
+wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
 
-#
-
+# Quick mode (fewer parameters, faster)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
 ```
 
+> [!TIP]
+> **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
+
 **Train with best parameters**
 
 After HPO completes, it prints the optimal command:
@@ -749,11 +786,11 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
 | `--optimizers` | all 6 | Optimizers to search |
 | `--schedulers` | all 8 | Schedulers to search |
 | `--losses` | all 6 | Losses to search |
-| `--n_jobs` |
+| `--n_jobs` | `-1` | Parallel trials (-1 = auto-detect GPUs) |
 | `--max_epochs` | `50` | Max epochs per trial |
 | `--output` | `hpo_results.json` | Output file |
 
-
+
 > See [Available Models](#available-models) for all 38 architectures you can search.
 
 </details>

{wavedl-1.4.4 → wavedl-1.4.6}/README.md

@@ -4,7 +4,7 @@
 
 ### A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 
-[](https://www.python.org/downloads/)
 [](https://pytorch.org/)
 [](https://huggingface.co/docs/accelerate/)
 <br>
@@ -12,7 +12,7 @@
 [](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
 [](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
 <br>
-[](https://pypistats.org/packages/wavedl)
 [](LICENSE)
 [](https://doi.org/10.5281/zenodo.18012338)
 
@@ -417,7 +417,43 @@ WaveDL/
 | **U-Net** — U-shaped Network |||
 | `unet_regression` | 31.1M | 1D/2D/3D |
 
-
+⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
+- **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
+- **Size**: ~20–350 MB per model depending on architecture
+
+**💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
+
+```bash
+# Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
+python -c "
+import os
+os.environ['TORCH_HOME'] = '.torch_cache'  # Match WaveDL's HPC cache location
+
+from torchvision import models as m
+from torchvision.models import video as v
+
+# Model name -> Weights class mapping
+weights = {
+    'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
+    'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
+    'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
+    'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
+    'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
+    'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
+    'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
+    'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
+    'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
+}
+for name, w in weights.items():
+    getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
+
+# 3D video models
+v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
+v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+print('\\n✓ All pretrained weights cached!')
+"
+```
+
 
 </details>
 
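
The pre-download recipe in the hunk above only works if `TORCH_HOME` points at the same directory WaveDL later reads from. As a quick sanity check on a compute node without internet, something like the following hypothetical snippet (not part of WaveDL) can confirm the cache is being used; if the weights were pre-downloaded, torchvision serves them from `./.torch_cache/hub/checkpoints/` instead of attempting a download.

```python
# Hypothetical cache check, assuming the login-node script above populated ./.torch_cache
import os
os.environ["TORCH_HOME"] = ".torch_cache"  # must be set before torchvision resolves its cache dir

from torchvision import models

# Served from ./.torch_cache/hub/checkpoints/ if present; would try to download otherwise
m = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
print(sum(p.numel() for p in m.parameters()), "parameters loaded")
```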

@@ -642,7 +678,6 @@ compile: false
 seed: 2025
 ```
 
-> [!TIP]
 > See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.
 
 </details>
@@ -654,18 +689,20 @@ Automatically find the best training configuration using [Optuna](https://optuna
 
 **Run HPO:**
 
-You specify which models to search and how many trials to run:
 ```bash
-#
-
+# Basic HPO (auto-detects GPUs for parallel trials)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 100
 
-# Search
-
+# Search multiple models
+wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
 
-#
-
+# Quick mode (fewer parameters, faster)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
 ```
 
+> [!TIP]
+> **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
+
 **Train with best parameters**
 
 After HPO completes, it prints the optimal command:
@@ -704,11 +741,11 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
 | `--optimizers` | all 6 | Optimizers to search |
 | `--schedulers` | all 8 | Schedulers to search |
 | `--losses` | all 6 | Losses to search |
-| `--n_jobs` |
+| `--n_jobs` | `-1` | Parallel trials (-1 = auto-detect GPUs) |
 | `--max_epochs` | `50` | Max epochs per trial |
 | `--output` | `hpo_results.json` | Output file |
 
-
+
 > See [Available Models](#available-models) for all 38 architectures you can search.
 
 </details>

{wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/hpc.py

@@ -174,7 +174,9 @@ Environment Variables:
     return args, remaining
 
 
-def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
+def print_summary(
+    exit_code: int, wandb_enabled: bool, wandb_mode: str, wandb_dir: str
+) -> None:
     """Print post-training summary and instructions."""
     print()
     print("=" * 40)
@@ -183,7 +185,8 @@ def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
         print("✅ Training completed successfully!")
         print("=" * 40)
 
-    if wandb_mode == "offline":
+    # Only show WandB sync instructions if user enabled wandb
+    if wandb_enabled and wandb_mode == "offline":
         print()
         print("📊 WandB Sync Instructions:")
         print("  From the login node, run:")
@@ -237,6 +240,10 @@ def main() -> int:
         f"--dynamo_backend={args.dynamo_backend}",
     ]
 
+    # Explicitly set multi_gpu to suppress accelerate auto-detection warning
+    if num_gpus > 1:
+        cmd.append("--multi_gpu")
+
     # Add multi-node networking args if specified (required for some clusters)
     if args.main_process_ip:
         cmd.append(f"--main_process_ip={args.main_process_ip}")
@@ -263,8 +270,10 @@ def main() -> int:
         exit_code = 130
 
     # Print summary
+    wandb_enabled = "--wandb" in train_args
     print_summary(
         exit_code,
+        wandb_enabled,
         os.environ.get("WANDB_MODE", "offline"),
         os.environ.get("WANDB_DIR", "/tmp/wandb"),
     )
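
Taken together, the hpc.py hunks do two things: append an explicit `--multi_gpu` flag to the `accelerate launch` command, and gate the WandB sync instructions on whether the user actually passed `--wandb`. The sketch below is illustrative only (the real launcher builds a longer command and handles more environment setup); `launch`, `num_gpus`, and `train_args` are assumed names, not WaveDL's API.

```python
# Illustrative sketch of how the hpc.py changes above fit together; not WaveDL's actual code.
import os
import subprocess

def launch(num_gpus: int, train_args: list[str]) -> int:
    cmd = ["accelerate", "launch", f"--num_processes={num_gpus}"]
    if num_gpus > 1:
        cmd.append("--multi_gpu")  # explicit flag avoids accelerate's auto-detection warning
    cmd += ["-m", "wavedl.train", *train_args]
    exit_code = subprocess.run(cmd).returncode

    # Sync instructions are only relevant when the user enabled --wandb and runs offline
    wandb_enabled = "--wandb" in train_args
    if wandb_enabled and os.environ.get("WANDB_MODE", "offline") == "offline":
        print("Run `wandb sync` from a login node with internet access.")
    return exit_code
```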

{wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/hpo.py

@@ -31,7 +31,7 @@ try:
     import optuna
     from optuna.trial import TrialState
 except ImportError:
-    print("Error: Optuna not installed. Run: pip install
+    print("Error: Optuna not installed. Run: pip install wavedl")
     sys.exit(1)
 
 
@@ -147,6 +147,32 @@ def create_objective(args):
         cmd.extend(["--output_dir", tmpdir])
         history_file = Path(tmpdir) / "training_history.csv"
 
+        # GPU isolation for parallel trials: assign each trial to a specific GPU
+        # This prevents multiple trials from competing for all GPUs
+        env = None
+        if args.n_jobs > 1:
+            import os
+
+            # Detect available GPUs
+            n_gpus = 1
+            try:
+                import subprocess as sp
+
+                result_gpu = sp.run(
+                    ["nvidia-smi", "--list-gpus"],
+                    capture_output=True,
+                    text=True,
+                )
+                if result_gpu.returncode == 0:
+                    n_gpus = len(result_gpu.stdout.strip().split("\n"))
+            except Exception:
+                pass
+
+            # Assign trial to a specific GPU (round-robin)
+            gpu_id = trial.number % n_gpus
+            env = os.environ.copy()
+            env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+
         # Run training
         try:
             result = subprocess.run(
@@ -155,6 +181,7 @@ def create_objective(args):
                 text=True,
                 timeout=args.timeout,
                 cwd=Path(__file__).parent,
+                env=env,
             )
 
             # Read best val_loss from training_history.csv (reliable machine-readable)
@@ -248,7 +275,10 @@ Examples:
         "--n_trials", type=int, default=50, help="Number of HPO trials (default: 50)"
     )
     parser.add_argument(
-        "--n_jobs",
+        "--n_jobs",
+        type=int,
+        default=-1,
+        help="Parallel trials (-1 = auto-detect GPUs, default: -1)",
     )
     parser.add_argument(
         "--quick",
@@ -315,11 +345,30 @@ Examples:
 
     args = parser.parse_args()
 
+    # Convert to absolute path (child processes may run in different cwd)
+    args.data_path = str(Path(args.data_path).resolve())
+
     # Validate data path
     if not Path(args.data_path).exists():
         print(f"Error: Data file not found: {args.data_path}")
         sys.exit(1)
 
+    # Auto-detect GPUs for n_jobs if not specified
+    if args.n_jobs == -1:
+        try:
+            result_gpu = subprocess.run(
+                ["nvidia-smi", "--list-gpus"],
+                capture_output=True,
+                text=True,
+            )
+            if result_gpu.returncode == 0:
+                args.n_jobs = max(1, len(result_gpu.stdout.strip().split("\n")))
+            else:
+                args.n_jobs = 1
+        except Exception:
+            args.n_jobs = 1
+        print(f"Auto-detected {args.n_jobs} GPU(s) for parallel trials")
+
     # Create study
     print("=" * 60)
    print("WaveDL Hyperparameter Optimization")
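
The GPU-isolation hunk above pins each trial to one device by masking `CUDA_VISIBLE_DEVICES` in the child process's environment, so parallel trials do not compete for the same GPUs. The standalone sketch below demonstrates the same round-robin idea against a toy objective; it does not call WaveDL, and the GPU count is a hard-coded assumption (wavedl-hpo detects it via `nvidia-smi --list-gpus`).

```python
# Standalone demo of round-robin GPU pinning for parallel Optuna trials (toy objective).
import os
import subprocess
import optuna

N_GPUS = 4  # assumption for this example

def objective(trial: optuna.Trial) -> float:
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(trial.number % N_GPUS)  # trial 0 -> GPU 0, trial 1 -> GPU 1, ...
    # The child process only ever sees the single GPU it was assigned
    out = subprocess.run(
        ["python", "-c", "import os; print(os.environ['CUDA_VISIBLE_DEVICES'])"],
        env=env, capture_output=True, text=True,
    )
    print(f"trial {trial.number} pinned to GPU {out.stdout.strip()}")
    x = trial.suggest_float("x", -1.0, 1.0)
    return x**2  # stand-in for the val_loss wavedl-hpo reads from training_history.csv

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=8, n_jobs=N_GPUS)  # runs N_GPUS trials concurrently
```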

{wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/test.py

@@ -366,23 +366,37 @@ def load_checkpoint(
     logging.info(f"  Building model: {model_name}")
     model = build_model(model_name, in_shape=in_shape, out_size=out_size)
 
-    # Load weights (
-    weight_path =
-
-
-
-
-
+    # Load weights (check multiple formats in order of preference)
+    weight_path = None
+    for fname in ["model.safetensors", "model.bin", "pytorch_model.bin"]:
+        candidate = checkpoint_dir / fname
+        if candidate.exists():
+            weight_path = candidate
+            break
+
+    if weight_path is None:
+        raise FileNotFoundError(
+            f"No model weights found in {checkpoint_dir}. "
+            f"Expected one of: model.safetensors, model.bin, pytorch_model.bin"
+        )
 
     if HAS_SAFETENSORS and weight_path.suffix == ".safetensors":
         state_dict = load_safetensors(str(weight_path))
     else:
         state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
 
-    # Remove
-
-
-    }
+    # Remove wrapper prefixes from checkpoints:
+    # - 'module.' from DDP (DistributedDataParallel)
+    # - '_orig_mod.' from torch.compile()
+    cleaned_dict = {}
+    for k, v in state_dict.items():
+        key = k
+        if key.startswith("module."):
+            key = key[7:]  # Remove 'module.' (7 chars)
+        if key.startswith("_orig_mod."):
+            key = key[10:]  # Remove '_orig_mod.' (10 chars)
+        cleaned_dict[key] = v
+    state_dict = cleaned_dict
 
     model.load_state_dict(state_dict)
     model.eval()
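
The prefix cleanup above exists because wrapped models change their state_dict keys: DistributedDataParallel prepends `module.` and `torch.compile()` prepends `_orig_mod.`, so a checkpoint saved from a wrapped model will not load into a plain one. A minimal, self-contained illustration of the problem and the same stripping idea (not WaveDL code):

```python
# Minimal illustration: compiled models gain an '_orig_mod.' key prefix that must be
# stripped before loading the state dict into a plain, unwrapped model.
import torch
import torch.nn as nn

base = nn.Linear(4, 2)
compiled = torch.compile(base)
print(list(compiled.state_dict())[:1])  # e.g. ['_orig_mod.weight']

def strip_prefixes(state_dict):
    cleaned = {}
    for k, v in state_dict.items():
        for prefix in ("module.", "_orig_mod."):  # DDP and torch.compile wrappers
            if k.startswith(prefix):
                k = k[len(prefix):]
        cleaned[k] = v
    return cleaned

fresh = nn.Linear(4, 2)
fresh.load_state_dict(strip_prefixes(compiled.state_dict()))  # keys now match
print("loaded OK")
```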

{wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/train.py

@@ -148,6 +148,24 @@ torch.set_float32_matmul_precision("high")  # Use TF32 for float32 ops
 torch.backends.cudnn.benchmark = True
 
 
+# ==============================================================================
+# LOGGING UTILITIES
+# ==============================================================================
+from contextlib import contextmanager
+
+
+@contextmanager
+def suppress_accelerate_logging():
+    """Temporarily suppress accelerate's verbose checkpoint save messages."""
+    accelerate_logger = logging.getLogger("accelerate.checkpointing")
+    original_level = accelerate_logger.level
+    accelerate_logger.setLevel(logging.WARNING)
+    try:
+        yield
+    finally:
+        accelerate_logger.setLevel(original_level)
+
+
 # ==============================================================================
 # ARGUMENT PARSING
 # ==============================================================================
@@ -1033,7 +1051,8 @@ def main():
             # Step 3: Save checkpoint with all ranks participating
             if is_best_epoch:
                 ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
-
+                with suppress_accelerate_logging():
+                    accelerator.save_state(ckpt_dir, safe_serialization=False)
 
             # Step 4: Rank 0 handles metadata and updates tracking variables
             if accelerator.is_main_process:
@@ -1096,7 +1115,8 @@ def main():
             if periodic_checkpoint_needed:
                 ckpt_name = f"epoch_{epoch + 1}_checkpoint"
                 ckpt_dir = os.path.join(args.output_dir, ckpt_name)
-
+                with suppress_accelerate_logging():
+                    accelerator.save_state(ckpt_dir, safe_serialization=False)
 
                 if accelerator.is_main_process:
                     with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
@@ -1147,7 +1167,11 @@ def main():
 
     except KeyboardInterrupt:
         logger.warning("Training interrupted. Saving emergency checkpoint...")
-
+        with suppress_accelerate_logging():
+            accelerator.save_state(
+                os.path.join(args.output_dir, "interrupted_checkpoint"),
+                safe_serialization=False,
+            )
 
     except Exception as e:
         logger.error(f"Critical error: {e}", exc_info=True)
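
The `suppress_accelerate_logging` helper added above follows a generic pattern: temporarily raise one named logger's level and restore it afterwards, so only the statements inside the `with` block are silenced. A self-contained version of the same pattern, using an illustrative `quiet()` helper rather than WaveDL's function:

```python
# Generic version of the logger-muting pattern; messages from the target logger
# are hidden inside the `with` block and restored afterwards.
import logging
from contextlib import contextmanager

logging.basicConfig(level=logging.INFO)
noisy = logging.getLogger("accelerate.checkpointing")

@contextmanager
def quiet(logger, level=logging.WARNING):
    previous = logger.level
    logger.setLevel(level)
    try:
        yield
    finally:
        logger.setLevel(previous)

noisy.info("checkpoint saved")      # printed
with quiet(noisy):
    noisy.info("checkpoint saved")  # suppressed
noisy.info("checkpoint saved")      # printed again
```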

{wavedl-1.4.4 → wavedl-1.4.6/src/wavedl.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.4.4
+Version: 1.4.6
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -49,7 +49,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 
 ### A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 
-[](https://www.python.org/downloads/)
 [](https://pytorch.org/)
 [](https://huggingface.co/docs/accelerate/)
 <br>
@@ -57,7 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 [](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
 [](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
 <br>
-[](https://pypistats.org/packages/wavedl)
 [](LICENSE)
 [](https://doi.org/10.5281/zenodo.18012338)
 
@@ -462,7 +462,43 @@ WaveDL/
 | **U-Net** — U-shaped Network |||
 | `unet_regression` | 31.1M | 1D/2D/3D |
 
-
+⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
+- **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
+- **Size**: ~20–350 MB per model depending on architecture
+
+**💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
+
+```bash
+# Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
+python -c "
+import os
+os.environ['TORCH_HOME'] = '.torch_cache'  # Match WaveDL's HPC cache location
+
+from torchvision import models as m
+from torchvision.models import video as v
+
+# Model name -> Weights class mapping
+weights = {
+    'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
+    'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
+    'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
+    'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
+    'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
+    'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
+    'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
+    'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
+    'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
+}
+for name, w in weights.items():
+    getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
+
+# 3D video models
+v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
+v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+print('\\n✓ All pretrained weights cached!')
+"
+```
+
 
 </details>
 
@@ -687,7 +723,6 @@ compile: false
 seed: 2025
 ```
 
-> [!TIP]
 > See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.
 
 </details>
@@ -699,18 +734,20 @@ Automatically find the best training configuration using [Optuna](https://optuna
 
 **Run HPO:**
 
-You specify which models to search and how many trials to run:
 ```bash
-#
-
+# Basic HPO (auto-detects GPUs for parallel trials)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 100
 
-# Search
-
+# Search multiple models
+wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
 
-#
-
+# Quick mode (fewer parameters, faster)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
 ```
 
+> [!TIP]
+> **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
+
 **Train with best parameters**
 
 After HPO completes, it prints the optimal command:
@@ -749,11 +786,11 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
 | `--optimizers` | all 6 | Optimizers to search |
 | `--schedulers` | all 8 | Schedulers to search |
 | `--losses` | all 6 | Losses to search |
-| `--n_jobs` |
+| `--n_jobs` | `-1` | Parallel trials (-1 = auto-detect GPUs) |
 | `--max_epochs` | `50` | Max epochs per trial |
 | `--output` | `hpo_results.json` | Output file |
 
-
+
 > See [Available Models](#available-models) for all 38 architectures you can search.
 
 </details>