wavedl 1.2.0__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.2.0"
21
+ __version__ = "1.3.1"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
wavedl/hpc.py ADDED
@@ -0,0 +1,243 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ WaveDL HPC Training Launcher.
4
+
5
+ This module provides a Python-based HPC training launcher that wraps accelerate
6
+ for distributed training on High-Performance Computing clusters.
7
+
8
+ Usage:
9
+ wavedl-hpc --model cnn --data_path train.npz --num_gpus 4
10
+
11
+ Example SLURM script:
12
+ #!/bin/bash
13
+ #SBATCH --nodes=1
14
+ #SBATCH --gpus-per-node=4
15
+ #SBATCH --time=12:00:00
16
+
17
+ wavedl-hpc --model cnn --data_path /scratch/data.npz --compile
18
+
19
+ Author: Ductho Le (ductho.le@outlook.com)
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import argparse
25
+ import os
26
+ import shutil
27
+ import subprocess
28
+ import sys
29
+ import tempfile
30
+ from pathlib import Path
31
+
32
+
33
def detect_gpus() -> int:
    """Return the number of GPUs reported by nvidia-smi (minimum 1).

    Falls back to 1 when nvidia-smi is missing, fails, or reports no
    devices, so the launcher can always start at least one process.

    Returns:
        Detected GPU count, or 1 if detection is not possible.
    """
    if shutil.which("nvidia-smi") is None:
        print("Warning: nvidia-smi not found, defaulting to 1 GPU")
        return 1

    try:
        result = subprocess.run(
            ["nvidia-smi", "--list-gpus"],
            capture_output=True,
            text=True,
            check=True,
        )
        # Count only non-blank lines: "".split("\n") would yield [""] and
        # falsely report 1 auto-detected GPU on empty output.
        gpu_count = len(
            [line for line in result.stdout.splitlines() if line.strip()]
        )
        if gpu_count > 0:
            print(f"Auto-detected {gpu_count} GPU(s)")
            return gpu_count
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Treat any nvidia-smi failure as "unknown" and use the fallback below.
        pass

    print("Warning: Could not detect GPUs, defaulting to 1")
    return 1
55
+
56
+
57
def setup_hpc_environment() -> None:
    """Prepare environment variables for running on an HPC cluster.

    Points cache/config directories at a node-local temp dir (SLURM_TMPDIR
    when available) to support restricted home directories (e.g., Compute
    Canada), and defaults WandB to offline logging. Pre-existing environment
    values are never overridden.
    """
    # Prefer the node-local scratch directory provided by SLURM.
    tmpdir = os.environ.get("SLURM_TMPDIR", tempfile.gettempdir())

    defaults = {
        # Redirect caches/configs away from (possibly read-only) $HOME.
        "MPLCONFIGDIR": f"{tmpdir}/matplotlib",
        "XDG_CACHE_HOME": f"{tmpdir}/.cache",
        # WandB: offline by default since compute nodes often lack internet.
        "WANDB_MODE": "offline",
        "WANDB_DIR": f"{tmpdir}/wandb",
        "WANDB_CACHE_DIR": f"{tmpdir}/wandb_cache",
        "WANDB_CONFIG_DIR": f"{tmpdir}/wandb_config",
        # Silence non-critical warning noise in job logs.
        "PYTHONWARNINGS": "ignore::UserWarning,ignore::FutureWarning,ignore::DeprecationWarning",
    }
    for key, value in defaults.items():
        os.environ.setdefault(key, value)

    # matplotlib requires its config directory to exist up front.
    Path(os.environ["MPLCONFIGDIR"]).mkdir(parents=True, exist_ok=True)
