wavedl 1.2.0__tar.gz → 1.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. {wavedl-1.2.0/src/wavedl.egg-info → wavedl-1.3.0}/PKG-INFO +83 -29
  2. {wavedl-1.2.0 → wavedl-1.3.0}/README.md +80 -26
  3. {wavedl-1.2.0 → wavedl-1.3.0}/pyproject.toml +3 -2
  4. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/__init__.py +1 -1
  5. wavedl-1.3.0/src/wavedl/hpc.py +243 -0
  6. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/hpo.py +8 -8
  7. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/models/_template.py +1 -1
  8. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/test.py +3 -3
  9. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/train.py +47 -9
  10. {wavedl-1.2.0 → wavedl-1.3.0/src/wavedl.egg-info}/PKG-INFO +83 -29
  11. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl.egg-info/SOURCES.txt +1 -0
  12. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl.egg-info/entry_points.txt +1 -0
  13. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl.egg-info/requires.txt +5 -1
  14. {wavedl-1.2.0 → wavedl-1.3.0}/LICENSE +0 -0
  15. {wavedl-1.2.0 → wavedl-1.3.0}/setup.cfg +0 -0
  16. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/models/__init__.py +0 -0
  17. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/models/base.py +0 -0
  18. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/models/cnn.py +0 -0
  19. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/models/convnext.py +0 -0
  20. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/models/densenet.py +0 -0
  21. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/models/efficientnet.py +0 -0
  22. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/models/registry.py +0 -0
  23. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/models/resnet.py +0 -0
  24. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/models/unet.py +0 -0
  25. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/models/vit.py +0 -0
  26. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/utils/__init__.py +0 -0
  27. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/utils/config.py +0 -0
  28. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/utils/cross_validation.py +0 -0
  29. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/utils/data.py +0 -0
  30. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/utils/distributed.py +0 -0
  31. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/utils/losses.py +0 -0
  32. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/utils/metrics.py +0 -0
  33. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/utils/optimizers.py +0 -0
  34. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl/utils/schedulers.py +0 -0
  35. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl.egg-info/dependency_links.txt +0 -0
  36. {wavedl-1.2.0 → wavedl-1.3.0}/src/wavedl.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.2.0
3
+ Version: 1.3.0
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -43,7 +43,7 @@ Provides-Extra: onnx
43
43
  Requires-Dist: onnx>=1.14.0; extra == "onnx"
44
44
  Requires-Dist: onnxruntime>=1.15.0; extra == "onnx"
45
45
  Provides-Extra: compile
46
- Requires-Dist: triton; extra == "compile"
46
+ Requires-Dist: triton; sys_platform == "linux" and extra == "compile"
47
47
  Provides-Extra: hpo
48
48
  Requires-Dist: optuna>=3.0.0; extra == "hpo"
49
49
  Provides-Extra: all
@@ -53,7 +53,7 @@ Requires-Dist: ruff>=0.8.0; extra == "all"
53
53
  Requires-Dist: pre-commit>=3.5.0; extra == "all"
54
54
  Requires-Dist: onnx>=1.14.0; extra == "all"
55
55
  Requires-Dist: onnxruntime>=1.15.0; extra == "all"
56
- Requires-Dist: triton; extra == "all"
56
+ Requires-Dist: triton; sys_platform == "linux" and extra == "all"
57
57
  Requires-Dist: optuna>=3.0.0; extra == "all"
58
58
 
59
59
  <div align="center">
@@ -211,40 +211,43 @@ Deploy models anywhere:
211
211
  ### Installation
212
212
 
213
213
  ```bash
214
- git clone https://github.com/ductho-le/WaveDL.git
215
- cd WaveDL
214
+ # Install from PyPI (recommended)
215
+ pip install wavedl
216
+
217
+ # Or install with all extras (ONNX export, HPO, dev tools)
218
+ pip install "wavedl[all]"
219
+ ```
216
220
 
217
- # Basic install (training + inference)
218
- pip install -e .
221
+ #### From Source (for development)
219
222
 
220
- # Full install (adds ONNX export, torch.compile, HPO, dev tools)
221
- pip install -e ".[all]"
223
+ ```bash
224
+ git clone https://github.com/ductho-le/WaveDL.git
225
+ cd WaveDL
226
+ pip install -e ".[dev]"
222
227
  ```
223
228
 
224
229
  > [!NOTE]
225
- > Dependencies are managed in `pyproject.toml`. Python 3.11+ required.
226
- >
227
- > For development setup (running tests, contributing), see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
230
+ > Python 3.11+ required. For development setup, see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
228
231
 
229
232
  ### Quick Start
230
233
 
231
234
  > [!TIP]
232
235
  > In all examples below, replace `<...>` placeholders with your values. See [Configuration](#️-configuration) for defaults and options.
233
236
 
234
- #### Option 1: Using the Helper Script (Recommended for HPC)
237
+ #### Option 1: Using wavedl-hpc (Recommended for HPC)
235
238
 
236
- The `run_training.sh` wrapper automatically configures the environment for HPC systems:
239
+ The `wavedl-hpc` command automatically configures the environment for HPC systems:
237
240
 
238
241
  ```bash
239
- # Make executable (first time only)
240
- chmod +x run_training.sh
241
-
242
242
  # Basic training (auto-detects available GPUs)
243
- ./run_training.sh --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
243
+ wavedl-hpc --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
244
244
 
245
245
  # Detailed configuration
246
- ./run_training.sh --model <model_name> --data_path <train_data> --batch_size <number> \
246
+ wavedl-hpc --model <model_name> --data_path <train_data> --batch_size <number> \
247
247
  --lr <number> --epochs <number> --patience <number> --compile --output_dir <output_folder>
248
+
249
+ # Specify GPU count explicitly
250
+ wavedl-hpc --num_gpus 4 --model cnn --data_path train.npz --output_dir results
248
251
  ```
249
252
 
250
253
  #### Option 2: Direct Accelerate Launch
@@ -261,13 +264,13 @@ accelerate launch -m wavedl.train --model <model_name> --data_path <train_data>
261
264
  accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --output_dir <output_folder> --fresh
262
265
 
263
266
  # List available models
264
- python -m wavedl.train --list_models
267
+ wavedl-train --list_models
265
268
  ```
266
269
 
267
270
  > [!TIP]
268
271
  > **Auto-Resume**: If training crashes or is interrupted, simply re-run with the same `--output_dir`. The framework automatically detects incomplete training and resumes from the last checkpoint. Use `--fresh` to force a fresh start.
269
272
  >
270
- > **GPU Auto-Detection**: By default, `run_training.sh` automatically detects available GPUs using `nvidia-smi`. Set `NUM_GPUS` to override this behavior.
273
+ > **GPU Auto-Detection**: `wavedl-hpc` automatically detects available GPUs using `nvidia-smi`. Use `--num_gpus` to override.
271
274
 
272
275
  ### Testing & Inference
273
276
 
@@ -299,6 +302,56 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
299
302
  > [!NOTE]
300
303
  > `wavedl.test` auto-detects the model architecture from checkpoint metadata. If unavailable, it falls back to folder name parsing. Use `--model` to override if needed.
301
304
 
305
+ ### Adding Custom Models
306
+
307
+ <details>
308
+ <summary><b>Creating Your Own Architecture</b></summary>
309
+
310
+ **Requirements** (your model must):
311
+ 1. Inherit from `BaseModel`
312
+ 2. Accept `in_channels`, `num_outputs`, `input_shape` in `__init__`
313
+ 3. Return a tensor of shape `(batch, num_outputs)` from `forward()`
314
+
315
+ ---
316
+
317
+ **Step 1: Create `my_model.py`**
318
+
319
+ ```python
320
+ import torch.nn as nn
321
+ import torch.nn.functional as F
322
+ from wavedl.models import BaseModel, register_model
323
+
324
+ @register_model("my_model") # This name is used with --model flag
325
+ class MyModel(BaseModel):
326
+ def __init__(self, in_channels, num_outputs, input_shape):
327
+ # in_channels: number of input channels (auto-detected from data)
328
+ # num_outputs: number of parameters to predict (auto-detected from data)
329
+ # input_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
330
+ super().__init__(in_channels, num_outputs, input_shape)
331
+
332
+ # Define your layers (this is just an example)
333
+ self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
334
+ self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
335
+ self.fc = nn.Linear(128, num_outputs)
336
+
337
+ def forward(self, x):
338
+ # Input x has shape: (batch, in_channels, *input_shape)
339
+ x = F.relu(self.conv1(x))
340
+ x = F.relu(self.conv2(x))
341
+ x = x.mean(dim=[-2, -1]) # Global average pooling
342
+ return self.fc(x) # Output shape: (batch, num_outputs)
343
+ ```
344
+
345
+ **Step 2: Train**
346
+
347
+ ```bash
348
+ wavedl-hpc --import my_model --model my_model --data_path train.npz
349
+ ```
350
+
351
+ WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
352
+
353
+ </details>
354
+
302
355
  ---
303
356
 
304
357
  ## 📁 Project Structure
@@ -311,6 +364,7 @@ WaveDL/
311
364
  │ ├── train.py # Training entry point
312
365
  │ ├── test.py # Testing & inference script
313
366
  │ ├── hpo.py # Hyperparameter optimization
367
+ │ ├── hpc.py # HPC distributed training launcher
314
368
  │ │
315
369
  │ ├── models/ # Model architectures
316
370
  │ │ ├── registry.py # Model factory (@register_model)
@@ -332,7 +386,6 @@ WaveDL/
332
386
  │ ├── schedulers.py # LR scheduler factory
333
387
  │ └── config.py # YAML configuration support
334
388
 
335
- ├── run_training.sh # HPC helper script
336
389
  ├── configs/ # YAML config templates
337
390
  ├── examples/ # Ready-to-run examples
338
391
  ├── notebooks/ # Jupyter notebooks
@@ -347,12 +400,12 @@ WaveDL/
347
400
  ## ⚙️ Configuration
348
401
 
349
402
  > [!NOTE]
350
- > All configuration options below work with **both** `run_training.sh` and direct `accelerate launch`. The wrapper script passes all arguments directly to `train.py`.
403
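+ > All configuration options below work with **both** `wavedl-hpc` and direct `accelerate launch`. `wavedl-hpc` consumes only its own launcher options (`--num_gpus`, `--num_machines`, `--machine_rank`, `--mixed_precision`, `--dynamo_backend`) and passes every other argument unchanged to `wavedl.train`.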
351
404
  >
352
405
  > **Examples:**
353
406
  > ```bash
354
- > # Using run_training.sh
355
- > ./run_training.sh --model cnn --batch_size 256 --lr 5e-4 --compile
407
+ > # Using wavedl-hpc
408
+ > wavedl-hpc --model cnn --batch_size 256 --lr 5e-4 --compile
356
409
  >
357
410
  > # Using accelerate launch directly
358
411
  > accelerate launch -m wavedl.train --model cnn --batch_size 256 --lr 5e-4 --compile
@@ -395,6 +448,7 @@ WaveDL/
395
448
  | Argument | Default | Description |
396
449
  |----------|---------|-------------|
397
450
  | `--model` | `cnn` | Model architecture |
451
+ | `--import` | - | Python modules to import (for custom models) |
398
452
  | `--batch_size` | `128` | Per-GPU batch size |
399
453
  | `--lr` | `1e-3` | Learning rate |
400
454
  | `--epochs` | `1000` | Maximum epochs |
@@ -434,7 +488,7 @@ WaveDL/
434
488
  </details>
435
489
 
436
490
  <details>
437
- <summary><b>Environment Variables (run_training.sh)</b></summary>
491
+ <summary><b>Environment Variables (wavedl-hpc)</b></summary>
438
492
 
439
493
  | Variable | Default | Description |
440
494
  |----------|---------|-------------|
@@ -527,15 +581,15 @@ For robust model evaluation, simply add the `--cv` flag:
527
581
 
528
582
  ```bash
529
583
  # 5-fold cross-validation (works with both methods!)
530
- ./run_training.sh --model cnn --cv 5 --data_path train_data.npz
584
+ wavedl-hpc --model cnn --cv 5 --data_path train_data.npz
531
585
  # OR
532
586
  accelerate launch -m wavedl.train --model cnn --cv 5 --data_path train_data.npz
533
587
 
534
588
  # Stratified CV (recommended for unbalanced data)
535
- ./run_training.sh --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
589
+ wavedl-hpc --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
536
590
 
537
591
  # Full configuration
538
- ./run_training.sh --model cnn --cv 5 --cv_stratify \
592
+ wavedl-hpc --model cnn --cv 5 --cv_stratify \
539
593
  --loss huber --optimizer adamw --scheduler cosine \
540
594
  --output_dir ./cv_results
541
595
  ```
@@ -153,40 +153,43 @@ Deploy models anywhere:
153
153
  ### Installation
154
154
 
155
155
  ```bash
156
- git clone https://github.com/ductho-le/WaveDL.git
157
- cd WaveDL
156
+ # Install from PyPI (recommended)
157
+ pip install wavedl
158
+
159
+ # Or install with all extras (ONNX export, HPO, dev tools)
160
+ pip install "wavedl[all]"
161
+ ```
158
162
 
159
- # Basic install (training + inference)
160
- pip install -e .
163
+ #### From Source (for development)
161
164
 
162
- # Full install (adds ONNX export, torch.compile, HPO, dev tools)
163
- pip install -e ".[all]"
165
+ ```bash
166
+ git clone https://github.com/ductho-le/WaveDL.git
167
+ cd WaveDL
168
+ pip install -e ".[dev]"
164
169
  ```
165
170
 
166
171
  > [!NOTE]
167
- > Dependencies are managed in `pyproject.toml`. Python 3.11+ required.
168
- >
169
- > For development setup (running tests, contributing), see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
172
+ > Python 3.11+ required. For development setup, see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
170
173
 
171
174
  ### Quick Start
172
175
 
173
176
  > [!TIP]
174
177
  > In all examples below, replace `<...>` placeholders with your values. See [Configuration](#️-configuration) for defaults and options.
175
178
 
176
- #### Option 1: Using the Helper Script (Recommended for HPC)
179
+ #### Option 1: Using wavedl-hpc (Recommended for HPC)
177
180
 
178
- The `run_training.sh` wrapper automatically configures the environment for HPC systems:
181
+ The `wavedl-hpc` command automatically configures the environment for HPC systems:
179
182
 
180
183
  ```bash
181
- # Make executable (first time only)
182
- chmod +x run_training.sh
183
-
184
184
  # Basic training (auto-detects available GPUs)
185
- ./run_training.sh --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
185
+ wavedl-hpc --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
186
186
 
187
187
  # Detailed configuration
188
- ./run_training.sh --model <model_name> --data_path <train_data> --batch_size <number> \
188
+ wavedl-hpc --model <model_name> --data_path <train_data> --batch_size <number> \
189
189
  --lr <number> --epochs <number> --patience <number> --compile --output_dir <output_folder>
190
+
191
+ # Specify GPU count explicitly
192
+ wavedl-hpc --num_gpus 4 --model cnn --data_path train.npz --output_dir results
190
193
  ```
191
194
 
192
195
  #### Option 2: Direct Accelerate Launch
@@ -203,13 +206,13 @@ accelerate launch -m wavedl.train --model <model_name> --data_path <train_data>
203
206
  accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --output_dir <output_folder> --fresh
204
207
 
205
208
  # List available models
206
- python -m wavedl.train --list_models
209
+ wavedl-train --list_models
207
210
  ```
208
211
 
209
212
  > [!TIP]
210
213
  > **Auto-Resume**: If training crashes or is interrupted, simply re-run with the same `--output_dir`. The framework automatically detects incomplete training and resumes from the last checkpoint. Use `--fresh` to force a fresh start.
211
214
  >
212
- > **GPU Auto-Detection**: By default, `run_training.sh` automatically detects available GPUs using `nvidia-smi`. Set `NUM_GPUS` to override this behavior.
215
+ > **GPU Auto-Detection**: `wavedl-hpc` automatically detects available GPUs using `nvidia-smi`. Use `--num_gpus` to override.
213
216
 
214
217
  ### Testing & Inference
215
218
 
@@ -241,6 +244,56 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
241
244
  > [!NOTE]
242
245
  > `wavedl.test` auto-detects the model architecture from checkpoint metadata. If unavailable, it falls back to folder name parsing. Use `--model` to override if needed.
243
246
 
247
+ ### Adding Custom Models
248
+
249
+ <details>
250
+ <summary><b>Creating Your Own Architecture</b></summary>
251
+
252
+ **Requirements** (your model must):
253
+ 1. Inherit from `BaseModel`
254
+ 2. Accept `in_channels`, `num_outputs`, `input_shape` in `__init__`
255
+ 3. Return a tensor of shape `(batch, num_outputs)` from `forward()`
256
+
257
+ ---
258
+
259
+ **Step 1: Create `my_model.py`**
260
+
261
+ ```python
262
+ import torch.nn as nn
263
+ import torch.nn.functional as F
264
+ from wavedl.models import BaseModel, register_model
265
+
266
+ @register_model("my_model") # This name is used with --model flag
267
+ class MyModel(BaseModel):
268
+ def __init__(self, in_channels, num_outputs, input_shape):
269
+ # in_channels: number of input channels (auto-detected from data)
270
+ # num_outputs: number of parameters to predict (auto-detected from data)
271
+ # input_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
272
+ super().__init__(in_channels, num_outputs, input_shape)
273
+
274
+ # Define your layers (this is just an example)
275
+ self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
276
+ self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
277
+ self.fc = nn.Linear(128, num_outputs)
278
+
279
+ def forward(self, x):
280
+ # Input x has shape: (batch, in_channels, *input_shape)
281
+ x = F.relu(self.conv1(x))
282
+ x = F.relu(self.conv2(x))
283
+ x = x.mean(dim=[-2, -1]) # Global average pooling
284
+ return self.fc(x) # Output shape: (batch, num_outputs)
285
+ ```
286
+
287
+ **Step 2: Train**
288
+
289
+ ```bash
290
+ wavedl-hpc --import my_model --model my_model --data_path train.npz
291
+ ```
292
+
293
+ WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
294
+
295
+ </details>
296
+
244
297
  ---
245
298
 
246
299
  ## 📁 Project Structure
@@ -253,6 +306,7 @@ WaveDL/
253
306
  │ ├── train.py # Training entry point
254
307
  │ ├── test.py # Testing & inference script
255
308
  │ ├── hpo.py # Hyperparameter optimization
309
+ │ ├── hpc.py # HPC distributed training launcher
256
310
  │ │
257
311
  │ ├── models/ # Model architectures
258
312
  │ │ ├── registry.py # Model factory (@register_model)
@@ -274,7 +328,6 @@ WaveDL/
274
328
  │ ├── schedulers.py # LR scheduler factory
275
329
  │ └── config.py # YAML configuration support
276
330
 
277
- ├── run_training.sh # HPC helper script
278
331
  ├── configs/ # YAML config templates
279
332
  ├── examples/ # Ready-to-run examples
280
333
  ├── notebooks/ # Jupyter notebooks
@@ -289,12 +342,12 @@ WaveDL/
289
342
  ## ⚙️ Configuration
290
343
 
291
344
  > [!NOTE]
292
- > All configuration options below work with **both** `run_training.sh` and direct `accelerate launch`. The wrapper script passes all arguments directly to `train.py`.
345
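+ > All configuration options below work with **both** `wavedl-hpc` and direct `accelerate launch`. `wavedl-hpc` consumes only its own launcher options (`--num_gpus`, `--num_machines`, `--machine_rank`, `--mixed_precision`, `--dynamo_backend`) and passes every other argument unchanged to `wavedl.train`.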
293
346
  >
294
347
  > **Examples:**
295
348
  > ```bash
296
- > # Using run_training.sh
297
- > ./run_training.sh --model cnn --batch_size 256 --lr 5e-4 --compile
349
+ > # Using wavedl-hpc
350
+ > wavedl-hpc --model cnn --batch_size 256 --lr 5e-4 --compile
298
351
  >
299
352
  > # Using accelerate launch directly
300
353
  > accelerate launch -m wavedl.train --model cnn --batch_size 256 --lr 5e-4 --compile
@@ -337,6 +390,7 @@ WaveDL/
337
390
  | Argument | Default | Description |
338
391
  |----------|---------|-------------|
339
392
  | `--model` | `cnn` | Model architecture |
393
+ | `--import` | - | Python modules to import (for custom models) |
340
394
  | `--batch_size` | `128` | Per-GPU batch size |
341
395
  | `--lr` | `1e-3` | Learning rate |
342
396
  | `--epochs` | `1000` | Maximum epochs |
@@ -376,7 +430,7 @@ WaveDL/
376
430
  </details>
377
431
 
378
432
  <details>
379
- <summary><b>Environment Variables (run_training.sh)</b></summary>
433
+ <summary><b>Environment Variables (wavedl-hpc)</b></summary>
380
434
 
381
435
  | Variable | Default | Description |
382
436
  |----------|---------|-------------|
@@ -469,15 +523,15 @@ For robust model evaluation, simply add the `--cv` flag:
469
523
 
470
524
  ```bash
471
525
  # 5-fold cross-validation (works with both methods!)
472
- ./run_training.sh --model cnn --cv 5 --data_path train_data.npz
526
+ wavedl-hpc --model cnn --cv 5 --data_path train_data.npz
473
527
  # OR
474
528
  accelerate launch -m wavedl.train --model cnn --cv 5 --data_path train_data.npz
475
529
 
476
530
  # Stratified CV (recommended for unbalanced data)
477
- ./run_training.sh --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
531
+ wavedl-hpc --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
478
532
 
479
533
  # Full configuration
480
- ./run_training.sh --model cnn --cv 5 --cv_stratify \
534
+ wavedl-hpc --model cnn --cv 5 --cv_stratify \
481
535
  --loss huber --optimizer adamw --scheduler cosine \
482
536
  --output_dir ./cv_results
483
537
  ```
@@ -67,12 +67,12 @@ dependencies = [
67
67
  [project.optional-dependencies]
68
68
  dev = ["pytest>=7.0.0", "pytest-xdist>=3.5.0", "ruff>=0.8.0", "pre-commit>=3.5.0"]
69
69
  onnx = ["onnx>=1.14.0", "onnxruntime>=1.15.0"]
70
- compile = ["triton"] # Version resolved by PyTorch compatibility
70
+ compile = ["triton; sys_platform == 'linux'"] # Linux-only, enables torch.compile
71
71
  hpo = ["optuna>=3.0.0"] # Hyperparameter optimization
72
72
  all = [
73
73
  "pytest>=7.0.0", "pytest-xdist>=3.5.0", "ruff>=0.8.0", "pre-commit>=3.5.0",
74
74
  "onnx>=1.14.0", "onnxruntime>=1.15.0",
75
- "triton",
75
+ "triton; sys_platform == 'linux'",
76
76
  "optuna>=3.0.0",
77
77
  ]
78
78
 
@@ -80,6 +80,7 @@ all = [
80
80
  wavedl-train = "wavedl.train:main"
81
81
  wavedl-test = "wavedl.test:main"
82
82
  wavedl-hpo = "wavedl.hpo:main"
83
+ wavedl-hpc = "wavedl.hpc:main"
83
84
 
84
85
  [project.urls]
85
86
  Homepage = "https://github.com/ductho-le/WaveDL"
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.2.0"
21
+ __version__ = "1.3.0"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
@@ -0,0 +1,243 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ WaveDL HPC Training Launcher.
4
+
5
+ This module provides a Python-based HPC training launcher that wraps accelerate
6
+ for distributed training on High-Performance Computing clusters.
7
+
8
+ Usage:
9
+ wavedl-hpc --model cnn --data_path train.npz --num_gpus 4
10
+
11
+ Example SLURM script:
12
+ #!/bin/bash
13
+ #SBATCH --nodes=1
14
+ #SBATCH --gpus-per-node=4
15
+ #SBATCH --time=12:00:00
16
+
17
+ wavedl-hpc --model cnn --data_path /scratch/data.npz --compile
18
+
19
+ Author: Ductho Le (ductho.le@outlook.com)
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import argparse
25
+ import os
26
+ import shutil
27
+ import subprocess
28
+ import sys
29
+ import tempfile
30
+ from pathlib import Path
31
+
32
+
33
def detect_gpus() -> int:
    """Auto-detect the number of GPUs visible on this machine via ``nvidia-smi``.

    Only the top-level ``GPU <n>: ...`` lines printed by
    ``nvidia-smi --list-gpus`` are counted; any other lines (e.g. MIG device
    entries, blank lines) are ignored.

    Returns:
        The detected GPU count, or 1 when ``nvidia-smi`` is missing, fails,
        times out, or reports no GPUs.
    """
    if shutil.which("nvidia-smi") is None:
        print("Warning: nvidia-smi not found, defaulting to 1 GPU")
        return 1

    try:
        result = subprocess.run(
            ["nvidia-smi", "--list-gpus"],
            capture_output=True,
            text=True,
            check=True,
            # A wedged driver can make nvidia-smi block indefinitely; never let
            # GPU detection stall the whole job.
            timeout=60,
        )
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        pass
    else:
        # Count real "GPU <n>: ..." entries. The previous split("\n") approach
        # reported an empty listing as one GPU ("".split("\n") == [""]).
        gpu_count = sum(
            1 for line in result.stdout.splitlines() if line.startswith("GPU ")
        )
        if gpu_count > 0:
            print(f"Auto-detected {gpu_count} GPU(s)")
            return gpu_count

    print("Warning: Could not detect GPUs, defaulting to 1")
    return 1
55
+
56
+
57
def setup_hpc_environment() -> None:
    """Configure environment variables for HPC systems.

    Handles restricted home directories (e.g., Compute Canada) and offline
    logging configurations. Every variable is set with ``setdefault``, so
    values already exported by the user always take precedence.

    The scratch root is ``$SLURM_TMPDIR`` when it is set and non-empty,
    otherwise the system temporary directory.
    """
    # Use `or` rather than a .get() default so an exported-but-empty
    # SLURM_TMPDIR does not turn every path below into "/matplotlib", "/wandb", ...
    tmpdir = os.environ.get("SLURM_TMPDIR") or tempfile.gettempdir()

    # Configure directories for systems with restricted home directories
    os.environ.setdefault("MPLCONFIGDIR", f"{tmpdir}/matplotlib")
    os.environ.setdefault("XDG_CACHE_HOME", f"{tmpdir}/.cache")

    # WandB configuration (offline by default for HPC)
    os.environ.setdefault("WANDB_MODE", "offline")
    os.environ.setdefault("WANDB_DIR", f"{tmpdir}/wandb")
    os.environ.setdefault("WANDB_CACHE_DIR", f"{tmpdir}/wandb_cache")
    os.environ.setdefault("WANDB_CONFIG_DIR", f"{tmpdir}/wandb_config")

    # Create every directory we point tools at, so they never have to create it
    # themselves. NOTE(review): some tools (e.g. wandb) may fall back to another
    # location when the configured directory is missing — confirm per version.
    # A failure here is reported, not fatal: the tools can still fall back.
    for var in ("MPLCONFIGDIR", "WANDB_DIR", "WANDB_CACHE_DIR", "WANDB_CONFIG_DIR"):
        try:
            Path(os.environ[var]).mkdir(parents=True, exist_ok=True)
        except OSError as exc:
            print(f"Warning: could not create {var}={os.environ[var]}: {exc}")

    # Suppress non-critical warnings
    os.environ.setdefault(
        "PYTHONWARNINGS",
        "ignore::UserWarning,ignore::FutureWarning,ignore::DeprecationWarning",
    )
84
+
85
+
86
def parse_args() -> tuple[argparse.Namespace, list[str]]:
    """Parse launcher-specific arguments and collect the rest for ``wavedl.train``.

    Only the options declared below are consumed; every other token on the
    command line is returned untouched so it can be forwarded to the trainer.

    Returns:
        ``(args, remaining)``: ``args`` holds the launcher options and
        ``remaining`` is the list of arguments destined for ``wavedl.train``.

    Raises:
        SystemExit: via ``parser.error`` when an option value is out of range
            (and, as usual for argparse, on ``--help``).
    """
    parser = argparse.ArgumentParser(
        description="WaveDL HPC Training Launcher",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # Disable prefix matching. With parse_known_args, a flag meant for
        # wavedl.train that is a prefix of a launcher option (e.g. "--mixed",
        # "--dynamo") would otherwise be silently captured here.
        allow_abbrev=False,
        epilog="""
Examples:
  # Basic training with auto-detected GPUs
  wavedl-hpc --model cnn --data_path train.npz --epochs 100

  # Specify GPU count and mixed precision
  wavedl-hpc --model cnn --data_path train.npz --num_gpus 4 --mixed_precision bf16

  # Full configuration
  wavedl-hpc --model resnet18 --data_path train.npz --num_gpus 8 \\
    --batch_size 256 --lr 1e-3 --compile --output_dir ./results

Environment Variables:
  WANDB_MODE      WandB mode: offline|online (default: offline)
  SLURM_TMPDIR    Temp directory for HPC systems
""",
    )

    # HPC-specific arguments
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=None,
        help="Number of GPUs to use (default: auto-detect)",
    )
    parser.add_argument(
        "--num_machines",
        type=int,
        default=1,
        help="Number of machines for multi-node training (default: 1)",
    )
    parser.add_argument(
        "--machine_rank",
        type=int,
        default=0,
        help="Rank of this machine in multi-node setup (default: 0)",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        choices=["bf16", "fp16", "no"],
        default="bf16",
        help="Mixed precision mode (default: bf16)",
    )
    parser.add_argument(
        "--dynamo_backend",
        type=str,
        default="no",
        help="PyTorch dynamo backend (default: no)",
    )

    # Parse known args, pass rest to wavedl.train
    args, remaining = parser.parse_known_args()

    # Reject impossible values here with a clear message instead of letting
    # accelerate fail later with a less obvious error.
    if args.num_gpus is not None and args.num_gpus < 1:
        parser.error("--num_gpus must be >= 1")
    if args.num_machines < 1:
        parser.error("--num_machines must be >= 1")
    if not 0 <= args.machine_rank < args.num_machines:
        parser.error("--machine_rank must satisfy 0 <= machine_rank < num_machines")

    return args, remaining
145
+
146
+
147
def print_summary(exit_code: int, wandb_mode: str, wandb_dir: str) -> None:
    """Print a post-training summary with follow-up instructions.

    Args:
        exit_code: Return code of the training launch (0 means success).
        wandb_mode: Value of ``WANDB_MODE``; sync instructions are printed
            only for ``"offline"``.
        wandb_dir: Directory WandB was configured to write into (``WANDB_DIR``).
    """
    rule = "=" * 50
    print()
    print(rule)

    if exit_code == 0:
        print("✅ Training completed successfully!")
        print(rule)

        if wandb_mode == "offline":
            print()
            print("📊 WandB Sync Instructions:")
            print(" From the login node, run:")
            print(f" wandb sync {wandb_dir}/wandb/offline-run-*")
            print()
            print(" This will upload your training logs to wandb.ai")

            # setup_hpc_environment() defaults WANDB_DIR to $SLURM_TMPDIR, which
            # is node-local job scratch that clusters typically wipe when the job
            # ends. In that case the login node will never see these files, so
            # warn the user to copy them out first.
            slurm_tmpdir = os.environ.get("SLURM_TMPDIR")
            if slurm_tmpdir and Path(wandb_dir).resolve().is_relative_to(
                Path(slurm_tmpdir).resolve()
            ):
                print()
                print(f" ⚠️ {wandb_dir} is inside $SLURM_TMPDIR (node-local job scratch),")
                print(" which is typically deleted when the job ends. Copy it to persistent")
                print(" storage before the job exits, or set WANDB_DIR to a persistent path.")
    else:
        print(f"❌ Training failed with exit code: {exit_code}")
        print(rule)
        print()
        print("Common issues:")
        print(" - Missing data file (check --data_path)")
        print(" - Insufficient GPU memory (reduce --batch_size)")
        print(" - Invalid model name (run: wavedl-train --list_models)")
        print()

    print(rule)
    print()
175
+
176
+
177
def main() -> int:
    """Main entry point for the ``wavedl-hpc`` command.

    Parses launcher options, prepares the HPC environment, builds an
    ``accelerate launch ... -m wavedl.train`` command (forwarding every
    argument the launcher does not consume itself), runs it and prints a
    summary.

    Returns:
        Exit code for ``sys.exit``: the trainer's return code, 130 if the
        launcher is interrupted with Ctrl-C, or ``128 + N`` if the launched
        process was terminated by signal ``N``.
    """
    # Parse arguments
    args, train_args = parse_args()

    # Setup HPC environment
    setup_hpc_environment()

    # GPUs on this machine: explicit --num_gpus, otherwise auto-detected locally.
    num_gpus = args.num_gpus if args.num_gpus is not None else detect_gpus()

    # accelerate's --num_processes is the TOTAL number of processes across all
    # machines, so scale the per-machine GPU count (a no-op for single-node runs).
    # NOTE(review): multi-node runs typically also need --main_process_ip /
    # --main_process_port for accelerate, which this launcher does not expose.
    num_processes = num_gpus * args.num_machines

    # Build accelerate launch command
    cmd = [
        sys.executable,
        "-m",
        "accelerate.commands.launch",
        f"--num_processes={num_processes}",
        f"--num_machines={args.num_machines}",
        f"--machine_rank={args.machine_rank}",
        f"--mixed_precision={args.mixed_precision}",
        f"--dynamo_backend={args.dynamo_backend}",
        "-m",
        "wavedl.train",
        *train_args,
    ]

    # Create the output directory up front. Accept both "--output_dir DIR" and
    # "--output_dir=DIR"; argparse keeps the LAST occurrence, so mirror that.
    output_dir = None
    for i, arg in enumerate(train_args):
        if arg == "--output_dir" and i + 1 < len(train_args):
            output_dir = train_args[i + 1]
        elif arg.startswith("--output_dir="):
            output_dir = arg.split("=", 1)[1]
    if output_dir:
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Print launch configuration
    print()
    print("=" * 50)
    print("🚀 WaveDL HPC Training Launcher")
    print("=" * 50)
    print(f" GPUs: {num_gpus}")
    print(f" Machines: {args.num_machines}")
    if args.num_machines > 1:
        print(f" Total Processes: {num_processes}")
    print(f" Mixed Precision: {args.mixed_precision}")
    print(f" Dynamo Backend: {args.dynamo_backend}")
    print(f" WandB Mode: {os.environ.get('WANDB_MODE', 'offline')}")
    print("=" * 50)
    print()

    # Launch training
    try:
        result = subprocess.run(cmd, check=False)
        exit_code = result.returncode
    except KeyboardInterrupt:
        print("\n\n⚠️ Training interrupted by user")
        exit_code = 130

    # Popen reports "killed by signal N" as -N; sys.exit(-N) would surface as
    # exit status 256-N. Use the shell convention 128+N instead.
    if exit_code < 0:
        exit_code = 128 - exit_code

    # Print summary
    print_summary(
        exit_code,
        os.environ.get("WANDB_MODE", "offline"),
        os.environ.get("WANDB_DIR", "/tmp/wandb"),
    )

    return exit_code
240
+
241
+
242
if __name__ == "__main__":
    # Run the launcher and hand its return value to the interpreter as the
    # process exit status (equivalent to sys.exit(main())).
    raise SystemExit(main())
@@ -5,16 +5,16 @@ Automated hyperparameter search for finding optimal training configurations.
5
5
 
6
6
  Usage:
7
7
  # Basic HPO (50 trials)
8
- python hpo.py --data_path train.npz --n_trials 50
8
+ wavedl-hpo --data_path train.npz --n_trials 50
9
9
 
10
10
  # Quick search (fewer parameters)
11
- python hpo.py --data_path train.npz --n_trials 30 --quick
11
+ wavedl-hpo --data_path train.npz --n_trials 30 --quick
12
12
 
13
13
  # Full search with specific models
14
- python hpo.py --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0
14
+ wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0
15
15
 
16
16
  # Parallel trials on multiple GPUs
17
- python hpo.py --data_path train.npz --n_trials 100 --n_jobs 4
17
+ wavedl-hpo --data_path train.npz --n_trials 100 --n_jobs 4
18
18
 
19
19
  Author: Ductho Le (ductho.le@outlook.com)
20
20
  """
@@ -205,9 +205,9 @@ def main():
205
205
  formatter_class=argparse.RawDescriptionHelpFormatter,
206
206
  epilog="""
207
207
  Examples:
208
- python hpo.py --data_path train.npz --n_trials 50
209
- python hpo.py --data_path train.npz --n_trials 30 --quick
210
- python hpo.py --data_path train.npz --n_trials 100 --models cnn resnet18
208
+ wavedl-hpo --data_path train.npz --n_trials 50
209
+ wavedl-hpo --data_path train.npz --n_trials 30 --quick
210
+ wavedl-hpo --data_path train.npz --n_trials 100 --models cnn resnet18
211
211
  """,
212
212
  )
213
213
 
@@ -355,7 +355,7 @@ Examples:
355
355
  print("\n" + "=" * 60)
356
356
  print("TO TRAIN WITH BEST PARAMETERS:")
357
357
  print("=" * 60)
358
- cmd_parts = ["accelerate launch train.py"]
358
+ cmd_parts = ["accelerate launch -m wavedl.train"]
359
359
  cmd_parts.append(f"--data_path {args.data_path}")
360
360
  for key, value in study.best_params.items():
361
361
  cmd_parts.append(f"--{key} {value}")
@@ -11,7 +11,7 @@ Steps to Add a New Model:
11
11
  3. Implement the __init__ and forward methods
12
12
  4. Import your model in models/__init__.py:
13
13
  from wavedl.models.your_model import YourModel
14
- 5. Run: accelerate launch train.py --model your_model --wandb
14
+ 5. Run: accelerate launch -m wavedl.train --model your_model --wandb
15
15
 
16
16
  Author: Ductho Le (ductho.le@outlook.com)
17
17
  Version: 1.0.0
@@ -13,14 +13,14 @@ Production-grade inference script for evaluating trained WaveDL models:
13
13
 
14
14
  Usage:
15
15
  # Basic inference
16
- python test.py --checkpoint ./best_checkpoint --data_path test_data.npz
16
+ wavedl-test --checkpoint ./best_checkpoint --data_path test_data.npz
17
17
 
18
18
  # With visualization and detailed output
19
- python test.py --checkpoint ./best_checkpoint --data_path test_data.npz \\
19
+ wavedl-test --checkpoint ./best_checkpoint --data_path test_data.npz \\
20
20
  --plot --plot_format png pdf --output_dir ./test_results --save_predictions
21
21
 
22
22
  # Export model to ONNX for deployment
23
- python test.py --checkpoint ./best_checkpoint --data_path test_data.npz \\
23
+ wavedl-test --checkpoint ./best_checkpoint --data_path test_data.npz \\
24
24
  --export onnx --export_path model.onnx
25
25
 
26
26
  Author: Ductho Le (ductho.le@outlook.com)
@@ -12,26 +12,25 @@ A modular training framework for wave-based inverse problems and regression:
12
12
  6. Deep Observability: WandB integration with scatter analysis
13
13
 
14
14
  Usage:
15
- # Recommended: Using the HPC helper script
16
- ./run_training.sh --model cnn --batch_size 128 --wandb
15
+ # Recommended: Using the HPC launcher
16
+ wavedl-hpc --model cnn --batch_size 128 --wandb
17
17
 
18
18
  # Or with direct accelerate launch
19
- accelerate launch train.py --model cnn --batch_size 128 --wandb
19
+ accelerate launch -m wavedl.train --model cnn --batch_size 128 --wandb
20
20
 
21
21
  # Multi-GPU with explicit config
22
- accelerate launch --num_processes=4 --mixed_precision=bf16 \
23
- train.py --model cnn --wandb --project_name "MyProject"
22
+ wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb
24
23
 
25
24
  # Resume from checkpoint
26
- accelerate launch train.py --model cnn --resume best_checkpoint --wandb
25
+ accelerate launch -m wavedl.train --model cnn --resume best_checkpoint --wandb
27
26
 
28
27
  # List available models
29
- python train.py --list_models
28
+ wavedl-train --list_models
30
29
 
31
30
  Note:
32
- For HPC clusters (Compute Canada, etc.), use run_training.sh which handles
31
+ For HPC clusters (Compute Canada, etc.), use wavedl-hpc which handles
33
32
  environment configuration automatically. Mixed precision is controlled via
34
- --precision flag (default: bf16).
33
+ the wavedl-hpc --mixed_precision flag (default: bf16).
35
34
 
36
35
  Author: Ductho Le (ductho.le@outlook.com)
37
36
  """
@@ -122,6 +121,14 @@ def parse_args() -> argparse.Namespace:
122
121
  parser.add_argument(
123
122
  "--list_models", action="store_true", help="List all available models and exit"
124
123
  )
124
+ parser.add_argument(
125
+ "--import",
126
+ dest="import_modules",
127
+ type=str,
128
+ nargs="+",
129
+ default=[],
130
+ help="Python modules to import before training (for custom models)",
131
+ )
125
132
 
126
133
  # Configuration File
127
134
  parser.add_argument(
@@ -314,6 +321,37 @@ def parse_args() -> argparse.Namespace:
314
321
  def main():
315
322
  args, parser = parse_args()
316
323
 
324
+ # Import custom model modules if specified
325
+ if args.import_modules:
326
+ import importlib
327
+ import sys
328
+
329
+ for module_name in args.import_modules:
330
+ try:
331
+ # Handle both module names (my_model) and file paths (./my_model.py)
332
+ if module_name.endswith(".py"):
333
+ # Import from file path
334
+ import importlib.util
335
+
336
+ spec = importlib.util.spec_from_file_location(
337
+ "custom_module", module_name
338
+ )
339
+ if spec and spec.loader:
340
+ module = importlib.util.module_from_spec(spec)
341
+ sys.modules["custom_module"] = module
342
+ spec.loader.exec_module(module)
343
+ print(f"✓ Imported custom module from: {module_name}")
344
+ else:
345
+ # Import as regular module
346
+ importlib.import_module(module_name)
347
+ print(f"✓ Imported module: {module_name}")
348
+ except ImportError as e:
349
+ print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
350
+ print(
351
+ " Make sure the module is in your Python path or current directory."
352
+ )
353
+ sys.exit(1)
354
+
317
355
  # Handle --list_models flag
318
356
  if args.list_models:
319
357
  print("Available models:")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.2.0
3
+ Version: 1.3.0
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -43,7 +43,7 @@ Provides-Extra: onnx
43
43
  Requires-Dist: onnx>=1.14.0; extra == "onnx"
44
44
  Requires-Dist: onnxruntime>=1.15.0; extra == "onnx"
45
45
  Provides-Extra: compile
46
- Requires-Dist: triton; extra == "compile"
46
+ Requires-Dist: triton; sys_platform == "linux" and extra == "compile"
47
47
  Provides-Extra: hpo
48
48
  Requires-Dist: optuna>=3.0.0; extra == "hpo"
49
49
  Provides-Extra: all
@@ -53,7 +53,7 @@ Requires-Dist: ruff>=0.8.0; extra == "all"
53
53
  Requires-Dist: pre-commit>=3.5.0; extra == "all"
54
54
  Requires-Dist: onnx>=1.14.0; extra == "all"
55
55
  Requires-Dist: onnxruntime>=1.15.0; extra == "all"
56
- Requires-Dist: triton; extra == "all"
56
+ Requires-Dist: triton; sys_platform == "linux" and extra == "all"
57
57
  Requires-Dist: optuna>=3.0.0; extra == "all"
58
58
 
59
59
  <div align="center">
@@ -211,40 +211,43 @@ Deploy models anywhere:
211
211
  ### Installation
212
212
 
213
213
  ```bash
214
- git clone https://github.com/ductho-le/WaveDL.git
215
- cd WaveDL
214
+ # Install from PyPI (recommended)
215
+ pip install wavedl
216
+
217
+ # Or install with all extras (ONNX export, HPO, dev tools)
218
+ pip install wavedl[all]
219
+ ```
216
220
 
217
- # Basic install (training + inference)
218
- pip install -e .
221
+ #### From Source (for development)
219
222
 
220
- # Full install (adds ONNX export, torch.compile, HPO, dev tools)
221
- pip install -e ".[all]"
223
+ ```bash
224
+ git clone https://github.com/ductho-le/WaveDL.git
225
+ cd WaveDL
226
+ pip install -e ".[dev]"
222
227
  ```
223
228
 
224
229
  > [!NOTE]
225
- > Dependencies are managed in `pyproject.toml`. Python 3.11+ required.
226
- >
227
- > For development setup (running tests, contributing), see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
230
+ > Python 3.11+ required. For development setup, see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
228
231
 
229
232
  ### Quick Start
230
233
 
231
234
  > [!TIP]
232
235
  > In all examples below, replace `<...>` placeholders with your values. See [Configuration](#️-configuration) for defaults and options.
233
236
 
234
- #### Option 1: Using the Helper Script (Recommended for HPC)
237
+ #### Option 1: Using wavedl-hpc (Recommended for HPC)
235
238
 
236
- The `run_training.sh` wrapper automatically configures the environment for HPC systems:
239
+ The `wavedl-hpc` command automatically configures the environment for HPC systems:
237
240
 
238
241
  ```bash
239
- # Make executable (first time only)
240
- chmod +x run_training.sh
241
-
242
242
  # Basic training (auto-detects available GPUs)
243
- ./run_training.sh --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
243
+ wavedl-hpc --model <model_name> --data_path <train_data> --batch_size <number> --output_dir <output_folder>
244
244
 
245
245
  # Detailed configuration
246
- ./run_training.sh --model <model_name> --data_path <train_data> --batch_size <number> \
246
+ wavedl-hpc --model <model_name> --data_path <train_data> --batch_size <number> \
247
247
  --lr <number> --epochs <number> --patience <number> --compile --output_dir <output_folder>
248
+
249
+ # Specify GPU count explicitly
250
+ wavedl-hpc --num_gpus 4 --model cnn --data_path train.npz --output_dir results
248
251
  ```
249
252
 
250
253
  #### Option 2: Direct Accelerate Launch
@@ -261,13 +264,13 @@ accelerate launch -m wavedl.train --model <model_name> --data_path <train_data>
261
264
  accelerate launch -m wavedl.train --model <model_name> --data_path <train_data> --output_dir <output_folder> --fresh
262
265
 
263
266
  # List available models
264
- python -m wavedl.train --list_models
267
+ wavedl-train --list_models
265
268
  ```
266
269
 
267
270
  > [!TIP]
268
271
  > **Auto-Resume**: If training crashes or is interrupted, simply re-run with the same `--output_dir`. The framework automatically detects incomplete training and resumes from the last checkpoint. Use `--fresh` to force a fresh start.
269
272
  >
270
- > **GPU Auto-Detection**: By default, `run_training.sh` automatically detects available GPUs using `nvidia-smi`. Set `NUM_GPUS` to override this behavior.
273
+ > **GPU Auto-Detection**: `wavedl-hpc` automatically detects available GPUs using `nvidia-smi`. Use `--num_gpus` to override.
271
274
 
272
275
  ### Testing & Inference
273
276
 
@@ -299,6 +302,56 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
299
302
  > [!NOTE]
300
303
  > `wavedl.test` auto-detects the model architecture from checkpoint metadata. If unavailable, it falls back to folder name parsing. Use `--model` to override if needed.
301
304
 
305
+ ### Adding Custom Models
306
+
307
+ <details>
308
+ <summary><b>Creating Your Own Architecture</b></summary>
309
+
310
+ **Requirements** (your model must):
311
+ 1. Inherit from `BaseModel`
312
+ 2. Accept `in_channels`, `num_outputs`, `input_shape` in `__init__`
313
+ 3. Return a tensor of shape `(batch, num_outputs)` from `forward()`
314
+
315
+ ---
316
+
317
+ **Step 1: Create `my_model.py`**
318
+
319
+ ```python
320
+ import torch.nn as nn
321
+ import torch.nn.functional as F
322
+ from wavedl.models import BaseModel, register_model
323
+
324
+ @register_model("my_model") # This name is used with --model flag
325
+ class MyModel(BaseModel):
326
+ def __init__(self, in_channels, num_outputs, input_shape):
327
+ # in_channels: number of input channels (auto-detected from data)
328
+ # num_outputs: number of parameters to predict (auto-detected from data)
329
+ # input_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
330
+ super().__init__(in_channels, num_outputs, input_shape)
331
+
332
+ # Define your layers (this example uses Conv2d, so it assumes 2D input_shape)
333
+ self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
334
+ self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
335
+ self.fc = nn.Linear(128, num_outputs)
336
+
337
+ def forward(self, x):
338
+ # Input x has shape: (batch, in_channels, *input_shape)
339
+ x = F.relu(self.conv1(x))
340
+ x = F.relu(self.conv2(x))
341
+ x = x.mean(dim=[-2, -1]) # Global average pooling
342
+ return self.fc(x) # Output shape: (batch, num_outputs)
343
+ ```
344
+
345
+ **Step 2: Train**
346
+
347
+ ```bash
348
+ wavedl-hpc --import my_model --model my_model --data_path train.npz
349
+ ```
350
+
351
+ WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
352
+
353
+ </details>
354
+
302
355
  ---
303
356
 
304
357
  ## 📁 Project Structure
@@ -311,6 +364,7 @@ WaveDL/
311
364
  │ ├── train.py # Training entry point
312
365
  │ ├── test.py # Testing & inference script
313
366
  │ ├── hpo.py # Hyperparameter optimization
367
+ │ ├── hpc.py # HPC distributed training launcher
314
368
  │ │
315
369
  │ ├── models/ # Model architectures
316
370
  │ │ ├── registry.py # Model factory (@register_model)
@@ -332,7 +386,6 @@ WaveDL/
332
386
  │ ├── schedulers.py # LR scheduler factory
333
387
  │ └── config.py # YAML configuration support
334
388
 
335
- ├── run_training.sh # HPC helper script
336
389
  ├── configs/ # YAML config templates
337
390
  ├── examples/ # Ready-to-run examples
338
391
  ├── notebooks/ # Jupyter notebooks
@@ -347,12 +400,12 @@ WaveDL/
347
400
  ## ⚙️ Configuration
348
401
 
349
402
  > [!NOTE]
350
- > All configuration options below work with **both** `run_training.sh` and direct `accelerate launch`. The wrapper script passes all arguments directly to `train.py`.
403
+ > All configuration options below work with **both** `wavedl-hpc` and direct `accelerate launch`. `wavedl-hpc` forwards these arguments to the `wavedl.train` module.
351
404
  >
352
405
  > **Examples:**
353
406
  > ```bash
354
- > # Using run_training.sh
355
- > ./run_training.sh --model cnn --batch_size 256 --lr 5e-4 --compile
407
+ > # Using wavedl-hpc
408
+ > wavedl-hpc --model cnn --batch_size 256 --lr 5e-4 --compile
356
409
  >
357
410
  > # Using accelerate launch directly
358
411
  > accelerate launch -m wavedl.train --model cnn --batch_size 256 --lr 5e-4 --compile
@@ -395,6 +448,7 @@ WaveDL/
395
448
  | Argument | Default | Description |
396
449
  |----------|---------|-------------|
397
450
  | `--model` | `cnn` | Model architecture |
451
+ | `--import` | - | Python modules to import (for custom models) |
398
452
  | `--batch_size` | `128` | Per-GPU batch size |
399
453
  | `--lr` | `1e-3` | Learning rate |
400
454
  | `--epochs` | `1000` | Maximum epochs |
@@ -434,7 +488,7 @@ WaveDL/
434
488
  </details>
435
489
 
436
490
  <details>
437
- <summary><b>Environment Variables (run_training.sh)</b></summary>
491
+ <summary><b>Environment Variables (wavedl-hpc)</b></summary>
438
492
 
439
493
  | Variable | Default | Description |
440
494
  |----------|---------|-------------|
@@ -527,15 +581,15 @@ For robust model evaluation, simply add the `--cv` flag:
527
581
 
528
582
  ```bash
529
583
  # 5-fold cross-validation (works with both methods!)
530
- ./run_training.sh --model cnn --cv 5 --data_path train_data.npz
584
+ wavedl-hpc --model cnn --cv 5 --data_path train_data.npz
531
585
  # OR
532
586
  accelerate launch -m wavedl.train --model cnn --cv 5 --data_path train_data.npz
533
587
 
534
588
  # Stratified CV (recommended for unbalanced data)
535
- ./run_training.sh --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
589
+ wavedl-hpc --model cnn --cv 5 --cv_stratify --loss huber --epochs 100
536
590
 
537
591
  # Full configuration
538
- ./run_training.sh --model cnn --cv 5 --cv_stratify \
592
+ wavedl-hpc --model cnn --cv 5 --cv_stratify \
539
593
  --loss huber --optimizer adamw --scheduler cosine \
540
594
  --output_dir ./cv_results
541
595
  ```
@@ -2,6 +2,7 @@ LICENSE
2
2
  README.md
3
3
  pyproject.toml
4
4
  src/wavedl/__init__.py
5
+ src/wavedl/hpc.py
5
6
  src/wavedl/hpo.py
6
7
  src/wavedl/test.py
7
8
  src/wavedl/train.py
@@ -1,4 +1,5 @@
1
1
  [console_scripts]
2
+ wavedl-hpc = wavedl.hpc:main
2
3
  wavedl-hpo = wavedl.hpo:main
3
4
  wavedl-test = wavedl.test:main
4
5
  wavedl-train = wavedl.train:main
@@ -19,10 +19,14 @@ ruff>=0.8.0
19
19
  pre-commit>=3.5.0
20
20
  onnx>=1.14.0
21
21
  onnxruntime>=1.15.0
22
- triton
23
22
  optuna>=3.0.0
24
23
 
24
+ [all:sys_platform == "linux"]
25
+ triton
26
+
25
27
  [compile]
28
+
29
+ [compile:sys_platform == "linux"]
26
30
  triton
27
31
 
28
32
  [dev]
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes