wavedl 1.4.5.tar.gz → 1.4.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {wavedl-1.4.5/src/wavedl.egg-info → wavedl-1.4.6}/PKG-INFO +13 -11
  2. {wavedl-1.4.5 → wavedl-1.4.6}/README.md +12 -10
  3. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/__init__.py +1 -1
  4. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/hpc.py +11 -2
  5. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/hpo.py +51 -2
  6. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/test.py +13 -7
  7. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/train.py +27 -3
  8. {wavedl-1.4.5 → wavedl-1.4.6/src/wavedl.egg-info}/PKG-INFO +13 -11
  9. {wavedl-1.4.5 → wavedl-1.4.6}/LICENSE +0 -0
  10. {wavedl-1.4.5 → wavedl-1.4.6}/pyproject.toml +0 -0
  11. {wavedl-1.4.5 → wavedl-1.4.6}/setup.cfg +0 -0
  12. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/__init__.py +0 -0
  13. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/_template.py +0 -0
  14. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/base.py +0 -0
  15. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/cnn.py +0 -0
  16. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/convnext.py +0 -0
  17. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/densenet.py +0 -0
  18. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/efficientnet.py +0 -0
  19. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/efficientnetv2.py +0 -0
  20. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/mobilenetv3.py +0 -0
  21. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/registry.py +0 -0
  22. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/regnet.py +0 -0
  23. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/resnet.py +0 -0
  24. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/resnet3d.py +0 -0
  25. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/swin.py +0 -0
  26. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/tcn.py +0 -0
  27. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/unet.py +0 -0
  28. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/models/vit.py +0 -0
  29. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/utils/__init__.py +0 -0
  30. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/utils/config.py +0 -0
  31. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/utils/cross_validation.py +0 -0
  32. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/utils/data.py +0 -0
  33. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/utils/distributed.py +0 -0
  34. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/utils/losses.py +0 -0
  35. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/utils/metrics.py +0 -0
  36. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/utils/optimizers.py +0 -0
  37. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/utils/schedulers.py +0 -0
  38. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl.egg-info/SOURCES.txt +0 -0
  39. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl.egg-info/dependency_links.txt +0 -0
  40. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl.egg-info/entry_points.txt +0 -0
  41. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl.egg-info/requires.txt +0 -0
  42. {wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl.egg-info/top_level.txt +0 -0
{wavedl-1.4.5/src/wavedl.egg-info → wavedl-1.4.6}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.4.5
+Version: 1.4.6
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -49,7 +49,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 
 ### A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 
-[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
+[![Python 3.11+](https://img.shields.io/badge/Python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
 [![PyTorch 2.x](https://img.shields.io/badge/PyTorch-2.x-ee4c2c.svg?style=plastic&logo=pytorch&logoColor=white)](https://pytorch.org/)
 [![Accelerate](https://img.shields.io/badge/Accelerate-Enabled-yellow.svg?style=plastic&logo=huggingface&logoColor=white)](https://huggingface.co/docs/accelerate/)
 <br>
@@ -57,7 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 [![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
 [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
 <br>
-[![Downloads](https://img.shields.io/pepy/dt/wavedl?style=plastic&logo=pypi&logoColor=white&color=9ACD32)](https://pepy.tech/project/wavedl)
+[![Downloads](https://img.shields.io/badge/dynamic/json?url=https://pypistats.org/api/packages/wavedl/recent?period=month%26mirrors=false&query=data.last_month&style=plastic&logo=pypi&logoColor=white&color=9ACD32&label=Downloads&suffix=/month)](https://pypistats.org/packages/wavedl)
 [![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
 [![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)
 
@@ -734,18 +734,20 @@ Automatically find the best training configuration using [Optuna](https://optuna
 
 **Run HPO:**
 
-You specify which models to search and how many trials to run:
 ```bash
-# Search 3 models with 100 trials
-python -m wavedl.hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 100
+# Basic HPO (auto-detects GPUs for parallel trials)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 100
 
-# Search 1 model (faster)
-python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
+# Search multiple models
+wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
 
-# Search all your candidate models
-python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
+# Quick mode (fewer parameters, faster)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
 ```
 
+> [!TIP]
+> **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
+
 **Train with best parameters**
 
 After HPO completes, it prints the optimal command:
@@ -784,7 +786,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
 | `--optimizers` | all 6 | Optimizers to search |
 | `--schedulers` | all 8 | Schedulers to search |
 | `--losses` | all 6 | Losses to search |
-| `--n_jobs` | `1` | Parallel trials (multi-GPU) |
+| `--n_jobs` | `-1` | Parallel trials (-1 = auto-detect GPUs) |
 | `--max_epochs` | `50` | Max epochs per trial |
 | `--output` | `hpo_results.json` | Output file |
 
{wavedl-1.4.5 → wavedl-1.4.6}/README.md
@@ -4,7 +4,7 @@
 
 ### A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 
-[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
+[![Python 3.11+](https://img.shields.io/badge/Python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
 [![PyTorch 2.x](https://img.shields.io/badge/PyTorch-2.x-ee4c2c.svg?style=plastic&logo=pytorch&logoColor=white)](https://pytorch.org/)
 [![Accelerate](https://img.shields.io/badge/Accelerate-Enabled-yellow.svg?style=plastic&logo=huggingface&logoColor=white)](https://huggingface.co/docs/accelerate/)
 <br>
@@ -12,7 +12,7 @@
 [![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
 [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
 <br>
-[![Downloads](https://img.shields.io/pepy/dt/wavedl?style=plastic&logo=pypi&logoColor=white&color=9ACD32)](https://pepy.tech/project/wavedl)
+[![Downloads](https://img.shields.io/badge/dynamic/json?url=https://pypistats.org/api/packages/wavedl/recent?period=month%26mirrors=false&query=data.last_month&style=plastic&logo=pypi&logoColor=white&color=9ACD32&label=Downloads&suffix=/month)](https://pypistats.org/packages/wavedl)
 [![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
 [![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)
 
@@ -689,18 +689,20 @@ Automatically find the best training configuration using [Optuna](https://optuna
 
 **Run HPO:**
 
-You specify which models to search and how many trials to run:
 ```bash
-# Search 3 models with 100 trials
-python -m wavedl.hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 100
+# Basic HPO (auto-detects GPUs for parallel trials)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 100
 
-# Search 1 model (faster)
-python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
+# Search multiple models
+wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
 
-# Search all your candidate models
-python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
+# Quick mode (fewer parameters, faster)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
 ```
 
+> [!TIP]
+> **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
+
 **Train with best parameters**
 
 After HPO completes, it prints the optimal command:
@@ -739,7 +741,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
 | `--optimizers` | all 6 | Optimizers to search |
 | `--schedulers` | all 8 | Schedulers to search |
 | `--losses` | all 6 | Losses to search |
-| `--n_jobs` | `1` | Parallel trials (multi-GPU) |
+| `--n_jobs` | `-1` | Parallel trials (-1 = auto-detect GPUs) |
 | `--max_epochs` | `50` | Max epochs per trial |
 | `--output` | `hpo_results.json` | Output file |
 
{wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/__init__.py
@@ -18,7 +18,7 @@ For inference:
 # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
 """
 
-__version__ = "1.4.5"
+__version__ = "1.4.6"
 __author__ = "Ductho Le"
 __email__ = "ductho.le@outlook.com"
 
{wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/hpc.py
@@ -174,7 +174,9 @@ Environment Variables:
     return args, remaining
 
 
-def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
+def print_summary(
+    exit_code: int, wandb_enabled: bool, wandb_mode: str, wandb_dir: str
+) -> None:
     """Print post-training summary and instructions."""
     print()
     print("=" * 40)
@@ -183,7 +185,8 @@ def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
         print("✅ Training completed successfully!")
         print("=" * 40)
 
-    if wandb_mode == "offline":
+    # Only show WandB sync instructions if user enabled wandb
+    if wandb_enabled and wandb_mode == "offline":
         print()
         print("📊 WandB Sync Instructions:")
         print(" From the login node, run:")
@@ -237,6 +240,10 @@ def main() -> int:
         f"--dynamo_backend={args.dynamo_backend}",
     ]
 
+    # Explicitly set multi_gpu to suppress accelerate auto-detection warning
+    if num_gpus > 1:
+        cmd.append("--multi_gpu")
+
     # Add multi-node networking args if specified (required for some clusters)
     if args.main_process_ip:
         cmd.append(f"--main_process_ip={args.main_process_ip}")
@@ -263,8 +270,10 @@
         exit_code = 130
 
     # Print summary
+    wandb_enabled = "--wandb" in train_args
    print_summary(
        exit_code,
+        wandb_enabled,
        os.environ.get("WANDB_MODE", "offline"),
        os.environ.get("WANDB_DIR", "/tmp/wandb"),
    )
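
The `--multi_gpu` addition addresses a quality-of-life issue: `accelerate launch` prints a warning when it detects more than one GPU but the flag was not passed explicitly. A minimal sketch of the launcher logic, assuming `num_gpus` comes from `torch.cuda.device_count()` (the diff does not show how hpc.py counts GPUs, and the trailing training arguments here are illustrative):

```python
# Sketch only: mirrors the cmd-assembly hunk above under stated assumptions.
import subprocess

import torch

num_gpus = torch.cuda.device_count()
cmd = ["accelerate", "launch", f"--num_processes={num_gpus}"]
if num_gpus > 1:
    cmd.append("--multi_gpu")  # explicit flag, so accelerate skips its auto-detection warning
cmd += ["-m", "wavedl.train", "--data_path", "train.npz"]
subprocess.run(cmd, check=False)
```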
{wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/hpo.py
@@ -31,7 +31,7 @@ try:
     import optuna
     from optuna.trial import TrialState
 except ImportError:
-    print("Error: Optuna not installed. Run: pip install -e '.[hpo]'")
+    print("Error: Optuna not installed. Run: pip install wavedl")
     sys.exit(1)
 
 
@@ -147,6 +147,32 @@ def create_objective(args):
         cmd.extend(["--output_dir", tmpdir])
         history_file = Path(tmpdir) / "training_history.csv"
 
+        # GPU isolation for parallel trials: assign each trial to a specific GPU
+        # This prevents multiple trials from competing for all GPUs
+        env = None
+        if args.n_jobs > 1:
+            import os
+
+            # Detect available GPUs
+            n_gpus = 1
+            try:
+                import subprocess as sp
+
+                result_gpu = sp.run(
+                    ["nvidia-smi", "--list-gpus"],
+                    capture_output=True,
+                    text=True,
+                )
+                if result_gpu.returncode == 0:
+                    n_gpus = len(result_gpu.stdout.strip().split("\n"))
+            except Exception:
+                pass
+
+            # Assign trial to a specific GPU (round-robin)
+            gpu_id = trial.number % n_gpus
+            env = os.environ.copy()
+            env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+
         # Run training
         try:
             result = subprocess.run(
@@ -155,6 +181,7 @@
                 text=True,
                 timeout=args.timeout,
                 cwd=Path(__file__).parent,
+                env=env,
             )
 
             # Read best val_loss from training_history.csv (reliable machine-readable)
@@ -248,7 +275,10 @@ Examples:
         "--n_trials", type=int, default=50, help="Number of HPO trials (default: 50)"
     )
     parser.add_argument(
-        "--n_jobs", type=int, default=1, help="Parallel trials (default: 1)"
+        "--n_jobs",
+        type=int,
+        default=-1,
+        help="Parallel trials (-1 = auto-detect GPUs, default: -1)",
     )
     parser.add_argument(
         "--quick",
@@ -315,11 +345,30 @@ Examples:
 
     args = parser.parse_args()
 
+    # Convert to absolute path (child processes may run in different cwd)
+    args.data_path = str(Path(args.data_path).resolve())
+
     # Validate data path
     if not Path(args.data_path).exists():
         print(f"Error: Data file not found: {args.data_path}")
         sys.exit(1)
 
+    # Auto-detect GPUs for n_jobs if not specified
+    if args.n_jobs == -1:
+        try:
+            result_gpu = subprocess.run(
+                ["nvidia-smi", "--list-gpus"],
+                capture_output=True,
+                text=True,
+            )
+            if result_gpu.returncode == 0:
+                args.n_jobs = max(1, len(result_gpu.stdout.strip().split("\n")))
+            else:
+                args.n_jobs = 1
+        except Exception:
+            args.n_jobs = 1
+        print(f"Auto-detected {args.n_jobs} GPU(s) for parallel trials")
+
     # Create study
     print("=" * 60)
     print("WaveDL Hyperparameter Optimization")
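
Taken together, these hpo.py changes make `--n_jobs` default to one worker per detected GPU, with each trial pinned to its own device through `CUDA_VISIBLE_DEVICES`. The sketch below shows the pattern in isolation; the objective body, the fixed `N_GPUS = 4`, and reading the score from the subprocess's last stdout line are illustrative stand-ins for the real training run, which reports val_loss through `training_history.csv`:

```python
import os
import subprocess

import optuna

N_GPUS = 4  # assumption: a 4-GPU node


def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)

    # Round-robin pinning, as in the diff: trial k sees only GPU k % N_GPUS,
    # so concurrently running trials never compete for the same device.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(trial.number % N_GPUS)

    result = subprocess.run(
        ["python", "-m", "wavedl.train", "--lr", str(lr)],
        env=env,
        capture_output=True,
        text=True,
    )
    # Stand-in metric: pretend the training process prints val_loss last.
    return float(result.stdout.strip().splitlines()[-1])


study = optuna.create_study(direction="minimize")
# n_jobs runs trials in threads; each thread blocks on its subprocess, so
# n_jobs=N_GPUS keeps exactly one trial per GPU in flight.
study.optimize(objective, n_trials=100, n_jobs=N_GPUS)
```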
{wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/test.py
@@ -366,13 +366,19 @@ def load_checkpoint(
     logging.info(f" Building model: {model_name}")
     model = build_model(model_name, in_shape=in_shape, out_size=out_size)
 
-    # Load weights (prefer safetensors)
-    weight_path = checkpoint_dir / "model.safetensors"
-    if not weight_path.exists():
-        weight_path = checkpoint_dir / "pytorch_model.bin"
-
-    if not weight_path.exists():
-        raise FileNotFoundError(f"No model weights found in {checkpoint_dir}")
+    # Load weights (check multiple formats in order of preference)
+    weight_path = None
+    for fname in ["model.safetensors", "model.bin", "pytorch_model.bin"]:
+        candidate = checkpoint_dir / fname
+        if candidate.exists():
+            weight_path = candidate
+            break
+
+    if weight_path is None:
+        raise FileNotFoundError(
+            f"No model weights found in {checkpoint_dir}. "
+            f"Expected one of: model.safetensors, model.bin, pytorch_model.bin"
+        )
 
     if HAS_SAFETENSORS and weight_path.suffix == ".safetensors":
         state_dict = load_safetensors(str(weight_path))
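
For context, the two loading paths that follow this fallback chain look roughly like the sketch below; it reuses the names from test.py (`HAS_SAFETENSORS`, `load_safetensors`), which are assumed to be defined near the module's imports:

```python
from pathlib import Path

import torch

try:
    from safetensors.torch import load_file as load_safetensors

    HAS_SAFETENSORS = True
except ImportError:
    HAS_SAFETENSORS = False


def load_state_dict(weight_path: Path) -> dict:
    """Load model weights from a safetensors file or a pickled .bin file."""
    if HAS_SAFETENSORS and weight_path.suffix == ".safetensors":
        # Safe, zero-copy tensor format; no arbitrary pickle execution
        return load_safetensors(str(weight_path))
    # .bin checkpoints are ordinary torch pickles
    return torch.load(weight_path, map_location="cpu")
```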
{wavedl-1.4.5 → wavedl-1.4.6}/src/wavedl/train.py
@@ -148,6 +148,24 @@ torch.set_float32_matmul_precision("high")  # Use TF32 for float32 ops
 torch.backends.cudnn.benchmark = True
 
 
+# ==============================================================================
+# LOGGING UTILITIES
+# ==============================================================================
+from contextlib import contextmanager
+
+
+@contextmanager
+def suppress_accelerate_logging():
+    """Temporarily suppress accelerate's verbose checkpoint save messages."""
+    accelerate_logger = logging.getLogger("accelerate.checkpointing")
+    original_level = accelerate_logger.level
+    accelerate_logger.setLevel(logging.WARNING)
+    try:
+        yield
+    finally:
+        accelerate_logger.setLevel(original_level)
+
+
 # ==============================================================================
 # ARGUMENT PARSING
 # ==============================================================================
@@ -1033,7 +1051,8 @@ def main():
         # Step 3: Save checkpoint with all ranks participating
         if is_best_epoch:
             ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
-            accelerator.save_state(ckpt_dir)  # All ranks must call this
+            with suppress_accelerate_logging():
+                accelerator.save_state(ckpt_dir, safe_serialization=False)
 
         # Step 4: Rank 0 handles metadata and updates tracking variables
         if accelerator.is_main_process:
@@ -1096,7 +1115,8 @@
         if periodic_checkpoint_needed:
             ckpt_name = f"epoch_{epoch + 1}_checkpoint"
             ckpt_dir = os.path.join(args.output_dir, ckpt_name)
-            accelerator.save_state(ckpt_dir)  # All ranks participate
+            with suppress_accelerate_logging():
+                accelerator.save_state(ckpt_dir, safe_serialization=False)
 
             if accelerator.is_main_process:
                 with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
@@ -1147,7 +1167,11 @@
 
     except KeyboardInterrupt:
         logger.warning("Training interrupted. Saving emergency checkpoint...")
-        accelerator.save_state(os.path.join(args.output_dir, "interrupted_checkpoint"))
+        with suppress_accelerate_logging():
+            accelerator.save_state(
+                os.path.join(args.output_dir, "interrupted_checkpoint"),
+                safe_serialization=False,
+            )
 
     except Exception as e:
         logger.error(f"Critical error: {e}", exc_info=True)
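
The new context manager is the standard temporarily-raise-the-level logging pattern: anything accelerate's `accelerate.checkpointing` logger emits below WARNING is hidden only while a checkpoint is written, and the previous level is always restored, even on error. A generalized, self-contained sketch (the commented-out `save_state` call marks where the real work happens; `accelerator` would be an `accelerate.Accelerator` instance):

```python
import logging
from contextlib import contextmanager


@contextmanager
def quiet_logger(name: str, level: int = logging.WARNING):
    """Raise a named logger's threshold for the duration of a block, then restore it."""
    log = logging.getLogger(name)
    previous = log.level
    log.setLevel(level)
    try:
        yield
    finally:
        log.setLevel(previous)


# Usage mirroring the save sites in train.py:
with quiet_logger("accelerate.checkpointing"):
    pass  # accelerator.save_state(ckpt_dir, safe_serialization=False)
```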
{wavedl-1.4.5 → wavedl-1.4.6/src/wavedl.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.4.5
+Version: 1.4.6
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -49,7 +49,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 
 ### A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 
-[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
+[![Python 3.11+](https://img.shields.io/badge/Python-3.11+-blue.svg?style=plastic&logo=python&logoColor=white)](https://www.python.org/downloads/)
 [![PyTorch 2.x](https://img.shields.io/badge/PyTorch-2.x-ee4c2c.svg?style=plastic&logo=pytorch&logoColor=white)](https://pytorch.org/)
 [![Accelerate](https://img.shields.io/badge/Accelerate-Enabled-yellow.svg?style=plastic&logo=huggingface&logoColor=white)](https://huggingface.co/docs/accelerate/)
 <br>
@@ -57,7 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
 [![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
 [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
 <br>
-[![Downloads](https://img.shields.io/pepy/dt/wavedl?style=plastic&logo=pypi&logoColor=white&color=9ACD32)](https://pepy.tech/project/wavedl)
+[![Downloads](https://img.shields.io/badge/dynamic/json?url=https://pypistats.org/api/packages/wavedl/recent?period=month%26mirrors=false&query=data.last_month&style=plastic&logo=pypi&logoColor=white&color=9ACD32&label=Downloads&suffix=/month)](https://pypistats.org/packages/wavedl)
 [![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
 [![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)
 
@@ -734,18 +734,20 @@ Automatically find the best training configuration using [Optuna](https://optuna
 
 **Run HPO:**
 
-You specify which models to search and how many trials to run:
 ```bash
-# Search 3 models with 100 trials
-python -m wavedl.hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 100
+# Basic HPO (auto-detects GPUs for parallel trials)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 100
 
-# Search 1 model (faster)
-python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
+# Search multiple models
+wavedl-hpo --data_path train.npz --models cnn resnet18 efficientnet_b0 --n_trials 200
 
-# Search all your candidate models
-python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
+# Quick mode (fewer parameters, faster)
+wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
 ```
 
+> [!TIP]
+> **Auto GPU Detection**: HPO automatically detects available GPUs and runs one trial per GPU in parallel. On a 4-GPU system, 4 trials run simultaneously. Use `--n_jobs 1` to force serial execution.
+
 **Train with best parameters**
 
 After HPO completes, it prints the optimal command:
@@ -784,7 +786,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
 | `--optimizers` | all 6 | Optimizers to search |
 | `--schedulers` | all 8 | Schedulers to search |
 | `--losses` | all 6 | Losses to search |
-| `--n_jobs` | `1` | Parallel trials (multi-GPU) |
+| `--n_jobs` | `-1` | Parallel trials (-1 = auto-detect GPUs) |
 | `--max_epochs` | `50` | Max epochs per trial |
 | `--output` | `hpo_results.json` | Output file |