84
+
85
+
86
def parse_args() -> tuple[argparse.Namespace, list[str]]:
    """Split the command line into launcher options and pass-through args.

    Only HPC-launcher flags are consumed here; everything else is returned
    verbatim so it can be forwarded to ``wavedl.train``.

    Returns:
        Tuple of (parsed launcher namespace, remaining argv for wavedl.train).
    """
    parser = argparse.ArgumentParser(
        description="WaveDL HPC Training Launcher",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic training with auto-detected GPUs
  wavedl-hpc --model cnn --data_path train.npz --epochs 100

  # Specify GPU count and mixed precision
  wavedl-hpc --model cnn --data_path train.npz --num_gpus 4 --mixed_precision bf16

  # Full configuration
  wavedl-hpc --model resnet18 --data_path train.npz --num_gpus 8 \\
    --batch_size 256 --lr 1e-3 --compile --output_dir ./results

Environment Variables:
  WANDB_MODE     WandB mode: offline|online (default: offline)
  SLURM_TMPDIR   Temp directory for HPC systems
""",
    )

    # Launcher-specific flags, registered in display order for --help.
    launcher_flags = [
        (
            "--num_gpus",
            {
                "type": int,
                "default": None,
                "help": "Number of GPUs to use (default: auto-detect)",
            },
        ),
        (
            "--num_machines",
            {
                "type": int,
                "default": 1,
                "help": "Number of machines for multi-node training (default: 1)",
            },
        ),
        (
            "--machine_rank",
            {
                "type": int,
                "default": 0,
                "help": "Rank of this machine in multi-node setup (default: 0)",
            },
        ),
        (
            "--mixed_precision",
            {
                "type": str,
                "choices": ["bf16", "fp16", "no"],
                "default": "bf16",
                "help": "Mixed precision mode (default: bf16)",
            },
        ),
        (
            "--dynamo_backend",
            {
                "type": str,
                "default": "no",
                "help": "PyTorch dynamo backend (default: no)",
            },
        ),
    ]
    for flag, options in launcher_flags:
        parser.add_argument(flag, **options)

    # Unrecognized arguments are collected and forwarded to wavedl.train.
    return parser.parse_known_args()
145
+
146
+
147
def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
    """Report the training outcome and any follow-up instructions.

    Args:
        exit_code: Return code of the training subprocess (0 means success).
        wandb_mode: Effective WANDB_MODE; "offline" triggers sync hints.
        wandb_dir: Directory containing offline WandB run data.
    """
    banner = "=" * 50
    print()
    print(banner)

    if exit_code != 0:
        print(f"❌ Training failed with exit code: {exit_code}")
        print(banner)
        print()
        print("Common issues:")
        for hint in (
            " - Missing data file (check --data_path)",
            " - Insufficient GPU memory (reduce --batch_size)",
            " - Invalid model name (run: wavedl-train --list_models)",
        ):
            print(hint)
        print()
    else:
        print("✅ Training completed successfully!")
        print(banner)
        # Offline runs must be synced manually from a node with internet.
        if wandb_mode == "offline":
            print()
            print("📊 WandB Sync Instructions:")
            print(" From the login node, run:")
            print(f" wandb sync {wandb_dir}/wandb/offline-run-*")
            print()
            print(" This will upload your training logs to wandb.ai")

    print(banner)
    print()
175
+
176
+
177
def main() -> int:
    """Entry point for the ``wavedl-hpc`` console command.

    Wraps ``accelerate launch`` around ``wavedl.train``, forwarding any
    unrecognized CLI arguments to the training module.

    Returns:
        Exit code of the training subprocess (130 on keyboard interrupt).
    """
    args, train_args = parse_args()

    setup_hpc_environment()

    # Fall back to nvidia-smi detection when --num_gpus is not given.
    num_gpus = detect_gpus() if args.num_gpus is None else args.num_gpus

    # Assemble the accelerate launch invocation around wavedl.train.
    launch_cmd = [
        sys.executable,
        "-m",
        "accelerate.commands.launch",
        f"--num_processes={num_gpus}",
        f"--num_machines={args.num_machines}",
        f"--machine_rank={args.machine_rank}",
        f"--mixed_precision={args.mixed_precision}",
        f"--dynamo_backend={args.dynamo_backend}",
        "-m",
        "wavedl.train",
        *train_args,
    ]

    # Pre-create the output directory (if requested) so training can write
    # to it immediately; handles both "--output_dir DIR" and "--output_dir=DIR".
    for index, token in enumerate(train_args):
        if token == "--output_dir" and index + 1 < len(train_args):
            Path(train_args[index + 1]).mkdir(parents=True, exist_ok=True)
            break
        if token.startswith("--output_dir="):
            Path(token.split("=", 1)[1]).mkdir(parents=True, exist_ok=True)
            break

    # Show the effective launch configuration before starting.
    banner = "=" * 50
    print()
    print(banner)
    print("🚀 WaveDL HPC Training Launcher")
    print(banner)
    print(f" GPUs: {num_gpus}")
    print(f" Machines: {args.num_machines}")
    print(f" Mixed Precision: {args.mixed_precision}")
    print(f" Dynamo Backend: {args.dynamo_backend}")
    print(f" WandB Mode: {os.environ.get('WANDB_MODE', 'offline')}")
    print(banner)
    print()

    # Run training; a Ctrl-C maps to the conventional 130 exit code.
    try:
        exit_code = subprocess.run(launch_cmd, check=False).returncode
    except KeyboardInterrupt:
        print("\n\n⚠️ Training interrupted by user")
        exit_code = 130

    print_summary(
        exit_code,
        os.environ.get("WANDB_MODE", "offline"),
        os.environ.get("WANDB_DIR", "/tmp/wandb"),
    )

    return exit_code
240
+
241
+
242
if __name__ == "__main__":
    # Propagate the launcher's exit code to the calling shell / scheduler.
    sys.exit(main())
wavedl/hpo.py CHANGED
@@ -5,16 +5,16 @@ Automated hyperparameter search for finding optimal training configurations.
5
5
 
6
6
  Usage:
7
7
  # Basic HPO (50 trials)
8
- python hpo.py --data_path train.npz --n_trials 50
8
+ wavedl-hpo --data_path train.npz --n_trials 50
9
9
 
10
10
  # Quick search (fewer parameters)
11
- python hpo.py --data_path train.npz --n_trials 30 --quick
11
+ wavedl-hpo --data_path train.npz --n_trials 30 --quick
12
12
 
13
13
  # Full search with specific models
14
- python hpo.py --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0
14
+ wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0
15
15
 
16
16
  # Parallel trials on multiple GPUs
17
- python hpo.py --data_path train.npz --n_trials 100 --n_jobs 4
17
+ wavedl-hpo --data_path train.npz --n_trials 100 --n_jobs 4
18
18
 
19
19
  Author: Ductho Le (ductho.le@outlook.com)
20
20
  """
@@ -205,9 +205,9 @@ def main():
205
205
  formatter_class=argparse.RawDescriptionHelpFormatter,
206
206
  epilog="""
207
207
  Examples:
208
- python hpo.py --data_path train.npz --n_trials 50
209
- python hpo.py --data_path train.npz --n_trials 30 --quick
210
- python hpo.py --data_path train.npz --n_trials 100 --models cnn resnet18
208
+ wavedl-hpo --data_path train.npz --n_trials 50
209
+ wavedl-hpo --data_path train.npz --n_trials 30 --quick
210
+ wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18
211
211
  """,
212
212
  )
