wavedl 1.5.2.tar.gz → 1.5.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {wavedl-1.5.2/src/wavedl.egg-info → wavedl-1.5.4}/PKG-INFO +32 -27
  2. {wavedl-1.5.2 → wavedl-1.5.4}/README.md +26 -22
  3. {wavedl-1.5.2 → wavedl-1.5.4}/pyproject.toml +6 -3
  4. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/__init__.py +1 -1
  5. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/hpc.py +22 -18
  6. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/resnet.py +38 -9
  7. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/swin.py +31 -10
  8. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/train.py +89 -59
  9. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/config.py +3 -1
  10. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/cross_validation.py +11 -0
  11. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/data.py +78 -2
  12. {wavedl-1.5.2 → wavedl-1.5.4/src/wavedl.egg-info}/PKG-INFO +32 -27
  13. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl.egg-info/requires.txt +5 -3
  14. {wavedl-1.5.2 → wavedl-1.5.4}/LICENSE +0 -0
  15. {wavedl-1.5.2 → wavedl-1.5.4}/setup.cfg +0 -0
  16. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/hpo.py +0 -0
  17. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/__init__.py +0 -0
  18. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/_template.py +0 -0
  19. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/base.py +0 -0
  20. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/cnn.py +0 -0
  21. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/convnext.py +0 -0
  22. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/densenet.py +0 -0
  23. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/efficientnet.py +0 -0
  24. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/efficientnetv2.py +0 -0
  25. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/mobilenetv3.py +0 -0
  26. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/registry.py +0 -0
  27. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/regnet.py +0 -0
  28. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/resnet3d.py +0 -0
  29. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/tcn.py +0 -0
  30. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/unet.py +0 -0
  31. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/vit.py +0 -0
  32. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/test.py +0 -0
  33. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/__init__.py +0 -0
  34. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/constraints.py +0 -0
  35. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/distributed.py +0 -0
  36. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/losses.py +0 -0
  37. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/metrics.py +0 -0
  38. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/optimizers.py +0 -0
  39. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/schedulers.py +0 -0
  40. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl.egg-info/SOURCES.txt +0 -0
  41. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl.egg-info/dependency_links.txt +0 -0
  42. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl.egg-info/entry_points.txt +0 -0
  43. {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.2
+Version: 1.5.4
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -37,11 +37,12 @@ Requires-Dist: wandb>=0.15.0
 Requires-Dist: optuna>=3.0.0
 Requires-Dist: onnx>=1.14.0
 Requires-Dist: onnxruntime>=1.15.0
-Requires-Dist: pytest>=7.0.0
-Requires-Dist: pytest-xdist>=3.5.0
-Requires-Dist: ruff>=0.8.0
-Requires-Dist: pre-commit>=3.5.0
 Requires-Dist: triton>=2.0.0; sys_platform == "linux"
+Provides-Extra: dev
+Requires-Dist: pytest>=7.0.0; extra == "dev"
+Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
+Requires-Dist: ruff>=0.8.0; extra == "dev"
+Requires-Dist: pre-commit>=3.5.0; extra == "dev"
 
 <div align="center">
 
@@ -204,7 +205,7 @@ Deploy models anywhere:
 pip install wavedl
 ```
 
-This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
+This installs everything you need: training, inference, HPO, ONNX export.
 
 #### From Source (for development)
 
@@ -301,8 +302,8 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
 
 **Requirements** (your model must):
 1. Inherit from `BaseModel`
-2. Accept `in_channels`, `num_outputs`, `input_shape` in `__init__`
-3. Return a tensor of shape `(batch, num_outputs)` from `forward()`
+2. Accept `in_shape`, `out_size` in `__init__`
+3. Return a tensor of shape `(batch, out_size)` from `forward()`
 
 ---
 
@@ -315,29 +316,28 @@ from wavedl.models import BaseModel, register_model
 
 @register_model("my_model")  # This name is used with --model flag
 class MyModel(BaseModel):
-    def __init__(self, in_channels, num_outputs, input_shape):
-        # in_channels: number of input channels (auto-detected from data)
-        # num_outputs: number of parameters to predict (auto-detected from data)
-        # input_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
-        super().__init__(in_channels, num_outputs, input_shape)
-
-        # Define your layers (this is just an example)
-        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
+    def __init__(self, in_shape, out_size, **kwargs):
+        # in_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
+        # out_size: number of parameters to predict (auto-detected from data)
+        super().__init__(in_shape, out_size)
+
+        # Define your layers (this is just an example for 2D)
+        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)  # Input always has 1 channel
         self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
-        self.fc = nn.Linear(128, num_outputs)
+        self.fc = nn.Linear(128, out_size)
 
     def forward(self, x):
-        # Input x has shape: (batch, in_channels, *input_shape)
+        # Input x has shape: (batch, 1, *in_shape)
         x = F.relu(self.conv1(x))
         x = F.relu(self.conv2(x))
         x = x.mean(dim=[-2, -1])  # Global average pooling
-        return self.fc(x)  # Output shape: (batch, num_outputs)
+        return self.fc(x)  # Output shape: (batch, out_size)
 ```
 
 **Step 2: Train**
 
 ```bash
-wavedl-hpc --import my_model --model my_model --data_path train.npz
+wavedl-hpc --import my_model.py --model my_model --data_path train.npz
 ```
 
 WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
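For readers porting custom models, a minimal sketch of the decorator-plus-dict idiom behind `@register_model` — illustrative only, not WaveDL's actual implementation; `_REGISTRY` and `get_model` are hypothetical names:

```python
# Illustrative registry sketch; _REGISTRY and get_model are hypothetical
# names, not WaveDL's API. Only the @register_model usage mirrors the docs.
_REGISTRY: dict[str, type] = {}


def register_model(name: str):
    """Map a CLI name (used with --model) to a model class."""
    def decorator(cls):
        _REGISTRY[name] = cls
        return cls
    return decorator


def get_model(name: str, **kwargs):
    """Instantiate a registered class, e.g. get_model("my_model", in_shape=(64, 64), out_size=3)."""
    return _REGISTRY[name](**kwargs)
```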
@@ -513,7 +513,7 @@ print('\\n✓ All pretrained weights cached!')
 | Argument | Default | Description |
 |----------|---------|-------------|
 | `--model` | `cnn` | Model architecture |
-| `--import` | - | Python modules to import (for custom models) |
+| `--import` | - | Python file(s) to import for custom models (supports multiple) |
 | `--batch_size` | `128` | Per-GPU batch size |
 | `--lr` | `1e-3` | Learning rate |
 | `--epochs` | `1000` | Maximum epochs |
@@ -573,14 +573,19 @@ WaveDL automatically enables performance optimizations for modern GPUs:
 </details>
 
 <details>
-<summary><b>Environment Variables (wavedl-hpc)</b></summary>
+<summary><b>HPC CLI Arguments (wavedl-hpc)</b></summary>
+
+| Argument | Default | Description |
+|----------|---------|-------------|
+| `--num_gpus` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override |
+| `--num_machines` | `1` | Number of machines in distributed setup |
+| `--mixed_precision` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
+| `--dynamo_backend` | `no` | PyTorch Dynamo backend |
+
+**Environment Variables (for logging):**
 
 | Variable | Default | Description |
 |----------|---------|-------------|
-| `NUM_GPUS` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override (e.g., `NUM_GPUS=2`) |
-| `NUM_MACHINES` | `1` | Number of machines in distributed setup |
-| `MIXED_PRECISION` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
-| `DYNAMO_BACKEND` | `no` | PyTorch Dynamo backend |
 | `WANDB_MODE` | `offline` | WandB mode: `offline` or `online` |
 
 </details>
@@ -1219,6 +1224,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
 [![Google Scholar](https://img.shields.io/badge/Google_Scholar-4285F4?style=plastic&logo=google-scholar&logoColor=white)](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
 [![ResearchGate](https://img.shields.io/badge/ResearchGate-00CCBB?style=plastic&logo=researchgate&logoColor=white)](https://www.researchgate.net/profile/Ductho-Le)
 
-<sub>Released under the MIT License</sub>
+<sub>May your signals be strong and your attenuation low 👋</sub>
 
 </div>
@@ -159,7 +159,7 @@ Deploy models anywhere:
 pip install wavedl
 ```
 
-This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
+This installs everything you need: training, inference, HPO, ONNX export.
 
 #### From Source (for development)
 
@@ -256,8 +256,8 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
 
 **Requirements** (your model must):
 1. Inherit from `BaseModel`
-2. Accept `in_channels`, `num_outputs`, `input_shape` in `__init__`
-3. Return a tensor of shape `(batch, num_outputs)` from `forward()`
+2. Accept `in_shape`, `out_size` in `__init__`
+3. Return a tensor of shape `(batch, out_size)` from `forward()`
 
 ---
 
@@ -270,29 +270,28 @@ from wavedl.models import BaseModel, register_model
 
 @register_model("my_model")  # This name is used with --model flag
 class MyModel(BaseModel):
-    def __init__(self, in_channels, num_outputs, input_shape):
-        # in_channels: number of input channels (auto-detected from data)
-        # num_outputs: number of parameters to predict (auto-detected from data)
-        # input_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
-        super().__init__(in_channels, num_outputs, input_shape)
-
-        # Define your layers (this is just an example)
-        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
+    def __init__(self, in_shape, out_size, **kwargs):
+        # in_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
+        # out_size: number of parameters to predict (auto-detected from data)
+        super().__init__(in_shape, out_size)
+
+        # Define your layers (this is just an example for 2D)
+        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)  # Input always has 1 channel
         self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
-        self.fc = nn.Linear(128, num_outputs)
+        self.fc = nn.Linear(128, out_size)
 
     def forward(self, x):
-        # Input x has shape: (batch, in_channels, *input_shape)
+        # Input x has shape: (batch, 1, *in_shape)
         x = F.relu(self.conv1(x))
         x = F.relu(self.conv2(x))
         x = x.mean(dim=[-2, -1])  # Global average pooling
-        return self.fc(x)  # Output shape: (batch, num_outputs)
+        return self.fc(x)  # Output shape: (batch, out_size)
 ```
 
 **Step 2: Train**
 
 ```bash
-wavedl-hpc --import my_model --model my_model --data_path train.npz
+wavedl-hpc --import my_model.py --model my_model --data_path train.npz
 ```
 
 WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -468,7 +467,7 @@ print('\\n✓ All pretrained weights cached!')
 | Argument | Default | Description |
 |----------|---------|-------------|
 | `--model` | `cnn` | Model architecture |
-| `--import` | - | Python modules to import (for custom models) |
+| `--import` | - | Python file(s) to import for custom models (supports multiple) |
 | `--batch_size` | `128` | Per-GPU batch size |
 | `--lr` | `1e-3` | Learning rate |
 | `--epochs` | `1000` | Maximum epochs |
@@ -528,14 +527,19 @@ WaveDL automatically enables performance optimizations for modern GPUs:
 </details>
 
 <details>
-<summary><b>Environment Variables (wavedl-hpc)</b></summary>
+<summary><b>HPC CLI Arguments (wavedl-hpc)</b></summary>
+
+| Argument | Default | Description |
+|----------|---------|-------------|
+| `--num_gpus` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override |
+| `--num_machines` | `1` | Number of machines in distributed setup |
+| `--mixed_precision` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
+| `--dynamo_backend` | `no` | PyTorch Dynamo backend |
+
+**Environment Variables (for logging):**
 
 | Variable | Default | Description |
 |----------|---------|-------------|
-| `NUM_GPUS` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override (e.g., `NUM_GPUS=2`) |
-| `NUM_MACHINES` | `1` | Number of machines in distributed setup |
-| `MIXED_PRECISION` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
-| `DYNAMO_BACKEND` | `no` | PyTorch Dynamo backend |
 | `WANDB_MODE` | `offline` | WandB mode: `offline` or `online` |
 
 </details>
@@ -1174,6 +1178,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
 [![Google Scholar](https://img.shields.io/badge/Google_Scholar-4285F4?style=plastic&logo=google-scholar&logoColor=white)](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
 [![ResearchGate](https://img.shields.io/badge/ResearchGate-00CCBB?style=plastic&logo=researchgate&logoColor=white)](https://www.researchgate.net/profile/Ductho-Le)
 
-<sub>Released under the MIT License</sub>
+<sub>May your signals be strong and your attenuation low 👋</sub>
 
 </div>
@@ -70,13 +70,16 @@ dependencies = [
     # ONNX export
     "onnx>=1.14.0",
    "onnxruntime>=1.15.0",
-    # Development tools
+    # torch.compile backend (Linux only)
+    "triton>=2.0.0; sys_platform == 'linux'",
+]
+
+[project.optional-dependencies]
+dev = [
     "pytest>=7.0.0",
     "pytest-xdist>=3.5.0",
     "ruff>=0.8.0",
     "pre-commit>=3.5.0",
-    # torch.compile backend (Linux only)
-    "triton>=2.0.0; sys_platform == 'linux'",
 ]
 
 [project.scripts]
@@ -18,7 +18,7 @@ For inference:
 # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
 """
 
-__version__ = "1.5.2"
+__version__ = "1.5.4"
 __author__ = "Ductho Le"
 __email__ = "ductho.le@outlook.com"
 
@@ -57,30 +57,35 @@ def setup_hpc_environment() -> None:
     """Configure environment variables for HPC systems.
 
     Handles restricted home directories (e.g., Compute Canada) and
-    offline logging configurations.
+    offline logging configurations. Always uses CWD-based TORCH_HOME
+    since compute nodes typically lack internet access.
     """
-    # Check if home is writable
+    # Use CWD for cache base since HPC compute nodes typically lack internet
+    cache_base = os.getcwd()
+
+    # TORCH_HOME always set to CWD - compute nodes need pre-cached weights
+    os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
+    Path(os.environ["TORCH_HOME"]).mkdir(parents=True, exist_ok=True)
+
+    # Triton/Inductor caches - prevents permission errors with --compile
+    # These MUST be set before any torch.compile calls
+    os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
+    os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache")
+    Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
+    Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
+
+    # Check if home is writable for other caches
     home = os.path.expanduser("~")
     home_writable = os.access(home, os.W_OK)
 
-    # Use SLURM_TMPDIR if available, otherwise CWD for HPC, or system temp
-    if home_writable:
-        # Local machine - let libraries use defaults
-        cache_base = None
-    else:
-        # HPC with restricted home - use CWD for persistent caches
-        cache_base = os.getcwd()
-
-    # Only set environment variables if home is not writable
-    if cache_base:
-        os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
+    # Other caches only if home is not writable
+    if not home_writable:
         os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
         os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
         os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
 
         # Ensure directories exist
         for env_var in [
-            "TORCH_HOME",
             "MPLCONFIGDIR",
             "FONTCONFIG_CACHE",
             "XDG_CACHE_HOME",
@@ -89,10 +94,9 @@ def setup_hpc_environment() -> None:
 
     # WandB configuration (offline by default for HPC)
     os.environ.setdefault("WANDB_MODE", "offline")
-    if cache_base:
-        os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
-        os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
-        os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
+    os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
+    os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
+    os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
 
     # Suppress non-critical warnings
     os.environ.setdefault(
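A hedged usage sketch of the new behavior: after `setup_hpc_environment()` (the real function from `src/wavedl/hpc.py` above) runs in a fresh process, the compile and WandB caches always resolve under the working directory, writable home or not:

```python
# Sketch: inspect the caches pinned by setup_hpc_environment().
# Assumes none of these variables were pre-set by the shell.
import os

from wavedl.hpc import setup_hpc_environment

setup_hpc_environment()
for var in ("TORCH_HOME", "TRITON_CACHE_DIR", "TORCHINDUCTOR_CACHE_DIR", "WANDB_DIR"):
    print(var, "->", os.environ[var])  # each path now lives under os.getcwd()
```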
@@ -49,6 +49,36 @@ def _get_conv_layers(
     raise ValueError(f"Unsupported dimensionality: {dim}D. Supported: 1D, 2D, 3D.")
 
 
+def _get_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
+    """
+    Get valid num_groups for GroupNorm that divides num_channels evenly.
+
+    Args:
+        num_channels: Number of channels to normalize
+        preferred_groups: Preferred number of groups (default: 32)
+
+    Returns:
+        Valid num_groups that divides num_channels
+
+    Raises:
+        ValueError: If no valid divisor found (shouldn't happen with power-of-2 channels)
+    """
+    # Try preferred groups first, then decrease
+    for groups in [preferred_groups, 16, 8, 4, 2, 1]:
+        if groups <= num_channels and num_channels % groups == 0:
+            return groups
+
+    # Fallback: find any valid divisor
+    for groups in range(min(32, num_channels), 0, -1):
+        if num_channels % groups == 0:
+            return groups
+
+    raise ValueError(
+        f"Cannot find valid num_groups for {num_channels} channels. "
+        f"Consider using base_width that is a power of 2 (e.g., 32, 64, 128)."
+    )
+
+
 class BasicBlock(nn.Module):
     """
     Basic residual block for ResNet-18/34.
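A quick sanity check of the divisor selection implemented above (the helper is module-level in `src/wavedl/models/resnet.py` as of 1.5.4; the old `min(32, out_channels)` form raised at construction whenever the channel count was not a multiple of 32, e.g. 48):

```python
# Behavior follows directly from the loop order [32, 16, 8, 4, 2, 1]:
from wavedl.models.resnet import _get_num_groups

assert _get_num_groups(64) == 32  # preferred 32 divides 64
assert _get_num_groups(48) == 16  # 32 does not divide 48; fall back to 16
assert _get_num_groups(24) == 8   # 32 > 24 and 24 % 16 != 0; use 8
assert _get_num_groups(7) == 1    # prime channel count ends at a single group
```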
@@ -77,12 +107,12 @@ class BasicBlock(nn.Module):
             padding=1,
             bias=False,
         )
-        self.gn1 = nn.GroupNorm(min(32, out_channels), out_channels)
+        self.gn1 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
         self.relu = nn.ReLU(inplace=True)
         self.conv2 = Conv(
             out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
         )
-        self.gn2 = nn.GroupNorm(min(32, out_channels), out_channels)
+        self.gn2 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
         self.downsample = downsample
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -125,7 +155,7 @@ class Bottleneck(nn.Module):
 
         # 1x1 reduce
         self.conv1 = Conv(in_channels, out_channels, kernel_size=1, bias=False)
-        self.gn1 = nn.GroupNorm(min(32, out_channels), out_channels)
+        self.gn1 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
 
         # 3x3 conv
         self.conv2 = Conv(
@@ -136,15 +166,14 @@ class Bottleneck(nn.Module):
             padding=1,
             bias=False,
         )
-        self.gn2 = nn.GroupNorm(min(32, out_channels), out_channels)
+        self.gn2 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
 
         # 1x1 expand
         self.conv3 = Conv(
             out_channels, out_channels * self.expansion, kernel_size=1, bias=False
         )
-        self.gn3 = nn.GroupNorm(
-            min(32, out_channels * self.expansion), out_channels * self.expansion
-        )
+        expanded_channels = out_channels * self.expansion
+        self.gn3 = nn.GroupNorm(_get_num_groups(expanded_channels), expanded_channels)
 
         self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
@@ -200,7 +229,7 @@ class ResNetBase(BaseModel):
 
         # Stem: 7x7 conv (or equivalent for 1D/3D)
         self.conv1 = Conv(1, base_width, kernel_size=7, stride=2, padding=3, bias=False)
-        self.gn1 = nn.GroupNorm(min(32, base_width), base_width)
+        self.gn1 = nn.GroupNorm(_get_num_groups(base_width), base_width)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = MaxPool(kernel_size=3, stride=2, padding=1)
 
@@ -246,7 +275,7 @@ class ResNetBase(BaseModel):
                 bias=False,
             ),
             nn.GroupNorm(
-                min(32, out_channels * block.expansion),
+                _get_num_groups(out_channels * block.expansion),
                 out_channels * block.expansion,
             ),
         )
@@ -191,22 +191,33 @@ class SwinTransformerBase(BaseModel):
         Returns:
             List of parameter group dictionaries
         """
-        # Separate parameters: head (full LR) vs backbone (decayed LR)
+        # Separate parameters into 4 groups for proper LR decay:
+        # 1. Head params with decay (full LR)
+        # 2. Backbone params with decay (0.1× LR)
+        # 3. Head bias/norm without decay (full LR)
+        # 4. Backbone bias/norm without decay (0.1× LR)
         head_params = []
         backbone_params = []
-        no_decay_params = []
+        head_no_decay = []
+        backbone_no_decay = []
 
         for name, param in self.backbone.named_parameters():
             if not param.requires_grad:
                 continue
 
-            # No weight decay for bias and normalization
-            if "bias" in name or "norm" in name:
-                no_decay_params.append(param)
-            elif "head" in name:
-                head_params.append(param)
+            is_head = "head" in name
+            is_no_decay = "bias" in name or "norm" in name
+
+            if is_head:
+                if is_no_decay:
+                    head_no_decay.append(param)
+                else:
+                    head_params.append(param)
             else:
-                backbone_params.append(param)
+                if is_no_decay:
+                    backbone_no_decay.append(param)
+                else:
+                    backbone_params.append(param)
 
         groups = []
 
@@ -229,15 +240,25 @@ class SwinTransformerBase(BaseModel):
                 }
             )
 
-        if no_decay_params:
+        if head_no_decay:
             groups.append(
                 {
-                    "params": no_decay_params,
+                    "params": head_no_decay,
                     "lr": base_lr,
                     "weight_decay": 0.0,
                 }
             )
 
+        if backbone_no_decay:
+            # Backbone bias/norm also gets 0.1× LR to match intended decay
+            groups.append(
+                {
+                    "params": backbone_no_decay,
+                    "lr": base_lr * 0.1,
+                    "weight_decay": 0.0,
+                }
+            )
+
         return groups if groups else [{"params": self.parameters(), "lr": base_lr}]
 
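To see how the four groups are consumed, a self-contained sketch feeding hand-built groups of the same shape into `torch.optim.AdamW` — stand-in `nn.Linear` modules replace the Swin head and backbone, since the method holding the logic above isn't named in this hunk:

```python
import torch
import torch.nn as nn

head, backbone = nn.Linear(8, 2), nn.Linear(8, 8)  # stand-ins for the Swin parts
base_lr = 1e-3

groups = [
    {"params": [head.weight], "lr": base_lr, "weight_decay": 0.05},            # 1. head, decayed
    {"params": [backbone.weight], "lr": base_lr * 0.1, "weight_decay": 0.05},  # 2. backbone, 0.1x LR
    {"params": [head.bias], "lr": base_lr, "weight_decay": 0.0},               # 3. head bias/norm
    {"params": [backbone.bias], "lr": base_lr * 0.1, "weight_decay": 0.0},     # 4. backbone bias/norm
]
optimizer = torch.optim.AdamW(groups)
# Before 1.5.4, group 4's parameters landed in the full-LR no-decay bucket;
# the fix gives them the intended 0.1x backbone rate.
```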
@@ -69,6 +69,39 @@ _setup_cache_dir("XDG_DATA_HOME", "local/share")
 _setup_cache_dir("XDG_STATE_HOME", "local/state")
 _setup_cache_dir("XDG_CACHE_HOME", "cache")
 
+
+def _setup_per_rank_compile_cache() -> None:
+    """Set per-GPU Triton/Inductor cache to prevent multi-process race warnings.
+
+    When using torch.compile with multiple GPUs, all processes try to write to
+    the same cache directory, causing 'Directory is not empty - skipping!' warnings.
+    This gives each GPU rank its own isolated cache subdirectory.
+    """
+    # Get local rank from environment (set by accelerate/torchrun)
+    local_rank = os.environ.get("LOCAL_RANK", "0")
+
+    # Get cache base from environment or use CWD
+    cache_base = os.environ.get(
+        "TRITON_CACHE_DIR", os.path.join(os.getcwd(), ".triton_cache")
+    )
+
+    # Set per-rank cache directories
+    os.environ["TRITON_CACHE_DIR"] = os.path.join(cache_base, f"rank_{local_rank}")
+    os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(
+        os.environ.get(
+            "TORCHINDUCTOR_CACHE_DIR", os.path.join(os.getcwd(), ".inductor_cache")
+        ),
+        f"rank_{local_rank}",
+    )
+
+    # Create directories
+    os.makedirs(os.environ["TRITON_CACHE_DIR"], exist_ok=True)
+    os.makedirs(os.environ["TORCHINDUCTOR_CACHE_DIR"], exist_ok=True)
+
+
+# Setup per-rank compile caches (before torch imports)
+_setup_per_rank_compile_cache()
+
 # =============================================================================
 # Standard imports (after environment setup)
 # =============================================================================
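A worked example of the layout the helper produces — illustrative only, since `_setup_per_rank_compile_cache()` is module-private in `wavedl.train` and runs automatically on import:

```python
# Replicates the per-rank path derivation from the hunk above for rank 1.
import os

os.environ["LOCAL_RANK"] = "1"  # normally set by accelerate/torchrun
cache_base = os.path.join(os.getcwd(), ".triton_cache")
rank_dir = os.path.join(cache_base, f"rank_{os.environ['LOCAL_RANK']}")
print(rank_dir)  # <cwd>/.triton_cache/rank_1 -- one isolated subtree per GPU process
```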
@@ -89,6 +122,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import torch
+import torch.distributed as dist
 from accelerate import Accelerator
 from accelerate.utils import set_seed
 from sklearn.metrics import r2_score
@@ -437,15 +471,19 @@ def main():
         try:
             # Handle both module names (my_model) and file paths (./my_model.py)
             if module_name.endswith(".py"):
-                # Import from file path
+                # Import from file path with unique module name
                 import importlib.util
 
+                # Derive unique module name from filename to avoid collisions
+                base_name = os.path.splitext(os.path.basename(module_name))[0]
+                unique_name = f"wavedl_custom_{base_name}"
+
                 spec = importlib.util.spec_from_file_location(
-                    "custom_module", module_name
+                    unique_name, module_name
                 )
                 if spec and spec.loader:
                     module = importlib.util.module_from_spec(spec)
-                    sys.modules["custom_module"] = module
+                    sys.modules[unique_name] = module
                     spec.loader.exec_module(module)
                     print(f"✓ Imported custom module from: {module_name}")
                 else:
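The collision the unique name avoids: with a fixed `"custom_module"` key, a second `--import` file overwrote the first in `sys.modules`. A condensed, standalone version of the fixed pattern:

```python
# Condensed from the hunk above: unique sys.modules keys per imported file.
import importlib.util
import os
import sys


def import_model_file(path: str):
    base = os.path.splitext(os.path.basename(path))[0]
    unique = f"wavedl_custom_{base}"  # e.g. wavedl_custom_my_model
    spec = importlib.util.spec_from_file_location(unique, path)
    assert spec and spec.loader, f"cannot load {path}"
    module = importlib.util.module_from_spec(spec)
    sys.modules[unique] = module  # a second file gets its own key, no clobbering
    spec.loader.exec_module(module)
    return module
```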
@@ -908,7 +946,6 @@ def main():
     logger.info("=" * len(header))
 
     try:
-        time.time()
         total_training_time = 0.0
 
         for epoch in range(start_epoch, args.epochs):
@@ -1002,49 +1039,29 @@ def main():
                 local_preds.append(pred.detach().cpu())
                 local_targets.append(y.detach().cpu())
 
-            # Concatenate locally on CPU (no GPU memory spike)
-            cpu_preds = torch.cat(local_preds)
-            cpu_targets = torch.cat(local_targets)
-
-            # Gather predictions and targets to rank 0 only (memory-efficient)
-            # Avoids duplicating full validation set on every GPU
-            if torch.distributed.is_initialized():
-                # DDP mode: gather only to rank 0
-                # NCCL backend requires CUDA tensors for collective ops
-                gpu_preds = cpu_preds.to(accelerator.device)
-                gpu_targets = cpu_targets.to(accelerator.device)
-
-                if accelerator.is_main_process:
-                    # Rank 0: allocate gather buffers on GPU
-                    all_preds_list = [
-                        torch.zeros_like(gpu_preds)
-                        for _ in range(accelerator.num_processes)
-                    ]
-                    all_targets_list = [
-                        torch.zeros_like(gpu_targets)
-                        for _ in range(accelerator.num_processes)
-                    ]
-                    torch.distributed.gather(
-                        gpu_preds, gather_list=all_preds_list, dst=0
-                    )
-                    torch.distributed.gather(
-                        gpu_targets, gather_list=all_targets_list, dst=0
-                    )
-                    # Move back to CPU for metric computation
-                    gathered = [
-                        (
-                            torch.cat(all_preds_list).cpu(),
-                            torch.cat(all_targets_list).cpu(),
-                        )
-                    ]
-                else:
-                    # Other ranks: send to rank 0, don't allocate gather buffers
-                    torch.distributed.gather(gpu_preds, gather_list=None, dst=0)
-                    torch.distributed.gather(gpu_targets, gather_list=None, dst=0)
-                    gathered = [(cpu_preds, cpu_targets)]  # Placeholder, not used
+            # Concatenate locally (keep on GPU for gather_for_metrics compatibility)
+            local_preds_cat = torch.cat(local_preds)
+            local_targets_cat = torch.cat(local_targets)
+
+            # Gather predictions and targets using Accelerate's CPU-efficient utility
+            # gather_for_metrics handles:
+            #   - DDP padding removal (no need to trim manually)
+            #   - Efficient cross-rank gathering without GPU memory spike
+            #   - Returns concatenated tensors on CPU for metric computation
+            if accelerator.num_processes > 1:
+                # Move to GPU for gather (required by NCCL), then back to CPU
+                # gather_for_metrics is more memory-efficient than manual gather
+                # as it processes in chunks internally
+                gathered_preds = accelerator.gather_for_metrics(
+                    local_preds_cat.to(accelerator.device)
+                ).cpu()
+                gathered_targets = accelerator.gather_for_metrics(
+                    local_targets_cat.to(accelerator.device)
+                ).cpu()
             else:
                 # Single-GPU mode: no gathering needed
-                gathered = [(cpu_preds, cpu_targets)]
+                gathered_preds = local_preds_cat
+                gathered_targets = local_targets_cat
 
             # Synchronize validation metrics (scalars only - efficient)
             val_loss_scalar = val_loss_sum.item()
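For reference, a minimal standalone sketch of the `gather_for_metrics` pattern the loop now relies on (`gather_for_metrics` is a real Accelerate API; the tensors here are synthetic stand-ins for validation predictions):

```python
# Minimal sketch: gather per-rank tensors and drop DDP padding in one call.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
local_preds = torch.randn(8, 3, device=accelerator.device)  # this rank's slice

# Concatenates across ranks and trims the duplicated samples DDP added to
# pad the last batch, so every process sees exactly len(dataset) rows.
all_preds = accelerator.gather_for_metrics(local_preds).cpu()
```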
@@ -1069,20 +1086,10 @@ def main():
 
             # ==================== LOGGING & CHECKPOINTING ====================
             if accelerator.is_main_process:
-                # Concatenate gathered tensors from all ranks (only on rank 0)
-                # gathered is list of tuples: [(preds_rank0, targs_rank0), (preds_rank1, targs_rank1), ...]
-                all_preds = torch.cat([item[0] for item in gathered])
-                all_targets = torch.cat([item[1] for item in gathered])
-
                 # Scientific metrics - cast to float32 before numpy
-                y_pred = all_preds.float().numpy()
-                y_true = all_targets.float().numpy()
-
-                # Trim DDP padding
-                real_len = len(val_dl.dataset)
-                if len(y_pred) > real_len:
-                    y_pred = y_pred[:real_len]
-                    y_true = y_true[:real_len]
+                # gather_for_metrics already handles DDP padding removal
+                y_pred = gathered_preds.float().numpy()
+                y_true = gathered_targets.float().numpy()
 
                 # Guard against tiny validation sets (R² undefined for <2 samples)
                 if len(y_true) >= 2:
@@ -1248,9 +1255,32 @@ def main():
                 )
 
            # Learning rate scheduling (epoch-based schedulers only)
+            # NOTE: For ReduceLROnPlateau with DDP, we must step only on main process
+            # to avoid patience counter being incremented by all GPU processes.
+            # Then we sync the new LR to all processes to keep them consistent.
             if not scheduler_step_per_batch:
                 if args.scheduler == "plateau":
-                    scheduler.step(avg_val_loss)
+                    # Step only on main process to avoid multi-GPU patience bug
+                    if accelerator.is_main_process:
+                        scheduler.step(avg_val_loss)
+
+                    # Sync LR across all processes after main process updates it
+                    accelerator.wait_for_everyone()
+
+                    # Broadcast new LR from rank 0 to all processes
+                    if dist.is_initialized():
+                        if accelerator.is_main_process:
+                            new_lr = optimizer.param_groups[0]["lr"]
+                        else:
+                            new_lr = 0.0
+                        new_lr_tensor = torch.tensor(
+                            new_lr, device=accelerator.device, dtype=torch.float32
+                        )
+                        dist.broadcast(new_lr_tensor, src=0)
+                        # Update LR on non-main processes
+                        if not accelerator.is_main_process:
+                            for param_group in optimizer.param_groups:
+                                param_group["lr"] = new_lr_tensor.item()
                 else:
                     scheduler.step()
 
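The broadcast primitive doing the synchronization above, shown in isolation — a sketch that assumes `torch.distributed` is already initialized, as Accelerate arranges under DDP:

```python
# Sketch: rank 0's scalar overwrites every other rank's copy in place.
import torch
import torch.distributed as dist

# assumes dist.init_process_group(...) has already run on every rank
lr_tensor = torch.zeros(1, device="cuda")
if dist.get_rank() == 0:
    lr_tensor.fill_(5e-4)  # the value ReduceLROnPlateau chose on rank 0
dist.broadcast(lr_tensor, src=0)
# every rank now reads 5e-4 via lr_tensor.item()
```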
@@ -183,9 +183,11 @@ def save_config(
         config[key] = value
 
     # Add metadata
+    from wavedl import __version__
+
     config["_metadata"] = {
         "saved_at": datetime.now().isoformat(),
-        "wavedl_version": "1.0.0",
+        "wavedl_version": __version__,
     }
 
     output_path = Path(output_path)
@@ -337,6 +337,17 @@ def run_cross_validation(
     torch.cuda.manual_seed_all(seed)
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Auto-detect optimal DataLoader workers if not specified (matches train.py behavior)
+    if workers < 0:
+        cpu_count = os.cpu_count() or 4
+        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
+        # Heuristic: 4-16 workers per GPU, bounded by available CPU cores
+        workers = min(16, max(2, (cpu_count - 2) // max(1, num_gpus)))
+        logger.info(
+            f"⚙️ Auto-detected workers: {workers} (CPUs: {cpu_count}, GPUs: {num_gpus})"
+        )
+
     logger.info(f"🚀 K-Fold Cross-Validation ({folds} folds)")
     logger.info(f"   Model: {model_name} | Device: {device}")
     logger.info(
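A worked instance of the heuristic: on a node with 32 CPU cores and 4 GPUs,

```python
# (32 - 2) // 4 = 7, inside the [2, 16] clamp -> 7 workers per DataLoader
cpu_count, num_gpus = 32, 4
workers = min(16, max(2, (cpu_count - 2) // max(1, num_gpus)))
assert workers == 7
```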
@@ -763,9 +763,74 @@ def load_test_data(
             k for k in OUTPUT_KEYS if k != "output_test"
         ]
 
-    # Load data using appropriate source
+    # Load data using appropriate source with test-key priority
+    # We detect keys first to ensure input_test/output_test are used when present
     try:
-        inp, outp = source.load(path)
+        if format == "npz":
+            with np.load(path, allow_pickle=False) as probe:
+                keys = list(probe.keys())
+                inp_key = DataSource._find_key(keys, custom_input_keys)
+                out_key = DataSource._find_key(keys, custom_output_keys)
+                if inp_key is None:
+                    raise KeyError(
+                        f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
+                    )
+            data = NPZSource._safe_load(
+                path, [inp_key] + ([out_key] if out_key else [])
+            )
+            inp = data[inp_key]
+            if inp.dtype == object:
+                inp = np.array(
+                    [x.toarray() if hasattr(x, "toarray") else x for x in inp]
+                )
+            outp = data[out_key] if out_key else None
+        elif format == "hdf5":
+            with h5py.File(path, "r") as f:
+                keys = list(f.keys())
+                inp_key = DataSource._find_key(keys, custom_input_keys)
+                out_key = DataSource._find_key(keys, custom_output_keys)
+                if inp_key is None:
+                    raise KeyError(
+                        f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
+                    )
+                # OOM guard: raise if dataset is very large
+                n_samples = f[inp_key].shape[0]
+                if n_samples > 100000:
+                    raise ValueError(
+                        f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                        f"everything into RAM which may cause OOM. For large inference "
+                        f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
+                    )
+                inp = f[inp_key][:]
+                outp = f[out_key][:] if out_key else None
+        elif format == "mat":
+            mat_source = MATSource()
+            with h5py.File(path, "r") as f:
+                keys = list(f.keys())
+                inp_key = DataSource._find_key(keys, custom_input_keys)
+                out_key = DataSource._find_key(keys, custom_output_keys)
+                if inp_key is None:
+                    raise KeyError(
+                        f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
+                    )
+                # OOM guard: raise if dataset is very large (MAT is transposed)
+                n_samples = f[inp_key].shape[-1]
+                if n_samples > 100000:
+                    raise ValueError(
+                        f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                        f"everything into RAM which may cause OOM. For large inference "
+                        f"sets, use a DataLoader with MATSource.load_mmap() instead."
+                    )
+                inp = mat_source._load_dataset(f, inp_key)
+                if out_key:
+                    outp = mat_source._load_dataset(f, out_key)
+                    if outp.ndim == 2 and outp.shape[0] == 1:
+                        outp = outp.T
+                else:
+                    outp = None
+        else:
+            # Fallback to default source.load() for unknown formats
+            inp, outp = source.load(path)
     except KeyError:
         # Try with just inputs if outputs not found (inference-only mode)
         if format == "npz":
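The key priority this buys, sketched with an illustrative stand-in for `DataSource._find_key`, whose exact semantics the diff doesn't show — first-match-in-priority-order is the assumed behavior:

```python
# Illustrative only: assumed first-match semantics for DataSource._find_key.
def find_key(available, candidates):
    for key in candidates:
        if key in available:
            return key
    return None

keys = ["input", "output", "input_test", "output_test"]
assert find_key(keys, ["input_test", "input"]) == "input_test"  # test split wins
assert find_key(["input"], ["input_test", "input"]) == "input"  # graceful fallback
```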
@@ -1077,6 +1142,17 @@ def prepare_data(
 
     if not cache_exists:
         if accelerator.is_main_process:
+            # Delete stale cache files to force regeneration
+            # This prevents silent reuse of old data when metadata invalidates cache
+            for stale_file in [CACHE_FILE, SCALER_FILE]:
+                if os.path.exists(stale_file):
+                    try:
+                        os.remove(stale_file)
+                        logger.debug(f"   Removed stale cache: {stale_file}")
+                    except OSError as e:
+                        logger.warning(
+                            f"   Failed to remove stale cache {stale_file}: {e}"
+                        )
             # RANK 0: Create cache (can take a long time for large datasets)
             # Other ranks will wait at the barrier below
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.2
+Version: 1.5.4
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -37,11 +37,12 @@ Requires-Dist: wandb>=0.15.0
 Requires-Dist: optuna>=3.0.0
 Requires-Dist: onnx>=1.14.0
 Requires-Dist: onnxruntime>=1.15.0
-Requires-Dist: pytest>=7.0.0
-Requires-Dist: pytest-xdist>=3.5.0
-Requires-Dist: ruff>=0.8.0
-Requires-Dist: pre-commit>=3.5.0
 Requires-Dist: triton>=2.0.0; sys_platform == "linux"
+Provides-Extra: dev
+Requires-Dist: pytest>=7.0.0; extra == "dev"
+Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
+Requires-Dist: ruff>=0.8.0; extra == "dev"
+Requires-Dist: pre-commit>=3.5.0; extra == "dev"
 
 <div align="center">
 
@@ -204,7 +205,7 @@ Deploy models anywhere:
 pip install wavedl
 ```
 
-This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
+This installs everything you need: training, inference, HPO, ONNX export.
 
 #### From Source (for development)
 
@@ -301,8 +302,8 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
 
 **Requirements** (your model must):
 1. Inherit from `BaseModel`
-2. Accept `in_channels`, `num_outputs`, `input_shape` in `__init__`
-3. Return a tensor of shape `(batch, num_outputs)` from `forward()`
+2. Accept `in_shape`, `out_size` in `__init__`
+3. Return a tensor of shape `(batch, out_size)` from `forward()`
 
 ---
 
@@ -315,29 +316,28 @@ from wavedl.models import BaseModel, register_model
 
 @register_model("my_model")  # This name is used with --model flag
 class MyModel(BaseModel):
-    def __init__(self, in_channels, num_outputs, input_shape):
-        # in_channels: number of input channels (auto-detected from data)
-        # num_outputs: number of parameters to predict (auto-detected from data)
-        # input_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
-        super().__init__(in_channels, num_outputs, input_shape)
-
-        # Define your layers (this is just an example)
-        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
+    def __init__(self, in_shape, out_size, **kwargs):
+        # in_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
+        # out_size: number of parameters to predict (auto-detected from data)
+        super().__init__(in_shape, out_size)
+
+        # Define your layers (this is just an example for 2D)
+        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)  # Input always has 1 channel
         self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
-        self.fc = nn.Linear(128, num_outputs)
+        self.fc = nn.Linear(128, out_size)
 
     def forward(self, x):
-        # Input x has shape: (batch, in_channels, *input_shape)
+        # Input x has shape: (batch, 1, *in_shape)
         x = F.relu(self.conv1(x))
         x = F.relu(self.conv2(x))
         x = x.mean(dim=[-2, -1])  # Global average pooling
-        return self.fc(x)  # Output shape: (batch, num_outputs)
+        return self.fc(x)  # Output shape: (batch, out_size)
 ```
 
 **Step 2: Train**
 
 ```bash
-wavedl-hpc --import my_model --model my_model --data_path train.npz
+wavedl-hpc --import my_model.py --model my_model --data_path train.npz
 ```
 
 WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -513,7 +513,7 @@ print('\\n✓ All pretrained weights cached!')
 | Argument | Default | Description |
 |----------|---------|-------------|
 | `--model` | `cnn` | Model architecture |
-| `--import` | - | Python modules to import (for custom models) |
+| `--import` | - | Python file(s) to import for custom models (supports multiple) |
 | `--batch_size` | `128` | Per-GPU batch size |
 | `--lr` | `1e-3` | Learning rate |
 | `--epochs` | `1000` | Maximum epochs |
@@ -573,14 +573,19 @@ WaveDL automatically enables performance optimizations for modern GPUs:
 </details>
 
 <details>
-<summary><b>Environment Variables (wavedl-hpc)</b></summary>
+<summary><b>HPC CLI Arguments (wavedl-hpc)</b></summary>
+
+| Argument | Default | Description |
+|----------|---------|-------------|
+| `--num_gpus` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override |
+| `--num_machines` | `1` | Number of machines in distributed setup |
+| `--mixed_precision` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
+| `--dynamo_backend` | `no` | PyTorch Dynamo backend |
+
+**Environment Variables (for logging):**
 
 | Variable | Default | Description |
 |----------|---------|-------------|
-| `NUM_GPUS` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override (e.g., `NUM_GPUS=2`) |
-| `NUM_MACHINES` | `1` | Number of machines in distributed setup |
-| `MIXED_PRECISION` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
-| `DYNAMO_BACKEND` | `no` | PyTorch Dynamo backend |
 | `WANDB_MODE` | `offline` | WandB mode: `offline` or `online` |
 
 </details>
@@ -1219,6 +1224,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
 [![Google Scholar](https://img.shields.io/badge/Google_Scholar-4285F4?style=plastic&logo=google-scholar&logoColor=white)](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
 [![ResearchGate](https://img.shields.io/badge/ResearchGate-00CCBB?style=plastic&logo=researchgate&logoColor=white)](https://www.researchgate.net/profile/Ductho-Le)
 
-<sub>Released under the MIT License</sub>
+<sub>May your signals be strong and your attenuation low 👋</sub>
 
 </div>
@@ -14,10 +14,12 @@ wandb>=0.15.0
 optuna>=3.0.0
 onnx>=1.14.0
 onnxruntime>=1.15.0
+
+[:sys_platform == "linux"]
+triton>=2.0.0
+
+[dev]
 pytest>=7.0.0
 pytest-xdist>=3.5.0
 ruff>=0.8.0
 pre-commit>=3.5.0
-
-[:sys_platform == "linux"]
-triton>=2.0.0