wavedl 1.4.4__py3-none-any.whl → 1.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.4.4"
21
+ __version__ = "1.4.5"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
wavedl/test.py CHANGED
@@ -379,10 +379,18 @@ def load_checkpoint(
379
379
  else:
380
380
  state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
381
381
 
382
- # Remove 'module.' prefix from DDP checkpoints (leading only, not all occurrences)
383
- state_dict = {
384
- (k[7:] if k.startswith("module.") else k): v for k, v in state_dict.items()
385
- }
382
+ # Remove wrapper prefixes from checkpoints:
383
+ # - 'module.' from DDP (DistributedDataParallel)
384
+ # - '_orig_mod.' from torch.compile()
385
+ cleaned_dict = {}
386
+ for k, v in state_dict.items():
387
+ key = k
388
+ if key.startswith("module."):
389
+ key = key[7:] # Remove 'module.' (7 chars)
390
+ if key.startswith("_orig_mod."):
391
+ key = key[10:] # Remove '_orig_mod.' (10 chars)
392
+ cleaned_dict[key] = v
393
+ state_dict = cleaned_dict
386
394
 
387
395
  model.load_state_dict(state_dict)
388
396
  model.eval()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.4.4
3
+ Version: 1.4.5
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -462,7 +462,43 @@ WaveDL/
462
462
  | **U-Net** — U-shaped Network |||
463
463
  | `unet_regression` | 31.1M | 1D/2D/3D |
464
464
 
465
- > ⭐ = Pretrained on ImageNet. Recommended for smaller datasets.
465
+ ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
466
+ - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
467
+ - **Size**: ~10–450 MB per model depending on architecture
468
+
469
+ **💡 HPC Users**: If compute nodes have no internet access, pre-download the weights on the login node:
470
+
471
+ ```bash
472
+ # Run once on login node (with internet) — downloads ALL pretrained weights (~2 GB total)
473
+ python -c "
474
+ import os
475
+ os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
476
+
477
+ from torchvision import models as m
478
+ from torchvision.models import video as v
479
+
480
+ # Model name -> Weights class mapping
481
+ weights = {
482
+ 'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
483
+ 'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
484
+ 'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
485
+ 'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
486
+ 'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
487
+ 'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
488
+ 'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
489
+ 'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
490
+ 'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
491
+ }
492
+ for name, w in weights.items():
493
+ getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
494
+
495
+ # 3D video models
496
+ v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
497
+ v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
498
+ print('\\n✓ All pretrained weights cached!')
499
+ "
500
+ ```
501
+
466
502
 
467
503
  </details>
468
504
 
@@ -687,7 +723,6 @@ compile: false
687
723
  seed: 2025
688
724
  ```
689
725
 
690
- > [!TIP]
691
726
  > See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.
692
727
 
693
728
  </details>
@@ -753,7 +788,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
753
788
  | `--max_epochs` | `50` | Max epochs per trial |
754
789
  | `--output` | `hpo_results.json` | Output file |
755
790
 
756
- > [!TIP]
791
+
757
792
  > See [Available Models](#available-models) for all 38 architectures you can search.
758
793
 
759
794
  </details>
@@ -1,7 +1,7 @@
1
- wavedl/__init__.py,sha256=n0XNSrp0aGEE6HQpzbF1wiKK-jORKgcc0Q2Op4MtQGk,1177
1
+ wavedl/__init__.py,sha256=2ro7SYQ3wCmq-ejiAm5sd6BeXf6sZgixC9U2vS7Ckbs,1177
2
2
  wavedl/hpc.py,sha256=0h8IZzOT0EzmEv3fU9cKyRVE9V1ivtBzbjuBCaxYadc,8445
3
3
  wavedl/hpo.py,sha256=YJXsnSGEBSVUqp_2ah7zu3_VClAUqZrdkuzDaSqQUjU,12952
4
- wavedl/test.py,sha256=Wajcze8gFEyJ9VyN_Bq-YadZ_VZtVaX_HicvUmW6MXM,37365
4
+ wavedl/test.py,sha256=81al6vQBDAJ3CpSEtxZn6xzR1c4-jo28R7tX_84KROc,37642
5
5
  wavedl/train.py,sha256=_pW7prvlNqfUGrGweHO2QelS87UiAYKvyJwqMAIj6yI,49292
6
6
  wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
7
7
  wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
@@ -29,9 +29,9 @@ wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
29
29
  wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
30
30
  wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
31
31
  wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
32
- wavedl-1.4.4.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
33
- wavedl-1.4.4.dist-info/METADATA,sha256=4tgSvkwzJmZP3PLCrqx-FYV_w6VT6Mi4XIsB0Dvb6_0,40386
34
- wavedl-1.4.4.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
35
- wavedl-1.4.4.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
36
- wavedl-1.4.4.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
37
- wavedl-1.4.4.dist-info/RECORD,,
32
+ wavedl-1.4.5.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
33
+ wavedl-1.4.5.dist-info/METADATA,sha256=4ltxFDaqPqh4XUAW_K8nkFmvqBzPcL2cxmghH11GMWg,42191
34
+ wavedl-1.4.5.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
35
+ wavedl-1.4.5.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
36
+ wavedl-1.4.5.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
37
+ wavedl-1.4.5.dist-info/RECORD,,
File without changes