213
213
 
@@ -355,7 +355,7 @@ Examples:
355
355
  print("\n" + "=" * 60)
356
356
  print("TO TRAIN WITH BEST PARAMETERS:")
357
357
  print("=" * 60)
358
- cmd_parts = ["accelerate launch train.py"]
358
+ cmd_parts = ["accelerate launch -m wavedl.train"]
359
359
  cmd_parts.append(f"--data_path {args.data_path}")
360
360
  for key, value in study.best_params.items():
361
361
  cmd_parts.append(f"--{key} {value}")
@@ -11,7 +11,7 @@ Steps to Add a New Model:
11
11
  3. Implement the __init__ and forward methods
12
12
  4. Import your model in models/__init__.py:
13
13
  from wavedl.models.your_model import YourModel
14
- 5. Run: accelerate launch train.py --model your_model --wandb
14
+ 5. Run: accelerate launch -m wavedl.train --model your_model --wandb
15
15
 
16
16
  Author: Ductho Le (ductho.le@outlook.com)
17
17
  Version: 1.0.0
wavedl/test.py CHANGED
@@ -13,14 +13,14 @@ Production-grade inference script for evaluating trained WaveDL models:
13
13
 
14
14
  Usage:
15
15
  # Basic inference
16
- python test.py --checkpoint ./best_checkpoint --data_path test_data.npz
16
+ wavedl-test --checkpoint ./best_checkpoint --data_path test_data.npz
17
17
 
