wavedl 1.4.3__py3-none-any.whl → 1.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpc.py +31 -18
- wavedl/test.py +41 -18
- wavedl/train.py +27 -28
- {wavedl-1.4.3.dist-info → wavedl-1.4.4.dist-info}/METADATA +1 -1
- {wavedl-1.4.3.dist-info → wavedl-1.4.4.dist-info}/RECORD +10 -10
- {wavedl-1.4.3.dist-info → wavedl-1.4.4.dist-info}/LICENSE +0 -0
- {wavedl-1.4.3.dist-info → wavedl-1.4.4.dist-info}/WHEEL +0 -0
- {wavedl-1.4.3.dist-info → wavedl-1.4.4.dist-info}/entry_points.txt +0 -0
- {wavedl-1.4.3.dist-info → wavedl-1.4.4.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpc.py
CHANGED
|
@@ -26,7 +26,6 @@ import os
|
|
|
26
26
|
import shutil
|
|
27
27
|
import subprocess
|
|
28
28
|
import sys
|
|
29
|
-
import tempfile
|
|
30
29
|
from pathlib import Path
|
|
31
30
|
|
|
32
31
|
|
|
@@ -60,26 +59,40 @@ def setup_hpc_environment() -> None:
|
|
|
60
59
|
Handles restricted home directories (e.g., Compute Canada) and
|
|
61
60
|
offline logging configurations.
|
|
62
61
|
"""
|
|
63
|
-
#
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
#
|
|
76
|
-
|
|
62
|
+
# Check if home is writable
|
|
63
|
+
home = os.path.expanduser("~")
|
|
64
|
+
home_writable = os.access(home, os.W_OK)
|
|
65
|
+
|
|
66
|
+
# Use SLURM_TMPDIR if available, otherwise CWD for HPC, or system temp
|
|
67
|
+
if home_writable:
|
|
68
|
+
# Local machine - let libraries use defaults
|
|
69
|
+
cache_base = None
|
|
70
|
+
else:
|
|
71
|
+
# HPC with restricted home - use CWD for persistent caches
|
|
72
|
+
cache_base = os.getcwd()
|
|
73
|
+
|
|
74
|
+
# Only set environment variables if home is not writable
|
|
75
|
+
if cache_base:
|
|
76
|
+
os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
|
|
77
|
+
os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
|
|
78
|
+
os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
|
|
79
|
+
os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
|
|
80
|
+
|
|
81
|
+
# Ensure directories exist
|
|
82
|
+
for env_var in [
|
|
83
|
+
"TORCH_HOME",
|
|
84
|
+
"MPLCONFIGDIR",
|
|
85
|
+
"FONTCONFIG_CACHE",
|
|
86
|
+
"XDG_CACHE_HOME",
|
|
87
|
+
]:
|
|
88
|
+
Path(os.environ[env_var]).mkdir(parents=True, exist_ok=True)
|
|
77
89
|
|
|
78
90
|
# WandB configuration (offline by default for HPC)
|
|
79
91
|
os.environ.setdefault("WANDB_MODE", "offline")
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
92
|
+
if cache_base:
|
|
93
|
+
os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
|
|
94
|
+
os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
|
|
95
|
+
os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
|
|
83
96
|
|
|
84
97
|
# Suppress non-critical warnings
|
|
85
98
|
os.environ.setdefault(
|
wavedl/test.py
CHANGED
|
@@ -29,29 +29,52 @@ Author: Ductho Le (ductho.le@outlook.com)
|
|
|
29
29
|
# ==============================================================================
|
|
30
30
|
# ENVIRONMENT CONFIGURATION (must be before matplotlib import)
|
|
31
31
|
# ==============================================================================
|
|
32
|
+
# Auto-configure writable cache directories when home is not writable.
|
|
33
|
+
# Uses current working directory as fallback - works on HPC and local machines.
|
|
32
34
|
import os
|
|
33
35
|
|
|
34
36
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
37
|
+
def _setup_cache_dir(env_var: str, subdir: str) -> None:
|
|
38
|
+
"""Set cache directory to CWD if home is not writable."""
|
|
39
|
+
if env_var in os.environ:
|
|
40
|
+
return # User already set, respect their choice
|
|
41
|
+
|
|
42
|
+
# Check if home is writable
|
|
43
|
+
home = os.path.expanduser("~")
|
|
44
|
+
if os.access(home, os.W_OK):
|
|
45
|
+
return # Home is writable, let library use defaults
|
|
46
|
+
|
|
47
|
+
# Home not writable - use current working directory
|
|
48
|
+
cache_path = os.path.join(os.getcwd(), f".{subdir}")
|
|
49
|
+
os.makedirs(cache_path, exist_ok=True)
|
|
50
|
+
os.environ[env_var] = cache_path
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Configure cache directories (before any library imports)
|
|
54
|
+
_setup_cache_dir("TORCH_HOME", "torch_cache")
|
|
55
|
+
_setup_cache_dir("MPLCONFIGDIR", "matplotlib")
|
|
56
|
+
_setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
|
|
57
|
+
_setup_cache_dir("XDG_DATA_HOME", "local/share")
|
|
58
|
+
_setup_cache_dir("XDG_STATE_HOME", "local/state")
|
|
59
|
+
_setup_cache_dir("XDG_CACHE_HOME", "cache")
|
|
60
|
+
|
|
61
|
+
import argparse # noqa: E402
|
|
62
|
+
import logging # noqa: E402
|
|
63
|
+
import pickle # noqa: E402
|
|
64
|
+
from pathlib import Path # noqa: E402
|
|
65
|
+
|
|
66
|
+
import matplotlib.pyplot as plt # noqa: E402
|
|
67
|
+
import numpy as np # noqa: E402
|
|
68
|
+
import pandas as pd # noqa: E402
|
|
69
|
+
import torch # noqa: E402
|
|
70
|
+
import torch.nn as nn # noqa: E402
|
|
71
|
+
from sklearn.metrics import mean_absolute_error, r2_score # noqa: E402
|
|
72
|
+
from torch.utils.data import DataLoader, TensorDataset # noqa: E402
|
|
73
|
+
from tqdm.auto import tqdm # noqa: E402
|
|
51
74
|
|
|
52
75
|
# Local imports
|
|
53
|
-
from wavedl.models import build_model, list_models
|
|
54
|
-
from wavedl.utils import (
|
|
76
|
+
from wavedl.models import build_model, list_models # noqa: E402
|
|
77
|
+
from wavedl.utils import ( # noqa: E402
|
|
55
78
|
FIGURE_DPI,
|
|
56
79
|
calc_pearson,
|
|
57
80
|
load_test_data,
|
wavedl/train.py
CHANGED
|
@@ -40,45 +40,34 @@ from __future__ import annotations
|
|
|
40
40
|
# =============================================================================
|
|
41
41
|
# HPC Environment Setup (MUST be before any library imports)
|
|
42
42
|
# =============================================================================
|
|
43
|
-
#
|
|
44
|
-
#
|
|
43
|
+
# Auto-configure writable cache directories when home is not writable.
|
|
44
|
+
# Uses current working directory as fallback - works on HPC and local machines.
|
|
45
45
|
import os
|
|
46
|
-
import tempfile
|
|
47
46
|
|
|
48
47
|
|
|
49
|
-
def _setup_cache_dir(env_var: str,
|
|
50
|
-
"""Set cache directory
|
|
48
|
+
def _setup_cache_dir(env_var: str, subdir: str) -> None:
|
|
49
|
+
"""Set cache directory to CWD if home is not writable."""
|
|
51
50
|
if env_var in os.environ:
|
|
52
51
|
return # User already set, respect their choice
|
|
53
52
|
|
|
54
|
-
# Check if
|
|
53
|
+
# Check if home is writable
|
|
55
54
|
home = os.path.expanduser("~")
|
|
56
|
-
|
|
57
|
-
|
|
55
|
+
if os.access(home, os.W_OK):
|
|
56
|
+
return # Home is writable, let library use defaults
|
|
58
57
|
|
|
59
|
-
#
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
or os.access(os.path.join(home, ".config"), os.W_OK)
|
|
64
|
-
):
|
|
65
|
-
return
|
|
66
|
-
|
|
67
|
-
# Default not writable - find alternative location
|
|
68
|
-
for cache_base in [
|
|
69
|
-
os.environ.get("SCRATCH"),
|
|
70
|
-
os.environ.get("SLURM_TMPDIR"),
|
|
71
|
-
tempfile.gettempdir(),
|
|
72
|
-
]:
|
|
73
|
-
if cache_base and os.access(cache_base, os.W_OK):
|
|
74
|
-
cache_path = os.path.join(cache_base, f".{default_subpath}")
|
|
75
|
-
os.makedirs(cache_path, exist_ok=True)
|
|
76
|
-
os.environ[env_var] = cache_path
|
|
77
|
-
return
|
|
58
|
+
# Home not writable - use current working directory
|
|
59
|
+
cache_path = os.path.join(os.getcwd(), f".{subdir}")
|
|
60
|
+
os.makedirs(cache_path, exist_ok=True)
|
|
61
|
+
os.environ[env_var] = cache_path
|
|
78
62
|
|
|
79
63
|
|
|
64
|
+
# Configure cache directories (before any library imports)
|
|
65
|
+
_setup_cache_dir("TORCH_HOME", "torch_cache")
|
|
80
66
|
_setup_cache_dir("MPLCONFIGDIR", "matplotlib")
|
|
81
67
|
_setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
|
|
68
|
+
_setup_cache_dir("XDG_DATA_HOME", "local/share")
|
|
69
|
+
_setup_cache_dir("XDG_STATE_HOME", "local/state")
|
|
70
|
+
_setup_cache_dir("XDG_CACHE_HOME", "cache")
|
|
82
71
|
|
|
83
72
|
# =============================================================================
|
|
84
73
|
# Standard imports (after environment setup)
|
|
@@ -1065,7 +1054,17 @@ def main():
|
|
|
1065
1054
|
f,
|
|
1066
1055
|
)
|
|
1067
1056
|
|
|
1068
|
-
|
|
1057
|
+
# Unwrap model for saving (handle torch.compile compatibility)
|
|
1058
|
+
try:
|
|
1059
|
+
unwrapped = accelerator.unwrap_model(model)
|
|
1060
|
+
except KeyError:
|
|
1061
|
+
# torch.compile model may not have _orig_mod in expected location
|
|
1062
|
+
# Fall back to getting the module directly
|
|
1063
|
+
unwrapped = model.module if hasattr(model, "module") else model
|
|
1064
|
+
# If still compiled, try to get the underlying model
|
|
1065
|
+
if hasattr(unwrapped, "_orig_mod"):
|
|
1066
|
+
unwrapped = unwrapped._orig_mod
|
|
1067
|
+
|
|
1069
1068
|
torch.save(
|
|
1070
1069
|
unwrapped.state_dict(),
|
|
1071
1070
|
os.path.join(args.output_dir, "best_model_weights.pth"),
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
wavedl/__init__.py,sha256=
|
|
2
|
-
wavedl/hpc.py,sha256=
|
|
1
|
+
wavedl/__init__.py,sha256=n0XNSrp0aGEE6HQpzbF1wiKK-jORKgcc0Q2Op4MtQGk,1177
|
|
2
|
+
wavedl/hpc.py,sha256=0h8IZzOT0EzmEv3fU9cKyRVE9V1ivtBzbjuBCaxYadc,8445
|
|
3
3
|
wavedl/hpo.py,sha256=YJXsnSGEBSVUqp_2ah7zu3_VClAUqZrdkuzDaSqQUjU,12952
|
|
4
|
-
wavedl/test.py,sha256=
|
|
5
|
-
wavedl/train.py,sha256=
|
|
4
|
+
wavedl/test.py,sha256=Wajcze8gFEyJ9VyN_Bq-YadZ_VZtVaX_HicvUmW6MXM,37365
|
|
5
|
+
wavedl/train.py,sha256=_pW7prvlNqfUGrGweHO2QelS87UiAYKvyJwqMAIj6yI,49292
|
|
6
6
|
wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
|
|
7
7
|
wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
|
|
8
8
|
wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
|
|
@@ -29,9 +29,9 @@ wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
|
|
|
29
29
|
wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
|
|
30
30
|
wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
|
|
31
31
|
wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
|
|
32
|
-
wavedl-1.4.
|
|
33
|
-
wavedl-1.4.
|
|
34
|
-
wavedl-1.4.
|
|
35
|
-
wavedl-1.4.
|
|
36
|
-
wavedl-1.4.
|
|
37
|
-
wavedl-1.4.
|
|
32
|
+
wavedl-1.4.4.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
|
|
33
|
+
wavedl-1.4.4.dist-info/METADATA,sha256=4tgSvkwzJmZP3PLCrqx-FYV_w6VT6Mi4XIsB0Dvb6_0,40386
|
|
34
|
+
wavedl-1.4.4.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
35
|
+
wavedl-1.4.4.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
|
|
36
|
+
wavedl-1.4.4.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
|
|
37
|
+
wavedl-1.4.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|