wavedl-1.4.2-py3-none-any.whl → wavedl-1.4.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpc.py +31 -18
- wavedl/test.py +41 -18
- wavedl/train.py +60 -8
- {wavedl-1.4.2.dist-info → wavedl-1.4.4.dist-info}/METADATA +2 -1
- {wavedl-1.4.2.dist-info → wavedl-1.4.4.dist-info}/RECORD +10 -10
- {wavedl-1.4.2.dist-info → wavedl-1.4.4.dist-info}/LICENSE +0 -0
- {wavedl-1.4.2.dist-info → wavedl-1.4.4.dist-info}/WHEEL +0 -0
- {wavedl-1.4.2.dist-info → wavedl-1.4.4.dist-info}/entry_points.txt +0 -0
- {wavedl-1.4.2.dist-info → wavedl-1.4.4.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpc.py
CHANGED

@@ -26,7 +26,6 @@ import os
 import shutil
 import subprocess
 import sys
-import tempfile
 from pathlib import Path


@@ -60,26 +59,40 @@ def setup_hpc_environment() -> None:
     Handles restricted home directories (e.g., Compute Canada) and
     offline logging configurations.
     """
-    # … (14 removed lines not shown in the source diff)
+    # Check if home is writable
+    home = os.path.expanduser("~")
+    home_writable = os.access(home, os.W_OK)
+
+    # Use SLURM_TMPDIR if available, otherwise CWD for HPC, or system temp
+    if home_writable:
+        # Local machine - let libraries use defaults
+        cache_base = None
+    else:
+        # HPC with restricted home - use CWD for persistent caches
+        cache_base = os.getcwd()
+
+    # Only set environment variables if home is not writable
+    if cache_base:
+        os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
+        os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
+        os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
+        os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
+
+        # Ensure directories exist
+        for env_var in [
+            "TORCH_HOME",
+            "MPLCONFIGDIR",
+            "FONTCONFIG_CACHE",
+            "XDG_CACHE_HOME",
+        ]:
+            Path(os.environ[env_var]).mkdir(parents=True, exist_ok=True)

     # WandB configuration (offline by default for HPC)
     os.environ.setdefault("WANDB_MODE", "offline")
-    # … (3 removed lines not shown in the source diff)
+    if cache_base:
+        os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
+        os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
+        os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")

     # Suppress non-critical warnings
     os.environ.setdefault(
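
The rewritten `setup_hpc_environment()` only overrides cache locations when `$HOME` is not writable, and it pre-creates every redirected directory. A minimal standalone sketch of that pattern (illustrative, with the variable list shortened from the diff above):

```python
import os
from pathlib import Path

# Redirect per-library caches to the working directory only when the home
# directory is read-only (typical on some HPC clusters such as Compute Canada).
home_writable = os.access(os.path.expanduser("~"), os.W_OK)
cache_base = None if home_writable else os.getcwd()

if cache_base:
    # setdefault() keeps any value already exported in the job script.
    os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
    os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
    # Pre-create the directories so downstream libraries never have to.
    for var in ("TORCH_HOME", "MPLCONFIGDIR"):
        Path(os.environ[var]).mkdir(parents=True, exist_ok=True)
```

Because everything goes through `setdefault`, values exported before launch still win; on a machine with a writable home, the function now leaves the environment untouched.
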
wavedl/test.py
CHANGED

@@ -29,29 +29,52 @@ Author: Ductho Le (ductho.le@outlook.com)
 # ==============================================================================
 # ENVIRONMENT CONFIGURATION (must be before matplotlib import)
 # ==============================================================================
+# Auto-configure writable cache directories when home is not writable.
+# Uses current working directory as fallback - works on HPC and local machines.
 import os


-# … (16 removed lines not shown in the source diff)
+def _setup_cache_dir(env_var: str, subdir: str) -> None:
+    """Set cache directory to CWD if home is not writable."""
+    if env_var in os.environ:
+        return  # User already set, respect their choice
+
+    # Check if home is writable
+    home = os.path.expanduser("~")
+    if os.access(home, os.W_OK):
+        return  # Home is writable, let library use defaults
+
+    # Home not writable - use current working directory
+    cache_path = os.path.join(os.getcwd(), f".{subdir}")
+    os.makedirs(cache_path, exist_ok=True)
+    os.environ[env_var] = cache_path
+
+
+# Configure cache directories (before any library imports)
+_setup_cache_dir("TORCH_HOME", "torch_cache")
+_setup_cache_dir("MPLCONFIGDIR", "matplotlib")
+_setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
+_setup_cache_dir("XDG_DATA_HOME", "local/share")
+_setup_cache_dir("XDG_STATE_HOME", "local/state")
+_setup_cache_dir("XDG_CACHE_HOME", "cache")
+
+import argparse  # noqa: E402
+import logging  # noqa: E402
+import pickle  # noqa: E402
+from pathlib import Path  # noqa: E402
+
+import matplotlib.pyplot as plt  # noqa: E402
+import numpy as np  # noqa: E402
+import pandas as pd  # noqa: E402
+import torch  # noqa: E402
+import torch.nn as nn  # noqa: E402
+from sklearn.metrics import mean_absolute_error, r2_score  # noqa: E402
+from torch.utils.data import DataLoader, TensorDataset  # noqa: E402
+from tqdm.auto import tqdm  # noqa: E402

 # Local imports
-from wavedl.models import build_model, list_models
-from wavedl.utils import (
+from wavedl.models import build_model, list_models  # noqa: E402
+from wavedl.utils import (  # noqa: E402
     FIGURE_DPI,
     calc_pearson,
     load_test_data,
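
The `# noqa: E402` additions exist because these variables only take effect if they are set before the corresponding library is imported; matplotlib, for instance, resolves its config directory at import time. A small illustration of that ordering constraint (hypothetical path, not wavedl code):

```python
import os

# MPLCONFIGDIR must be exported before matplotlib is imported; otherwise
# matplotlib falls back to ~/.config/matplotlib (or a temp dir) on import.
os.environ.setdefault("MPLCONFIGDIR", os.path.join(os.getcwd(), ".matplotlib"))
os.makedirs(os.environ["MPLCONFIGDIR"], exist_ok=True)

import matplotlib  # noqa: E402  (deliberately after the environment setup)

print(matplotlib.get_configdir())  # -> <cwd>/.matplotlib
```
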
wavedl/train.py
CHANGED

@@ -37,9 +37,43 @@ Author: Ductho Le (ductho.le@outlook.com)

 from __future__ import annotations

+# =============================================================================
+# HPC Environment Setup (MUST be before any library imports)
+# =============================================================================
+# Auto-configure writable cache directories when home is not writable.
+# Uses current working directory as fallback - works on HPC and local machines.
+import os
+
+
+def _setup_cache_dir(env_var: str, subdir: str) -> None:
+    """Set cache directory to CWD if home is not writable."""
+    if env_var in os.environ:
+        return  # User already set, respect their choice
+
+    # Check if home is writable
+    home = os.path.expanduser("~")
+    if os.access(home, os.W_OK):
+        return  # Home is writable, let library use defaults
+
+    # Home not writable - use current working directory
+    cache_path = os.path.join(os.getcwd(), f".{subdir}")
+    os.makedirs(cache_path, exist_ok=True)
+    os.environ[env_var] = cache_path
+
+
+# Configure cache directories (before any library imports)
+_setup_cache_dir("TORCH_HOME", "torch_cache")
+_setup_cache_dir("MPLCONFIGDIR", "matplotlib")
+_setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
+_setup_cache_dir("XDG_DATA_HOME", "local/share")
+_setup_cache_dir("XDG_STATE_HOME", "local/state")
+_setup_cache_dir("XDG_CACHE_HOME", "cache")
+
+# =============================================================================
+# Standard imports (after environment setup)
+# =============================================================================
 import argparse
 import logging
-import os
 import pickle
 import shutil
 import sys
@@ -47,6 +81,10 @@ import time
 import warnings
 from typing import Any

+
+# Suppress Pydantic warnings from accelerate's internal Field() usage
+warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
+
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
@@ -582,9 +620,9 @@ def main():
     # Torch 2.0 compilation (requires compatible Triton on GPU)
     if args.compile:
         try:
-            # Test if Triton is available
-            # … (2 more removed lines not shown in the source diff)
+            # Test if Triton is available - just import the package
+            # Different Triton versions have different internal APIs, so just check base import
+            import triton

             model = torch.compile(model)
             if accelerator.is_main_process:
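
The new check only verifies that the `triton` package imports instead of probing version-specific internals; if the import fails, the surrounding `except` (not shown in this hunk) presumably skips compilation. A hedged sketch of that guard pattern, using a helper name (`maybe_compile`) that is not part of wavedl:

```python
import torch


def maybe_compile(model: torch.nn.Module) -> torch.nn.Module:
    """Compile the model only when a usable Triton backend is importable."""
    try:
        import triton  # noqa: F401  (availability check only)
    except ImportError:
        return model  # fall back to eager execution
    return torch.compile(model)
```
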
@@ -875,9 +913,13 @@ def main():
         cpu_preds = torch.cat(local_preds)
         cpu_targets = torch.cat(local_targets)

-        # Gather …
-        # … (2 more removed lines not shown in the source diff)
+        # Gather predictions and targets across all ranks
+        # Use accelerator.gather (works with all accelerate versions)
+        gpu_preds = cpu_preds.to(accelerator.device)
+        gpu_targets = cpu_targets.to(accelerator.device)
+        all_preds_gathered = accelerator.gather(gpu_preds).cpu()
+        all_targets_gathered = accelerator.gather(gpu_targets).cpu()
+        gathered = [(all_preds_gathered, all_targets_gathered)]

         # Synchronize validation metrics (scalars only - efficient)
         val_loss_scalar = val_loss_sum.item()
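
The replacement gathers per-rank predictions and targets with `accelerator.gather`, moving them onto the accelerator device first and back to CPU afterwards. A minimal sketch of that call pattern with Hugging Face accelerate (toy tensors, not the wavedl training loop):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Each rank holds its own slice of predictions/targets on CPU.
local_preds = torch.randn(4, 3)
local_targets = torch.randn(4, 3)

# gather() concatenates tensors from all processes along dim 0.
all_preds = accelerator.gather(local_preds.to(accelerator.device)).cpu()
all_targets = accelerator.gather(local_targets.to(accelerator.device)).cpu()
print(all_preds.shape)  # (4 * num_processes, 3)
```
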
@@ -1012,7 +1054,17 @@ def main():
                 f,
             )

-
+            # Unwrap model for saving (handle torch.compile compatibility)
+            try:
+                unwrapped = accelerator.unwrap_model(model)
+            except KeyError:
+                # torch.compile model may not have _orig_mod in expected location
+                # Fall back to getting the module directly
+                unwrapped = model.module if hasattr(model, "module") else model
+                # If still compiled, try to get the underlying model
+                if hasattr(unwrapped, "_orig_mod"):
+                    unwrapped = unwrapped._orig_mod
+
             torch.save(
                 unwrapped.state_dict(),
                 os.path.join(args.output_dir, "best_model_weights.pth"),
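
The new save path guards `accelerator.unwrap_model()` against models wrapped by `torch.compile`, where the original module sits behind `_orig_mod`. A reduced sketch of that fallback (the helper name `unwrap_for_saving` is illustrative, not wavedl API):

```python
import torch


def unwrap_for_saving(model: torch.nn.Module) -> torch.nn.Module:
    """Return the underlying module whose state_dict should be saved."""
    unwrapped = model.module if hasattr(model, "module") else model  # DDP-style wrapper
    if hasattr(unwrapped, "_orig_mod"):  # torch.compile's OptimizedModule
        unwrapped = unwrapped._orig_mod
    return unwrapped


# torch.save(unwrap_for_saving(model).state_dict(), "best_model_weights.pth")
```
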

{wavedl-1.4.2.dist-info → wavedl-1.4.4.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.4.2
+Version: 1.4.4
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -57,6 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 [](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
 [](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
 <br>
+[](https://pepy.tech/project/wavedl)
 [](LICENSE)
 [](https://doi.org/10.5281/zenodo.18012338)

{wavedl-1.4.2.dist-info → wavedl-1.4.4.dist-info}/RECORD
CHANGED

@@ -1,8 +1,8 @@
-wavedl/__init__.py,sha256=…
-wavedl/hpc.py,sha256=…
+wavedl/__init__.py,sha256=n0XNSrp0aGEE6HQpzbF1wiKK-jORKgcc0Q2Op4MtQGk,1177
+wavedl/hpc.py,sha256=0h8IZzOT0EzmEv3fU9cKyRVE9V1ivtBzbjuBCaxYadc,8445
 wavedl/hpo.py,sha256=YJXsnSGEBSVUqp_2ah7zu3_VClAUqZrdkuzDaSqQUjU,12952
-wavedl/test.py,sha256=…
-wavedl/train.py,sha256=…
+wavedl/test.py,sha256=Wajcze8gFEyJ9VyN_Bq-YadZ_VZtVaX_HicvUmW6MXM,37365
+wavedl/train.py,sha256=_pW7prvlNqfUGrGweHO2QelS87UiAYKvyJwqMAIj6yI,49292
 wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
 wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
 wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
@@ -29,9 +29,9 @@ wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
 wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
 wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
 wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
-wavedl-1.4.2.dist-info/LICENSE,sha256=…
-wavedl-1.4.2.dist-info/METADATA,sha256=…
-wavedl-1.4.2.dist-info/WHEEL,sha256=…
-wavedl-1.4.2.dist-info/entry_points.txt,sha256=…
-wavedl-1.4.2.dist-info/top_level.txt,sha256=…
-wavedl-1.4.2.dist-info/RECORD,,
+wavedl-1.4.4.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
+wavedl-1.4.4.dist-info/METADATA,sha256=4tgSvkwzJmZP3PLCrqx-FYV_w6VT6Mi4XIsB0Dvb6_0,40386
+wavedl-1.4.4.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
+wavedl-1.4.4.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
+wavedl-1.4.4.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
+wavedl-1.4.4.dist-info/RECORD,,

{wavedl-1.4.2.dist-info → wavedl-1.4.4.dist-info}/LICENSE
File without changes

{wavedl-1.4.2.dist-info → wavedl-1.4.4.dist-info}/WHEEL
File without changes

{wavedl-1.4.2.dist-info → wavedl-1.4.4.dist-info}/entry_points.txt
File without changes

{wavedl-1.4.2.dist-info → wavedl-1.4.4.dist-info}/top_level.txt
File without changes