18
18
  # With visualization and detailed output
19
- python test.py --checkpoint ./best_checkpoint --data_path test_data.npz \\
19
+ wavedl-test --checkpoint ./best_checkpoint --data_path test_data.npz \\
20
20
  --plot --plot_format png pdf --output_dir ./test_results --save_predictions
21
21
 
22
22
  # Export model to ONNX for deployment
23
- python test.py --checkpoint ./best_checkpoint --data_path test_data.npz \\
23
+ wavedl-test --checkpoint ./best_checkpoint --data_path test_data.npz \\
24
24
  --export onnx --export_path model.onnx
25
25
 
26
26
  Author: Ductho Le (ductho.le@outlook.com)
wavedl/train.py CHANGED
@@ -12,26 +12,25 @@ A modular training framework for wave-based inverse problems and regression:
12
12
  6. Deep Observability: WandB integration with scatter analysis
13
13
 
14
14
  Usage:
15
- # Recommended: Using the HPC helper script
16
- ./run_training.sh --model cnn --batch_size 128 --wandb
15
+ # Recommended: Using the HPC launcher
16
+ wavedl-hpc --model cnn --batch_size 128 --wandb
17
17
 
18
18
  # Or with direct accelerate launch
19
- accelerate launch train.py --model cnn --batch_size 128 --wandb
19
+ accelerate launch -m wavedl.train --model cnn --batch_size 128 --wandb
20
20
 
21
21
  # Multi-GPU with explicit config
22
- accelerate launch --num_processes=4 --mixed_precision=bf16 \
23
- train.py --model cnn --wandb --project_name "MyProject"
22
+ wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb
24
23
 
25
24
  # Resume from checkpoint
26
- accelerate launch train.py --model cnn --resume best_checkpoint --wandb
25
+ accelerate launch -m wavedl.train --model cnn --resume best_checkpoint --wandb
27
26
 
28
27
  # List available models
29
- python train.py --list_models
28
+ wavedl-train --list_models
30
29
 
31
30
  Note:
32
- For HPC clusters (Compute Canada, etc.), use run_training.sh which handles
31
+ For HPC clusters (Compute Canada, etc.), use wavedl-hpc which handles
33
32
  environment configuration automatically. Mixed precision is controlled via
34
- --precision flag (default: bf16).
33
+ --mixed_precision flag (default: bf16).
35
34
 
36
35
  Author: Ductho Le (ductho.le@outlook.com)
