wavedl 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +1 -1
- wavedl/{hpc.py → launcher.py} +135 -61
- wavedl/models/__init__.py +22 -0
- wavedl/test.py +8 -0
- wavedl/train.py +10 -13
- wavedl/utils/data.py +36 -6
- {wavedl-1.6.1.dist-info → wavedl-1.6.3.dist-info}/METADATA +59 -64
- {wavedl-1.6.1.dist-info → wavedl-1.6.3.dist-info}/RECORD +13 -13
- {wavedl-1.6.1.dist-info → wavedl-1.6.3.dist-info}/entry_points.txt +2 -2
- {wavedl-1.6.1.dist-info → wavedl-1.6.3.dist-info}/LICENSE +0 -0
- {wavedl-1.6.1.dist-info → wavedl-1.6.3.dist-info}/WHEEL +0 -0
- {wavedl-1.6.1.dist-info → wavedl-1.6.3.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpo.py
CHANGED
@@ -440,7 +440,7 @@ Examples:
     print("\n" + "=" * 60)
     print("TO TRAIN WITH BEST PARAMETERS:")
     print("=" * 60)
-    cmd_parts = ["
+    cmd_parts = ["wavedl-train"]
     cmd_parts.append(f"--data_path {args.data_path}")
     for key, value in study.best_params.items():
         cmd_parts.append(f"--{key} {value}")
wavedl/{hpc.py → launcher.py}
RENAMED
@@ -1,12 +1,21 @@
 #!/usr/bin/env python
 """
-WaveDL
+WaveDL Training Launcher.
 
-This module provides a
-for distributed training
+This module provides a universal training launcher that wraps accelerate
+for distributed training. It works seamlessly on both:
+- Local machines (uses standard cache locations)
+- HPC clusters (uses local caching, offline WandB)
+
+The environment is auto-detected based on scheduler variables (SLURM, PBS, etc.)
+and home directory writability.
 
 Usage:
-
+    # Local machine or HPC - same command!
+    wavedl-train --model cnn --data_path train.npz --output_dir results
+
+    # Multi-GPU is automatic (uses all available GPUs)
+    wavedl-train --model resnet18 --data_path train.npz --num_gpus 4
 
 Example SLURM script:
     #!/bin/bash
@@ -14,7 +23,7 @@ Example SLURM script:
     #SBATCH --gpus-per-node=4
     #SBATCH --time=12:00:00
 
-    wavedl-
+    wavedl-train --model cnn --data_path /scratch/data.npz --compile
 
 Author: Ductho Le (ductho.le@outlook.com)
 """
@@ -53,78 +62,138 @@ def detect_gpus() -> int:
     return 1
 
 
-def
-    """
+def is_hpc_environment() -> bool:
+    """Detect if running on an HPC cluster.
+
+    Checks for:
+    1. Common HPC scheduler environment variables (SLURM, PBS, LSF, SGE, Cobalt)
+    2. Non-writable home directory (common on HPC systems)
 
-
-
-    since compute nodes typically lack internet access.
+    Returns:
+        True if HPC environment detected, False otherwise.
     """
-    #
-
+    # Check for common HPC scheduler environment variables
+    hpc_indicators = [
+        "SLURM_JOB_ID",  # SLURM
+        "PBS_JOBID",  # PBS/Torque
+        "LSB_JOBID",  # LSF
+        "SGE_TASK_ID",  # Sun Grid Engine
+        "COBALT_JOBID",  # Cobalt
+    ]
+    if any(var in os.environ for var in hpc_indicators):
+        return True
 
-    #
-    os.
-
+    # Check if home directory is not writable (common on HPC)
+    home = os.path.expanduser("~")
+    return not os.access(home, os.W_OK)
 
-    # Triton/Inductor caches - prevents permission errors with --compile
-    # These MUST be set before any torch.compile calls
-    os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
-    os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache")
-    Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
-    Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
 
-
-
-
-
-
-
-
-
-
-
-    #
-
-
-
-
-    ]
-
-
-
-
-
-
-
-
-
+def setup_environment() -> None:
+    """Configure environment for HPC or local machine.
+
+    Automatically detects the environment and configures accordingly:
+    - HPC: Uses CWD-based caching, offline WandB (compute nodes lack internet)
+    - Local: Uses standard cache locations (~/.cache), doesn't override WandB
+    """
+    is_hpc = is_hpc_environment()
+
+    if is_hpc:
+        # HPC: use CWD-based caching (compute nodes lack internet)
+        cache_base = os.getcwd()
+
+        # TORCH_HOME set to CWD - compute nodes need pre-cached weights
+        os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
+        Path(os.environ["TORCH_HOME"]).mkdir(parents=True, exist_ok=True)
+
+        # Triton/Inductor caches - prevents permission errors with --compile
+        os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
+        os.environ.setdefault(
+            "TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache"
+        )
+        Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
+        Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
+
+        # Check if home is writable for other caches
+        home = os.path.expanduser("~")
+        home_writable = os.access(home, os.W_OK)
+
+        # Other caches only if home is not writable
+        if not home_writable:
+            os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
+            os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
+            os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
+
+            for env_var in [
+                "MPLCONFIGDIR",
+                "FONTCONFIG_CACHE",
+                "XDG_CACHE_HOME",
+            ]:
+                Path(os.environ[env_var]).mkdir(parents=True, exist_ok=True)
+
+        # WandB configuration (offline by default for HPC)
+        os.environ.setdefault("WANDB_MODE", "offline")
+        os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
+        os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
+        os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
+
+        print("🖥️ HPC environment detected - using local caching")
+    else:
+        # Local machine: use standard locations, don't override user settings
+        # TORCH_HOME defaults to ~/.cache/torch (PyTorch default)
+        # WANDB_MODE defaults to online (WandB default)
+        print("💻 Local environment detected - using standard cache locations")
+
+    # Suppress non-critical warnings (both environments)
     os.environ.setdefault(
         "PYTHONWARNINGS",
         "ignore::UserWarning,ignore::FutureWarning,ignore::DeprecationWarning",
     )
 
 
+def handle_fast_path_args() -> int | None:
+    """Handle utility flags that don't need accelerate launch.
+
+    Returns:
+        Exit code if handled (0 for success), None if should continue to full launch.
+    """
+    # --list_models: print models and exit immediately
+    if "--list_models" in sys.argv:
+        from wavedl.models import list_models
+
+        print("Available models:")
+        for name in list_models():
+            print(f"  {name}")
+        return 0
+
+    return None  # Continue to full launch
+
+
 def parse_args() -> tuple[argparse.Namespace, list[str]]:
-    """Parse
+    """Parse launcher-specific arguments, pass remaining to wavedl.train."""
     parser = argparse.ArgumentParser(
-        description="WaveDL
+        description="WaveDL Training Launcher (works on local machines and HPC clusters)",
         formatter_class=argparse.RawDescriptionHelpFormatter,
         epilog="""
 Examples:
-    # Basic training
-    wavedl-
+    # Basic training (auto-detects GPUs and environment)
+    wavedl-train --model cnn --data_path train.npz --output_dir results
 
-    # Specify GPU count
-    wavedl-
+    # Specify GPU count explicitly
+    wavedl-train --model cnn --data_path train.npz --num_gpus 4
 
     # Full configuration
-    wavedl-
-
+    wavedl-train --model resnet18 --data_path train.npz --batch_size 256 \\
+        --lr 1e-3 --epochs 100 --compile --output_dir ./results
+
+    # List available models
+    wavedl-train --list_models
 
-Environment
-
-
+Environment Detection:
+    The launcher automatically detects your environment:
+    - HPC (SLURM, PBS, etc.): Uses local caching, offline WandB
+    - Local machine: Uses standard cache locations (~/.cache)
+
+    For full training options, see: python -m wavedl.train --help
 """,
     )
@@ -204,7 +273,7 @@ def print_summary(
     print("Common issues:")
     print("  - Missing data file (check --data_path)")
     print("  - Insufficient GPU memory (reduce --batch_size)")
-    print("  - Invalid model name (run:
+    print("  - Invalid model name (run: wavedl-train --list_models)")
     print()
 
     print("=" * 40)
@@ -212,12 +281,17 @@ def print_summary(
 
 
 def main() -> int:
-    """Main entry point for wavedl-
+    """Main entry point for wavedl-train command."""
+    # Fast path for utility flags (avoid accelerate launch overhead)
+    exit_code = handle_fast_path_args()
+    if exit_code is not None:
+        return exit_code
+
     # Parse arguments
     args, train_args = parse_args()
 
-    # Setup
-
+    # Setup environment (smart detection)
+    setup_environment()
 
     # Check if wavedl package is importable
     try:
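The detection logic above can be exercised on its own; here is a self-contained sketch (reimplemented for illustration, not imported from wavedl), with a fake SLURM variable standing in for a real allocation:

```python
# Standalone sketch of the HPC-detection heuristics added in launcher.py.
import os


def looks_like_hpc() -> bool:
    # Scheduler variables checked by the launcher (SLURM, PBS/Torque, LSF, SGE, Cobalt)
    schedulers = ["SLURM_JOB_ID", "PBS_JOBID", "LSB_JOBID", "SGE_TASK_ID", "COBALT_JOBID"]
    if any(var in os.environ for var in schedulers):
        return True
    # Fallback heuristic: a non-writable home directory also signals an HPC node
    return not os.access(os.path.expanduser("~"), os.W_OK)


os.environ["SLURM_JOB_ID"] = "12345"  # simulate running inside a SLURM job
print(looks_like_hpc())  # True -> wavedl-train would switch to CWD caches and offline WandB
```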
wavedl/models/__init__.py
CHANGED
@@ -77,6 +77,15 @@ from .unet import UNetRegression
 from .vit import ViTBase_, ViTSmall, ViTTiny
 
 
+# Optional RATENet (unpublished, may be gitignored)
+try:
+    from .ratenet import RATENet, RATENetLite, RATENetTiny, RATENetV2
+
+    _HAS_RATENET = True
+except ImportError:
+    _HAS_RATENET = False
+
+
 # Optional timm-based models (imported conditionally)
 try:
     from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
@@ -111,6 +120,7 @@ __all__ = [
     "MC3_18",
     "MODEL_REGISTRY",
     "TCN",
+    # Classes (uppercase first, alphabetically)
     "BaseModel",
     "ConvNeXtBase_",
     "ConvNeXtSmall",
@@ -152,6 +162,7 @@ __all__ = [
     "VimBase",
     "VimSmall",
     "VimTiny",
+    # Functions (lowercase, alphabetically)
     "build_model",
     "get_model",
     "list_models",
@@ -186,3 +197,14 @@ if _HAS_TIMM_MODELS:
             "UniRepLKNetTiny",
         ]
     )
+
+# Add RATENet models to __all__ if available (unpublished)
+if _HAS_RATENET:
+    __all__.extend(
+        [
+            "RATENet",
+            "RATENetLite",
+            "RATENetTiny",
+            "RATENetV2",
+        ]
+    )
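Since RATENet is only imported when its module is present, downstream code should probe the registry rather than import it directly. A hedged usage sketch (the registry key "ratenet" is a guess for illustration; check `list_models()` for the real name):

```python
# Pick an optional model if it registered, otherwise fall back to a built-in one.
from wavedl.models import list_models

available = set(list_models())
model_name = "ratenet" if "ratenet" in available else "cnn"  # "ratenet" key is assumed
print(f"Selected model: {model_name}")
```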
wavedl/test.py
CHANGED
@@ -398,6 +398,14 @@ def load_checkpoint(
 
     if HAS_SAFETENSORS and weight_path.suffix == ".safetensors":
         state_dict = load_safetensors(str(weight_path))
+    elif weight_path.suffix == ".safetensors":
+        # Safetensors file exists but library not installed
+        raise ImportError(
+            f"Checkpoint uses safetensors format ({weight_path.name}) but "
+            f"'safetensors' package is not installed. Install it with:\n"
+            f"  pip install safetensors\n"
+            f"Or convert the checkpoint to PyTorch format (model.bin)."
+        )
     else:
         state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
 
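The new branch turns a silent fallback into an explicit error when a `.safetensors` checkpoint is found without the library installed. A standalone sketch of the same branching (the `HAS_SAFETENSORS` probe and paths are re-created here for illustration):

```python
from pathlib import Path

import torch

try:
    from safetensors.torch import load_file as load_safetensors
    HAS_SAFETENSORS = True
except ImportError:
    HAS_SAFETENSORS = False


def load_weights(weight_path: Path) -> dict:
    """Load a state dict, failing loudly if safetensors support is missing."""
    if HAS_SAFETENSORS and weight_path.suffix == ".safetensors":
        return load_safetensors(str(weight_path))
    if weight_path.suffix == ".safetensors":
        raise ImportError(
            f"Checkpoint uses safetensors format ({weight_path.name}) but "
            "'safetensors' is not installed. Install it with: pip install safetensors"
        )
    return torch.load(weight_path, map_location="cpu", weights_only=True)
```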
wavedl/train.py
CHANGED
@@ -12,25 +12,22 @@ A modular training framework for wave-based inverse problems and regression:
     6. Deep Observability: WandB integration with scatter analysis
 
 Usage:
-    # Recommended:
-    wavedl-
-
-    # Or direct training module (use --precision, not --mixed_precision)
-    accelerate launch -m wavedl.train --model cnn --batch_size 128 --precision bf16
+    # Recommended: Universal training command (works on local machines and HPC)
+    wavedl-train --model cnn --batch_size 128 --compile
 
     # Multi-GPU with explicit config
-    wavedl-
+    wavedl-train --num_gpus 4 --model cnn --output_dir results
 
     # Resume from checkpoint
-
+    wavedl-train --model cnn --output_dir results  # auto-resumes if interrupted
 
     # List available models
     wavedl-train --list_models
 
 Note:
-
-    -
-
+    wavedl-train automatically detects your environment:
+    - HPC clusters (SLURM, PBS, etc.): Uses local caching, offline WandB
+    - Local machines: Uses standard cache locations (~/.cache)
 
 Author: Ductho Le (ductho.le@outlook.com)
 """
@@ -429,7 +426,7 @@ def parse_args() -> argparse.Namespace:
         choices=["bf16", "fp16", "no"],
         help="Mixed precision mode",
     )
-    # Alias for consistency with wavedl-
+    # Alias for consistency with wavedl-train (--mixed_precision is passed to accelerate)
     parser.add_argument(
         "--mixed_precision",
         dest="precision",
@@ -1269,10 +1266,10 @@ def main():
             os.path.join(args.output_dir, "best_model_weights.pth"),
         )
 
-        # Copy scaler to checkpoint for portability
+        # Copy scaler to checkpoint for portability (always overwrite to stay current)
         scaler_src = os.path.join(args.output_dir, "scaler.pkl")
         scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
-        if os.path.exists(scaler_src)
+        if os.path.exists(scaler_src):
            shutil.copy2(scaler_src, scaler_dst)
 
         logger.info(
wavedl/utils/data.py
CHANGED
@@ -984,7 +984,15 @@ def load_test_data(
             f"Available keys depend on file format. Original error: {e}"
         ) from e
 
-    #
+    # Also fail-fast if explicit input_key was provided but not found
+    # This prevents silently loading a different tensor when user mistyped key
+    if input_key is not None:
+        raise KeyError(
+            f"Explicit --input_key '{input_key}' not found in file. "
+            f"Original error: {e}"
+        ) from e
+
+    # Legitimate fallback: no explicit keys, outputs just not present
     if format == "npz":
         # First pass to find keys
         with np.load(path, allow_pickle=False) as probe:
@@ -1524,21 +1532,43 @@ def prepare_data(
 
         logger.info("  ✔ Cache creation complete, synchronizing ranks...")
     else:
-        # NON-MAIN RANKS: Wait for cache creation
-        #
+        # NON-MAIN RANKS: Wait for cache creation with timeout
+        # Use monotonic clock (immune to system clock changes)
         import time
 
-        wait_start = time.
+        wait_start = time.monotonic()
+
+        # Robust env parsing with guards for invalid/non-positive values
+        DEFAULT_CACHE_TIMEOUT = 3600  # 1 hour default
+        try:
+            env_timeout = os.environ.get("WAVEDL_CACHE_TIMEOUT", "")
+            CACHE_TIMEOUT = (
+                int(env_timeout) if env_timeout else DEFAULT_CACHE_TIMEOUT
+            )
+            if CACHE_TIMEOUT <= 0:
+                CACHE_TIMEOUT = DEFAULT_CACHE_TIMEOUT
+        except ValueError:
+            CACHE_TIMEOUT = DEFAULT_CACHE_TIMEOUT
+
         while not (
             os.path.exists(CACHE_FILE)
             and os.path.exists(SCALER_FILE)
            and os.path.exists(META_FILE)
        ):
            time.sleep(5)  # Check every 5 seconds
-            elapsed = time.
+            elapsed = time.monotonic() - wait_start
+
+            if elapsed > CACHE_TIMEOUT:
+                raise RuntimeError(
+                    f"[Rank {accelerator.process_index}] Timeout waiting for cache "
+                    f"files after {CACHE_TIMEOUT}s. Rank 0 may have failed during "
+                    f"cache generation. Check rank 0 logs for errors."
+                )
+
             if elapsed > 60 and int(elapsed) % 60 < 5:  # Log every ~minute
                 logger.info(
-                    f"  [Rank {accelerator.process_index}] Waiting for cache
+                    f"  [Rank {accelerator.process_index}] Waiting for cache "
+                    f"creation... ({int(elapsed)}s / {CACHE_TIMEOUT}s max)"
                 )
         # Small delay to ensure files are fully written
         time.sleep(2)
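The timeout guard is driven by the `WAVEDL_CACHE_TIMEOUT` environment variable with a 3600 s default; invalid or non-positive values fall back to the default. A standalone sketch of that parsing (the helper name is ours, wavedl inlines the logic):

```python
import os


def resolve_cache_timeout(default: int = 3600) -> int:
    """Return the cache-wait timeout in seconds, guarding against bad env values."""
    raw = os.environ.get("WAVEDL_CACHE_TIMEOUT", "")
    try:
        timeout = int(raw) if raw else default
    except ValueError:
        return default
    return timeout if timeout > 0 else default


# e.g. WAVEDL_CACHE_TIMEOUT=7200 -> 7200; "abc" or "-5" -> 3600
print(resolve_cache_timeout())
```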
{wavedl-1.6.1.dist-info → wavedl-1.6.3.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.6.1
+Version: 1.6.3
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -214,77 +214,83 @@ This installs everything you need: training, inference, HPO, ONNX export.
 ```bash
 git clone https://github.com/ductho-le/WaveDL.git
 cd WaveDL
-pip install -e .
+pip install -e ".[dev]"
 ```
 
 > [!NOTE]
-> Python 3.11+ required. For
+> Python 3.11+ required. For contributor setup (pre-commit hooks), see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
 
 ### Quick Start
 
 > [!TIP]
 > In all examples below, replace `<...>` placeholders with your values. See [Configuration](#️-configuration) for defaults and options.
 
-
-
-The `wavedl-hpc` command automatically configures the environment for HPC systems:
+### Training
 
 ```bash
-# Basic training (auto-detects
-wavedl-
+# Basic training (auto-detects GPUs and environment)
+wavedl-train --model <model_name> --data_path <train_data> --output_dir <output_folder>
 
 # Detailed configuration
-wavedl-
+wavedl-train --model <model_name> --data_path <train_data> --batch_size <number> \
     --lr <number> --epochs <number> --patience <number> --compile --output_dir <output_folder>
 
-#
-
-
-
-#### Option 2: Direct Accelerate Launch
-
-```bash
-# Local - auto-detects GPUs
-accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
+# Multi-GPU is automatic (uses all available GPUs)
+# Override with --num_gpus if needed
+wavedl-train --model cnn --data_path train.npz --num_gpus 4 --output_dir results
 
 # Resume training (automatic - just re-run with same output_dir)
-
-accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --resume <checkpoint_folder> --output_dir <output_folder>
+wavedl-train --model <model_name> --data_path <train_data> --output_dir <output_folder>
 
 # Force fresh start (ignores existing checkpoints)
-
+wavedl-train --model <model_name> --data_path <train_data> --output_dir <output_folder> --fresh
 
 # List available models
 wavedl-train --list_models
 ```
 
-> [!
->
+> [!NOTE]
+> `wavedl-train` automatically detects your environment:
+> - **HPC clusters** (SLURM, PBS, etc.): Uses local caching, offline WandB
+> - **Local machines**: Uses standard cache locations (~/.cache)
 >
-> **
+> **Auto-Resume**: If training crashes or is interrupted, simply re-run with the same `--output_dir`. The framework automatically detects incomplete training and resumes from the last checkpoint.
 
-
+<details>
+<summary><b>Advanced: Direct Accelerate Launch</b></summary>
+
+For fine-grained control over distributed training, you can use `accelerate launch` directly:
 
-
+```bash
+# Custom accelerate configuration
+accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --output_dir <output_folder>
+
+# Multi-node training
+accelerate launch --num_machines 2 --main_process_ip <ip> -m wavedl.train --model cnn --data_path train.npz
+```
+
+</details>
+
+### Testing & Inference
 
 ```bash
 # Basic inference
-
+wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data>
 
 # With visualization, CSV export, and multiple file formats
-
+wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
     --plot --plot_format png pdf --save_predictions --output_dir <output_folder>
 
 # With custom parameter names
-
+wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
     --param_names '$p_1$' '$p_2$' '$p_3$' --plot
 
 # Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
-
+wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
     --export onnx --export_path <output_file.onnx>
 
 # For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
-
+wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
     --input_channels 1
 ```
 
@@ -295,7 +301,7 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
 - **Format** (with `--plot_format`): Supported formats: `png` (default), `pdf` (vector), `svg` (vector), `eps` (LaTeX), `tiff`, `jpg`, `ps`
 
 > [!NOTE]
-> `wavedl
+> `wavedl-test` auto-detects the model architecture from checkpoint metadata. If unavailable, it falls back to folder name parsing. Use `--model` to override if needed.
 
 ### Adding Custom Models
 
@@ -339,7 +345,7 @@ class MyModel(BaseModel):
 **Step 2: Train**
 
 ```bash
-wavedl-
+wavedl-train --import my_model.py --model my_model --data_path train.npz
 ```
 
 WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -355,10 +361,10 @@ WaveDL/
 ├── src/
 │   └── wavedl/              # Main package (namespaced)
 │       ├── __init__.py      # Package init with __version__
-│       ├── train.py         # Training
+│       ├── train.py         # Training script
 │       ├── test.py          # Testing & inference script
 │       ├── hpo.py           # Hyperparameter optimization
-│       ├──
+│       ├── launcher.py      # Training launcher (wavedl-train)
 │       │
 │       ├── models/          # Model Zoo (69 architectures)
 │       │   ├── registry.py  # Model factory (@register_model)
@@ -389,16 +395,7 @@ WaveDL/
 ## ⚙️ Configuration
 
 > [!NOTE]
-> All configuration options below work with
->
-> **Examples:**
-> ```bash
-> # Using wavedl-hpc
-> wavedl-hpc --model cnn --batch_size 256 --lr 5e-4 --compile
->
-> # Using accelerate launch directly
-> accelerate launch -m wavedl.train --model cnn --batch_size 256 --lr 5e-4 --compile
-> ```
+> All configuration options below work with `wavedl-train`. The wrapper script passes all arguments directly to `train.py`.
 
 <details>
 <summary><b>Available Models</b> — 69 architectures</summary>
@@ -642,7 +639,7 @@ WaveDL automatically enables performance optimizations for modern GPUs:
 </details>
 
 <details>
-<summary><b>
+<summary><b>Distributed Training Arguments</b></summary>
 
 | Argument | Default | Description |
 |----------|---------|-------------|
@@ -674,10 +671,10 @@ WaveDL automatically enables performance optimizations for modern GPUs:
 **Example:**
 ```bash
 # Use Huber loss for noisy NDE data
-
+wavedl-train --model cnn --loss huber --huber_delta 0.5
 
 # Weighted MSE: prioritize thickness (first target)
-
+wavedl-train --model cnn --loss weighted_mse --loss_weights "2.0,1.0,1.0"
 ```
 
 </details>
@@ -697,10 +694,10 @@ accelerate launch -m wavedl.train --model cnn --loss weighted_mse --loss_weights
 **Example:**
 ```bash
 # SGD with Nesterov momentum (often better generalization)
-
+wavedl-train --model cnn --optimizer sgd --lr 0.01 --momentum 0.9 --nesterov
 
 # RAdam for more stable training
-
+wavedl-train --model cnn --optimizer radam --lr 1e-3
 ```
 
 </details>
@@ -722,13 +719,13 @@ accelerate launch -m wavedl.train --model cnn --optimizer radam --lr 1e-3
 **Example:**
 ```bash
 # Cosine annealing for 1000 epochs
-
+wavedl-train --model cnn --scheduler cosine --epochs 1000 --min_lr 1e-7
 
 # OneCycleLR for super-convergence
-
+wavedl-train --model cnn --scheduler onecycle --lr 1e-2 --epochs 50
 
 # MultiStep with custom milestones
-
+wavedl-train --model cnn --scheduler multistep --milestones "100,200,300"
 ```
 
 </details>
@@ -739,16 +736,14 @@ accelerate launch -m wavedl.train --model cnn --scheduler multistep --milestones
 For robust model evaluation, simply add the `--cv` flag:
 
 ```bash
-# 5-fold cross-validation
-wavedl-
-# OR
-accelerate launch -m wavedl.train --model cnn --cv 5 --data_path train_data.npz
+# 5-fold cross-validation
+wavedl-train --model cnn --cv 5 --data_path train_data.npz
 
 # Stratified CV (recommended for unbalanced data)
-wavedl-
+wavedl-train --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
 
 # Full configuration
-wavedl-
+wavedl-train --model cnn --cv 5 --cv_stratify \
     --loss huber --optimizer adamw --scheduler cosine \
     --output_dir ./cv_results
 ```
@@ -773,10 +768,10 @@ Use YAML files for reproducible experiments. CLI arguments can override any conf
 ```bash
 # Use a config file
-
+wavedl-train --config configs/config.yaml --data_path train.npz
 
 # Override specific values from config
-
+wavedl-train --config configs/config.yaml --lr 5e-4 --epochs 500
 ```
 
 **Example config (`configs/config.yaml`):**
@@ -929,7 +924,7 @@ wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
 
 After HPO completes, it prints the optimal command:
 ```bash
-
+wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
 ```
 
 ---
@@ -1122,12 +1117,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
 
 ```bash
 # Run inference on the example data
-
+wavedl-test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
     --data_path ./examples/elasticity_prediction/Test_data_100.mat \
     --plot --save_predictions --output_dir ./examples/elasticity_prediction/test_results
 
 # Export to ONNX (already included as model.onnx)
-
+wavedl-test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
     --data_path ./examples/elasticity_prediction/Test_data_100.mat \
     --export onnx --export_path ./examples/elasticity_prediction/model.onnx
 ```
{wavedl-1.6.1.dist-info → wavedl-1.6.3.dist-info}/RECORD
CHANGED
@@ -1,9 +1,9 @@
-wavedl/__init__.py,sha256=
-wavedl/
-wavedl/
-wavedl/test.py,sha256=
-wavedl/train.py,sha256=
-wavedl/models/__init__.py,sha256=
+wavedl/__init__.py,sha256=CdjeIFEZh4ccwPiNsUxpac5l3pxwb9e1SmCNiLiYG1I,1177
+wavedl/hpo.py,sha256=nEiy-2O_5EhxF5hU8X5TviSAiXfVrTQx0-VE6baW7JQ,14633
+wavedl/launcher.py,sha256=_CFlgpKgHrtZebl1yQbJZJEcob06Y9-fqnRYzwW7UJQ,11776
+wavedl/test.py,sha256=5MzBtEH2lWWYG23Fz-VpMFAWR5SfZbFomBbu8ptsZRU,39208
+wavedl/train.py,sha256=DizXhi9BFL8heLmO8ENiNm2QubAMm9mdpDiaBlULeKM,57824
+wavedl/models/__init__.py,sha256=hyR__h_D8PsUQCBSM5tj94yYK00uG8ABjEmj_RR8SGE,5719
 wavedl/models/_pretrained_utils.py,sha256=VPdU1DwJB93ZBf_GFIgb8-6BbAt18Phs4yorwlhLw70,12404
 wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
 wavedl/models/base.py,sha256=bDoHYFli-aR8amcFYXbF98QYaKSCEwZWpvOhN21ODro,9075
@@ -32,15 +32,15 @@ wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
 wavedl/utils/config.py,sha256=MXkaVc1_zo8sDro8mjtK1MV65t2z8b1Z6fviwSorNiY,10534
 wavedl/utils/constraints.py,sha256=V9Gyi8-uIMbLUWb2cOaHZD0SliWLxVrHZHFyo4HWK7g,18031
 wavedl/utils/cross_validation.py,sha256=HfInyZ8gUROc_AyihYKzzUE0vnoPt_mFvAI2OPK4P54,17945
-wavedl/utils/data.py,sha256=
+wavedl/utils/data.py,sha256=HXod6i6g76oFAjLz7xepBPQEFHRgQ7E1M-YSKwUya-I,64799
 wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
 wavedl/utils/losses.py,sha256=KWpU5S5noFzp3bLbcH9RNpkFPajy6fyTIh5cNjI-BYA,7038
 wavedl/utils/metrics.py,sha256=YoqiXWOsUB9Y4_alj8CmHcTgnV4MFcH5PH4XlIC13HY,40304
 wavedl/utils/optimizers.py,sha256=ZoETDSOK1fWUT2dx69PyYebeM8Vcqf9zOIKUERWk5HY,6107
 wavedl/utils/schedulers.py,sha256=K6YCiyiMM9rb0cCRXTp89noXeXcAyUEiePr27O5Cozs,7408
-wavedl-1.6.
-wavedl-1.6.
-wavedl-1.6.
-wavedl-1.6.
-wavedl-1.6.
-wavedl-1.6.
+wavedl-1.6.3.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
+wavedl-1.6.3.dist-info/METADATA,sha256=BtPoAiMwvE58b-dmR0TfPPLLBqYzaAiJ2GBUd0FBSY0,47613
+wavedl-1.6.3.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
+wavedl-1.6.3.dist-info/entry_points.txt,sha256=NuAvdiG93EYYpqv-_1wf6PN0WqBfABanDKalNKe2GOs,148
+wavedl-1.6.3.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
+wavedl-1.6.3.dist-info/RECORD,,
{wavedl-1.6.1.dist-info → wavedl-1.6.3.dist-info}/LICENSE
File without changes
{wavedl-1.6.1.dist-info → wavedl-1.6.3.dist-info}/WHEEL
File without changes
{wavedl-1.6.1.dist-info → wavedl-1.6.3.dist-info}/top_level.txt
File without changes