wavedl 1.6.1-py3-none-any.whl → 1.6.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +1 -1
- wavedl/{hpc.py → launcher.py} +135 -61
- wavedl/train.py +8 -11
- {wavedl-1.6.1.dist-info → wavedl-1.6.2.dist-info}/METADATA +58 -61
- {wavedl-1.6.1.dist-info → wavedl-1.6.2.dist-info}/RECORD +10 -10
- {wavedl-1.6.1.dist-info → wavedl-1.6.2.dist-info}/entry_points.txt +2 -2
- {wavedl-1.6.1.dist-info → wavedl-1.6.2.dist-info}/LICENSE +0 -0
- {wavedl-1.6.1.dist-info → wavedl-1.6.2.dist-info}/WHEEL +0 -0
- {wavedl-1.6.1.dist-info → wavedl-1.6.2.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpo.py
CHANGED
@@ -440,7 +440,7 @@ Examples:
     print("\n" + "=" * 60)
     print("TO TRAIN WITH BEST PARAMETERS:")
     print("=" * 60)
-    cmd_parts = ["
+    cmd_parts = ["wavedl-train"]
     cmd_parts.append(f"--data_path {args.data_path}")
     for key, value in study.best_params.items():
         cmd_parts.append(f"--{key} {value}")
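For reference, a minimal sketch of how the printed suggestion is assembled from `cmd_parts` (assumes the parts are simply joined with spaces; the join itself is outside the hunk shown above, and the best-parameter values here are hypothetical):

```python
# Sketch only: illustrates the cmd_parts pattern from the hunk above.
cmd_parts = ["wavedl-train"]
cmd_parts.append("--data_path train.npz")
best_params = {"lr": 3.2e-4, "batch_size": 128}  # hypothetical HPO result
for key, value in best_params.items():
    cmd_parts.append(f"--{key} {value}")
print(" ".join(cmd_parts))
# wavedl-train --data_path train.npz --lr 0.00032 --batch_size 128
```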
wavedl/{hpc.py → launcher.py}
RENAMED
@@ -1,12 +1,21 @@
 #!/usr/bin/env python
 """
-WaveDL
+WaveDL Training Launcher.
 
-This module provides a
-for distributed training
+This module provides a universal training launcher that wraps accelerate
+for distributed training. It works seamlessly on both:
+- Local machines (uses standard cache locations)
+- HPC clusters (uses local caching, offline WandB)
+
+The environment is auto-detected based on scheduler variables (SLURM, PBS, etc.)
+and home directory writability.
 
 Usage:
-
+    # Local machine or HPC - same command!
+    wavedl-train --model cnn --data_path train.npz --output_dir results
+
+    # Multi-GPU is automatic (uses all available GPUs)
+    wavedl-train --model resnet18 --data_path train.npz --num_gpus 4
 
 Example SLURM script:
     #!/bin/bash

@@ -14,7 +23,7 @@ Example SLURM script:
     #SBATCH --gpus-per-node=4
     #SBATCH --time=12:00:00
 
-    wavedl-
+    wavedl-train --model cnn --data_path /scratch/data.npz --compile
 
 Author: Ductho Le (ductho.le@outlook.com)
 """

@@ -53,78 +62,138 @@ def detect_gpus() -> int:
     return 1
 
 
-def
-    """
+def is_hpc_environment() -> bool:
+    """Detect if running on an HPC cluster.
+
+    Checks for:
+    1. Common HPC scheduler environment variables (SLURM, PBS, LSF, SGE, Cobalt)
+    2. Non-writable home directory (common on HPC systems)
 
-
-
-    since compute nodes typically lack internet access.
+    Returns:
+        True if HPC environment detected, False otherwise.
     """
-    #
-
+    # Check for common HPC scheduler environment variables
+    hpc_indicators = [
+        "SLURM_JOB_ID",  # SLURM
+        "PBS_JOBID",  # PBS/Torque
+        "LSB_JOBID",  # LSF
+        "SGE_TASK_ID",  # Sun Grid Engine
+        "COBALT_JOBID",  # Cobalt
+    ]
+    if any(var in os.environ for var in hpc_indicators):
+        return True
 
-    #
-    os.
-
+    # Check if home directory is not writable (common on HPC)
+    home = os.path.expanduser("~")
+    return not os.access(home, os.W_OK)
 
-    # Triton/Inductor caches - prevents permission errors with --compile
-    # These MUST be set before any torch.compile calls
-    os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
-    os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache")
-    Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
-    Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
 
-
-
-
-
-
-
-
-
-
-
-    #
-
-
-
-
-    ]
-
-
-
-
-
-
-
-
-
+def setup_environment() -> None:
+    """Configure environment for HPC or local machine.
+
+    Automatically detects the environment and configures accordingly:
+    - HPC: Uses CWD-based caching, offline WandB (compute nodes lack internet)
+    - Local: Uses standard cache locations (~/.cache), doesn't override WandB
+    """
+    is_hpc = is_hpc_environment()
+
+    if is_hpc:
+        # HPC: use CWD-based caching (compute nodes lack internet)
+        cache_base = os.getcwd()
+
+        # TORCH_HOME set to CWD - compute nodes need pre-cached weights
+        os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
+        Path(os.environ["TORCH_HOME"]).mkdir(parents=True, exist_ok=True)
+
+        # Triton/Inductor caches - prevents permission errors with --compile
+        os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
+        os.environ.setdefault(
+            "TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache"
+        )
+        Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
+        Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
+
+        # Check if home is writable for other caches
+        home = os.path.expanduser("~")
+        home_writable = os.access(home, os.W_OK)
+
+        # Other caches only if home is not writable
+        if not home_writable:
+            os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
+            os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
+            os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
+
+            for env_var in [
+                "MPLCONFIGDIR",
+                "FONTCONFIG_CACHE",
+                "XDG_CACHE_HOME",
+            ]:
+                Path(os.environ[env_var]).mkdir(parents=True, exist_ok=True)
+
+        # WandB configuration (offline by default for HPC)
+        os.environ.setdefault("WANDB_MODE", "offline")
+        os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
+        os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
+        os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
+
+        print("🖥️ HPC environment detected - using local caching")
+    else:
+        # Local machine: use standard locations, don't override user settings
+        # TORCH_HOME defaults to ~/.cache/torch (PyTorch default)
+        # WANDB_MODE defaults to online (WandB default)
+        print("💻 Local environment detected - using standard cache locations")
+
+    # Suppress non-critical warnings (both environments)
     os.environ.setdefault(
         "PYTHONWARNINGS",
         "ignore::UserWarning,ignore::FutureWarning,ignore::DeprecationWarning",
    )
 
 
+def handle_fast_path_args() -> int | None:
+    """Handle utility flags that don't need accelerate launch.
+
+    Returns:
+        Exit code if handled (0 for success), None if should continue to full launch.
+    """
+    # --list_models: print models and exit immediately
+    if "--list_models" in sys.argv:
+        from wavedl.models import list_models
+
+        print("Available models:")
+        for name in list_models():
+            print(f"  {name}")
+        return 0
+
+    return None  # Continue to full launch
+
+
 def parse_args() -> tuple[argparse.Namespace, list[str]]:
-    """Parse
+    """Parse launcher-specific arguments, pass remaining to wavedl.train."""
     parser = argparse.ArgumentParser(
-        description="WaveDL
+        description="WaveDL Training Launcher (works on local machines and HPC clusters)",
         formatter_class=argparse.RawDescriptionHelpFormatter,
         epilog="""
 Examples:
-  # Basic training
-  wavedl-
+  # Basic training (auto-detects GPUs and environment)
+  wavedl-train --model cnn --data_path train.npz --output_dir results
 
-  # Specify GPU count
-  wavedl-
+  # Specify GPU count explicitly
+  wavedl-train --model cnn --data_path train.npz --num_gpus 4
 
   # Full configuration
-  wavedl-
-
+  wavedl-train --model resnet18 --data_path train.npz --batch_size 256 \\
+      --lr 1e-3 --epochs 100 --compile --output_dir ./results
+
+  # List available models
+  wavedl-train --list_models
 
-Environment
-
-
+Environment Detection:
+  The launcher automatically detects your environment:
+  - HPC (SLURM, PBS, etc.): Uses local caching, offline WandB
+  - Local machine: Uses standard cache locations (~/.cache)
+
+For full training options, see: python -m wavedl.train --help
 """,
     )
 

@@ -204,7 +273,7 @@ def print_summary(
     print("Common issues:")
     print("  - Missing data file (check --data_path)")
     print("  - Insufficient GPU memory (reduce --batch_size)")
-    print("  - Invalid model name (run:
+    print("  - Invalid model name (run: wavedl-train --list_models)")
     print()
 
     print("=" * 40)

@@ -212,12 +281,17 @@
 
 
 def main() -> int:
-    """Main entry point for wavedl-
+    """Main entry point for wavedl-train command."""
+    # Fast path for utility flags (avoid accelerate launch overhead)
+    exit_code = handle_fast_path_args()
+    if exit_code is not None:
+        return exit_code
+
     # Parse arguments
     args, train_args = parse_args()
 
-    # Setup
-
+    # Setup environment (smart detection)
+    setup_environment()
 
     # Check if wavedl package is importable
     try:
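Since the rename, the detection helper added above is importable directly. A quick, hedged way to exercise the new logic (assumes wavedl 1.6.2 is installed; the scheduler variable is set here only to simulate an allocation):

```python
# Sketch: exercise the scheduler-variable branch of is_hpc_environment().
import os

from wavedl.launcher import is_hpc_environment

print(is_hpc_environment())            # depends on your machine (home writability)
os.environ["SLURM_JOB_ID"] = "12345"   # simulate a SLURM allocation
print(is_hpc_environment())            # True: a scheduler variable is present
```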
wavedl/train.py
CHANGED
@@ -12,25 +12,22 @@ A modular training framework for wave-based inverse problems and regression:
 6. Deep Observability: WandB integration with scatter analysis
 
 Usage:
-    # Recommended:
-    wavedl-
-
-    # Or direct training module (use --precision, not --mixed_precision)
-    accelerate launch -m wavedl.train --model cnn --batch_size 128 --precision bf16
+    # Recommended: Universal training command (works on local machines and HPC)
+    wavedl-train --model cnn --batch_size 128 --compile
 
     # Multi-GPU with explicit config
-    wavedl-
+    wavedl-train --num_gpus 4 --model cnn --output_dir results
 
     # Resume from checkpoint
-
+    wavedl-train --model cnn --output_dir results  # auto-resumes if interrupted
 
     # List available models
     wavedl-train --list_models
 
 Note:
-
-    -
-
+    wavedl-train automatically detects your environment:
+    - HPC clusters (SLURM, PBS, etc.): Uses local caching, offline WandB
+    - Local machines: Uses standard cache locations (~/.cache)
 
 Author: Ductho Le (ductho.le@outlook.com)
 """

@@ -429,7 +426,7 @@ def parse_args() -> argparse.Namespace:
         choices=["bf16", "fp16", "no"],
         help="Mixed precision mode",
     )
-    # Alias for consistency with wavedl-
+    # Alias for consistency with wavedl-train (--mixed_precision is passed to accelerate)
     parser.add_argument(
         "--mixed_precision",
         dest="precision",
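The alias above means `--mixed_precision` and `--precision` write to the same destination. A standalone argparse sketch of that pattern (not the actual wavedl parser, just the behavior the hunk shows):

```python
# Sketch of the dest="precision" alias shown in the diff above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--precision", choices=["bf16", "fp16", "no"], default="bf16")
parser.add_argument("--mixed_precision", dest="precision", choices=["bf16", "fp16", "no"])

print(parser.parse_args(["--mixed_precision", "fp16"]).precision)  # fp16
print(parser.parse_args(["--precision", "bf16"]).precision)        # bf16
```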
{wavedl-1.6.1.dist-info → wavedl-1.6.2.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.6.1
+Version: 1.6.2
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT

@@ -225,66 +225,74 @@ pip install -e .
 > [!TIP]
 > In all examples below, replace `<...>` placeholders with your values. See [Configuration](#️-configuration) for defaults and options.
 
-
-
-The `wavedl-hpc` command automatically configures the environment for HPC systems:
+### Training
 
 ```bash
-# Basic training (auto-detects
-wavedl-
+# Basic training (auto-detects GPUs and environment)
+wavedl-train --model <model_name> --data_path <train_data> --output_dir <output_folder>
 
 # Detailed configuration
-wavedl-
+wavedl-train --model <model_name> --data_path <train_data> --batch_size <number> \
     --lr <number> --epochs <number> --patience <number> --compile --output_dir <output_folder>
 
-#
-
-
-
-#### Option 2: Direct Accelerate Launch
-
-```bash
-# Local - auto-detects GPUs
-accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
+# Multi-GPU is automatic (uses all available GPUs)
+# Override with --num_gpus if needed
+wavedl-train --model cnn --data_path train.npz --num_gpus 4 --output_dir results
 
 # Resume training (automatic - just re-run with same output_dir)
-
-accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --resume <checkpoint_folder> --output_dir <output_folder>
+wavedl-train --model <model_name> --data_path <train_data> --output_dir <output_folder>
 
 # Force fresh start (ignores existing checkpoints)
-
+wavedl-train --model <model_name> --data_path <train_data> --output_dir <output_folder> --fresh
 
 # List available models
 wavedl-train --list_models
 ```
 
-> [!
->
+> [!NOTE]
+> `wavedl-train` automatically detects your environment:
+> - **HPC clusters** (SLURM, PBS, etc.): Uses local caching, offline WandB
+> - **Local machines**: Uses standard cache locations (~/.cache)
 >
-> **
+> **Auto-Resume**: If training crashes or is interrupted, simply re-run with the same `--output_dir`. The framework automatically detects incomplete training and resumes from the last checkpoint.
+
+<details>
+<summary><b>Advanced: Direct Accelerate Launch</b></summary>
+
+For fine-grained control over distributed training, you can use `accelerate launch` directly:
+
+```bash
+# Custom accelerate configuration
+accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --output_dir <output_folder>
+
+# Multi-node training
+accelerate launch --num_machines 2 --main_process_ip <ip> -m wavedl.train --model cnn --data_path train.npz
+```
+
+</details>
 
 ### Testing & Inference
 
-After training, use `wavedl
+After training, use `wavedl-test` to evaluate your model on test data:
 
 ```bash
 # Basic inference
-
+wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data>
 
 # With visualization, CSV export, and multiple file formats
-
+wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
     --plot --plot_format png pdf --save_predictions --output_dir <output_folder>
 
 # With custom parameter names
-
+wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
     --param_names '$p_1$' '$p_2$' '$p_3$' --plot
 
 # Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
-
+wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
     --export onnx --export_path <output_file.onnx>
 
 # For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
-
+wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
     --input_channels 1
 ```
 

@@ -295,7 +303,7 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
 - **Format** (with `--plot_format`): Supported formats: `png` (default), `pdf` (vector), `svg` (vector), `eps` (LaTeX), `tiff`, `jpg`, `ps`
 
 > [!NOTE]
-> `wavedl
+> `wavedl-test` auto-detects the model architecture from checkpoint metadata. If unavailable, it falls back to folder name parsing. Use `--model` to override if needed.
 
 ### Adding Custom Models
 

@@ -339,7 +347,7 @@ class MyModel(BaseModel):
 **Step 2: Train**
 
 ```bash
-wavedl-
+wavedl-train --import my_model.py --model my_model --data_path train.npz
 ```
 
 WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.

@@ -355,10 +363,10 @@ WaveDL/
 ├── src/
 │   └── wavedl/              # Main package (namespaced)
 │       ├── __init__.py      # Package init with __version__
-│       ├── train.py         # Training
+│       ├── train.py         # Training script
 │       ├── test.py          # Testing & inference script
 │       ├── hpo.py           # Hyperparameter optimization
-│       ├──
+│       ├── launcher.py      # Training launcher (wavedl-train)
 │       │
 │       ├── models/          # Model Zoo (69 architectures)
 │       │   ├── registry.py  # Model factory (@register_model)

@@ -389,16 +397,7 @@ WaveDL/
 ## ⚙️ Configuration
 
 > [!NOTE]
-> All configuration options below work with
->
-> **Examples:**
-> ```bash
-> # Using wavedl-hpc
-> wavedl-hpc --model cnn --batch_size 256 --lr 5e-4 --compile
->
-> # Using accelerate launch directly
-> accelerate launch -m wavedl.train --model cnn --batch_size 256 --lr 5e-4 --compile
-> ```
+> All configuration options below work with `wavedl-train`. The wrapper script passes all arguments directly to `train.py`.
 
 <details>
 <summary><b>Available Models</b> — 69 architectures</summary>

@@ -642,7 +641,7 @@ WaveDL automatically enables performance optimizations for modern GPUs:
 </details>
 
 <details>
-<summary><b>
+<summary><b>Distributed Training Arguments</b></summary>
 
 | Argument | Default | Description |
 |----------|---------|-------------|

@@ -674,10 +673,10 @@ WaveDL automatically enables performance optimizations for modern GPUs:
 **Example:**
 ```bash
 # Use Huber loss for noisy NDE data
-
+wavedl-train --model cnn --loss huber --huber_delta 0.5
 
 # Weighted MSE: prioritize thickness (first target)
-
+wavedl-train --model cnn --loss weighted_mse --loss_weights "2.0,1.0,1.0"
 ```
 
 </details>

@@ -697,10 +696,10 @@ accelerate launch -m wavedl.train --model cnn --loss weighted_mse --loss_weights
 **Example:**
 ```bash
 # SGD with Nesterov momentum (often better generalization)
-
+wavedl-train --model cnn --optimizer sgd --lr 0.01 --momentum 0.9 --nesterov
 
 # RAdam for more stable training
-
+wavedl-train --model cnn --optimizer radam --lr 1e-3
 ```
 
 </details>

@@ -722,13 +721,13 @@ accelerate launch -m wavedl.train --model cnn --optimizer radam --lr 1e-3
 **Example:**
 ```bash
 # Cosine annealing for 1000 epochs
-
+wavedl-train --model cnn --scheduler cosine --epochs 1000 --min_lr 1e-7
 
 # OneCycleLR for super-convergence
-
+wavedl-train --model cnn --scheduler onecycle --lr 1e-2 --epochs 50
 
 # MultiStep with custom milestones
-
+wavedl-train --model cnn --scheduler multistep --milestones "100,200,300"
 ```
 
 </details>

@@ -739,16 +738,14 @@ accelerate launch -m wavedl.train --model cnn --scheduler multistep --milestones
 For robust model evaluation, simply add the `--cv` flag:
 
 ```bash
-# 5-fold cross-validation
-wavedl-
-# OR
-accelerate launch -m wavedl.train --model cnn --cv 5 --data_path train_data.npz
+# 5-fold cross-validation
+wavedl-train --model cnn --cv 5 --data_path train_data.npz
 
 # Stratified CV (recommended for unbalanced data)
-wavedl-
+wavedl-train --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
 
 # Full configuration
-wavedl-
+wavedl-train --model cnn --cv 5 --cv_stratify \
     --loss huber --optimizer adamw --scheduler cosine \
     --output_dir ./cv_results
 ```

@@ -773,10 +770,10 @@ Use YAML files for reproducible experiments. CLI arguments can override any conf
 
 ```bash
 # Use a config file
-
+wavedl-train --config configs/config.yaml --data_path train.npz
 
 # Override specific values from config
-
+wavedl-train --config configs/config.yaml --lr 5e-4 --epochs 500
 ```
 
 **Example config (`configs/config.yaml`):**

@@ -929,7 +926,7 @@ wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
 
 After HPO completes, it prints the optimal command:
 ```bash
-
+wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
 ```
 
 ---

@@ -1122,12 +1119,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
 
 ```bash
 # Run inference on the example data
-
+wavedl-test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
     --data_path ./examples/elasticity_prediction/Test_data_100.mat \
     --plot --save_predictions --output_dir ./examples/elasticity_prediction/test_results
 
 # Export to ONNX (already included as model.onnx)
-
+wavedl-test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
    --data_path ./examples/elasticity_prediction/Test_data_100.mat \
    --export onnx --export_path ./examples/elasticity_prediction/model.onnx
 ```
{wavedl-1.6.1.dist-info → wavedl-1.6.2.dist-info}/RECORD
CHANGED

@@ -1,8 +1,8 @@
-wavedl/__init__.py,sha256=
-wavedl/
-wavedl/
+wavedl/__init__.py,sha256=hFGU_j86Beexkcrn_V3fotGQ4ncwLGvz2lCOejEJ-f0,1177
+wavedl/hpo.py,sha256=nEiy-2O_5EhxF5hU8X5TviSAiXfVrTQx0-VE6baW7JQ,14633
+wavedl/launcher.py,sha256=_CFlgpKgHrtZebl1yQbJZJEcob06Y9-fqnRYzwW7UJQ,11776
 wavedl/test.py,sha256=1UUy9phCqrr3h_lN6mGJ7Sj73skDg4KyLk2Yuq9DiKU,38797
-wavedl/train.py,sha256=
+wavedl/train.py,sha256=vBufy6gHShawgj8O6dvVER9TPhORa1s7L6pQtTe-N5M,57824
 wavedl/models/__init__.py,sha256=8OiT2seq1qBiUzKaSkmh_VOLJlLTT9Cn-mjhMHKGFpI,5203
 wavedl/models/_pretrained_utils.py,sha256=VPdU1DwJB93ZBf_GFIgb8-6BbAt18Phs4yorwlhLw70,12404
 wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600

@@ -38,9 +38,9 @@ wavedl/utils/losses.py,sha256=KWpU5S5noFzp3bLbcH9RNpkFPajy6fyTIh5cNjI-BYA,7038
 wavedl/utils/metrics.py,sha256=YoqiXWOsUB9Y4_alj8CmHcTgnV4MFcH5PH4XlIC13HY,40304
 wavedl/utils/optimizers.py,sha256=ZoETDSOK1fWUT2dx69PyYebeM8Vcqf9zOIKUERWk5HY,6107
 wavedl/utils/schedulers.py,sha256=K6YCiyiMM9rb0cCRXTp89noXeXcAyUEiePr27O5Cozs,7408
-wavedl-1.6.
-wavedl-1.6.
-wavedl-1.6.
-wavedl-1.6.
-wavedl-1.6.
-wavedl-1.6.
+wavedl-1.6.2.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
+wavedl-1.6.2.dist-info/METADATA,sha256=2mTyuip32AneUURV3K8oAjZQ2rA_13AB16R-VyRN5s8,47659
+wavedl-1.6.2.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
+wavedl-1.6.2.dist-info/entry_points.txt,sha256=NuAvdiG93EYYpqv-_1wf6PN0WqBfABanDKalNKe2GOs,148
+wavedl-1.6.2.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
+wavedl-1.6.2.dist-info/RECORD,,
{wavedl-1.6.1.dist-info → wavedl-1.6.2.dist-info}/LICENSE
File without changes

{wavedl-1.6.1.dist-info → wavedl-1.6.2.dist-info}/WHEEL
File without changes

{wavedl-1.6.1.dist-info → wavedl-1.6.2.dist-info}/top_level.txt
File without changes
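For readers unfamiliar with the RECORD format: each entry is `path,sha256=<hash>,<size>`, where the hash is the urlsafe-base64 SHA-256 digest with the `=` padding stripped (PEP 376 / wheel spec). A small sketch for recomputing an entry and comparing it against the RECORD lines above (the file path is illustrative):

```python
# Sketch: recompute a wheel RECORD hash (urlsafe base64, '=' padding stripped).
import base64
import hashlib

def record_hash(path: str) -> str:
    digest = hashlib.sha256(open(path, "rb").read()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

print(record_hash("wavedl/launcher.py"))  # compare against the RECORD entry above
```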