wavedl 1.4.4.tar.gz → 1.4.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {wavedl-1.4.4/src/wavedl.egg-info → wavedl-1.4.6}/PKG-INFO +51 -14
  2. {wavedl-1.4.4 → wavedl-1.4.6}/README.md +50 -13
  3. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/__init__.py +1 -1
  4. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/hpc.py +11 -2
  5. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/hpo.py +51 -2
  6. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/test.py +25 -11
  7. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/train.py +27 -3
  8. {wavedl-1.4.4 → wavedl-1.4.6/src/wavedl.egg-info}/PKG-INFO +51 -14
  9. {wavedl-1.4.4 → wavedl-1.4.6}/LICENSE +0 -0
  10. {wavedl-1.4.4 → wavedl-1.4.6}/pyproject.toml +0 -0
  11. {wavedl-1.4.4 → wavedl-1.4.6}/setup.cfg +0 -0
  12. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/__init__.py +0 -0
  13. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/_template.py +0 -0
  14. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/base.py +0 -0
  15. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/cnn.py +0 -0
  16. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/convnext.py +0 -0
  17. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/densenet.py +0 -0
  18. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/efficientnet.py +0 -0
  19. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/efficientnetv2.py +0 -0
  20. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/mobilenetv3.py +0 -0
  21. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/registry.py +0 -0
  22. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/regnet.py +0 -0
  23. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/resnet.py +0 -0
  24. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/resnet3d.py +0 -0
  25. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/swin.py +0 -0
  26. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/tcn.py +0 -0
  27. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/unet.py +0 -0
  28. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/models/vit.py +0 -0
  29. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/__init__.py +0 -0
  30. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/config.py +0 -0
  31. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/cross_validation.py +0 -0
  32. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/data.py +0 -0
  33. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/distributed.py +0 -0
  34. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/losses.py +0 -0
  35. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/metrics.py +0 -0
  36. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/optimizers.py +0 -0
  37. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl/utils/schedulers.py +0 -0
  38. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl.egg-info/SOURCES.txt +0 -0
  39. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl.egg-info/dependency_links.txt +0 -0
  40. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl.egg-info/entry_points.txt +0 -0
  41. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl.egg-info/requires.txt +0 -0
  42. {wavedl-1.4.4 → wavedl-1.4.6}/src/wavedl.egg-info/top_level.txt +0 -0
--- wavedl-1.4.4/src/wavedl.egg-info/PKG-INFO
+++ wavedl-1.4.6/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: wavedl
- Version: 1.4.4
+ Version: 1.4.6
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
  Author: Ductho Le
  License: MIT
@@ -49,7 +49,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 
  ### A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 
- [![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
+ [![Python 3.11+](https://img.shields.io/badge/Python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
  [![PyTorch 2.x](https://img.shields.io/badge/PyTorch-2.x-ee4c2c.svg?style=plastic&logo=pytorch&logoColor=white)](https://pytorch.org/)
  [![Accelerate](https://img.shields.io/badge/Accelerate-Enabled-yellow.svg?style=plastic&logo=huggingface&logoColor=white)](https://huggingface.co/docs/accelerate/)
  <br>
@@ -57,7 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
  [![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
  [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
  <br>
- [![Downloads](https://img.shields.io/pepy/dt/wavedl?style=plastic&logo=pypi&logoColor=white&color=9ACD32)](https://pepy.tech/project/wavedl)
+ [![Downloads](https://img.shields.io/badge/dynamic/json?url=https://pypistats.org/api/packages/wavedl/recent?period=month%26mirrors=false&query=data.last_month&style=plastic&logo=pypi&logoColor=white&color=9ACD32&label=Downloads&suffix=/month)](https://pypistats.org/packages/wavedl)
  [![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
  [![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)
 
@@ -462,7 +462,43 @@ WaveDL/
  | **U-Net** — U-shaped Network |||
  | `unet_regression` | 31.1M | 1D/2D/3D |
 
- > ⭐ = Pretrained on ImageNet. Recommended for smaller datasets.
+ ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
+ - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
+ - **Size**: ~20–350 MB per model depending on architecture
+
+ **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
+
+ ```bash
+ # Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
+ python -c "
+ import os
+ os.environ['TORCH_HOME'] = '.torch_cache'  # Match WaveDL's HPC cache location
+
+ from torchvision import models as m
+ from torchvision.models import video as v
+
+ # Model name -> Weights class mapping
+ weights = {
+     'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
+     'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
+     'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
+     'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
+     'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
+     'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
+     'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
+     'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
+     'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
+ }
+ for name, w in weights.items():
+     getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
+
+ # 3D video models
+ v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
+ v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+ print('\\n✓ All pretrained weights cached!')
+ "
+ ```
+
 
  </details>
 
@@ -687,7 +723,6 @@ compile: false
  seed: 2025
  ```
 
- > [!TIP]
  > See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.
 
  </details>
@@ -699,18 +734,20 @@ Automatically find the best training configuration using [Optuna](https://optuna
 
  **Run HPO:**
 
- You specify which models to search and how many trials to run:
  ```bash
- # Search 3 models with 100 trials
- python -m wavedl.hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 100
+ # Basic HPO (auto-detects GPUs for parallel trials)
+ wavedl-hpo --data_path train.npz --models cnn --n_trials 100
 
- # Search 1 model (faster)
- python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
+ # Search multiple models
+ wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
 
- # Search all your candidate models
- python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
+ # Quick mode (fewer parameters, faster)
+ wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
  ```
 
+ > [!TIP]
+ > **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
+
  **Train with best parameters**
 
  After HPO completes, it prints the optimal command:
@@ -749,11 +786,11 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
  | `--optimizers` | all 6 | Optimizers to search |
  | `--schedulers` | all 8 | Schedulers to search |
  | `--losses` | all 6 | Losses to search |
- | `--n_jobs` | `1` | Parallel trials (multi-GPU) |
+ | `--n_jobs` | `-1` | Parallel trials (-1 = auto-detect GPUs) |
  | `--max_epochs` | `50` | Max epochs per trial |
  | `--output` | `hpo_results.json` | Output file |
 
- > [!TIP]
+ 
  > See [Available Models](#available-models) for all 38 architectures you can search.
 
  </details>
--- wavedl-1.4.4/README.md
+++ wavedl-1.4.6/README.md
@@ -4,7 +4,7 @@
 
  ### A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 
- [![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
+ [![Python 3.11+](https://img.shields.io/badge/Python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
  [![PyTorch 2.x](https://img.shields.io/badge/PyTorch-2.x-ee4c2c.svg?style=plastic&logo=pytorch&logoColor=white)](https://pytorch.org/)
  [![Accelerate](https://img.shields.io/badge/Accelerate-Enabled-yellow.svg?style=plastic&logo=huggingface&logoColor=white)](https://huggingface.co/docs/accelerate/)
  <br>
@@ -12,7 +12,7 @@
  [![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
  [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
  <br>
- [![Downloads](https://img.shields.io/pepy/dt/wavedl?style=plastic&logo=pypi&logoColor=white&color=9ACD32)](https://pepy.tech/project/wavedl)
+ [![Downloads](https://img.shields.io/badge/dynamic/json?url=https://pypistats.org/api/packages/wavedl/recent?period=month%26mirrors=false&query=data.last_month&style=plastic&logo=pypi&logoColor=white&color=9ACD32&label=Downloads&suffix=/month)](https://pypistats.org/packages/wavedl)
  [![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
  [![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)
 
@@ -417,7 +417,43 @@ WaveDL/
  | **U-Net** — U-shaped Network |||
  | `unet_regression` | 31.1M | 1D/2D/3D |
 
- > ⭐ = Pretrained on ImageNet. Recommended for smaller datasets.
+ ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
+ - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
+ - **Size**: ~20–350 MB per model depending on architecture
+
+ **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
+
+ ```bash
+ # Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
+ python -c "
+ import os
+ os.environ['TORCH_HOME'] = '.torch_cache'  # Match WaveDL's HPC cache location
+
+ from torchvision import models as m
+ from torchvision.models import video as v
+
+ # Model name -> Weights class mapping
+ weights = {
+     'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
+     'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
+     'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
+     'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
+     'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
+     'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
+     'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
+     'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
+     'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
+ }
+ for name, w in weights.items():
+     getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
+
+ # 3D video models
+ v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
+ v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+ print('\\n✓ All pretrained weights cached!')
+ "
+ ```
+
 
  </details>
 
@@ -642,7 +678,6 @@ compile: false
  seed: 2025
  ```
 
- > [!TIP]
  > See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.
 
  </details>
@@ -654,18 +689,20 @@ Automatically find the best training configuration using [Optuna](https://optuna
 
  **Run HPO:**
 
- You specify which models to search and how many trials to run:
  ```bash
- # Search 3 models with 100 trials
- python -m wavedl.hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 100
+ # Basic HPO (auto-detects GPUs for parallel trials)
+ wavedl-hpo --data_path train.npz --models cnn --n_trials 100
 
- # Search 1 model (faster)
- python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
+ # Search multiple models
+ wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
 
- # Search all your candidate models
- python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
+ # Quick mode (fewer parameters, faster)
+ wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
  ```
 
+ > [!TIP]
+ > **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
+
  **Train with best parameters**
 
  After HPO completes, it prints the optimal command:
@@ -704,11 +741,11 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
  | `--optimizers` | all 6 | Optimizers to search |
  | `--schedulers` | all 8 | Schedulers to search |
  | `--losses` | all 6 | Losses to search |
- | `--n_jobs` | `1` | Parallel trials (multi-GPU) |
+ | `--n_jobs` | `-1` | Parallel trials (-1 = auto-detect GPUs) |
  | `--max_epochs` | `50` | Max epochs per trial |
  | `--output` | `hpo_results.json` | Output file |
 
- > [!TIP]
+ 
  > See [Available Models](#available-models) for all 38 architectures you can search.
 
  </details>
--- wavedl-1.4.4/src/wavedl/__init__.py
+++ wavedl-1.4.6/src/wavedl/__init__.py
@@ -18,7 +18,7 @@ For inference:
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
  """
 
- __version__ = "1.4.4"
+ __version__ = "1.4.6"
  __author__ = "Ductho Le"
  __email__ = "ductho.le@outlook.com"
 
--- wavedl-1.4.4/src/wavedl/hpc.py
+++ wavedl-1.4.6/src/wavedl/hpc.py
@@ -174,7 +174,9 @@ Environment Variables:
      return args, remaining
 
 
- def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
+ def print_summary(
+     exit_code: int, wandb_enabled: bool, wandb_mode: str, wandb_dir: str
+ ) -> None:
      """Print post-training summary and instructions."""
      print()
      print("=" * 40)
@@ -183,7 +185,8 @@ def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
      print("✅ Training completed successfully!")
      print("=" * 40)
 
-     if wandb_mode == "offline":
+     # Only show WandB sync instructions if user enabled wandb
+     if wandb_enabled and wandb_mode == "offline":
          print()
          print("📊 WandB Sync Instructions:")
          print("  From the login node, run:")
@@ -237,6 +240,10 @@ def main() -> int:
          f"--dynamo_backend={args.dynamo_backend}",
      ]
 
+     # Explicitly set multi_gpu to suppress accelerate auto-detection warning
+     if num_gpus > 1:
+         cmd.append("--multi_gpu")
+
      # Add multi-node networking args if specified (required for some clusters)
      if args.main_process_ip:
          cmd.append(f"--main_process_ip={args.main_process_ip}")
@@ -263,8 +270,10 @@ def main() -> int:
          exit_code = 130
 
      # Print summary
+     wandb_enabled = "--wandb" in train_args
      print_summary(
          exit_code,
+         wandb_enabled,
          os.environ.get("WANDB_MODE", "offline"),
          os.environ.get("WANDB_DIR", "/tmp/wandb"),
      )
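The `wandb_enabled` change above gates the offline-sync instructions on whether the user actually passed `--wandb` through to training. A minimal standalone sketch of that check (the function name is invented here for illustration; the package simply inlines the test):

```python
def should_show_sync_instructions(train_args: list[str], wandb_mode: str) -> bool:
    """Show WandB offline-sync instructions only when wandb is enabled AND offline."""
    return "--wandb" in train_args and wandb_mode == "offline"


# Enabled and offline -> show instructions
print(should_show_sync_instructions(["--model", "cnn", "--wandb"], "offline"))  # True
# WandB never enabled -> stay quiet, even in offline mode (the 1.4.6 fix)
print(should_show_sync_instructions(["--model", "cnn"], "offline"))  # False
```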
--- wavedl-1.4.4/src/wavedl/hpo.py
+++ wavedl-1.4.6/src/wavedl/hpo.py
@@ -31,7 +31,7 @@ try:
      import optuna
      from optuna.trial import TrialState
  except ImportError:
-     print("Error: Optuna not installed. Run: pip install -e '.[hpo]'")
+     print("Error: Optuna not installed. Run: pip install wavedl")
      sys.exit(1)
 
 
@@ -147,6 +147,32 @@ def create_objective(args):
      cmd.extend(["--output_dir", tmpdir])
      history_file = Path(tmpdir) / "training_history.csv"
 
+     # GPU isolation for parallel trials: assign each trial to a specific GPU
+     # This prevents multiple trials from competing for all GPUs
+     env = None
+     if args.n_jobs > 1:
+         import os
+
+         # Detect available GPUs
+         n_gpus = 1
+         try:
+             import subprocess as sp
+
+             result_gpu = sp.run(
+                 ["nvidia-smi", "--list-gpus"],
+                 capture_output=True,
+                 text=True,
+             )
+             if result_gpu.returncode == 0:
+                 n_gpus = len(result_gpu.stdout.strip().split("\n"))
+         except Exception:
+             pass
+
+         # Assign trial to a specific GPU (round-robin)
+         gpu_id = trial.number % n_gpus
+         env = os.environ.copy()
+         env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+
      # Run training
      try:
          result = subprocess.run(
@@ -155,6 +181,7 @@ def create_objective(args):
              text=True,
              timeout=args.timeout,
              cwd=Path(__file__).parent,
+             env=env,
          )
 
          # Read best val_loss from training_history.csv (reliable machine-readable)
@@ -248,7 +275,10 @@ Examples:
          "--n_trials", type=int, default=50, help="Number of HPO trials (default: 50)"
      )
      parser.add_argument(
-         "--n_jobs", type=int, default=1, help="Parallel trials (default: 1)"
+         "--n_jobs",
+         type=int,
+         default=-1,
+         help="Parallel trials (-1 = auto-detect GPUs, default: -1)",
      )
      parser.add_argument(
          "--quick",
@@ -315,11 +345,30 @@ Examples:
 
      args = parser.parse_args()
 
+     # Convert to absolute path (child processes may run in different cwd)
+     args.data_path = str(Path(args.data_path).resolve())
+
      # Validate data path
      if not Path(args.data_path).exists():
          print(f"Error: Data file not found: {args.data_path}")
          sys.exit(1)
 
+     # Auto-detect GPUs for n_jobs if not specified
+     if args.n_jobs == -1:
+         try:
+             result_gpu = subprocess.run(
+                 ["nvidia-smi", "--list-gpus"],
+                 capture_output=True,
+                 text=True,
+             )
+             if result_gpu.returncode == 0:
+                 args.n_jobs = max(1, len(result_gpu.stdout.strip().split("\n")))
+             else:
+                 args.n_jobs = 1
+         except Exception:
+             args.n_jobs = 1
+         print(f"Auto-detected {args.n_jobs} GPU(s) for parallel trials")
+
      # Create study
      print("=" * 60)
      print("WaveDL Hyperparameter Optimization")
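The GPU-isolation hunk above pins trial N to GPU N mod n_gpus by setting `CUDA_VISIBLE_DEVICES` in the child process's environment. The assignment logic can be sketched in isolation (the `trial_env` helper is a name invented here, not the package's API):

```python
import os


def trial_env(trial_number: int, n_gpus: int) -> dict:
    """Build a child-process environment that pins one HPO trial to one GPU."""
    env = os.environ.copy()
    # Round-robin: trial 0 -> GPU 0, trial 1 -> GPU 1, ..., trial n_gpus -> GPU 0
    env["CUDA_VISIBLE_DEVICES"] = str(trial_number % n_gpus)
    return env


# With 4 GPUs, eight trials cycle through devices 0,1,2,3,0,1,2,3
assignments = [trial_env(t, 4)["CUDA_VISIBLE_DEVICES"] for t in range(8)]
print(assignments)  # ['0', '1', '2', '3', '0', '1', '2', '3']
```

Because each child sees only one device, `accelerate`/PyTorch inside the trial treats that GPU as device 0, so trials never contend for the same card.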
--- wavedl-1.4.4/src/wavedl/test.py
+++ wavedl-1.4.6/src/wavedl/test.py
@@ -366,23 +366,37 @@ def load_checkpoint(
      logging.info(f"  Building model: {model_name}")
      model = build_model(model_name, in_shape=in_shape, out_size=out_size)
 
-     # Load weights (prefer safetensors)
-     weight_path = checkpoint_dir / "model.safetensors"
-     if not weight_path.exists():
-         weight_path = checkpoint_dir / "pytorch_model.bin"
-
-     if not weight_path.exists():
-         raise FileNotFoundError(f"No model weights found in {checkpoint_dir}")
+     # Load weights (check multiple formats in order of preference)
+     weight_path = None
+     for fname in ["model.safetensors", "model.bin", "pytorch_model.bin"]:
+         candidate = checkpoint_dir / fname
+         if candidate.exists():
+             weight_path = candidate
+             break
+
+     if weight_path is None:
+         raise FileNotFoundError(
+             f"No model weights found in {checkpoint_dir}. "
+             f"Expected one of: model.safetensors, model.bin, pytorch_model.bin"
+         )
 
      if HAS_SAFETENSORS and weight_path.suffix == ".safetensors":
          state_dict = load_safetensors(str(weight_path))
      else:
          state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
 
-     # Remove 'module.' prefix from DDP checkpoints (leading only, not all occurrences)
-     state_dict = {
-         (k[7:] if k.startswith("module.") else k): v for k, v in state_dict.items()
-     }
+     # Remove wrapper prefixes from checkpoints:
+     # - 'module.' from DDP (DistributedDataParallel)
+     # - '_orig_mod.' from torch.compile()
+     cleaned_dict = {}
+     for k, v in state_dict.items():
+         key = k
+         if key.startswith("module."):
+             key = key[7:]  # Remove 'module.' (7 chars)
+         if key.startswith("_orig_mod."):
+             key = key[10:]  # Remove '_orig_mod.' (10 chars)
+         cleaned_dict[key] = v
+     state_dict = cleaned_dict
 
      model.load_state_dict(state_dict)
      model.eval()
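The new prefix handling strips `module.` (DDP) first and then `_orig_mod.` (torch.compile), so a checkpoint from a compiled DDP model is also cleaned correctly. A standalone sketch of the same logic, exercised with plain dicts instead of real checkpoints (`strip_wrapper_prefixes` is an illustrative name, not the package's API):

```python
def strip_wrapper_prefixes(state_dict: dict) -> dict:
    """Strip leading 'module.' (DDP) and '_orig_mod.' (torch.compile) key prefixes."""
    cleaned = {}
    for key, value in state_dict.items():
        if key.startswith("module."):
            key = key[len("module."):]
        # Checked second: a compiled DDP model yields 'module._orig_mod.<param>'
        if key.startswith("_orig_mod."):
            key = key[len("_orig_mod."):]
        cleaned[key] = value
    return cleaned


raw = {
    "module._orig_mod.conv1.weight": 1,  # compiled + DDP
    "module.fc.bias": 2,                 # DDP only
    "head.weight": 3,                    # plain
}
print(strip_wrapper_prefixes(raw))
# {'conv1.weight': 1, 'fc.bias': 2, 'head.weight': 3}
```

Stripping only leading prefixes (rather than `str.replace` everywhere) avoids mangling parameter names that legitimately contain `module.` mid-key.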
--- wavedl-1.4.4/src/wavedl/train.py
+++ wavedl-1.4.6/src/wavedl/train.py
@@ -148,6 +148,24 @@ torch.set_float32_matmul_precision("high")  # Use TF32 for float32 ops
  torch.backends.cudnn.benchmark = True
 
 
+ # ==============================================================================
+ # LOGGING UTILITIES
+ # ==============================================================================
+ from contextlib import contextmanager
+
+
+ @contextmanager
+ def suppress_accelerate_logging():
+     """Temporarily suppress accelerate's verbose checkpoint save messages."""
+     accelerate_logger = logging.getLogger("accelerate.checkpointing")
+     original_level = accelerate_logger.level
+     accelerate_logger.setLevel(logging.WARNING)
+     try:
+         yield
+     finally:
+         accelerate_logger.setLevel(original_level)
+
+
  # ==============================================================================
  # ARGUMENT PARSING
  # ==============================================================================
@@ -1033,7 +1051,8 @@ def main():
      # Step 3: Save checkpoint with all ranks participating
      if is_best_epoch:
          ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
-         accelerator.save_state(ckpt_dir)  # All ranks must call this
+         with suppress_accelerate_logging():
+             accelerator.save_state(ckpt_dir, safe_serialization=False)
 
      # Step 4: Rank 0 handles metadata and updates tracking variables
      if accelerator.is_main_process:
@@ -1096,7 +1115,8 @@ def main():
      if periodic_checkpoint_needed:
          ckpt_name = f"epoch_{epoch + 1}_checkpoint"
          ckpt_dir = os.path.join(args.output_dir, ckpt_name)
-         accelerator.save_state(ckpt_dir)  # All ranks participate
+         with suppress_accelerate_logging():
+             accelerator.save_state(ckpt_dir, safe_serialization=False)
 
      if accelerator.is_main_process:
          with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
@@ -1147,7 +1167,11 @@ def main():
 
  except KeyboardInterrupt:
      logger.warning("Training interrupted. Saving emergency checkpoint...")
-     accelerator.save_state(os.path.join(args.output_dir, "interrupted_checkpoint"))
+     with suppress_accelerate_logging():
+         accelerator.save_state(
+             os.path.join(args.output_dir, "interrupted_checkpoint"),
+             safe_serialization=False,
+         )
 
  except Exception as e:
      logger.error(f"Critical error: {e}", exc_info=True)
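The `suppress_accelerate_logging` helper above follows a generic pattern: a context manager that raises a logger's threshold on entry and restores the original level in `finally`, so it is restored even if saving raises. The pattern can be demonstrated with any logger (`quiet` is a generic name for this sketch):

```python
import logging
from contextlib import contextmanager


@contextmanager
def quiet(logger_name: str, level: int = logging.WARNING):
    """Temporarily raise a logger's threshold; restore the original on exit."""
    log = logging.getLogger(logger_name)
    original = log.level
    log.setLevel(level)
    try:
        yield
    finally:
        # Runs even if the body raises, so the level is never left suppressed
        log.setLevel(original)


demo = logging.getLogger("demo")
demo.setLevel(logging.INFO)
with quiet("demo"):
    assert demo.level == logging.WARNING  # INFO records suppressed inside the block
assert demo.level == logging.INFO  # original level restored afterwards
```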
--- wavedl-1.4.4/PKG-INFO
+++ wavedl-1.4.6/src/wavedl.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: wavedl
- Version: 1.4.4
+ Version: 1.4.6
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
  Author: Ductho Le
  License: MIT
@@ -49,7 +49,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 
  ### A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 
- [![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
+ [![Python 3.11+](https://img.shields.io/badge/Python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
  [![PyTorch 2.x](https://img.shields.io/badge/PyTorch-2.x-ee4c2c.svg?style=plastic&logo=pytorch&logoColor=white)](https://pytorch.org/)
  [![Accelerate](https://img.shields.io/badge/Accelerate-Enabled-yellow.svg?style=plastic&logo=huggingface&logoColor=white)](https://huggingface.co/docs/accelerate/)
  <br>
@@ -57,7 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
  [![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
  [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
  <br>
- [![Downloads](https://img.shields.io/pepy/dt/wavedl?style=plastic&logo=pypi&logoColor=white&color=9ACD32)](https://pepy.tech/project/wavedl)
+ [![Downloads](https://img.shields.io/badge/dynamic/json?url=https://pypistats.org/api/packages/wavedl/recent?period=month%26mirrors=false&query=data.last_month&style=plastic&logo=pypi&logoColor=white&color=9ACD32&label=Downloads&suffix=/month)](https://pypistats.org/packages/wavedl)
  [![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
  [![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)
 
@@ -462,7 +462,43 @@ WaveDL/
  | **U-Net** — U-shaped Network |||
  | `unet_regression` | 31.1M | 1D/2D/3D |
 
- > ⭐ = Pretrained on ImageNet. Recommended for smaller datasets.
+ ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
+ - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
+ - **Size**: ~20–350 MB per model depending on architecture
+
+ **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
+
+ ```bash
+ # Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
+ python -c "
+ import os
+ os.environ['TORCH_HOME'] = '.torch_cache'  # Match WaveDL's HPC cache location
+
+ from torchvision import models as m
+ from torchvision.models import video as v
+
+ # Model name -> Weights class mapping
+ weights = {
+     'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
+     'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
+     'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
+     'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
+     'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
+     'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
+     'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
+     'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
+     'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
+ }
+ for name, w in weights.items():
+     getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
+
+ # 3D video models
+ v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
+ v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+ print('\\n✓ All pretrained weights cached!')
+ "
+ ```
+
 
  </details>
 
@@ -687,7 +723,6 @@ compile: false
  seed: 2025
  ```
 
- > [!TIP]
  > See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.
 
  </details>
@@ -699,18 +734,20 @@ Automatically find the best training configuration using [Optuna](https://optuna
 
  **Run HPO:**
 
- You specify which models to search and how many trials to run:
  ```bash
- # Search 3 models with 100 trials
- python -m wavedl.hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 100
+ # Basic HPO (auto-detects GPUs for parallel trials)
+ wavedl-hpo --data_path train.npz --models cnn --n_trials 100
 
- # Search 1 model (faster)
- python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
+ # Search multiple models
+ wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
 
- # Search all your candidate models
- python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
+ # Quick mode (fewer parameters, faster)
+ wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
  ```
 
+ > [!TIP]
+ > **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
+
  **Train with best parameters**
 
  After HPO completes, it prints the optimal command:
@@ -749,11 +786,11 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
  | `--optimizers` | all 6 | Optimizers to search |
  | `--schedulers` | all 8 | Schedulers to search |
  | `--losses` | all 6 | Losses to search |
- | `--n_jobs` | `1` | Parallel trials (multi-GPU) |
+ | `--n_jobs` | `-1` | Parallel trials (-1 = auto-detect GPUs) |
  | `--max_epochs` | `50` | Max epochs per trial |
  | `--output` | `hpo_results.json` | Output file |
 
- > [!TIP]
+ 
  > See [Available Models](#available-models) for all 38 architectures you can search.
 
  </details>