wavedl 1.6.1-py3-none-any.whl → 1.6.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.6.1"
21
+ __version__ = "1.6.3"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
wavedl/hpo.py CHANGED
@@ -440,7 +440,7 @@ Examples:
440
440
  print("\n" + "=" * 60)
441
441
  print("TO TRAIN WITH BEST PARAMETERS:")
442
442
  print("=" * 60)
443
- cmd_parts = ["accelerate launch -m wavedl.train"]
443
+ cmd_parts = ["wavedl-train"]
444
444
  cmd_parts.append(f"--data_path {args.data_path}")
445
445
  for key, value in study.best_params.items():
446
446
  cmd_parts.append(f"--{key} {value}")
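This hunk only swaps the printed suggestion from a raw `accelerate launch -m wavedl.train` invocation to the `wavedl-train` entry point; the command-assembly pattern is unchanged. A minimal sketch of that pattern, with made-up values standing in for a real Optuna `study.best_params`:

```python
# Illustrative values only; a real run takes them from study.best_params.
best_params = {"lr": 3.2e-4, "batch_size": 128}

cmd_parts = ["wavedl-train", "--data_path train.npz"]
for key, value in best_params.items():
    cmd_parts.append(f"--{key} {value}")

print(" ".join(cmd_parts))
# wavedl-train --data_path train.npz --lr 0.00032 --batch_size 128
```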
wavedl/hpc.py → wavedl/launcher.py RENAMED
@@ -1,12 +1,21 @@
1
1
  #!/usr/bin/env python
2
2
  """
3
- WaveDL HPC Training Launcher.
3
+ WaveDL Training Launcher.
4
4
 
5
- This module provides a Python-based HPC training launcher that wraps accelerate
6
- for distributed training on High-Performance Computing clusters.
5
+ This module provides a universal training launcher that wraps accelerate
6
+ for distributed training. It works seamlessly on both:
7
+ - Local machines (uses standard cache locations)
8
+ - HPC clusters (uses local caching, offline WandB)
9
+
10
+ The environment is auto-detected based on scheduler variables (SLURM, PBS, etc.)
11
+ and home directory writability.
7
12
 
8
13
  Usage:
9
- wavedl-hpc --model cnn --data_path train.npz --num_gpus 4
14
+ # Local machine or HPC - same command!
15
+ wavedl-train --model cnn --data_path train.npz --output_dir results
16
+
17
+ # Multi-GPU is automatic (uses all available GPUs)
18
+ wavedl-train --model resnet18 --data_path train.npz --num_gpus 4
10
19
 
11
20
  Example SLURM script:
12
21
  #!/bin/bash
@@ -14,7 +23,7 @@ Example SLURM script:
14
23
  #SBATCH --gpus-per-node=4
15
24
  #SBATCH --time=12:00:00
16
25
 
17
- wavedl-hpc --model cnn --data_path /scratch/data.npz --compile
26
+ wavedl-train --model cnn --data_path /scratch/data.npz --compile
18
27
 
19
28
  Author: Ductho Le (ductho.le@outlook.com)
20
29
  """
@@ -53,78 +62,138 @@ def detect_gpus() -> int:
53
62
  return 1
54
63
 
55
64
 
56
- def setup_hpc_environment() -> None:
57
- """Configure environment variables for HPC systems.
65
+ def is_hpc_environment() -> bool:
66
+ """Detect if running on an HPC cluster.
67
+
68
+ Checks for:
69
+ 1. Common HPC scheduler environment variables (SLURM, PBS, LSF, SGE, Cobalt)
70
+ 2. Non-writable home directory (common on HPC systems)
58
71
 
59
- Handles restricted home directories (e.g., Compute Canada) and
60
- offline logging configurations. Always uses CWD-based TORCH_HOME
61
- since compute nodes typically lack internet access.
72
+ Returns:
73
+ True if HPC environment detected, False otherwise.
62
74
  """
63
- # Use CWD for cache base since HPC compute nodes typically lack internet
64
- cache_base = os.getcwd()
75
+ # Check for common HPC scheduler environment variables
76
+ hpc_indicators = [
77
+ "SLURM_JOB_ID", # SLURM
78
+ "PBS_JOBID", # PBS/Torque
79
+ "LSB_JOBID", # LSF
80
+ "SGE_TASK_ID", # Sun Grid Engine
81
+ "COBALT_JOBID", # Cobalt
82
+ ]
83
+ if any(var in os.environ for var in hpc_indicators):
84
+ return True
65
85
 
66
- # TORCH_HOME always set to CWD - compute nodes need pre-cached weights
67
- os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
68
- Path(os.environ["TORCH_HOME"]).mkdir(parents=True, exist_ok=True)
86
+ # Check if home directory is not writable (common on HPC)
87
+ home = os.path.expanduser("~")
88
+ return not os.access(home, os.W_OK)
69
89
 
70
- # Triton/Inductor caches - prevents permission errors with --compile
71
- # These MUST be set before any torch.compile calls
72
- os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
73
- os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache")
74
- Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
75
- Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
76
90
 
77
- # Check if home is writable for other caches
78
- home = os.path.expanduser("~")
79
- home_writable = os.access(home, os.W_OK)
80
-
81
- # Other caches only if home is not writable
82
- if not home_writable:
83
- os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
84
- os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
85
- os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
86
-
87
- # Ensure directories exist
88
- for env_var in [
89
- "MPLCONFIGDIR",
90
- "FONTCONFIG_CACHE",
91
- "XDG_CACHE_HOME",
92
- ]:
93
- Path(os.environ[env_var]).mkdir(parents=True, exist_ok=True)
94
-
95
- # WandB configuration (offline by default for HPC)
96
- os.environ.setdefault("WANDB_MODE", "offline")
97
- os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
98
- os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
99
- os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
100
-
101
- # Suppress non-critical warnings
91
+ def setup_environment() -> None:
92
+ """Configure environment for HPC or local machine.
93
+
94
+ Automatically detects the environment and configures accordingly:
95
+ - HPC: Uses CWD-based caching, offline WandB (compute nodes lack internet)
96
+ - Local: Uses standard cache locations (~/.cache), doesn't override WandB
97
+ """
98
+ is_hpc = is_hpc_environment()
99
+
100
+ if is_hpc:
101
+ # HPC: use CWD-based caching (compute nodes lack internet)
102
+ cache_base = os.getcwd()
103
+
104
+ # TORCH_HOME set to CWD - compute nodes need pre-cached weights
105
+ os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
106
+ Path(os.environ["TORCH_HOME"]).mkdir(parents=True, exist_ok=True)
107
+
108
+ # Triton/Inductor caches - prevents permission errors with --compile
109
+ os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
110
+ os.environ.setdefault(
111
+ "TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache"
112
+ )
113
+ Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
114
+ Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
115
+
116
+ # Check if home is writable for other caches
117
+ home = os.path.expanduser("~")
118
+ home_writable = os.access(home, os.W_OK)
119
+
120
+ # Other caches only if home is not writable
121
+ if not home_writable:
122
+ os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
123
+ os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
124
+ os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
125
+
126
+ for env_var in [
127
+ "MPLCONFIGDIR",
128
+ "FONTCONFIG_CACHE",
129
+ "XDG_CACHE_HOME",
130
+ ]:
131
+ Path(os.environ[env_var]).mkdir(parents=True, exist_ok=True)
132
+
133
+ # WandB configuration (offline by default for HPC)
134
+ os.environ.setdefault("WANDB_MODE", "offline")
135
+ os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
136
+ os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
137
+ os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
138
+
139
+ print("🖥️ HPC environment detected - using local caching")
140
+ else:
141
+ # Local machine: use standard locations, don't override user settings
142
+ # TORCH_HOME defaults to ~/.cache/torch (PyTorch default)
143
+ # WANDB_MODE defaults to online (WandB default)
144
+ print("💻 Local environment detected - using standard cache locations")
145
+
146
+ # Suppress non-critical warnings (both environments)
102
147
  os.environ.setdefault(
103
148
  "PYTHONWARNINGS",
104
149
  "ignore::UserWarning,ignore::FutureWarning,ignore::DeprecationWarning",
105
150
  )
106
151
 
107
152
 
153
+ def handle_fast_path_args() -> int | None:
154
+ """Handle utility flags that don't need accelerate launch.
155
+
156
+ Returns:
157
+ Exit code if handled (0 for success), None if should continue to full launch.
158
+ """
159
+ # --list_models: print models and exit immediately
160
+ if "--list_models" in sys.argv:
161
+ from wavedl.models import list_models
162
+
163
+ print("Available models:")
164
+ for name in list_models():
165
+ print(f" {name}")
166
+ return 0
167
+
168
+ return None # Continue to full launch
169
+
170
+
108
171
  def parse_args() -> tuple[argparse.Namespace, list[str]]:
109
- """Parse HPC-specific arguments, pass remaining to wavedl.train."""
172
+ """Parse launcher-specific arguments, pass remaining to wavedl.train."""
110
173
  parser = argparse.ArgumentParser(
111
- description="WaveDL HPC Training Launcher",
174
+ description="WaveDL Training Launcher (works on local machines and HPC clusters)",
112
175
  formatter_class=argparse.RawDescriptionHelpFormatter,
113
176
  epilog="""
114
177
  Examples:
115
- # Basic training with auto-detected GPUs
116
- wavedl-hpc --model cnn --data_path train.npz --epochs 100
178
+ # Basic training (auto-detects GPUs and environment)
179
+ wavedl-train --model cnn --data_path train.npz --output_dir results
117
180
 
118
- # Specify GPU count and mixed precision
119
- wavedl-hpc --model cnn --data_path train.npz --num_gpus 4 --mixed_precision bf16
181
+ # Specify GPU count explicitly
182
+ wavedl-train --model cnn --data_path train.npz --num_gpus 4
120
183
 
121
184
  # Full configuration
122
- wavedl-hpc --model resnet18 --data_path train.npz --num_gpus 8 \\
123
- --batch_size 256 --lr 1e-3 --compile --output_dir ./results
185
+ wavedl-train --model resnet18 --data_path train.npz --batch_size 256 \\
186
+ --lr 1e-3 --epochs 100 --compile --output_dir ./results
187
+
188
+ # List available models
189
+ wavedl-train --list_models
124
190
 
125
- Environment Variables:
126
- WANDB_MODE WandB mode: offline|online (default: offline)
127
- SLURM_TMPDIR Temp directory for HPC systems
191
+ Environment Detection:
192
+ The launcher automatically detects your environment:
193
+ - HPC (SLURM, PBS, etc.): Uses local caching, offline WandB
194
+ - Local machine: Uses standard cache locations (~/.cache)
195
+
196
+ For full training options, see: python -m wavedl.train --help
128
197
  """,
129
198
  )
130
199
 
@@ -204,7 +273,7 @@ def print_summary(
204
273
  print("Common issues:")
205
274
  print(" - Missing data file (check --data_path)")
206
275
  print(" - Insufficient GPU memory (reduce --batch_size)")
207
- print(" - Invalid model name (run: python train.py --list_models)")
276
+ print(" - Invalid model name (run: wavedl-train --list_models)")
208
277
  print()
209
278
 
210
279
  print("=" * 40)
@@ -212,12 +281,17 @@ def print_summary(
212
281
 
213
282
 
214
283
  def main() -> int:
215
- """Main entry point for wavedl-hpc command."""
284
+ """Main entry point for wavedl-train command."""
285
+ # Fast path for utility flags (avoid accelerate launch overhead)
286
+ exit_code = handle_fast_path_args()
287
+ if exit_code is not None:
288
+ return exit_code
289
+
216
290
  # Parse arguments
217
291
  args, train_args = parse_args()
218
292
 
219
- # Setup HPC environment
220
- setup_hpc_environment()
293
+ # Setup environment (smart detection)
294
+ setup_environment()
221
295
 
222
296
  # Check if wavedl package is importable
223
297
  try:
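As a quick sanity check of the detection logic added above, the new helpers can be called directly. The import path `wavedl.launcher` is an assumption based on the RECORD and entry_points changes further down (wavedl/hpc.py is replaced by wavedl/launcher.py); adjust it if your install differs:

```python
# Ask the launcher what it would decide on this machine.
from wavedl.launcher import is_hpc_environment, setup_environment

if is_hpc_environment():
    print("Scheduler variables or a read-only home detected -> HPC mode")
else:
    print("No scheduler detected -> local mode")

# setup_environment() only uses os.environ.setdefault, so anything you
# export beforehand (e.g. WANDB_MODE=online) takes precedence.
setup_environment()
```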
wavedl/models/__init__.py CHANGED
@@ -77,6 +77,15 @@ from .unet import UNetRegression
77
77
  from .vit import ViTBase_, ViTSmall, ViTTiny
78
78
 
79
79
 
80
+ # Optional RATENet (unpublished, may be gitignored)
81
+ try:
82
+ from .ratenet import RATENet, RATENetLite, RATENetTiny, RATENetV2
83
+
84
+ _HAS_RATENET = True
85
+ except ImportError:
86
+ _HAS_RATENET = False
87
+
88
+
80
89
  # Optional timm-based models (imported conditionally)
81
90
  try:
82
91
  from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
@@ -111,6 +120,7 @@ __all__ = [
111
120
  "MC3_18",
112
121
  "MODEL_REGISTRY",
113
122
  "TCN",
123
+ # Classes (uppercase first, alphabetically)
114
124
  "BaseModel",
115
125
  "ConvNeXtBase_",
116
126
  "ConvNeXtSmall",
@@ -152,6 +162,7 @@ __all__ = [
152
162
  "VimBase",
153
163
  "VimSmall",
154
164
  "VimTiny",
165
+ # Functions (lowercase, alphabetically)
155
166
  "build_model",
156
167
  "get_model",
157
168
  "list_models",
@@ -186,3 +197,14 @@ if _HAS_TIMM_MODELS:
186
197
  "UniRepLKNetTiny",
187
198
  ]
188
199
  )
200
+
201
+ # Add RATENet models to __all__ if available (unpublished)
202
+ if _HAS_RATENET:
203
+ __all__.extend(
204
+ [
205
+ "RATENet",
206
+ "RATENetLite",
207
+ "RATENetTiny",
208
+ "RATENetV2",
209
+ ]
210
+ )
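Because the RATENet family is imported inside a try/except and only appended to `__all__` when present, downstream code should not assume those names exist. A small sketch of how a caller might guard on that (the guard itself is illustrative, not part of WaveDL's API):

```python
# Detect whether the optional (unpublished) RATENet variants shipped with
# this particular install before referencing them.
import wavedl.models as wm

if hasattr(wm, "RATENet"):
    ratenet_names = [name for name in wm.__all__ if name.startswith("RATENet")]
    print("RATENet variants available:", ratenet_names)
else:
    print("RATENet not bundled in this build; use the public registry models")
```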
wavedl/test.py CHANGED
@@ -398,6 +398,14 @@ def load_checkpoint(
398
398
 
399
399
  if HAS_SAFETENSORS and weight_path.suffix == ".safetensors":
400
400
  state_dict = load_safetensors(str(weight_path))
401
+ elif weight_path.suffix == ".safetensors":
402
+ # Safetensors file exists but library not installed
403
+ raise ImportError(
404
+ f"Checkpoint uses safetensors format ({weight_path.name}) but "
405
+ f"'safetensors' package is not installed. Install it with:\n"
406
+ f" pip install safetensors\n"
407
+ f"Or convert the checkpoint to PyTorch format (model.bin)."
408
+ )
401
409
  else:
402
410
  state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
403
411
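The new branch turns a silent failure mode into an explicit error: a `.safetensors` checkpoint without the `safetensors` library installed previously fell through to `torch.load`, which cannot read that format and fails with a far less helpful message. A standalone sketch of the same guard, assuming nothing about the rest of `wavedl.test`:

```python
from pathlib import Path

import torch

try:
    from safetensors.torch import load_file as load_safetensors
    HAS_SAFETENSORS = True
except ImportError:
    HAS_SAFETENSORS = False


def load_weights(weight_path: Path) -> dict:
    """Load a state dict, failing loudly if safetensors support is missing."""
    if weight_path.suffix == ".safetensors":
        if not HAS_SAFETENSORS:
            raise ImportError(
                f"{weight_path.name} is a safetensors checkpoint; "
                "run 'pip install safetensors' or convert it to model.bin"
            )
        return load_safetensors(str(weight_path))
    return torch.load(weight_path, map_location="cpu", weights_only=True)
```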
 
wavedl/train.py CHANGED
@@ -12,25 +12,22 @@ A modular training framework for wave-based inverse problems and regression:
12
12
  6. Deep Observability: WandB integration with scatter analysis
13
13
 
14
14
  Usage:
15
- # Recommended: Using the HPC launcher (handles accelerate configuration)
16
- wavedl-hpc --model cnn --batch_size 128 --mixed_precision bf16 --wandb
17
-
18
- # Or direct training module (use --precision, not --mixed_precision)
19
- accelerate launch -m wavedl.train --model cnn --batch_size 128 --precision bf16
15
+ # Recommended: Universal training command (works on local machines and HPC)
16
+ wavedl-train --model cnn --batch_size 128 --compile
20
17
 
21
18
  # Multi-GPU with explicit config
22
- wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb
19
+ wavedl-train --num_gpus 4 --model cnn --output_dir results
23
20
 
24
21
  # Resume from checkpoint
25
- accelerate launch -m wavedl.train --model cnn --resume best_checkpoint --wandb
22
+ wavedl-train --model cnn --output_dir results # auto-resumes if interrupted
26
23
 
27
24
  # List available models
28
25
  wavedl-train --list_models
29
26
 
30
27
  Note:
31
- - wavedl-hpc: Uses --mixed_precision (passed to accelerate launch)
32
- - wavedl.train: Uses --precision (internal module flag)
33
- Both control the same behavior; use the appropriate flag for your entry point.
28
+ wavedl-train automatically detects your environment:
29
+ - HPC clusters (SLURM, PBS, etc.): Uses local caching, offline WandB
30
+ - Local machines: Uses standard cache locations (~/.cache)
34
31
 
35
32
  Author: Ductho Le (ductho.le@outlook.com)
36
33
  """
@@ -429,7 +426,7 @@ def parse_args() -> argparse.Namespace:
429
426
  choices=["bf16", "fp16", "no"],
430
427
  help="Mixed precision mode",
431
428
  )
432
- # Alias for consistency with wavedl-hpc (--mixed_precision)
429
+ # Alias for consistency with wavedl-train (--mixed_precision is passed to accelerate)
433
430
  parser.add_argument(
434
431
  "--mixed_precision",
435
432
  dest="precision",
@@ -1269,10 +1266,10 @@ def main():
1269
1266
  os.path.join(args.output_dir, "best_model_weights.pth"),
1270
1267
  )
1271
1268
 
1272
- # Copy scaler to checkpoint for portability
1269
+ # Copy scaler to checkpoint for portability (always overwrite to stay current)
1273
1270
  scaler_src = os.path.join(args.output_dir, "scaler.pkl")
1274
1271
  scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
1275
- if os.path.exists(scaler_src) and not os.path.exists(scaler_dst):
1272
+ if os.path.exists(scaler_src):
1276
1273
  shutil.copy2(scaler_src, scaler_dst)
1277
1274
 
1278
1275
  logger.info(
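The parse_args hunk above keeps `--mixed_precision` as an alias that writes to the same destination as `--precision`, so either spelling reaches the training code. A minimal argparse sketch of that alias pattern (the default value here is illustrative, not WaveDL's):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--precision", choices=["bf16", "fp16", "no"], default="bf16")
# Alias: stores into the same 'precision' attribute on the namespace.
parser.add_argument("--mixed_precision", dest="precision", choices=["bf16", "fp16", "no"])

print(parser.parse_args(["--mixed_precision", "fp16"]).precision)  # -> fp16
print(parser.parse_args([]).precision)                             # -> bf16
```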
wavedl/utils/data.py CHANGED
@@ -984,7 +984,15 @@ def load_test_data(
984
984
  f"Available keys depend on file format. Original error: {e}"
985
985
  ) from e
986
986
 
987
- # Legitimate fallback: no explicit output_key, outputs just not present
987
+ # Also fail-fast if explicit input_key was provided but not found
988
+ # This prevents silently loading a different tensor when user mistyped key
989
+ if input_key is not None:
990
+ raise KeyError(
991
+ f"Explicit --input_key '{input_key}' not found in file. "
992
+ f"Original error: {e}"
993
+ ) from e
994
+
995
+ # Legitimate fallback: no explicit keys, outputs just not present
988
996
  if format == "npz":
989
997
  # First pass to find keys
990
998
  with np.load(path, allow_pickle=False) as probe:
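The added check makes an explicitly passed `--input_key` fail fast instead of silently loading a different array when the key is mistyped, while keeping the fallback for the no-explicit-key case. A reduced sketch of that rule over a plain `.npz` file (function and candidate names here are illustrative, not WaveDL's API):

```python
import numpy as np


def fetch_array(
    path: str, key: str | None, candidates: tuple[str, ...] = ("x", "inputs")
) -> np.ndarray:
    with np.load(path, allow_pickle=False) as data:
        if key is not None:
            # Explicit key: a miss is an error, never a silent fallback.
            if key not in data:
                raise KeyError(
                    f"Explicit key '{key}' not found; file has {list(data.keys())}"
                )
            return data[key]
        # No explicit key: fall back to the first known candidate name.
        for name in candidates:
            if name in data:
                return data[name]
        raise KeyError(f"None of {candidates} found in {path}")
```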
@@ -1524,21 +1532,43 @@ def prepare_data(
1524
1532
 
1525
1533
  logger.info(" ✔ Cache creation complete, synchronizing ranks...")
1526
1534
  else:
1527
- # NON-MAIN RANKS: Wait for cache creation
1528
- # Log that we're waiting (helps with debugging)
1535
+ # NON-MAIN RANKS: Wait for cache creation with timeout
1536
+ # Use monotonic clock (immune to system clock changes)
1529
1537
  import time
1530
1538
 
1531
- wait_start = time.time()
1539
+ wait_start = time.monotonic()
1540
+
1541
+ # Robust env parsing with guards for invalid/non-positive values
1542
+ DEFAULT_CACHE_TIMEOUT = 3600 # 1 hour default
1543
+ try:
1544
+ env_timeout = os.environ.get("WAVEDL_CACHE_TIMEOUT", "")
1545
+ CACHE_TIMEOUT = (
1546
+ int(env_timeout) if env_timeout else DEFAULT_CACHE_TIMEOUT
1547
+ )
1548
+ if CACHE_TIMEOUT <= 0:
1549
+ CACHE_TIMEOUT = DEFAULT_CACHE_TIMEOUT
1550
+ except ValueError:
1551
+ CACHE_TIMEOUT = DEFAULT_CACHE_TIMEOUT
1552
+
1532
1553
  while not (
1533
1554
  os.path.exists(CACHE_FILE)
1534
1555
  and os.path.exists(SCALER_FILE)
1535
1556
  and os.path.exists(META_FILE)
1536
1557
  ):
1537
1558
  time.sleep(5) # Check every 5 seconds
1538
- elapsed = time.time() - wait_start
1559
+ elapsed = time.monotonic() - wait_start
1560
+
1561
+ if elapsed > CACHE_TIMEOUT:
1562
+ raise RuntimeError(
1563
+ f"[Rank {accelerator.process_index}] Timeout waiting for cache "
1564
+ f"files after {CACHE_TIMEOUT}s. Rank 0 may have failed during "
1565
+ f"cache generation. Check rank 0 logs for errors."
1566
+ )
1567
+
1539
1568
  if elapsed > 60 and int(elapsed) % 60 < 5: # Log every ~minute
1540
1569
  logger.info(
1541
- f" [Rank {accelerator.process_index}] Waiting for cache creation... ({int(elapsed)}s)"
1570
+ f" [Rank {accelerator.process_index}] Waiting for cache "
1571
+ f"creation... ({int(elapsed)}s / {CACHE_TIMEOUT}s max)"
1542
1572
  )
1543
1573
  # Small delay to ensure files are fully written
1544
1574
  time.sleep(2)
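The waiting logic for non-main ranks now combines a monotonic clock, a hard timeout, and defensive parsing of `WAVEDL_CACHE_TIMEOUT`. A self-contained sketch of that pattern (helper names are illustrative; the real code lives inside `prepare_data`):

```python
import os
import time
from pathlib import Path

DEFAULT_TIMEOUT = 3600  # seconds, same default the hunk uses


def read_timeout() -> int:
    """Parse WAVEDL_CACHE_TIMEOUT defensively: empty, non-numeric, or
    non-positive values all fall back to the default."""
    raw = os.environ.get("WAVEDL_CACHE_TIMEOUT", "")
    try:
        value = int(raw) if raw else DEFAULT_TIMEOUT
    except ValueError:
        return DEFAULT_TIMEOUT
    return value if value > 0 else DEFAULT_TIMEOUT


def wait_for_files(files: list[Path], timeout: int) -> None:
    start = time.monotonic()  # immune to wall-clock adjustments
    while not all(f.exists() for f in files):
        if time.monotonic() - start > timeout:
            raise RuntimeError(f"Timed out after {timeout}s waiting for {files}")
        time.sleep(5)
```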
wavedl-1.6.1.dist-info/METADATA → wavedl-1.6.3.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.6.1
3
+ Version: 1.6.3
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -214,77 +214,83 @@ This installs everything you need: training, inference, HPO, ONNX export.
214
214
  ```bash
215
215
  git clone https://github.com/ductho-le/WaveDL.git
216
216
  cd WaveDL
217
- pip install -e .
217
+ pip install -e ".[dev]"
218
218
  ```
219
219
 
220
220
  > [!NOTE]
221
- > Python 3.11+ required. For development setup, see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
221
+ > Python 3.11+ required. For contributor setup (pre-commit hooks), see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
222
222
 
223
223
  ### Quick Start
224
224
 
225
225
  > [!TIP]
226
226
  > In all examples below, replace `<...>` placeholders with your values. See [Configuration](#️-configuration) for defaults and options.
227
227
 
228
- #### Option 1: Using wavedl-hpc (Recommended for HPC)
229
-
230
- The `wavedl-hpc` command automatically configures the environment for HPC systems:
228
+ ### Training
231
229
 
232
230
  ```bash
233
- # Basic training (auto-detects available GPUs)
234
- wavedl-hpc --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
231
+ # Basic training (auto-detects GPUs and environment)
232
+ wavedl-train --model <model_name> --data_path <train_data> --output_dir <output_folder>
235
233
 
236
234
  # Detailed configuration
237
- wavedl-hpc --model <model_name> --data_path <train_data> --batch_size <number> \
235
+ wavedl-train --model <model_name> --data_path <train_data> --batch_size <number> \
238
236
  --lr <number> --epochs <number> --patience <number> --compile --output_dir <output_folder>
239
237
 
240
- # Specify GPU count explicitly
241
- wavedl-hpc --num_gpus 4 --model cnn --data_path train.npz --output_dir results
242
- ```
243
-
244
- #### Option 2: Direct Accelerate Launch
245
-
246
- ```bash
247
- # Local - auto-detects GPUs
248
- accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
238
+ # Multi-GPU is automatic (uses all available GPUs)
239
+ # Override with --num_gpus if needed
240
+ wavedl-train --model cnn --data_path train.npz --num_gpus 4 --output_dir results
249
241
 
250
242
  # Resume training (automatic - just re-run with same output_dir)
251
- # Manual resume from specific checkpoint:
252
- accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --resume <checkpoint_folder> --output_dir <output_folder>
243
+ wavedl-train --model <model_name> --data_path <train_data> --output_dir <output_folder>
253
244
 
254
245
  # Force fresh start (ignores existing checkpoints)
255
- accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --output_dir <output_folder> --fresh
246
+ wavedl-train --model <model_name> --data_path <train_data> --output_dir <output_folder> --fresh
256
247
 
257
248
  # List available models
258
249
  wavedl-train --list_models
259
250
  ```
260
251
 
261
- > [!TIP]
262
- > **Auto-Resume**: If training crashes or is interrupted, simply re-run with the same `--output_dir`. The framework automatically detects incomplete training and resumes from the last checkpoint. Use `--fresh` to force a fresh start.
252
+ > [!NOTE]
253
+ > `wavedl-train` automatically detects your environment:
254
+ > - **HPC clusters** (SLURM, PBS, etc.): Uses local caching, offline WandB
255
+ > - **Local machines**: Uses standard cache locations (~/.cache)
263
256
  >
264
- > **GPU Auto-Detection**: `wavedl-hpc` automatically detects available GPUs using `nvidia-smi`. Use `--num_gpus` to override.
257
+ > **Auto-Resume**: If training crashes or is interrupted, simply re-run with the same `--output_dir`. The framework automatically detects incomplete training and resumes from the last checkpoint.
265
258
 
266
- ### Testing & Inference
259
+ <details>
260
+ <summary><b>Advanced: Direct Accelerate Launch</b></summary>
261
+
262
+ For fine-grained control over distributed training, you can use `accelerate launch` directly:
267
263
 
268
- After training, use `wavedl.test` to evaluate your model on test data:
264
+ ```bash
265
+ # Custom accelerate configuration
266
+ accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --output_dir <output_folder>
267
+
268
+ # Multi-node training
269
+ accelerate launch --num_machines 2 --main_process_ip <ip> -m wavedl.train --model cnn --data_path train.npz
270
+ ```
271
+
272
+ </details>
273
+
274
+ ### Testing & Inference
269
275
 
270
276
  ```bash
271
277
  # Basic inference
272
- python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data>
278
+ wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data>
273
279
 
274
280
  # With visualization, CSV export, and multiple file formats
275
- python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
281
+ wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
276
282
  --plot --plot_format png pdf --save_predictions --output_dir <output_folder>
277
283
 
278
284
  # With custom parameter names
279
- python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
285
+ wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
280
286
  --param_names '$p_1$' '$p_2$' '$p_3$' --plot
281
287
 
282
288
  # Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
283
- python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
289
+ wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
284
290
  --export onnx --export_path <output_file.onnx>
285
291
 
286
292
  # For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
287
- python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
293
+ wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data> \
288
294
  --input_channels 1
289
295
  ```
290
296
 
@@ -295,7 +301,7 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
295
301
  - **Format** (with `--plot_format`): Supported formats: `png` (default), `pdf` (vector), `svg` (vector), `eps` (LaTeX), `tiff`, `jpg`, `ps`
296
302
 
297
303
  > [!NOTE]
298
- > `wavedl.test` auto-detects the model architecture from checkpoint metadata. If unavailable, it falls back to folder name parsing. Use `--model` to override if needed.
304
+ > `wavedl-test` auto-detects the model architecture from checkpoint metadata. If unavailable, it falls back to folder name parsing. Use `--model` to override if needed.
299
305
 
300
306
  ### Adding Custom Models
301
307
 
@@ -339,7 +345,7 @@ class MyModel(BaseModel):
339
345
  **Step 2: Train**
340
346
 
341
347
  ```bash
342
- wavedl-hpc --import my_model.py --model my_model --data_path train.npz
348
+ wavedl-train --import my_model.py --model my_model --data_path train.npz
343
349
  ```
344
350
 
345
351
  WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -355,10 +361,10 @@ WaveDL/
355
361
  ├── src/
356
362
  │ └── wavedl/ # Main package (namespaced)
357
363
  │ ├── __init__.py # Package init with __version__
358
- │ ├── train.py # Training entry point
364
+ │ ├── train.py # Training script
359
365
  │ ├── test.py # Testing & inference script
360
366
  │ ├── hpo.py # Hyperparameter optimization
361
- │ ├── hpc.py # HPC distributed training launcher
367
+ │ ├── launcher.py # Training launcher (wavedl-train)
362
368
  │ │
363
369
  │ ├── models/ # Model Zoo (69 architectures)
364
370
  │ │ ├── registry.py # Model factory (@register_model)
@@ -389,16 +395,7 @@ WaveDL/
389
395
  ## ⚙️ Configuration
390
396
 
391
397
  > [!NOTE]
392
- > All configuration options below work with **both** `wavedl-hpc` and direct `accelerate launch`. The wrapper script passes all arguments directly to `train.py`.
393
- >
394
- > **Examples:**
395
- > ```bash
396
- > # Using wavedl-hpc
397
- > wavedl-hpc --model cnn --batch_size 256 --lr 5e-4 --compile
398
- >
399
- > # Using accelerate launch directly
400
- > accelerate launch -m wavedl.train --model cnn --batch_size 256 --lr 5e-4 --compile
401
- > ```
398
+ > All configuration options below work with `wavedl-train`. The wrapper script passes all arguments directly to `train.py`.
402
399
 
403
400
  <details>
404
401
  <summary><b>Available Models</b> — 69 architectures</summary>
@@ -642,7 +639,7 @@ WaveDL automatically enables performance optimizations for modern GPUs:
642
639
  </details>
643
640
 
644
641
  <details>
645
- <summary><b>HPC CLI Arguments (wavedl-hpc)</b></summary>
642
+ <summary><b>Distributed Training Arguments</b></summary>
646
643
 
647
644
  | Argument | Default | Description |
648
645
  |----------|---------|-------------|
@@ -674,10 +671,10 @@ WaveDL automatically enables performance optimizations for modern GPUs:
674
671
  **Example:**
675
672
  ```bash
676
673
  # Use Huber loss for noisy NDE data
677
- accelerate launch -m wavedl.train --model cnn --loss huber --huber_delta 0.5
674
+ wavedl-train --model cnn --loss huber --huber_delta 0.5
678
675
 
679
676
  # Weighted MSE: prioritize thickness (first target)
680
- accelerate launch -m wavedl.train --model cnn --loss weighted_mse --loss_weights "2.0,1.0,1.0"
677
+ wavedl-train --model cnn --loss weighted_mse --loss_weights "2.0,1.0,1.0"
681
678
  ```
682
679
 
683
680
  </details>
@@ -697,10 +694,10 @@ accelerate launch -m wavedl.train --model cnn --loss weighted_mse --loss_weights
697
694
  **Example:**
698
695
  ```bash
699
696
  # SGD with Nesterov momentum (often better generalization)
700
- accelerate launch -m wavedl.train --model cnn --optimizer sgd --lr 0.01 --momentum 0.9 --nesterov
697
+ wavedl-train --model cnn --optimizer sgd --lr 0.01 --momentum 0.9 --nesterov
701
698
 
702
699
  # RAdam for more stable training
703
- accelerate launch -m wavedl.train --model cnn --optimizer radam --lr 1e-3
700
+ wavedl-train --model cnn --optimizer radam --lr 1e-3
704
701
  ```
705
702
 
706
703
  </details>
@@ -722,13 +719,13 @@ accelerate launch -m wavedl.train --model cnn --optimizer radam --lr 1e-3
722
719
  **Example:**
723
720
  ```bash
724
721
  # Cosine annealing for 1000 epochs
725
- accelerate launch -m wavedl.train --model cnn --scheduler cosine --epochs 1000 --min_lr 1e-7
722
+ wavedl-train --model cnn --scheduler cosine --epochs 1000 --min_lr 1e-7
726
723
 
727
724
  # OneCycleLR for super-convergence
728
- accelerate launch -m wavedl.train --model cnn --scheduler onecycle --lr 1e-2 --epochs 50
725
+ wavedl-train --model cnn --scheduler onecycle --lr 1e-2 --epochs 50
729
726
 
730
727
  # MultiStep with custom milestones
731
- accelerate launch -m wavedl.train --model cnn --scheduler multistep --milestones "100,200,300"
728
+ wavedl-train --model cnn --scheduler multistep --milestones "100,200,300"
732
729
  ```
733
730
 
734
731
  </details>
@@ -739,16 +736,14 @@ accelerate launch -m wavedl.train --model cnn --scheduler multistep --milestones
739
736
  For robust model evaluation, simply add the `--cv` flag:
740
737
 
741
738
  ```bash
742
- # 5-fold cross-validation (works with both methods!)
743
- wavedl-hpc --model cnn --cv 5 --data_path train_data.npz
744
- # OR
745
- accelerate launch -m wavedl.train --model cnn --cv 5 --data_path train_data.npz
739
+ # 5-fold cross-validation
740
+ wavedl-train --model cnn --cv 5 --data_path train_data.npz
746
741
 
747
742
  # Stratified CV (recommended for unbalanced data)
748
- wavedl-hpc --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
743
+ wavedl-train --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
749
744
 
750
745
  # Full configuration
751
- wavedl-hpc --model cnn --cv 5 --cv_stratify \
746
+ wavedl-train --model cnn --cv 5 --cv_stratify \
752
747
  --loss huber --optimizer adamw --scheduler cosine \
753
748
  --output_dir ./cv_results
754
749
  ```
@@ -773,10 +768,10 @@ Use YAML files for reproducible experiments. CLI arguments can override any conf
773
768
 
774
769
  ```bash
775
770
  # Use a config file
776
- accelerate launch -m wavedl.train --config configs/config.yaml --data_path train.npz
771
+ wavedl-train --config configs/config.yaml --data_path train.npz
777
772
 
778
773
  # Override specific values from config
779
- accelerate launch -m wavedl.train --config configs/config.yaml --lr 5e-4 --epochs 500
774
+ wavedl-train --config configs/config.yaml --lr 5e-4 --epochs 500
780
775
  ```
781
776
 
782
777
  **Example config (`configs/config.yaml`):**
@@ -929,7 +924,7 @@ wavedl-hpo --data_path train.npz --models cnn --n_trials 50 --quick
929
924
 
930
925
  After HPO completes, it prints the optimal command:
931
926
  ```bash
932
- accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
927
+ wavedl-train --data_path train.npz --model cnn --lr 3.2e-4 --batch_size 128 ...
933
928
  ```
934
929
 
935
930
  ---
@@ -1122,12 +1117,12 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
1122
1117
 
1123
1118
  ```bash
1124
1119
  # Run inference on the example data
1125
- python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
1120
+ wavedl-test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
1126
1121
  --data_path ./examples/elasticity_prediction/Test_data_100.mat \
1127
1122
  --plot --save_predictions --output_dir ./examples/elasticity_prediction/test_results
1128
1123
 
1129
1124
  # Export to ONNX (already included as model.onnx)
1130
- python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
1125
+ wavedl-test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
1131
1126
  --data_path ./examples/elasticity_prediction/Test_data_100.mat \
1132
1127
  --export onnx --export_path ./examples/elasticity_prediction/model.onnx
1133
1128
  ```
wavedl-1.6.1.dist-info/RECORD → wavedl-1.6.3.dist-info/RECORD RENAMED
@@ -1,9 +1,9 @@
1
- wavedl/__init__.py,sha256=tz3qpBFZ4wTNv32Nz7aKGmHCXrTFTxAEJQeJO8La38Q,1177
2
- wavedl/hpc.py,sha256=6rV38nozzMt0-jKZbVJNwvQZXK0wUsIZmr9lgWN_XUw,9212
3
- wavedl/hpo.py,sha256=6eHYV9Nzbp2YbTY52NRnW7pwzlI_DNWskN-zBR-wj24,14654
4
- wavedl/test.py,sha256=1UUy9phCqrr3h_lN6mGJ7Sj73skDg4KyLk2Yuq9DiKU,38797
5
- wavedl/train.py,sha256=PzJGARHounr6R8WUOrUwwd2hRcLsGkxes08jYKkBRIo,58003
6
- wavedl/models/__init__.py,sha256=8OiT2seq1qBiUzKaSkmh_VOLJlLTT9Cn-mjhMHKGFpI,5203
1
+ wavedl/__init__.py,sha256=CdjeIFEZh4ccwPiNsUxpac5l3pxwb9e1SmCNiLiYG1I,1177
2
+ wavedl/hpo.py,sha256=nEiy-2O_5EhxF5hU8X5TviSAiXfVrTQx0-VE6baW7JQ,14633
3
+ wavedl/launcher.py,sha256=_CFlgpKgHrtZebl1yQbJZJEcob06Y9-fqnRYzwW7UJQ,11776
4
+ wavedl/test.py,sha256=5MzBtEH2lWWYG23Fz-VpMFAWR5SfZbFomBbu8ptsZRU,39208
5
+ wavedl/train.py,sha256=DizXhi9BFL8heLmO8ENiNm2QubAMm9mdpDiaBlULeKM,57824
6
+ wavedl/models/__init__.py,sha256=hyR__h_D8PsUQCBSM5tj94yYK00uG8ABjEmj_RR8SGE,5719
7
7
  wavedl/models/_pretrained_utils.py,sha256=VPdU1DwJB93ZBf_GFIgb8-6BbAt18Phs4yorwlhLw70,12404
8
8
  wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
9
9
  wavedl/models/base.py,sha256=bDoHYFli-aR8amcFYXbF98QYaKSCEwZWpvOhN21ODro,9075
@@ -32,15 +32,15 @@ wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
32
32
  wavedl/utils/config.py,sha256=MXkaVc1_zo8sDro8mjtK1MV65t2z8b1Z6fviwSorNiY,10534
33
33
  wavedl/utils/constraints.py,sha256=V9Gyi8-uIMbLUWb2cOaHZD0SliWLxVrHZHFyo4HWK7g,18031
34
34
  wavedl/utils/cross_validation.py,sha256=HfInyZ8gUROc_AyihYKzzUE0vnoPt_mFvAI2OPK4P54,17945
35
- wavedl/utils/data.py,sha256=5ph2Pi8PKvuaSoJaXbFIL9WsX8pTN0A6P8FdmxvXdv4,63469
35
+ wavedl/utils/data.py,sha256=HXod6i6g76oFAjLz7xepBPQEFHRgQ7E1M-YSKwUya-I,64799
36
36
  wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
37
37
  wavedl/utils/losses.py,sha256=KWpU5S5noFzp3bLbcH9RNpkFPajy6fyTIh5cNjI-BYA,7038
38
38
  wavedl/utils/metrics.py,sha256=YoqiXWOsUB9Y4_alj8CmHcTgnV4MFcH5PH4XlIC13HY,40304
39
39
  wavedl/utils/optimizers.py,sha256=ZoETDSOK1fWUT2dx69PyYebeM8Vcqf9zOIKUERWk5HY,6107
40
40
  wavedl/utils/schedulers.py,sha256=K6YCiyiMM9rb0cCRXTp89noXeXcAyUEiePr27O5Cozs,7408
41
- wavedl-1.6.1.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
42
- wavedl-1.6.1.dist-info/METADATA,sha256=eS4uG6dzEVs25zYmiZZnGeHz8lUHIVKL9TpyCJt7kh8,48232
43
- wavedl-1.6.1.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
44
- wavedl-1.6.1.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
45
- wavedl-1.6.1.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
46
- wavedl-1.6.1.dist-info/RECORD,,
41
+ wavedl-1.6.3.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
42
+ wavedl-1.6.3.dist-info/METADATA,sha256=BtPoAiMwvE58b-dmR0TfPPLLBqYzaAiJ2GBUd0FBSY0,47613
43
+ wavedl-1.6.3.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
44
+ wavedl-1.6.3.dist-info/entry_points.txt,sha256=NuAvdiG93EYYpqv-_1wf6PN0WqBfABanDKalNKe2GOs,148
45
+ wavedl-1.6.3.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
46
+ wavedl-1.6.3.dist-info/RECORD,,
wavedl-1.6.1.dist-info/entry_points.txt → wavedl-1.6.3.dist-info/entry_points.txt RENAMED
@@ -1,5 +1,5 @@
1
1
  [console_scripts]
2
- wavedl-hpc = wavedl.hpc:main
2
+ wavedl-hpc = wavedl.launcher:main
3
3
  wavedl-hpo = wavedl.hpo:main
4
4
  wavedl-test = wavedl.test:main
5
- wavedl-train = wavedl.train:main
5
+ wavedl-train = wavedl.launcher:main