wavedl-1.4.5-py3-none-any.whl → wavedl-1.4.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpc.py +11 -2
- wavedl/hpo.py +51 -2
- wavedl/test.py +13 -7
- wavedl/train.py +27 -3
- {wavedl-1.4.5.dist-info → wavedl-1.4.6.dist-info}/METADATA +13 -11
- {wavedl-1.4.5.dist-info → wavedl-1.4.6.dist-info}/RECORD +11 -11
- {wavedl-1.4.5.dist-info → wavedl-1.4.6.dist-info}/LICENSE +0 -0
- {wavedl-1.4.5.dist-info → wavedl-1.4.6.dist-info}/WHEEL +0 -0
- {wavedl-1.4.5.dist-info → wavedl-1.4.6.dist-info}/entry_points.txt +0 -0
- {wavedl-1.4.5.dist-info → wavedl-1.4.6.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpc.py
CHANGED
|
@@ -174,7 +174,9 @@ Environment Variables:
|
|
|
174
174
|
return args, remaining
|
|
175
175
|
|
|
176
176
|
|
|
177
|
-
def print_summary(
|
|
177
|
+
def print_summary(
|
|
178
|
+
exit_code: int, wandb_enabled: bool, wandb_mode: str, wandb_dir: str
|
|
179
|
+
) -> None:
|
|
178
180
|
"""Print post-training summary and instructions."""
|
|
179
181
|
print()
|
|
180
182
|
print("=" * 40)
|
|
@@ -183,7 +185,8 @@ def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
|
|
|
183
185
|
print("✅ Training completed successfully!")
|
|
184
186
|
print("=" * 40)
|
|
185
187
|
|
|
186
|
-
if
|
|
188
|
+
# Only show WandB sync instructions if user enabled wandb
|
|
189
|
+
if wandb_enabled and wandb_mode == "offline":
|
|
187
190
|
print()
|
|
188
191
|
print("📊 WandB Sync Instructions:")
|
|
189
192
|
print(" From the login node, run:")
|
|
@@ -237,6 +240,10 @@ def main() -> int:
|
|
|
237
240
|
f"--dynamo_backend={args.dynamo_backend}",
|
|
238
241
|
]
|
|
239
242
|
|
|
243
|
+
# Explicitly set multi_gpu to suppress accelerate auto-detection warning
|
|
244
|
+
if num_gpus > 1:
|
|
245
|
+
cmd.append("--multi_gpu")
|
|
246
|
+
|
|
240
247
|
# Add multi-node networking args if specified (required for some clusters)
|
|
241
248
|
if args.main_process_ip:
|
|
242
249
|
cmd.append(f"--main_process_ip={args.main_process_ip}")
|
|
@@ -263,8 +270,10 @@ def main() -> int:
|
|
|
263
270
|
exit_code = 130
|
|
264
271
|
|
|
265
272
|
# Print summary
|
|
273
|
+
wandb_enabled = "--wandb" in train_args
|
|
266
274
|
print_summary(
|
|
267
275
|
exit_code,
|
|
276
|
+
wandb_enabled,
|
|
268
277
|
os.environ.get("WANDB_MODE", "offline"),
|
|
269
278
|
os.environ.get("WANDB_DIR", "/tmp/wandb"),
|
|
270
279
|
)
|
wavedl/hpo.py
CHANGED
|
@@ -31,7 +31,7 @@ try:
|
|
|
31
31
|
import optuna
|
|
32
32
|
from optuna.trial import TrialState
|
|
33
33
|
except ImportError:
|
|
34
|
-
print("Error: Optuna not installed. Run: pip install
|
|
34
|
+
print("Error: Optuna not installed. Run: pip install wavedl")
|
|
35
35
|
sys.exit(1)
|
|
36
36
|
|
|
37
37
|
|
|
@@ -147,6 +147,32 @@ def create_objective(args):
|
|
|
147
147
|
cmd.extend(["--output_dir", tmpdir])
|
|
148
148
|
history_file = Path(tmpdir) / "training_history.csv"
|
|
149
149
|
|
|
150
|
+
# GPU isolation for parallel trials: assign each trial to a specific GPU
|
|
151
|
+
# This prevents multiple trials from competing for all GPUs
|
|
152
|
+
env = None
|
|
153
|
+
if args.n_jobs > 1:
|
|
154
|
+
import os
|
|
155
|
+
|
|
156
|
+
# Detect available GPUs
|
|
157
|
+
n_gpus = 1
|
|
158
|
+
try:
|
|
159
|
+
import subprocess as sp
|
|
160
|
+
|
|
161
|
+
result_gpu = sp.run(
|
|
162
|
+
["nvidia-smi", "--list-gpus"],
|
|
163
|
+
capture_output=True,
|
|
164
|
+
text=True,
|
|
165
|
+
)
|
|
166
|
+
if result_gpu.returncode == 0:
|
|
167
|
+
n_gpus = len(result_gpu.stdout.strip().split("\n"))
|
|
168
|
+
except Exception:
|
|
169
|
+
pass
|
|
170
|
+
|
|
171
|
+
# Assign trial to a specific GPU (round-robin)
|
|
172
|
+
gpu_id = trial.number % n_gpus
|
|
173
|
+
env = os.environ.copy()
|
|
174
|
+
env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
|
175
|
+
|
|
150
176
|
# Run training
|
|
151
177
|
try:
|
|
152
178
|
result = subprocess.run(
|
|
@@ -155,6 +181,7 @@ def create_objective(args):
|
|
|
155
181
|
text=True,
|
|
156
182
|
timeout=args.timeout,
|
|
157
183
|
cwd=Path(__file__).parent,
|
|
184
|
+
env=env,
|
|
158
185
|
)
|
|
159
186
|
|
|
160
187
|
# Read best val_loss from training_history.csv (reliable machine-readable)
|
|
@@ -248,7 +275,10 @@ Examples:
|
|
|
248
275
|
"--n_trials", type=int, default=50, help="Number of HPO trials (default: 50)"
|
|
249
276
|
)
|
|
250
277
|
parser.add_argument(
|
|
251
|
-
"--n_jobs",
|
|
278
|
+
"--n_jobs",
|
|
279
|
+
type=int,
|
|
280
|
+
default=-1,
|
|
281
|
+
help="Parallel trials (-1 = auto-detect GPUs, default: -1)",
|
|
252
282
|
)
|
|
253
283
|
parser.add_argument(
|
|
254
284
|
"--quick",
|
|
@@ -315,11 +345,30 @@ Examples:
|
|
|
315
345
|
|
|
316
346
|
args = parser.parse_args()
|
|
317
347
|
|
|
348
|
+
# Convert to absolute path (child processes may run in different cwd)
|
|
349
|
+
args.data_path = str(Path(args.data_path).resolve())
|
|
350
|
+
|
|
318
351
|
# Validate data path
|
|
319
352
|
if not Path(args.data_path).exists():
|
|
320
353
|
print(f"Error: Data file not found: {args.data_path}")
|
|
321
354
|
sys.exit(1)
|
|
322
355
|
|
|
356
|
+
# Auto-detect GPUs for n_jobs if not specified
|
|
357
|
+
if args.n_jobs == -1:
|
|
358
|
+
try:
|
|
359
|
+
result_gpu = subprocess.run(
|
|
360
|
+
["nvidia-smi", "--list-gpus"],
|
|
361
|
+
capture_output=True,
|
|
362
|
+
text=True,
|
|
363
|
+
)
|
|
364
|
+
if result_gpu.returncode == 0:
|
|
365
|
+
args.n_jobs = max(1, len(result_gpu.stdout.strip().split("\n")))
|
|
366
|
+
else:
|
|
367
|
+
args.n_jobs = 1
|
|
368
|
+
except Exception:
|
|
369
|
+
args.n_jobs = 1
|
|
370
|
+
print(f"Auto-detected {args.n_jobs} GPU(s) for parallel trials")
|
|
371
|
+
|
|
323
372
|
# Create study
|
|
324
373
|
print("=" * 60)
|
|
325
374
|
print("WaveDL Hyperparameter Optimization")
|
wavedl/test.py
CHANGED
|
@@ -366,13 +366,19 @@ def load_checkpoint(
|
|
|
366
366
|
logging.info(f" Building model: {model_name}")
|
|
367
367
|
model = build_model(model_name, in_shape=in_shape, out_size=out_size)
|
|
368
368
|
|
|
369
|
-
# Load weights (
|
|
370
|
-
weight_path =
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
369
|
+
# Load weights (check multiple formats in order of preference)
|
|
370
|
+
weight_path = None
|
|
371
|
+
for fname in ["model.safetensors", "model.bin", "pytorch_model.bin"]:
|
|
372
|
+
candidate = checkpoint_dir / fname
|
|
373
|
+
if candidate.exists():
|
|
374
|
+
weight_path = candidate
|
|
375
|
+
break
|
|
376
|
+
|
|
377
|
+
if weight_path is None:
|
|
378
|
+
raise FileNotFoundError(
|
|
379
|
+
f"No model weights found in {checkpoint_dir}. "
|
|
380
|
+
f"Expected one of: model.safetensors, model.bin, pytorch_model.bin"
|
|
381
|
+
)
|
|
376
382
|
|
|
377
383
|
if HAS_SAFETENSORS and weight_path.suffix == ".safetensors":
|
|
378
384
|
state_dict = load_safetensors(str(weight_path))
|
wavedl/train.py
CHANGED
|
@@ -148,6 +148,24 @@ torch.set_float32_matmul_precision("high") # Use TF32 for float32 ops
|
|
|
148
148
|
torch.backends.cudnn.benchmark = True
|
|
149
149
|
|
|
150
150
|
|
|
151
|
+
# ==============================================================================
|
|
152
|
+
# LOGGING UTILITIES
|
|
153
|
+
# ==============================================================================
|
|
154
|
+
from contextlib import contextmanager
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
@contextmanager
|
|
158
|
+
def suppress_accelerate_logging():
|
|
159
|
+
"""Temporarily suppress accelerate's verbose checkpoint save messages."""
|
|
160
|
+
accelerate_logger = logging.getLogger("accelerate.checkpointing")
|
|
161
|
+
original_level = accelerate_logger.level
|
|
162
|
+
accelerate_logger.setLevel(logging.WARNING)
|
|
163
|
+
try:
|
|
164
|
+
yield
|
|
165
|
+
finally:
|
|
166
|
+
accelerate_logger.setLevel(original_level)
|
|
167
|
+
|
|
168
|
+
|
|
151
169
|
# ==============================================================================
|
|
152
170
|
# ARGUMENT PARSING
|
|
153
171
|
# ==============================================================================
|
|
@@ -1033,7 +1051,8 @@ def main():
|
|
|
1033
1051
|
# Step 3: Save checkpoint with all ranks participating
|
|
1034
1052
|
if is_best_epoch:
|
|
1035
1053
|
ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
|
|
1036
|
-
|
|
1054
|
+
with suppress_accelerate_logging():
|
|
1055
|
+
accelerator.save_state(ckpt_dir, safe_serialization=False)
|
|
1037
1056
|
|
|
1038
1057
|
# Step 4: Rank 0 handles metadata and updates tracking variables
|
|
1039
1058
|
if accelerator.is_main_process:
|
|
@@ -1096,7 +1115,8 @@ def main():
|
|
|
1096
1115
|
if periodic_checkpoint_needed:
|
|
1097
1116
|
ckpt_name = f"epoch_{epoch + 1}_checkpoint"
|
|
1098
1117
|
ckpt_dir = os.path.join(args.output_dir, ckpt_name)
|
|
1099
|
-
|
|
1118
|
+
with suppress_accelerate_logging():
|
|
1119
|
+
accelerator.save_state(ckpt_dir, safe_serialization=False)
|
|
1100
1120
|
|
|
1101
1121
|
if accelerator.is_main_process:
|
|
1102
1122
|
with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
|
|
@@ -1147,7 +1167,11 @@ def main():
|
|
|
1147
1167
|
|
|
1148
1168
|
except KeyboardInterrupt:
|
|
1149
1169
|
logger.warning("Training interrupted. Saving emergency checkpoint...")
|
|
1150
|
-
|
|
1170
|
+
with suppress_accelerate_logging():
|
|
1171
|
+
accelerator.save_state(
|
|
1172
|
+
os.path.join(args.output_dir, "interrupted_checkpoint"),
|
|
1173
|
+
safe_serialization=False,
|
|
1174
|
+
)
|
|
1151
1175
|
|
|
1152
1176
|
except Exception as e:
|
|
1153
1177
|
logger.error(f"Critical error: {e}", exc_info=True)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.4.
|
|
3
|
+
Version: 1.4.6
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -49,7 +49,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
|
|
|
49
49
|
|
|
50
50
|
### A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
51
51
|
|
|
52
|
-
[](https://www.python.org/downloads/)
|
|
53
53
|
[](https://pytorch.org/)
|
|
54
54
|
[](https://huggingface.co/docs/accelerate/)
|
|
55
55
|
<br>
|
|
@@ -57,7 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
|
|
|
57
57
|
[](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
|
|
58
58
|
[](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
|
|
59
59
|
<br>
|
|
60
|
-
[](https://pypistats.org/packages/wavedl)
|
|
61
61
|
[](LICENSE)
|
|
62
62
|
[](https://doi.org/10.5281/zenodo.18012338)
|
|
63
63
|
|
|
@@ -734,18 +734,20 @@ Automatically find the best training configuration using [Optuna](https://optuna
|
|
|
734
734
|
|
|
735
735
|
**Run HPO:**
|
|
736
736
|
|
|
737
|
-
You specify which models to search and how many trials to run:
|
|
738
737
|
```bash
|
|
739
|
-
#
|
|
740
|
-
|
|
738
|
+
# Basic HPO (auto-detects GPUs for parallel trials)
|
|
739
|
+
wavedl-hpo --data_path train.npz --models cnn --n_trials 100
|
|
741
740
|
|
|
742
|
-
# Search
|
|
743
|
-
|
|
741
|
+
# Search multiple models
|
|
742
|
+
wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
|
|
744
743
|
|
|
745
|
-
#
|
|
746
|
-
|
|
744
|
+
# Quick mode (fewer parameters, faster)
|
|
745
|
+
wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
|
|
747
746
|
```
|
|
748
747
|
|
|
748
|
+
> [!TIP]
|
|
749
|
+
> **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
|
|
750
|
+
|
|
749
751
|
**Train with best parameters**
|
|
750
752
|
|
|
751
753
|
After HPO completes, it prints the optimal command:
|
|
@@ -784,7 +786,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
|
|
|
784
786
|
| `--optimizers` | all 6 | Optimizers to search |
|
|
785
787
|
| `--schedulers` | all 8 | Schedulers to search |
|
|
786
788
|
| `--losses` | all 6 | Losses to search |
|
|
787
|
-
| `--n_jobs` |
|
|
789
|
+
| `--n_jobs` | `-1` | Parallel trials (-1 = auto-detect GPUs) |
|
|
788
790
|
| `--max_epochs` | `50` | Max epochs per trial |
|
|
789
791
|
| `--output` | `hpo_results.json` | Output file |
|
|
790
792
|
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
wavedl/__init__.py,sha256=
|
|
2
|
-
wavedl/hpc.py,sha256
|
|
3
|
-
wavedl/hpo.py,sha256=
|
|
4
|
-
wavedl/test.py,sha256=
|
|
5
|
-
wavedl/train.py,sha256=
|
|
1
|
+
wavedl/__init__.py,sha256=ItdZLt3f7sbtAMgiwUtGwwG5Cko4tPLugC_OVhfHMno,1177
|
|
2
|
+
wavedl/hpc.py,sha256=-iOjjKkXPcV_quj4vAsMBJN_zWKtD1lMRfIZZBhyGms,8756
|
|
3
|
+
wavedl/hpo.py,sha256=JQvwPgiVHj3sB9Wombn1QO4ammpuo0QAMpRee0LjkuI,14731
|
|
4
|
+
wavedl/test.py,sha256=oWGSSC7178loqOxwti-oDXUVogOqbwHL__GfoXSE5Ss,37846
|
|
5
|
+
wavedl/train.py,sha256=9l4aVW1Jd1Sq6yBr8BOoVIKUYmxASDO8XK6BqEkLLWs,50151
|
|
6
6
|
wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
|
|
7
7
|
wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
|
|
8
8
|
wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
|
|
@@ -29,9 +29,9 @@ wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
|
|
|
29
29
|
wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
|
|
30
30
|
wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
|
|
31
31
|
wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
|
|
32
|
-
wavedl-1.4.
|
|
33
|
-
wavedl-1.4.
|
|
34
|
-
wavedl-1.4.
|
|
35
|
-
wavedl-1.4.
|
|
36
|
-
wavedl-1.4.
|
|
37
|
-
wavedl-1.4.
|
|
32
|
+
wavedl-1.4.6.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
|
|
33
|
+
wavedl-1.4.6.dist-info/METADATA,sha256=Hnot8ui2oksCz2UXhj3FHd_Z9MtoP8MJyiMzC6eWq5s,42453
|
|
34
|
+
wavedl-1.4.6.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
35
|
+
wavedl-1.4.6.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
|
|
36
|
+
wavedl-1.4.6.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
|
|
37
|
+
wavedl-1.4.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|