wavedl 1.4.4.tar.gz → 1.4.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {wavedl-1.4.4/src/wavedl.egg-info → wavedl-1.4.5}/PKG-INFO +39 -4
  2. {wavedl-1.4.4 → wavedl-1.4.5}/README.md +38 -3
  3. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/__init__.py +1 -1
  4. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/test.py +12 -4
  5. {wavedl-1.4.4 → wavedl-1.4.5/src/wavedl.egg-info}/PKG-INFO +39 -4
  6. {wavedl-1.4.4 → wavedl-1.4.5}/LICENSE +0 -0
  7. {wavedl-1.4.4 → wavedl-1.4.5}/pyproject.toml +0 -0
  8. {wavedl-1.4.4 → wavedl-1.4.5}/setup.cfg +0 -0
  9. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/hpc.py +0 -0
  10. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/hpo.py +0 -0
  11. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/__init__.py +0 -0
  12. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/_template.py +0 -0
  13. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/base.py +0 -0
  14. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/cnn.py +0 -0
  15. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/convnext.py +0 -0
  16. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/densenet.py +0 -0
  17. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/efficientnet.py +0 -0
  18. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/efficientnetv2.py +0 -0
  19. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/mobilenetv3.py +0 -0
  20. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/registry.py +0 -0
  21. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/regnet.py +0 -0
  22. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/resnet.py +0 -0
  23. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/resnet3d.py +0 -0
  24. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/swin.py +0 -0
  25. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/tcn.py +0 -0
  26. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/unet.py +0 -0
  27. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/models/vit.py +0 -0
  28. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/train.py +0 -0
  29. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/utils/__init__.py +0 -0
  30. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/utils/config.py +0 -0
  31. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/utils/cross_validation.py +0 -0
  32. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/utils/data.py +0 -0
  33. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/utils/distributed.py +0 -0
  34. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/utils/losses.py +0 -0
  35. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/utils/metrics.py +0 -0
  36. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/utils/optimizers.py +0 -0
  37. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/utils/schedulers.py +0 -0
  38. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl.egg-info/SOURCES.txt +0 -0
  39. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl.egg-info/dependency_links.txt +0 -0
  40. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl.egg-info/entry_points.txt +0 -0
  41. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl.egg-info/requires.txt +0 -0
  42. {wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl.egg-info/top_level.txt +0 -0
{wavedl-1.4.4/src/wavedl.egg-info → wavedl-1.4.5}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: wavedl
- Version: 1.4.4
+ Version: 1.4.5
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
  Author: Ductho Le
  License: MIT
@@ -462,7 +462,43 @@ WaveDL/
  | **U-Net** — U-shaped Network |||
  | `unet_regression` | 31.1M | 1D/2D/3D |

- > ⭐ = Pretrained on ImageNet. Recommended for smaller datasets.
+ ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
+ - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
+ - **Size**: ~20–350 MB per model depending on architecture
+
+ **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
+
+ ```bash
+ # Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
+ python -c "
+ import os
+ os.environ['TORCH_HOME'] = '.torch_cache'  # Match WaveDL's HPC cache location
+
+ from torchvision import models as m
+ from torchvision.models import video as v
+
+ # Model name -> Weights class mapping
+ weights = {
+     'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
+     'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
+     'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
+     'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
+     'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
+     'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
+     'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
+     'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
+     'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
+ }
+ for name, w in weights.items():
+     getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
+
+ # 3D video models
+ v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
+ v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+ print('\\n✓ All pretrained weights cached!')
+ "
+ ```
+

  </details>

@@ -687,7 +723,6 @@ compile: false
  seed: 2025
  ```

- > [!TIP]
  > See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.

  </details>
@@ -753,7 +788,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
  | `--max_epochs` | `50` | Max epochs per trial |
  | `--output` | `hpo_results.json` | Output file |

- > [!TIP]
+
  > See [Available Models](#available-models) for all 38 architectures you can search.

  </details>
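The hunk above documents only the login-node half of the offline workflow. As a sketch of the compute-node half (illustrative only, not part of this diff): it assumes the job runs from the same working directory that holds the `.torch_cache/` populated on the login node, and relies on torchvision's standard `TORCH_HOME` lookup.

```python
# Hypothetical compute-node usage: reuse the weight cache populated on the login node.
# Assumes ./.torch_cache already exists in the current working directory.
import os

os.environ["TORCH_HOME"] = ".torch_cache"  # point torch.hub at the local cache before loading weights

from torchvision import models as m

# Resolves to .torch_cache/hub/checkpoints/ — no network access needed once the file is cached.
net = m.resnet18(weights=m.ResNet18_Weights.DEFAULT)
print("resnet18 weights loaded from the local cache")
```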
{wavedl-1.4.4 → wavedl-1.4.5}/README.md

@@ -417,7 +417,43 @@ WaveDL/
  | **U-Net** — U-shaped Network |||
  | `unet_regression` | 31.1M | 1D/2D/3D |

- > ⭐ = Pretrained on ImageNet. Recommended for smaller datasets.
+ ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
+ - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
+ - **Size**: ~20–350 MB per model depending on architecture
+
+ **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
+
+ ```bash
+ # Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
+ python -c "
+ import os
+ os.environ['TORCH_HOME'] = '.torch_cache'  # Match WaveDL's HPC cache location
+
+ from torchvision import models as m
+ from torchvision.models import video as v
+
+ # Model name -> Weights class mapping
+ weights = {
+     'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
+     'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
+     'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
+     'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
+     'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
+     'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
+     'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
+     'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
+     'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
+ }
+ for name, w in weights.items():
+     getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
+
+ # 3D video models
+ v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
+ v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+ print('\\n✓ All pretrained weights cached!')
+ "
+ ```
+

  </details>

@@ -642,7 +678,6 @@ compile: false
  seed: 2025
  ```

- > [!TIP]
  > See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.

  </details>
@@ -708,7 +743,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
  | `--max_epochs` | `50` | Max epochs per trial |
  | `--output` | `hpo_results.json` | Output file |

- > [!TIP]
+
  > See [Available Models](#available-models) for all 38 architectures you can search.

  </details>
{wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/__init__.py

@@ -18,7 +18,7 @@ For inference:
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
  """

- __version__ = "1.4.4"
+ __version__ = "1.4.5"
  __author__ = "Ductho Le"
  __email__ = "ductho.le@outlook.com"

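Since the only change to `src/wavedl/__init__.py` is the version bump, a quick post-upgrade check confirms the new release is the one being imported. A minimal sketch, assuming the package was upgraded with `pip install --upgrade wavedl`:

```python
# Verify the installed release matches this diff's target version.
import wavedl

assert wavedl.__version__ == "1.4.5", f"unexpected version: {wavedl.__version__}"
print(f"wavedl {wavedl.__version__} by {wavedl.__author__}")
```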
{wavedl-1.4.4 → wavedl-1.4.5}/src/wavedl/test.py

@@ -379,10 +379,18 @@ def load_checkpoint(
  else:
      state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)

- # Remove 'module.' prefix from DDP checkpoints (leading only, not all occurrences)
- state_dict = {
-     (k[7:] if k.startswith("module.") else k): v for k, v in state_dict.items()
- }
+ # Remove wrapper prefixes from checkpoints:
+ # - 'module.' from DDP (DistributedDataParallel)
+ # - '_orig_mod.' from torch.compile()
+ cleaned_dict = {}
+ for k, v in state_dict.items():
+     key = k
+     if key.startswith("module."):
+         key = key[7:]  # Remove 'module.' (7 chars)
+     if key.startswith("_orig_mod."):
+         key = key[10:]  # Remove '_orig_mod.' (10 chars)
+     cleaned_dict[key] = v
+ state_dict = cleaned_dict

  model.load_state_dict(state_dict)
  model.eval()
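The `test.py` hunk above is the substantive code change in this release: `load_checkpoint` now strips the `_orig_mod.` prefix that `torch.compile()` adds to state-dict keys, in addition to the `module.` prefix from DDP. The sketch below reproduces that cleanup in isolation on a toy `nn.Linear` (an illustrative stand-in, not a WaveDL model) to show why keys saved from a wrapped model would otherwise fail `load_state_dict`.

```python
# Stand-alone illustration of the prefix cleanup added in 1.4.5 (toy model, not WaveDL code).
import torch.nn as nn

model = nn.Linear(4, 2)

# Simulate a checkpoint saved from DDP(torch.compile(model)): keys gain 'module._orig_mod.'.
wrapped_state = {f"module._orig_mod.{k}": v for k, v in model.state_dict().items()}

cleaned = {}
for k, v in wrapped_state.items():
    key = k
    if key.startswith("module."):      # DDP wrapper prefix
        key = key[len("module."):]
    if key.startswith("_orig_mod."):   # torch.compile wrapper prefix
        key = key[len("_orig_mod."):]
    cleaned[key] = v

model.load_state_dict(cleaned)  # succeeds: keys are back to plain 'weight' / 'bias'
print(sorted(cleaned))          # ['bias', 'weight']
```

Checking `module.` before `_orig_mod.` matches the order in which the wrappers nest when a compiled model is wrapped in DDP.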
{wavedl-1.4.4 → wavedl-1.4.5/src/wavedl.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: wavedl
- Version: 1.4.4
+ Version: 1.4.5
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
  Author: Ductho Le
  License: MIT
@@ -462,7 +462,43 @@ WaveDL/
  | **U-Net** — U-shaped Network |||
  | `unet_regression` | 31.1M | 1D/2D/3D |

- > ⭐ = Pretrained on ImageNet. Recommended for smaller datasets.
+ ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
+ - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
+ - **Size**: ~20–350 MB per model depending on architecture
+
+ **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
+
+ ```bash
+ # Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
+ python -c "
+ import os
+ os.environ['TORCH_HOME'] = '.torch_cache'  # Match WaveDL's HPC cache location
+
+ from torchvision import models as m
+ from torchvision.models import video as v
+
+ # Model name -> Weights class mapping
+ weights = {
+     'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
+     'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
+     'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
+     'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
+     'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
+     'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
+     'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
+     'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
+     'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
+ }
+ for name, w in weights.items():
+     getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
+
+ # 3D video models
+ v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
+ v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
+ print('\\n✓ All pretrained weights cached!')
+ "
+ ```
+

  </details>

@@ -687,7 +723,6 @@ compile: false
  seed: 2025
  ```

- > [!TIP]
  > See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.

  </details>
@@ -753,7 +788,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
  | `--max_epochs` | `50` | Max epochs per trial |
  | `--output` | `hpo_results.json` | Output file |

- > [!TIP]
+
  > See [Available Models](#available-models) for all 38 architectures you can search.

  </details>