wavedl 1.4.3__py3-none-any.whl → 1.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpc.py +31 -18
- wavedl/test.py +53 -22
- wavedl/train.py +27 -28
- {wavedl-1.4.3.dist-info → wavedl-1.4.5.dist-info}/METADATA +39 -4
- {wavedl-1.4.3.dist-info → wavedl-1.4.5.dist-info}/RECORD +10 -10
- {wavedl-1.4.3.dist-info → wavedl-1.4.5.dist-info}/LICENSE +0 -0
- {wavedl-1.4.3.dist-info → wavedl-1.4.5.dist-info}/WHEEL +0 -0
- {wavedl-1.4.3.dist-info → wavedl-1.4.5.dist-info}/entry_points.txt +0 -0
- {wavedl-1.4.3.dist-info → wavedl-1.4.5.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpc.py
CHANGED
|
@@ -26,7 +26,6 @@ import os
|
|
|
26
26
|
import shutil
|
|
27
27
|
import subprocess
|
|
28
28
|
import sys
|
|
29
|
-
import tempfile
|
|
30
29
|
from pathlib import Path
|
|
31
30
|
|
|
32
31
|
|
|
@@ -60,26 +59,40 @@ def setup_hpc_environment() -> None:
|
|
|
60
59
|
Handles restricted home directories (e.g., Compute Canada) and
|
|
61
60
|
offline logging configurations.
|
|
62
61
|
"""
|
|
63
|
-
#
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
#
|
|
76
|
-
|
|
62
|
+
# Check if home is writable
|
|
63
|
+
home = os.path.expanduser("~")
|
|
64
|
+
home_writable = os.access(home, os.W_OK)
|
|
65
|
+
|
|
66
|
+
# Use SLURM_TMPDIR if available, otherwise CWD for HPC, or system temp
|
|
67
|
+
if home_writable:
|
|
68
|
+
# Local machine - let libraries use defaults
|
|
69
|
+
cache_base = None
|
|
70
|
+
else:
|
|
71
|
+
# HPC with restricted home - use CWD for persistent caches
|
|
72
|
+
cache_base = os.getcwd()
|
|
73
|
+
|
|
74
|
+
# Only set environment variables if home is not writable
|
|
75
|
+
if cache_base:
|
|
76
|
+
os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
|
|
77
|
+
os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
|
|
78
|
+
os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
|
|
79
|
+
os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
|
|
80
|
+
|
|
81
|
+
# Ensure directories exist
|
|
82
|
+
for env_var in [
|
|
83
|
+
"TORCH_HOME",
|
|
84
|
+
"MPLCONFIGDIR",
|
|
85
|
+
"FONTCONFIG_CACHE",
|
|
86
|
+
"XDG_CACHE_HOME",
|
|
87
|
+
]:
|
|
88
|
+
Path(os.environ[env_var]).mkdir(parents=True, exist_ok=True)
|
|
77
89
|
|
|
78
90
|
# WandB configuration (offline by default for HPC)
|
|
79
91
|
os.environ.setdefault("WANDB_MODE", "offline")
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
92
|
+
if cache_base:
|
|
93
|
+
os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
|
|
94
|
+
os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
|
|
95
|
+
os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
|
|
83
96
|
|
|
84
97
|
# Suppress non-critical warnings
|
|
85
98
|
os.environ.setdefault(
|
wavedl/test.py
CHANGED
|
@@ -29,29 +29,52 @@ Author: Ductho Le (ductho.le@outlook.com)
|
|
|
29
29
|
# ==============================================================================
|
|
30
30
|
# ENVIRONMENT CONFIGURATION (must be before matplotlib import)
|
|
31
31
|
# ==============================================================================
|
|
32
|
+
# Auto-configure writable cache directories when home is not writable.
|
|
33
|
+
# Uses current working directory as fallback - works on HPC and local machines.
|
|
32
34
|
import os
|
|
33
35
|
|
|
34
36
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
37
|
+
def _setup_cache_dir(env_var: str, subdir: str) -> None:
|
|
38
|
+
"""Set cache directory to CWD if home is not writable."""
|
|
39
|
+
if env_var in os.environ:
|
|
40
|
+
return # User already set, respect their choice
|
|
41
|
+
|
|
42
|
+
# Check if home is writable
|
|
43
|
+
home = os.path.expanduser("~")
|
|
44
|
+
if os.access(home, os.W_OK):
|
|
45
|
+
return # Home is writable, let library use defaults
|
|
46
|
+
|
|
47
|
+
# Home not writable - use current working directory
|
|
48
|
+
cache_path = os.path.join(os.getcwd(), f".{subdir}")
|
|
49
|
+
os.makedirs(cache_path, exist_ok=True)
|
|
50
|
+
os.environ[env_var] = cache_path
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Configure cache directories (before any library imports)
|
|
54
|
+
_setup_cache_dir("TORCH_HOME", "torch_cache")
|
|
55
|
+
_setup_cache_dir("MPLCONFIGDIR", "matplotlib")
|
|
56
|
+
_setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
|
|
57
|
+
_setup_cache_dir("XDG_DATA_HOME", "local/share")
|
|
58
|
+
_setup_cache_dir("XDG_STATE_HOME", "local/state")
|
|
59
|
+
_setup_cache_dir("XDG_CACHE_HOME", "cache")
|
|
60
|
+
|
|
61
|
+
import argparse # noqa: E402
|
|
62
|
+
import logging # noqa: E402
|
|
63
|
+
import pickle # noqa: E402
|
|
64
|
+
from pathlib import Path # noqa: E402
|
|
65
|
+
|
|
66
|
+
import matplotlib.pyplot as plt # noqa: E402
|
|
67
|
+
import numpy as np # noqa: E402
|
|
68
|
+
import pandas as pd # noqa: E402
|
|
69
|
+
import torch # noqa: E402
|
|
70
|
+
import torch.nn as nn # noqa: E402
|
|
71
|
+
from sklearn.metrics import mean_absolute_error, r2_score # noqa: E402
|
|
72
|
+
from torch.utils.data import DataLoader, TensorDataset # noqa: E402
|
|
73
|
+
from tqdm.auto import tqdm # noqa: E402
|
|
51
74
|
|
|
52
75
|
# Local imports
|
|
53
|
-
from wavedl.models import build_model, list_models
|
|
54
|
-
from wavedl.utils import (
|
|
76
|
+
from wavedl.models import build_model, list_models # noqa: E402
|
|
77
|
+
from wavedl.utils import ( # noqa: E402
|
|
55
78
|
FIGURE_DPI,
|
|
56
79
|
calc_pearson,
|
|
57
80
|
load_test_data,
|
|
@@ -356,10 +379,18 @@ def load_checkpoint(
|
|
|
356
379
|
else:
|
|
357
380
|
state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
|
|
358
381
|
|
|
359
|
-
# Remove
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
}
|
|
382
|
+
# Remove wrapper prefixes from checkpoints:
|
|
383
|
+
# - 'module.' from DDP (DistributedDataParallel)
|
|
384
|
+
# - '_orig_mod.' from torch.compile()
|
|
385
|
+
cleaned_dict = {}
|
|
386
|
+
for k, v in state_dict.items():
|
|
387
|
+
key = k
|
|
388
|
+
if key.startswith("module."):
|
|
389
|
+
key = key[7:] # Remove 'module.' (7 chars)
|
|
390
|
+
if key.startswith("_orig_mod."):
|
|
391
|
+
key = key[10:] # Remove '_orig_mod.' (10 chars)
|
|
392
|
+
cleaned_dict[key] = v
|
|
393
|
+
state_dict = cleaned_dict
|
|
363
394
|
|
|
364
395
|
model.load_state_dict(state_dict)
|
|
365
396
|
model.eval()
|
wavedl/train.py
CHANGED
|
@@ -40,45 +40,34 @@ from __future__ import annotations
|
|
|
40
40
|
# =============================================================================
|
|
41
41
|
# HPC Environment Setup (MUST be before any library imports)
|
|
42
42
|
# =============================================================================
|
|
43
|
-
#
|
|
44
|
-
#
|
|
43
|
+
# Auto-configure writable cache directories when home is not writable.
|
|
44
|
+
# Uses current working directory as fallback - works on HPC and local machines.
|
|
45
45
|
import os
|
|
46
|
-
import tempfile
|
|
47
46
|
|
|
48
47
|
|
|
49
|
-
def _setup_cache_dir(env_var: str,
|
|
50
|
-
"""Set cache directory
|
|
48
|
+
def _setup_cache_dir(env_var: str, subdir: str) -> None:
|
|
49
|
+
"""Set cache directory to CWD if home is not writable."""
|
|
51
50
|
if env_var in os.environ:
|
|
52
51
|
return # User already set, respect their choice
|
|
53
52
|
|
|
54
|
-
# Check if
|
|
53
|
+
# Check if home is writable
|
|
55
54
|
home = os.path.expanduser("~")
|
|
56
|
-
|
|
57
|
-
|
|
55
|
+
if os.access(home, os.W_OK):
|
|
56
|
+
return # Home is writable, let library use defaults
|
|
58
57
|
|
|
59
|
-
#
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
or os.access(os.path.join(home, ".config"), os.W_OK)
|
|
64
|
-
):
|
|
65
|
-
return
|
|
66
|
-
|
|
67
|
-
# Default not writable - find alternative location
|
|
68
|
-
for cache_base in [
|
|
69
|
-
os.environ.get("SCRATCH"),
|
|
70
|
-
os.environ.get("SLURM_TMPDIR"),
|
|
71
|
-
tempfile.gettempdir(),
|
|
72
|
-
]:
|
|
73
|
-
if cache_base and os.access(cache_base, os.W_OK):
|
|
74
|
-
cache_path = os.path.join(cache_base, f".{default_subpath}")
|
|
75
|
-
os.makedirs(cache_path, exist_ok=True)
|
|
76
|
-
os.environ[env_var] = cache_path
|
|
77
|
-
return
|
|
58
|
+
# Home not writable - use current working directory
|
|
59
|
+
cache_path = os.path.join(os.getcwd(), f".{subdir}")
|
|
60
|
+
os.makedirs(cache_path, exist_ok=True)
|
|
61
|
+
os.environ[env_var] = cache_path
|
|
78
62
|
|
|
79
63
|
|
|
64
|
+
# Configure cache directories (before any library imports)
|
|
65
|
+
_setup_cache_dir("TORCH_HOME", "torch_cache")
|
|
80
66
|
_setup_cache_dir("MPLCONFIGDIR", "matplotlib")
|
|
81
67
|
_setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
|
|
68
|
+
_setup_cache_dir("XDG_DATA_HOME", "local/share")
|
|
69
|
+
_setup_cache_dir("XDG_STATE_HOME", "local/state")
|
|
70
|
+
_setup_cache_dir("XDG_CACHE_HOME", "cache")
|
|
82
71
|
|
|
83
72
|
# =============================================================================
|
|
84
73
|
# Standard imports (after environment setup)
|
|
@@ -1065,7 +1054,17 @@ def main():
|
|
|
1065
1054
|
f,
|
|
1066
1055
|
)
|
|
1067
1056
|
|
|
1068
|
-
|
|
1057
|
+
# Unwrap model for saving (handle torch.compile compatibility)
|
|
1058
|
+
try:
|
|
1059
|
+
unwrapped = accelerator.unwrap_model(model)
|
|
1060
|
+
except KeyError:
|
|
1061
|
+
# torch.compile model may not have _orig_mod in expected location
|
|
1062
|
+
# Fall back to getting the module directly
|
|
1063
|
+
unwrapped = model.module if hasattr(model, "module") else model
|
|
1064
|
+
# If still compiled, try to get the underlying model
|
|
1065
|
+
if hasattr(unwrapped, "_orig_mod"):
|
|
1066
|
+
unwrapped = unwrapped._orig_mod
|
|
1067
|
+
|
|
1069
1068
|
torch.save(
|
|
1070
1069
|
unwrapped.state_dict(),
|
|
1071
1070
|
os.path.join(args.output_dir, "best_model_weights.pth"),
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.4.
|
|
3
|
+
Version: 1.4.5
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -462,7 +462,43 @@ WaveDL/
|
|
|
462
462
|
| **U-Net** — U-shaped Network |||
|
|
463
463
|
| `unet_regression` | 31.1M | 1D/2D/3D |
|
|
464
464
|
|
|
465
|
-
|
|
465
|
+
⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
|
|
466
|
+
- **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
|
|
467
|
+
- **Size**: ~20–350 MB per model depending on architecture
|
|
468
|
+
|
|
469
|
+
**💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
|
|
470
|
+
|
|
471
|
+
```bash
|
|
472
|
+
# Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
|
|
473
|
+
python -c "
|
|
474
|
+
import os
|
|
475
|
+
os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
|
|
476
|
+
|
|
477
|
+
from torchvision import models as m
|
|
478
|
+
from torchvision.models import video as v
|
|
479
|
+
|
|
480
|
+
# Model name -> Weights class mapping
|
|
481
|
+
weights = {
|
|
482
|
+
'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
|
|
483
|
+
'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
|
|
484
|
+
'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
|
|
485
|
+
'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
|
|
486
|
+
'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
|
|
487
|
+
'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
|
|
488
|
+
'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
|
|
489
|
+
'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
|
|
490
|
+
'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
|
|
491
|
+
}
|
|
492
|
+
for name, w in weights.items():
|
|
493
|
+
getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
|
|
494
|
+
|
|
495
|
+
# 3D video models
|
|
496
|
+
v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
|
|
497
|
+
v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
|
|
498
|
+
print('\\n✓ All pretrained weights cached!')
|
|
499
|
+
"
|
|
500
|
+
```
|
|
501
|
+
|
|
466
502
|
|
|
467
503
|
</details>
|
|
468
504
|
|
|
@@ -687,7 +723,6 @@ compile: false
|
|
|
687
723
|
seed: 2025
|
|
688
724
|
```
|
|
689
725
|
|
|
690
|
-
> [!TIP]
|
|
691
726
|
> See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.
|
|
692
727
|
|
|
693
728
|
</details>
|
|
@@ -753,7 +788,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
|
|
|
753
788
|
| `--max_epochs` | `50` | Max epochs per trial |
|
|
754
789
|
| `--output` | `hpo_results.json` | Output file |
|
|
755
790
|
|
|
756
|
-
|
|
791
|
+
|
|
757
792
|
> See [Available Models](#available-models) for all 38 architectures you can search.
|
|
758
793
|
|
|
759
794
|
</details>
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
wavedl/__init__.py,sha256=
|
|
2
|
-
wavedl/hpc.py,sha256=
|
|
1
|
+
wavedl/__init__.py,sha256=2ro7SYQ3wCmq-ejiAm5sd6BeXf6sZgixC9U2vS7Ckbs,1177
|
|
2
|
+
wavedl/hpc.py,sha256=0h8IZzOT0EzmEv3fU9cKyRVE9V1ivtBzbjuBCaxYadc,8445
|
|
3
3
|
wavedl/hpo.py,sha256=YJXsnSGEBSVUqp_2ah7zu3_VClAUqZrdkuzDaSqQUjU,12952
|
|
4
|
-
wavedl/test.py,sha256=
|
|
5
|
-
wavedl/train.py,sha256=
|
|
4
|
+
wavedl/test.py,sha256=81al6vQBDAJ3CpSEtxZn6xzR1c4-jo28R7tX_84KROc,37642
|
|
5
|
+
wavedl/train.py,sha256=_pW7prvlNqfUGrGweHO2QelS87UiAYKvyJwqMAIj6yI,49292
|
|
6
6
|
wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
|
|
7
7
|
wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
|
|
8
8
|
wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
|
|
@@ -29,9 +29,9 @@ wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
|
|
|
29
29
|
wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
|
|
30
30
|
wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
|
|
31
31
|
wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
|
|
32
|
-
wavedl-1.4.
|
|
33
|
-
wavedl-1.4.
|
|
34
|
-
wavedl-1.4.
|
|
35
|
-
wavedl-1.4.
|
|
36
|
-
wavedl-1.4.
|
|
37
|
-
wavedl-1.4.
|
|
32
|
+
wavedl-1.4.5.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
|
|
33
|
+
wavedl-1.4.5.dist-info/METADATA,sha256=4ltxFDaqPqh4XUAW_K8nkFmvqBzPcL2cxmghH11GMWg,42191
|
|
34
|
+
wavedl-1.4.5.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
35
|
+
wavedl-1.4.5.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
|
|
36
|
+
wavedl-1.4.5.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
|
|
37
|
+
wavedl-1.4.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|