wavedl 1.4.2__py3-none-any.whl → 1.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/train.py +60 -7
- {wavedl-1.4.2.dist-info → wavedl-1.4.3.dist-info}/METADATA +2 -1
- {wavedl-1.4.2.dist-info → wavedl-1.4.3.dist-info}/RECORD +8 -8
- {wavedl-1.4.2.dist-info → wavedl-1.4.3.dist-info}/LICENSE +0 -0
- {wavedl-1.4.2.dist-info → wavedl-1.4.3.dist-info}/WHEEL +0 -0
- {wavedl-1.4.2.dist-info → wavedl-1.4.3.dist-info}/entry_points.txt +0 -0
- {wavedl-1.4.2.dist-info → wavedl-1.4.3.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/train.py
CHANGED
|
@@ -37,9 +37,54 @@ Author: Ductho Le (ductho.le@outlook.com)
|
|
|
37
37
|
|
|
38
38
|
from __future__ import annotations
|
|
39
39
|
|
|
40
|
+
# =============================================================================
|
|
41
|
+
# HPC Environment Setup (MUST be before any library imports)
|
|
42
|
+
# =============================================================================
|
|
43
|
+
# Set writable cache directories for matplotlib and fontconfig ONLY when
|
|
44
|
+
# the default paths are not writable (common on HPC clusters).
|
|
45
|
+
import os
|
|
46
|
+
import tempfile
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _is_writable_or_creatable(path: str) -> bool:
    """Return True if *path* is a writable directory or could be created.

    A path that does not exist yet counts as creatable when its nearest
    existing ancestor is a writable directory. This assumes the consuming
    library creates missing parent directories itself
    (NOTE(review): confirm for each library this is used with).
    """
    probe = os.path.abspath(path)
    while not os.path.exists(probe):
        parent = os.path.dirname(probe)
        if parent == probe:  # Reached the filesystem root without a hit.
            return False
        probe = parent
    # An existing regular file in the way means the directory cannot be made.
    return os.path.isdir(probe) and os.access(probe, os.W_OK)


def _setup_cache_dir(env_var: str, default_subpath: str) -> None:
    """Point *env_var* at a writable cache directory if the default is unusable.

    Does nothing when *env_var* is already set (the user's choice wins), or
    when the default location ``$XDG_CONFIG_HOME/<default_subpath>`` (falling
    back to ``~/.config/<default_subpath>``) is writable or can be created.
    Otherwise the first usable candidate among ``$SCRATCH``, ``$SLURM_TMPDIR``
    and the system temp directory is used as ``<base>/.<default_subpath>``.
    If no candidate works, the environment is left untouched.

    This runs at module import time, so it never raises on filesystem
    problems: a candidate that cannot be prepared is simply skipped.

    Args:
        env_var: Environment variable to set (e.g. ``"MPLCONFIGDIR"``).
        default_subpath: Directory name under the config root, also used as
            the dotted directory name under the fallback base.
    """
    if env_var in os.environ:
        return  # User already set it; respect their choice.

    config_root = os.environ.get("XDG_CONFIG_HOME") or os.path.join(
        os.path.expanduser("~"), ".config"
    )
    if _is_writable_or_creatable(os.path.join(config_root, default_subpath)):
        return  # Default location is usable; let the library use it.

    # Default not usable - find an alternative location.
    for cache_base in (
        os.environ.get("SCRATCH"),
        os.environ.get("SLURM_TMPDIR"),
        tempfile.gettempdir(),
    ):
        if not cache_base or not os.access(cache_base, os.W_OK):
            continue
        cache_path = os.path.join(cache_base, f".{default_subpath}")
        try:
            os.makedirs(cache_path, exist_ok=True)
        except OSError:
            # e.g. a regular file already occupies the path, or a permission
            # problem os.access did not predict; try the next candidate
            # rather than crashing the import.
            continue
        # In a shared temp dir the directory may already exist but belong to
        # another user, so confirm that we can actually write into it.
        if os.access(cache_path, os.W_OK):
            os.environ[env_var] = cache_path
            return
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Apply the writable-cache-dir fallback for matplotlib (MPLCONFIGDIR) and
# fontconfig (FONTCONFIG_CACHE); each call is a no-op when the variable is
# already set or the default location is usable.
# NOTE(review): fontconfig's cache conventionally lives under ~/.cache rather
# than ~/.config, and it is unclear whether fontconfig reads FONTCONFIG_CACHE;
# confirm this variable is actually honoured.
_setup_cache_dir(env_var="MPLCONFIGDIR", default_subpath="matplotlib")
_setup_cache_dir(env_var="FONTCONFIG_CACHE", default_subpath="fontconfig")
|
|
82
|
+
|
|
83
|
+
# =============================================================================
|
|
84
|
+
# Standard imports (after environment setup)
|
|
85
|
+
# =============================================================================
|
|
40
86
|
import argparse
|
|
41
87
|
import logging
|
|
42
|
-
import os
|
|
43
88
|
import pickle
|
|
44
89
|
import shutil
|
|
45
90
|
import sys
|
|
@@ -47,6 +92,10 @@ import time
|
|
|
47
92
|
import warnings
|
|
48
93
|
from typing import Any
|
|
49
94
|
|
|
95
|
+
|
|
96
|
+
# Silence the pydantic UserWarnings triggered by accelerate's internal Field()
# usage. NOTE(review): ``module`` is a regex matched against the start of the
# module a warning is attributed to; confirm pydantic attributes these warnings
# to its own modules rather than to the calling (accelerate) module.
warnings.filterwarnings(action="ignore", module="pydantic", category=UserWarning)
|
|
98
|
+
|
|
50
99
|
import matplotlib.pyplot as plt
|
|
51
100
|
import numpy as np
|
|
52
101
|
import pandas as pd
|
|
@@ -582,9 +631,9 @@ def main():
|
|
|
582
631
|
# Torch 2.0 compilation (requires compatible Triton on GPU)
|
|
583
632
|
if args.compile:
|
|
584
633
|
try:
|
|
585
|
-
# Test if Triton is available
|
|
586
|
-
#
|
|
587
|
-
|
|
634
|
+
# Test if Triton is available - just import the package
|
|
635
|
+
# Different Triton versions have different internal APIs, so just check base import
|
|
636
|
+
import triton
|
|
588
637
|
|
|
589
638
|
model = torch.compile(model)
|
|
590
639
|
if accelerator.is_main_process:
|
|
@@ -875,9 +924,13 @@ def main():
|
|
|
875
924
|
cpu_preds = torch.cat(local_preds)
|
|
876
925
|
cpu_targets = torch.cat(local_targets)
|
|
877
926
|
|
|
878
|
-
# Gather
|
|
879
|
-
#
|
|
880
|
-
|
|
927
|
+
# Gather predictions and targets across all ranks
|
|
928
|
+
# Use accelerator.gather (works with all accelerate versions)
|
|
929
|
+
gpu_preds = cpu_preds.to(accelerator.device)
|
|
930
|
+
gpu_targets = cpu_targets.to(accelerator.device)
|
|
931
|
+
all_preds_gathered = accelerator.gather(gpu_preds).cpu()
|
|
932
|
+
all_targets_gathered = accelerator.gather(gpu_targets).cpu()
|
|
933
|
+
gathered = [(all_preds_gathered, all_targets_gathered)]
|
|
881
934
|
|
|
882
935
|
# Synchronize validation metrics (scalars only - efficient)
|
|
883
936
|
val_loss_scalar = val_loss_sum.item()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.4.2
|
|
3
|
+
Version: 1.4.3
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -57,6 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
|
|
|
57
57
|
[](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
|
|
58
58
|
[](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
|
|
59
59
|
<br>
|
|
60
|
+
[](https://pepy.tech/project/wavedl)
|
|
60
61
|
[](LICENSE)
|
|
61
62
|
[](https://doi.org/10.5281/zenodo.18012338)
|
|
62
63
|
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
wavedl/__init__.py,sha256=
|
|
1
|
+
wavedl/__init__.py,sha256=7uvgr9r21qLGu6RIZDUfQBhg1vXNAfmIWl_P6BMj_KQ,1177
|
|
2
2
|
wavedl/hpc.py,sha256=de_GKERX8GS10sXRX9yXiGzMnk1jjq8JPzRw7QDs6d4,7967
|
|
3
3
|
wavedl/hpo.py,sha256=YJXsnSGEBSVUqp_2ah7zu3_VClAUqZrdkuzDaSqQUjU,12952
|
|
4
4
|
wavedl/test.py,sha256=jZmRJaivYYTMMTaccCi0yQjHOfp0a9YWR1wAPeKFH-k,36246
|
|
5
|
-
wavedl/train.py,sha256=
|
|
5
|
+
wavedl/train.py,sha256=TjLABBPCqu9r7FEWlxJlKsT7uAMo6hiDxRfii2SKXe4,49052
|
|
6
6
|
wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
|
|
7
7
|
wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
|
|
8
8
|
wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
|
|
@@ -29,9 +29,9 @@ wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
|
|
|
29
29
|
wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
|
|
30
30
|
wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
|
|
31
31
|
wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
|
|
32
|
-
wavedl-1.4.
|
|
33
|
-
wavedl-1.4.
|
|
34
|
-
wavedl-1.4.
|
|
35
|
-
wavedl-1.4.
|
|
36
|
-
wavedl-1.4.
|
|
37
|
-
wavedl-1.4.
|
|
32
|
+
wavedl-1.4.3.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
|
|
33
|
+
wavedl-1.4.3.dist-info/METADATA,sha256=vs3nt8R2O5lD7q-si9M5ChyrriIKm0fFzDwi4HVIYxw,40386
|
|
34
|
+
wavedl-1.4.3.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
35
|
+
wavedl-1.4.3.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
|
|
36
|
+
wavedl-1.4.3.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
|
|
37
|
+
wavedl-1.4.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|