wavedl 1.4.4__py3-none-any.whl → 1.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/test.py +12 -4
- {wavedl-1.4.4.dist-info → wavedl-1.4.5.dist-info}/METADATA +39 -4
- {wavedl-1.4.4.dist-info → wavedl-1.4.5.dist-info}/RECORD +8 -8
- {wavedl-1.4.4.dist-info → wavedl-1.4.5.dist-info}/LICENSE +0 -0
- {wavedl-1.4.4.dist-info → wavedl-1.4.5.dist-info}/WHEEL +0 -0
- {wavedl-1.4.4.dist-info → wavedl-1.4.5.dist-info}/entry_points.txt +0 -0
- {wavedl-1.4.4.dist-info → wavedl-1.4.5.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/test.py
CHANGED
|
@@ -379,10 +379,18 @@ def load_checkpoint(
|
|
|
379
379
|
else:
|
|
380
380
|
state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
|
|
381
381
|
|
|
382
|
-
# Remove
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
}
|
|
382
|
+
# Remove wrapper prefixes from checkpoints:
|
|
383
|
+
# - 'module.' from DDP (DistributedDataParallel)
|
|
384
|
+
# - '_orig_mod.' from torch.compile()
|
|
385
|
+
cleaned_dict = {}
|
|
386
|
+
for k, v in state_dict.items():
|
|
387
|
+
key = k
|
|
388
|
+
if key.startswith("module."):
|
|
389
|
+
key = key[7:] # Remove 'module.' (7 chars)
|
|
390
|
+
if key.startswith("_orig_mod."):
|
|
391
|
+
key = key[10:] # Remove '_orig_mod.' (10 chars)
|
|
392
|
+
cleaned_dict[key] = v
|
|
393
|
+
state_dict = cleaned_dict
|
|
386
394
|
|
|
387
395
|
model.load_state_dict(state_dict)
|
|
388
396
|
model.eval()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.4.4
|
|
3
|
+
Version: 1.4.5
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -462,7 +462,43 @@ WaveDL/
|
|
|
462
462
|
| **U-Net** — U-shaped Network |||
|
|
463
463
|
| `unet_regression` | 31.1M | 1D/2D/3D |
|
|
464
464
|
|
|
465
|
-
|
|
465
|
+
⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
|
|
466
|
+
- **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC systems where the home directory is not writable)
|
|
467
|
+
- **Size**: roughly 10–450 MB per model, depending on the architecture
|
|
468
|
+
|
|
469
|
+
**💡 HPC Users**: If your compute nodes have no internet access, pre-download the weights on a login node:
|
|
470
|
+
|
|
471
|
+
```bash
|
|
472
|
+
# Run once on a login node (with internet access) — downloads ALL pretrained weights (~2.2 GB total)
|
|
473
|
+
python -c "
|
|
474
|
+
import os
|
|
475
|
+
os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
|
|
476
|
+
|
|
477
|
+
from torchvision import models as m
|
|
478
|
+
from torchvision.models import video as v
|
|
479
|
+
|
|
480
|
+
# Model name -> Weights class mapping
|
|
481
|
+
weights = {
|
|
482
|
+
'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
|
|
483
|
+
'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
|
|
484
|
+
'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
|
|
485
|
+
'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
|
|
486
|
+
'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
|
|
487
|
+
'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
|
|
488
|
+
'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
|
|
489
|
+
'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
|
|
490
|
+
'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
|
|
491
|
+
}
|
|
492
|
+
for name, w in weights.items():
|
|
493
|
+
getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
|
|
494
|
+
|
|
495
|
+
# 3D video models
|
|
496
|
+
v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
|
|
497
|
+
v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
|
|
498
|
+
print('\\n✓ All pretrained weights cached!')
|
|
499
|
+
"
|
|
500
|
+
```
|
|
501
|
+
|
|
466
502
|
|
|
467
503
|
</details>
|
|
468
504
|
|
|
@@ -687,7 +723,6 @@ compile: false
|
|
|
687
723
|
seed: 2025
|
|
688
724
|
```
|
|
689
725
|
|
|
690
|
-
> [!TIP]
|
|
691
726
|
> See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.
|
|
692
727
|
|
|
693
728
|
</details>
|
|
@@ -753,7 +788,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
|
|
|
753
788
|
| `--max_epochs` | `50` | Max epochs per trial |
|
|
754
789
|
| `--output` | `hpo_results.json` | Output file |
|
|
755
790
|
|
|
756
|
-
|
|
791
|
+
|
|
757
792
|
> See [Available Models](#available-models) for all 38 architectures you can search.
|
|
758
793
|
|
|
759
794
|
</details>
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
wavedl/__init__.py,sha256=
|
|
1
|
+
wavedl/__init__.py,sha256=2ro7SYQ3wCmq-ejiAm5sd6BeXf6sZgixC9U2vS7Ckbs,1177
|
|
2
2
|
wavedl/hpc.py,sha256=0h8IZzOT0EzmEv3fU9cKyRVE9V1ivtBzbjuBCaxYadc,8445
|
|
3
3
|
wavedl/hpo.py,sha256=YJXsnSGEBSVUqp_2ah7zu3_VClAUqZrdkuzDaSqQUjU,12952
|
|
4
|
-
wavedl/test.py,sha256=
|
|
4
|
+
wavedl/test.py,sha256=81al6vQBDAJ3CpSEtxZn6xzR1c4-jo28R7tX_84KROc,37642
|
|
5
5
|
wavedl/train.py,sha256=_pW7prvlNqfUGrGweHO2QelS87UiAYKvyJwqMAIj6yI,49292
|
|
6
6
|
wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
|
|
7
7
|
wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
|
|
@@ -29,9 +29,9 @@ wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
|
|
|
29
29
|
wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
|
|
30
30
|
wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
|
|
31
31
|
wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
|
|
32
|
-
wavedl-1.4.
|
|
33
|
-
wavedl-1.4.
|
|
34
|
-
wavedl-1.4.
|
|
35
|
-
wavedl-1.4.
|
|
36
|
-
wavedl-1.4.
|
|
37
|
-
wavedl-1.4.
|
|
32
|
+
wavedl-1.4.5.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
|
|
33
|
+
wavedl-1.4.5.dist-info/METADATA,sha256=4ltxFDaqPqh4XUAW_K8nkFmvqBzPcL2cxmghH11GMWg,42191
|
|
34
|
+
wavedl-1.4.5.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
35
|
+
wavedl-1.4.5.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
|
|
36
|
+
wavedl-1.4.5.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
|
|
37
|
+
wavedl-1.4.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|