wavedl 1.4.2__py3-none-any.whl → 1.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.4.2"
21
+ __version__ = "1.4.4"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
wavedl/hpc.py CHANGED
@@ -26,7 +26,6 @@ import os
26
26
  import shutil
27
27
  import subprocess
28
28
  import sys
29
- import tempfile
30
29
  from pathlib import Path
31
30
 
32
31
 
@@ -60,26 +59,40 @@ def setup_hpc_environment() -> None:
60
59
  Handles restricted home directories (e.g., Compute Canada) and
61
60
  offline logging configurations.
62
61
  """
63
- # Use SLURM_TMPDIR if available, otherwise system temp
64
- tmpdir = os.environ.get(
65
- "SLURM_TMPDIR", os.environ.get("TMPDIR", tempfile.gettempdir())
66
- )
67
-
68
- # Configure directories for systems with restricted home directories
69
- os.environ.setdefault("MPLCONFIGDIR", f"{tmpdir}/matplotlib")
70
- os.environ.setdefault(
71
- "FONTCONFIG_PATH", os.environ.get("FONTCONFIG_PATH", "/etc/fonts")
72
- )
73
- os.environ.setdefault("XDG_CACHE_HOME", f"{tmpdir}/.cache")
74
-
75
- # Ensure matplotlib config dir exists
76
- Path(os.environ["MPLCONFIGDIR"]).mkdir(parents=True, exist_ok=True)
62
+ # Check if home is writable
63
+ home = os.path.expanduser("~")
64
+ home_writable = os.access(home, os.W_OK)
65
+
66
+ # Use SLURM_TMPDIR if available, otherwise CWD for HPC, or system temp
67
+ if home_writable:
68
+ # Local machine - let libraries use defaults
69
+ cache_base = None
70
+ else:
71
+ # HPC with restricted home - use CWD for persistent caches
72
+ cache_base = os.getcwd()
73
+
74
+ # Only set environment variables if home is not writable
75
+ if cache_base:
76
+ os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
77
+ os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
78
+ os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
79
+ os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
80
+
81
+ # Ensure directories exist
82
+ for env_var in [
83
+ "TORCH_HOME",
84
+ "MPLCONFIGDIR",
85
+ "FONTCONFIG_CACHE",
86
+ "XDG_CACHE_HOME",
87
+ ]:
88
+ Path(os.environ[env_var]).mkdir(parents=True, exist_ok=True)
77
89
 
78
90
  # WandB configuration (offline by default for HPC)
79
91
  os.environ.setdefault("WANDB_MODE", "offline")
80
- os.environ.setdefault("WANDB_DIR", f"{tmpdir}/wandb")
81
- os.environ.setdefault("WANDB_CACHE_DIR", f"{tmpdir}/wandb_cache")
82
- os.environ.setdefault("WANDB_CONFIG_DIR", f"{tmpdir}/wandb_config")
92
+ if cache_base:
93
+ os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
94
+ os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
95
+ os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
83
96
 
84
97
  # Suppress non-critical warnings
85
98
  os.environ.setdefault(
wavedl/test.py CHANGED
@@ -29,29 +29,52 @@ Author: Ductho Le (ductho.le@outlook.com)
29
29
  # ==============================================================================
30
30
  # ENVIRONMENT CONFIGURATION (must be before matplotlib import)
31
31
  # ==============================================================================
32
+ # Auto-configure writable cache directories when home is not writable.
33
+ # Uses current working directory as fallback - works on HPC and local machines.
32
34
  import os
33
35
 
34
36
 
35
- os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
36
- os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
37
-
38
- import argparse
39
- import logging
40
- import pickle
41
- from pathlib import Path
42
-
43
- import matplotlib.pyplot as plt
44
- import numpy as np
45
- import pandas as pd
46
- import torch
47
- import torch.nn as nn
48
- from sklearn.metrics import mean_absolute_error, r2_score
49
- from torch.utils.data import DataLoader, TensorDataset
50
- from tqdm.auto import tqdm
37
+ def _setup_cache_dir(env_var: str, subdir: str) -> None:
38
+ """Set cache directory to CWD if home is not writable."""
39
+ if env_var in os.environ:
40
+ return # User already set, respect their choice
41
+
42
+ # Check if home is writable
43
+ home = os.path.expanduser("~")
44
+ if os.access(home, os.W_OK):
45
+ return # Home is writable, let library use defaults
46
+
47
+ # Home not writable - use current working directory
48
+ cache_path = os.path.join(os.getcwd(), f".{subdir}")
49
+ os.makedirs(cache_path, exist_ok=True)
50
+ os.environ[env_var] = cache_path
51
+
52
+
53
+ # Configure cache directories (before any library imports)
54
+ _setup_cache_dir("TORCH_HOME", "torch_cache")
55
+ _setup_cache_dir("MPLCONFIGDIR", "matplotlib")
56
+ _setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
57
+ _setup_cache_dir("XDG_DATA_HOME", "local/share")
58
+ _setup_cache_dir("XDG_STATE_HOME", "local/state")
59
+ _setup_cache_dir("XDG_CACHE_HOME", "cache")
60
+
61
+ import argparse # noqa: E402
62
+ import logging # noqa: E402
63
+ import pickle # noqa: E402
64
+ from pathlib import Path # noqa: E402
65
+
66
+ import matplotlib.pyplot as plt # noqa: E402
67
+ import numpy as np # noqa: E402
68
+ import pandas as pd # noqa: E402
69
+ import torch # noqa: E402
70
+ import torch.nn as nn # noqa: E402
71
+ from sklearn.metrics import mean_absolute_error, r2_score # noqa: E402
72
+ from torch.utils.data import DataLoader, TensorDataset # noqa: E402
73
+ from tqdm.auto import tqdm # noqa: E402
51
74
 
52
75
  # Local imports
53
- from wavedl.models import build_model, list_models
54
- from wavedl.utils import (
76
+ from wavedl.models import build_model, list_models # noqa: E402
77
+ from wavedl.utils import ( # noqa: E402
55
78
  FIGURE_DPI,
56
79
  calc_pearson,
57
80
  load_test_data,
wavedl/train.py CHANGED
@@ -37,9 +37,43 @@ Author: Ductho Le (ductho.le@outlook.com)
37
37
 
38
38
  from __future__ import annotations
39
39
 
40
+ # =============================================================================
41
+ # HPC Environment Setup (MUST be before any library imports)
42
+ # =============================================================================
43
+ # Auto-configure writable cache directories when home is not writable.
44
+ # Uses current working directory as fallback - works on HPC and local machines.
45
+ import os
46
+
47
+
48
+ def _setup_cache_dir(env_var: str, subdir: str) -> None:
49
+ """Set cache directory to CWD if home is not writable."""
50
+ if env_var in os.environ:
51
+ return # User already set, respect their choice
52
+
53
+ # Check if home is writable
54
+ home = os.path.expanduser("~")
55
+ if os.access(home, os.W_OK):
56
+ return # Home is writable, let library use defaults
57
+
58
+ # Home not writable - use current working directory
59
+ cache_path = os.path.join(os.getcwd(), f".{subdir}")
60
+ os.makedirs(cache_path, exist_ok=True)
61
+ os.environ[env_var] = cache_path
62
+
63
+
64
+ # Configure cache directories (before any library imports)
65
+ _setup_cache_dir("TORCH_HOME", "torch_cache")
66
+ _setup_cache_dir("MPLCONFIGDIR", "matplotlib")
67
+ _setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
68
+ _setup_cache_dir("XDG_DATA_HOME", "local/share")
69
+ _setup_cache_dir("XDG_STATE_HOME", "local/state")
70
+ _setup_cache_dir("XDG_CACHE_HOME", "cache")
71
+
72
+ # =============================================================================
73
+ # Standard imports (after environment setup)
74
+ # =============================================================================
40
75
  import argparse
41
76
  import logging
42
- import os
43
77
  import pickle
44
78
  import shutil
45
79
  import sys
@@ -47,6 +81,10 @@ import time
47
81
  import warnings
48
82
  from typing import Any
49
83
 
84
+
85
+ # Suppress Pydantic warnings from accelerate's internal Field() usage
86
+ warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
87
+
50
88
  import matplotlib.pyplot as plt
51
89
  import numpy as np
52
90
  import pandas as pd
@@ -582,9 +620,9 @@ def main():
582
620
  # Torch 2.0 compilation (requires compatible Triton on GPU)
583
621
  if args.compile:
584
622
  try:
585
- # Test if Triton is available AND compatible with this PyTorch version
586
- # PyTorch needs triton_key from triton.compiler.compiler
587
- from triton.compiler.compiler import triton_key
623
+ # Test if Triton is available - just import the package
624
+ # Different Triton versions have different internal APIs, so just check base import
625
+ import triton
588
626
 
589
627
  model = torch.compile(model)
590
628
  if accelerator.is_main_process:
@@ -875,9 +913,13 @@ def main():
875
913
  cpu_preds = torch.cat(local_preds)
876
914
  cpu_targets = torch.cat(local_targets)
877
915
 
878
- # Gather to rank 0 only via gather_object (avoids all-gather to every rank)
879
- # gather_object returns list of objects from each rank: [(preds0, targs0), (preds1, targs1), ...]
880
- gathered = accelerator.gather_object((cpu_preds, cpu_targets))
916
+ # Gather predictions and targets across all ranks
917
+ # Use accelerator.gather (works with all accelerate versions)
918
+ gpu_preds = cpu_preds.to(accelerator.device)
919
+ gpu_targets = cpu_targets.to(accelerator.device)
920
+ all_preds_gathered = accelerator.gather(gpu_preds).cpu()
921
+ all_targets_gathered = accelerator.gather(gpu_targets).cpu()
922
+ gathered = [(all_preds_gathered, all_targets_gathered)]
881
923
 
882
924
  # Synchronize validation metrics (scalars only - efficient)
883
925
  val_loss_scalar = val_loss_sum.item()
@@ -1012,7 +1054,17 @@ def main():
1012
1054
  f,
1013
1055
  )
1014
1056
 
1015
- unwrapped = accelerator.unwrap_model(model)
1057
+ # Unwrap model for saving (handle torch.compile compatibility)
1058
+ try:
1059
+ unwrapped = accelerator.unwrap_model(model)
1060
+ except KeyError:
1061
+ # torch.compile model may not have _orig_mod in expected location
1062
+ # Fall back to getting the module directly
1063
+ unwrapped = model.module if hasattr(model, "module") else model
1064
+ # If still compiled, try to get the underlying model
1065
+ if hasattr(unwrapped, "_orig_mod"):
1066
+ unwrapped = unwrapped._orig_mod
1067
+
1016
1068
  torch.save(
1017
1069
  unwrapped.state_dict(),
1018
1070
  os.path.join(args.output_dir, "best_model_weights.pth"),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.4.2
3
+ Version: 1.4.4
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -57,6 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
57
57
  [![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
58
58
  [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
59
59
  <br>
60
+ [![Downloads](https://img.shields.io/pepy/dt/wavedl?style=plastic&logo=pypi&logoColor=white&color=9ACD32)](https://pepy.tech/project/wavedl)
60
61
  [![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
61
62
  [![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)
62
63
 
@@ -1,8 +1,8 @@
1
- wavedl/__init__.py,sha256=K52yq0nkj2B3W0ZpR3tb7RUcHMIiANAE2d1WTuGVZLI,1177
2
- wavedl/hpc.py,sha256=de_GKERX8GS10sXRX9yXiGzMnk1jjq8JPzRw7QDs6d4,7967
1
+ wavedl/__init__.py,sha256=n0XNSrp0aGEE6HQpzbF1wiKK-jORKgcc0Q2Op4MtQGk,1177
2
+ wavedl/hpc.py,sha256=0h8IZzOT0EzmEv3fU9cKyRVE9V1ivtBzbjuBCaxYadc,8445
3
3
  wavedl/hpo.py,sha256=YJXsnSGEBSVUqp_2ah7zu3_VClAUqZrdkuzDaSqQUjU,12952
4
- wavedl/test.py,sha256=jZmRJaivYYTMMTaccCi0yQjHOfp0a9YWR1wAPeKFH-k,36246
5
- wavedl/train.py,sha256=Gh02hlfjcote6w1sgUndJzKU_FhkJckTAhlq1aL8FY8,46842
4
+ wavedl/test.py,sha256=Wajcze8gFEyJ9VyN_Bq-YadZ_VZtVaX_HicvUmW6MXM,37365
5
+ wavedl/train.py,sha256=_pW7prvlNqfUGrGweHO2QelS87UiAYKvyJwqMAIj6yI,49292
6
6
  wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
7
7
  wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
8
8
  wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
@@ -29,9 +29,9 @@ wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
29
29
  wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
30
30
  wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
31
31
  wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
32
- wavedl-1.4.2.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
33
- wavedl-1.4.2.dist-info/METADATA,sha256=xHMuRcGdF8Tdju6aBK601mqLYx0uSIdmtjoQrHpPYmI,40245
34
- wavedl-1.4.2.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
35
- wavedl-1.4.2.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
36
- wavedl-1.4.2.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
37
- wavedl-1.4.2.dist-info/RECORD,,
32
+ wavedl-1.4.4.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
33
+ wavedl-1.4.4.dist-info/METADATA,sha256=4tgSvkwzJmZP3PLCrqx-FYV_w6VT6Mi4XIsB0Dvb6_0,40386
34
+ wavedl-1.4.4.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
35
+ wavedl-1.4.4.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
36
+ wavedl-1.4.4.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
37
+ wavedl-1.4.4.dist-info/RECORD,,
File without changes