37
36
  """
@@ -122,6 +121,14 @@ def parse_args() -> argparse.Namespace:
122
121
  parser.add_argument(
123
122
  "--list_models", action="store_true", help="List all available models and exit"
124
123
  )
124
+ parser.add_argument(
125
+ "--import",
126
+ dest="import_modules",
127
+ type=str,
128
+ nargs="+",
129
+ default=[],
130
+ help="Python modules to import before training (for custom models)",
131
+ )
125
132
 
126
133
  # Configuration File
127
134
  parser.add_argument(
@@ -314,6 +321,36 @@ def parse_args() -> argparse.Namespace:
314
321
  def main():
315
322
  args, parser = parse_args()
316
323
 
324
+ # Import custom model modules if specified
325
+ if args.import_modules:
326
+ import importlib
327
+
328
+ for module_name in args.import_modules:
329
+ try:
330
+ # Handle both module names (my_model) and file paths (./my_model.py)
331
+ if module_name.endswith(".py"):
332
+ # Import from file path
333
+ import importlib.util
334
+
335
+ spec = importlib.util.spec_from_file_location(
336
+ "custom_module", module_name
337
+ )
338
+ if spec and spec.loader:
339
+ module = importlib.util.module_from_spec(spec)
340
+ sys.modules["custom_module"] = module
341
+ spec.loader.exec_module(module)
342
+ print(f"✓ Imported custom module from: {module_name}")
343
+ else:
344
+ # Import as regular module
345
+ importlib.import_module(module_name)
346
+ print(f"✓ Imported module: {module_name}")
347
+ except ImportError as e:
348
+ print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
349
+ print(
350
+ " Make sure the module is in your Python path or current directory."
351
+ )
352
+ sys.exit(1)
353
+
317
354
  # Handle --list_models flag
318
355
  if args.list_models:
319
356
  print("Available models:")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.2.0
3
+ Version: 1.3.1
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -43,7 +43,7 @@ Provides-Extra: onnx
43
43
  Requires-Dist: onnx>=1.14.0; extra == "onnx"
44
44
  Requires-Dist: onnxruntime>=1.15.0; extra == "onnx"
45
45
  Provides-Extra: compile
46
- Requires-Dist: triton; extra == "compile"
46
+ Requires-Dist: triton; sys_platform == "linux" and extra == "compile"
47
47
  Provides-Extra: hpo
48
48
  Requires-Dist: optuna>=3.0.0; extra == "hpo"
49
49
  Provides-Extra: all
@@ -53,7 +53,7 @@ Requires-Dist: ruff>=0.8.0; extra == "all"
53
53
  Requires-Dist: pre-commit>=3.5.0; extra == "all"
54
54
  Requires-Dist: onnx>=1.14.0; extra == "all"
55
55
  Requires-Dist: onnxruntime>=1.15.0; extra == "all"
56
- Requires-Dist: triton; extra == "all"
56
+ Requires-Dist: triton; sys_platform == "linux" and extra == "all"
57
57
  Requires-Dist: optuna>=3.0.0; extra == "all"
58
58
 
59
59
  <div align="center">
@@ -211,40 +211,43 @@ Deploy models anywhere:
211
211
  ### Installation
212
212
 
213
213
  ```bash
214
- git clone https://github.com/ductho-le/WaveDL.git
215
- cd WaveDL
214
+ # Install from PyPI (recommended)
215
+ pip install wavedl
216
+
217
+ # Or install with all extras (ONNX export, HPO, dev tools)
218
+ pip install wavedl[all]
219
+ ```
216
220
 
217
- # Basic install (training + inference)
218
- pip install -e .
221
+ #### From Source (for development)
219
222
 
220
- # Full install (adds ONNX export, torch.compile, HPO, dev tools)
221
- pip install -e ".[all]"
223
+ ```bash
224
+ git clone https://github.com/ductho-le/WaveDL.git
225
+ cd WaveDL
226
+ pip install -e ".[dev]"
222
227
  ```
223
228
 
224
229
  > [!NOTE]
225
- > Dependencies are managed in `pyproject.toml`. Python 3.11+ required.
226
- >
227
- > For development setup (running tests, contributing), see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
230
+ > Python 3.11+ required. For development setup, see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
228
231
 
229
232
  ### Quick Start
230
233
 
231
234
  > [!TIP]
232
235
  > In all examples below, replace `<...>` placeholders with your values. See [Configuration](#️-configuration) for defaults and options.
233
236
 
234
- #### Option 1: Using the Helper Script (Recommended for HPC)
237
+ #### Option 1: Using wavedl-hpc (Recommended for HPC)
235
238
 
236
- The `run_training.sh` wrapper automatically configures the environment for HPC systems:
239
+ The `wavedl-hpc` command automatically configures the environment for HPC systems:
237
240
 
238
241
  ```bash
239
- # Make executable (first time only)
240
- chmod +x run_training.sh
241
-
242
242
  # Basic training (auto-detects available GPUs)
243
- ./run_training.sh --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
243
+ wavedl-hpc --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
244
244
 
245
245
  # Detailed configuration
246
- ./run_training.sh --model <model_name> --data_path <train_data> --batch_size <number> \
246
+ wavedl-hpc --model <model_name> --data_path <train_data> --batch_size <number> \
247
247
  --lr <number> --epochs <number> --patience <number> --compile --output_dir <output_folder>
248
+
249
+ # Specify GPU count explicitly
250
+ wavedl-hpc --num_gpus 4 --model cnn --data_path train.npz --output_dir results
248
251
  ```
249
252
 
250
253
  #### Option 2: Direct Accelerate Launch
@@ -261,13 +264,13 @@ accelerate launch -m wavedl.train --model <model_name> --data_path <train_data>
261
264
  accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --output_dir <output_folder> --fresh
262
265
 
263
266
  # List available models
264
- python -m wavedl.train --list_models
267
+ wavedl-train --list_models
265
268
  ```
266
269
 
267
270
  > [!TIP]
268
271
  > **Auto-Resume**: If training crashes or is interrupted, simply re-run with the same `--output_dir`. The framework automatically detects incomplete training and resumes from the last checkpoint. Use `--fresh` to force a fresh start.
269
272
  >
270
- > **GPU Auto-Detection**: By default, `run_training.sh` automatically detects available GPUs using `nvidia-smi`. Set `NUM_GPUS` to override this behavior.
273
+ > **GPU Auto-Detection**: `wavedl-hpc` automatically detects available GPUs using `nvidia-smi`. Use `--num_gpus` to override.
271
274
 
272
275
  ### Testing & Inference
273
276
 
@@ -299,6 +302,56 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
299
302
  > [!NOTE]
300
303
  > `wavedl.test` auto-detects the model architecture from checkpoint metadata. If unavailable, it falls back to folder name parsing. Use `--model` to override if needed.
301
304
 
305
+ ### Adding Custom Models
306
+
307
+ <details>
308
+ <summary><b>Creating Your Own Architecture</b></summary>
309
+
310
+ **Requirements** (your model must):
311
+ 1. Inherit from `BaseModel`
312
+ 2. Accept `in_channels`, `num_outputs`, `input_shape` in `__init__`
313
+ 3. Return a tensor of shape `(batch, num_outputs)` from `forward()`
314
+
315
+ ---
316
+
317
+ **Step 1: Create `my_model.py`**
318
+
319
+ ```python
320
+ import torch.nn as nn
321
+ import torch.nn.functional as F
322
+ from wavedl.models import BaseModel, register_model
323
+
324
+ @register_model("my_model") # This name is used with --model flag
325
+ class MyModel(BaseModel):
326
+ def __init__(self, in_channels, num_outputs, input_shape):
327
+ # in_channels: number of input channels (auto-detected from data)
328
+ # num_outputs: number of parameters to predict (auto-detected from data)
329
+ # input_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
330
+ super().__init__(in_channels, num_outputs, input_shape)
331
+
332
+ # Define your layers (this is just an example)
333
+ self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
334
+ self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
335
+ self.fc = nn.Linear(128, num_outputs)
336
+
337
+ def forward(self, x):
338
+ # Input x has shape: (batch, in_channels, *input_shape)
339
+ x = F.relu(self.conv1(x))
340
+ x = F.relu(self.conv2(x))
341
+ x = x.mean(dim=[-2, -1]) # Global average pooling
342
+ return self.fc(x) # Output shape: (batch, num_outputs)
343
+ ```
344
+
345
+ **Step 2: Train**
346
+
347
+ ```bash
348
+ wavedl-hpc --import my_model --model my_model --data_path train.npz
349
+ ```
350
+
351
+ WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
352
+
353
+ </details>
354
+
302
355
  ---
303
356
 
304
357
  ## 📁 Project Structure
@@ -311,6 +364,7 @@ WaveDL/
311
364
  │ ├── train.py # Training entry point
312
365
  │ ├── test.py # Testing & inference script
313
366
  │ ├── hpo.py # Hyperparameter optimization
367
+ │ ├── hpc.py # HPC distributed training launcher
314
368
  │ │
315
369
  │ ├── models/ # Model architectures
316
370
  │ │ ├── registry.py # Model factory (@register_model)
@@ -332,7 +386,6 @@ WaveDL/
332
386
  │ ├── schedulers.py # LR scheduler factory
333
387
  │ └── config.py # YAML configuration support
334
388
 
335
- ├── run_training.sh # HPC helper script
336
389
  ├── configs/ # YAML config templates
337
390
  ├── examples/ # Ready-to-run examples
338
391
  ├── notebooks/ # Jupyter notebooks
@@ -347,12 +400,12 @@ WaveDL/
347
400
  ## ⚙️ Configuration
348
401
 
349
402
  > [!NOTE]
350
- > All configuration options below work with **both** `run_training.sh` and direct `accelerate launch`. The wrapper script passes all arguments directly to `train.py`.
403
+ > All configuration options below work with **both** `wavedl-hpc` and direct `accelerate launch`. The wrapper script passes all arguments directly to `train.py`.
351
404
  >
352
405
  > **Examples:**
353
406
  > ```bash
354
- > # Using run_training.sh
355
- > ./run_training.sh --model cnn --batch_size 256 --lr 5e-4 --compile
407
+ > # Using wavedl-hpc
408
+ > wavedl-hpc --model cnn --batch_size 256 --lr 5e-4 --compile
356
409
  >
357
410
  > # Using accelerate launch directly
358
411
  > accelerate launch -m wavedl.train --model cnn --batch_size 256 --lr 5e-4 --compile
@@ -395,6 +448,7 @@ WaveDL/
395
448
  | Argument | Default | Description |
396
449
  |----------|---------|-------------|
397
450
  | `--model` | `cnn` | Model architecture |
451
+ | `--import` | - | Python modules to import (for custom models) |
398
452
  | `--batch_size` | `128` | Per-GPU batch size |
399
453
  | `--lr` | `1e-3` | Learning rate |
400
454
  | `--epochs` | `1000` | Maximum epochs |
@@ -434,7 +488,7 @@ WaveDL/
434
488
  </details>
435
489
 
436
490
  <details>
437
- <summary><b>Environment Variables (run_training.sh)</b></summary>
491
+ <summary><b>Environment Variables (wavedl-hpc)</b></summary>
438
492
 
439
493
  | Variable | Default | Description |
440
494
  |----------|---------|-------------|
@@ -527,15 +581,15 @@ For robust model evaluation, simply add the `--cv` flag:
527
581
 
528
582
  ```bash
529
583
  # 5-fold cross-validation (works with both methods!)
530
- ./run_training.sh --model cnn --cv 5 --data_path train_data.npz
584
+ wavedl-hpc --model cnn --cv 5 --data_path train_data.npz
531
585
  # OR
532
586
  accelerate launch -m wavedl.train --model cnn --cv 5 --data_path train_data.npz
533
587
 
534
588
  # Stratified CV (recommended for unbalanced data)
535
- ./run_training.sh --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
589
+ wavedl-hpc --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
536
590
 
537
591
  # Full configuration
538
- ./run_training.sh --model cnn --cv 5 --cv_stratify \
592
+ wavedl-hpc --model cnn --cv 5 --cv_stratify \
539
593
  --loss huber --optimizer adamw --scheduler cosine \
540
594
  --output_dir ./cv_results
541
595
  ```
@@ -1,9 +1,10 @@
1
- wavedl/__init__.py,sha256=YUkbm14YDpotLU8Gv2kmgzSgXaLzLphzawZ98YBAZ7w,1177
2
- wavedl/hpo.py,sha256=ncArshHv-DDtD-1ufqCHneQoUeplAP0dWxXD0Dka4qc,11556
3
- wavedl/test.py,sha256=14LHUwgQyrDcM13CdC7KRiHg-FYI9e26z-8B9gMnU5U,36255
4
- wavedl/train.py,sha256=X6ibjYRiSeKZBKRevJcNLvYeXNykH2hXcWoHDEwb7EU,42394
1
+ wavedl/__init__.py,sha256=5EO4WDuyQksw2UQnnojmuA6asc7_Ew9qtLCF-dxo_qo,1177
2
+ wavedl/hpc.py,sha256=OaiGo0Q_ylu6tCEZSnMZ9ohk3nWcqbnwNMXrbZgikF0,7325
3
+ wavedl/hpo.py,sha256=aZoa_Oto_anZpIhz-YM6kN8KxQXTolUvDEyg3NXwBrY,11542
4
+ wavedl/test.py,sha256=jZmRJaivYYTMMTaccCi0yQjHOfp0a9YWR1wAPeKFH-k,36246
5
+ wavedl/train.py,sha256=dO64C2ktW6on1wbYVdZSPk6w-ZzZbzOGym-2xi-gk_g,43868
5
6
  wavedl/models/__init__.py,sha256=AbsFkRNlsiWv4sJ-kLPdwjA2FS_cSp_TB3CV8884uUE,2219
6
- wavedl/models/_template.py,sha256=fBAzCPjceDZ-jTslSqXLgVvfcuBATOwnn7qXkgHUlrI,4830
7
+ wavedl/models/_template.py,sha256=O7SfL3Ef7eDXGmcOXPD0c82o_t3K4ybgJwpSEDsZNEg,4837
7
8
  wavedl/models/base.py,sha256=cql0wv8i1sMaVttXOSdBBTPfa2s2sLH5LyAsfKJdXX8,5304
8
9
  wavedl/models/cnn.py,sha256=2FFQetQaCJqeeku6glXbOQ3KJw5VvSTu9-u9cpygVk8,8356
9
10
  wavedl/models/convnext.py,sha256=zh-x5NFcZrcRv3bi55p-VKWHLYe-v1nvPcMp9xPizLk,12747
@@ -22,9 +23,9 @@ wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
22
23
  wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
23
24
  wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
24
25
  wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
25
- wavedl-1.2.0.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
26
- wavedl-1.2.0.dist-info/METADATA,sha256=PMVsB66TjbsNadJu-AWixvSuPuHA6SL3x_h1MXHZDsI,37229
27
- wavedl-1.2.0.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
28
- wavedl-1.2.0.dist-info/entry_points.txt,sha256=DeJ7crL01nqr_YUs9AekYjJ3P6hEN90GK_dASLvhPbM,111
29
- wavedl-1.2.0.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
30
- wavedl-1.2.0.dist-info/RECORD,,
26
+ wavedl-1.3.1.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
27
+ wavedl-1.3.1.dist-info/METADATA,sha256=JrPtQBD_sXt_8lUlqIYzSqe2KBYgQLCeGAaUXy_hmhA,38922
28
+ wavedl-1.3.1.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
29
+ wavedl-1.3.1.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
30
+ wavedl-1.3.1.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
31
+ wavedl-1.3.1.dist-info/RECORD,,
@@ -1,4 +1,5 @@
1
1
  [console_scripts]
2
+ wavedl-hpc = wavedl.hpc:main
2
3
  wavedl-hpo = wavedl.hpo:main
3
4
  wavedl-test = wavedl.test:main
4
5
  wavedl-train = wavedl.train:main
File without changes