wavedl 1.5.2__tar.gz → 1.5.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavedl-1.5.2/src/wavedl.egg-info → wavedl-1.5.4}/PKG-INFO +32 -27
- {wavedl-1.5.2 → wavedl-1.5.4}/README.md +26 -22
- {wavedl-1.5.2 → wavedl-1.5.4}/pyproject.toml +6 -3
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/__init__.py +1 -1
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/hpc.py +22 -18
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/resnet.py +38 -9
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/swin.py +31 -10
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/train.py +89 -59
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/config.py +3 -1
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/cross_validation.py +11 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/data.py +78 -2
- {wavedl-1.5.2 → wavedl-1.5.4/src/wavedl.egg-info}/PKG-INFO +32 -27
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl.egg-info/requires.txt +5 -3
- {wavedl-1.5.2 → wavedl-1.5.4}/LICENSE +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/setup.cfg +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/hpo.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/__init__.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/_template.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/base.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/cnn.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/convnext.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/densenet.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/efficientnet.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/efficientnetv2.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/mobilenetv3.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/registry.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/regnet.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/resnet3d.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/tcn.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/unet.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/models/vit.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/test.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/__init__.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/constraints.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/distributed.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/losses.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/metrics.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/optimizers.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl/utils/schedulers.py +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl.egg-info/SOURCES.txt +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl.egg-info/dependency_links.txt +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl.egg-info/entry_points.txt +0 -0
- {wavedl-1.5.2 → wavedl-1.5.4}/src/wavedl.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.5.
|
|
3
|
+
Version: 1.5.4
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -37,11 +37,12 @@ Requires-Dist: wandb>=0.15.0
|
|
|
37
37
|
Requires-Dist: optuna>=3.0.0
|
|
38
38
|
Requires-Dist: onnx>=1.14.0
|
|
39
39
|
Requires-Dist: onnxruntime>=1.15.0
|
|
40
|
-
Requires-Dist: pytest>=7.0.0
|
|
41
|
-
Requires-Dist: pytest-xdist>=3.5.0
|
|
42
|
-
Requires-Dist: ruff>=0.8.0
|
|
43
|
-
Requires-Dist: pre-commit>=3.5.0
|
|
44
40
|
Requires-Dist: triton>=2.0.0; sys_platform == "linux"
|
|
41
|
+
Provides-Extra: dev
|
|
42
|
+
Requires-Dist: pytest>=7.0.0; extra == "dev"
|
|
43
|
+
Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
|
|
44
|
+
Requires-Dist: ruff>=0.8.0; extra == "dev"
|
|
45
|
+
Requires-Dist: pre-commit>=3.5.0; extra == "dev"
|
|
45
46
|
|
|
46
47
|
<div align="center">
|
|
47
48
|
|
|
@@ -204,7 +205,7 @@ Deploy models anywhere:
|
|
|
204
205
|
pip install wavedl
|
|
205
206
|
```
|
|
206
207
|
|
|
207
|
-
This installs everything you need: training, inference, HPO, ONNX export
|
|
208
|
+
This installs everything you need: training, inference, HPO, ONNX export.
|
|
208
209
|
|
|
209
210
|
#### From Source (for development)
|
|
210
211
|
|
|
@@ -301,8 +302,8 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
|
301
302
|
|
|
302
303
|
**Requirements** (your model must):
|
|
303
304
|
1. Inherit from `BaseModel`
|
|
304
|
-
2. Accept `
|
|
305
|
-
3. Return a tensor of shape `(batch,
|
|
305
|
+
2. Accept `in_shape`, `out_size` in `__init__`
|
|
306
|
+
3. Return a tensor of shape `(batch, out_size)` from `forward()`
|
|
306
307
|
|
|
307
308
|
---
|
|
308
309
|
|
|
@@ -315,29 +316,28 @@ from wavedl.models import BaseModel, register_model
|
|
|
315
316
|
|
|
316
317
|
@register_model("my_model") # This name is used with --model flag
|
|
317
318
|
class MyModel(BaseModel):
|
|
318
|
-
def __init__(self,
|
|
319
|
-
#
|
|
320
|
-
#
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
|
|
319
|
+
def __init__(self, in_shape, out_size, **kwargs):
|
|
320
|
+
# in_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
|
|
321
|
+
# out_size: number of parameters to predict (auto-detected from data)
|
|
322
|
+
super().__init__(in_shape, out_size)
|
|
323
|
+
|
|
324
|
+
# Define your layers (this is just an example for 2D)
|
|
325
|
+
self.conv1 = nn.Conv2d(1, 64, 3, padding=1) # Input always has 1 channel
|
|
326
326
|
self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
|
|
327
|
-
self.fc = nn.Linear(128,
|
|
327
|
+
self.fc = nn.Linear(128, out_size)
|
|
328
328
|
|
|
329
329
|
def forward(self, x):
|
|
330
|
-
# Input x has shape: (batch,
|
|
330
|
+
# Input x has shape: (batch, 1, *in_shape)
|
|
331
331
|
x = F.relu(self.conv1(x))
|
|
332
332
|
x = F.relu(self.conv2(x))
|
|
333
333
|
x = x.mean(dim=[-2, -1]) # Global average pooling
|
|
334
|
-
return self.fc(x) # Output shape: (batch,
|
|
334
|
+
return self.fc(x) # Output shape: (batch, out_size)
|
|
335
335
|
```
|
|
336
336
|
|
|
337
337
|
**Step 2: Train**
|
|
338
338
|
|
|
339
339
|
```bash
|
|
340
|
-
wavedl-hpc --import my_model --model my_model --data_path train.npz
|
|
340
|
+
wavedl-hpc --import my_model.py --model my_model --data_path train.npz
|
|
341
341
|
```
|
|
342
342
|
|
|
343
343
|
WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
|
|
@@ -513,7 +513,7 @@ print('\\n✓ All pretrained weights cached!')
|
|
|
513
513
|
| Argument | Default | Description |
|
|
514
514
|
|----------|---------|-------------|
|
|
515
515
|
| `--model` | `cnn` | Model architecture |
|
|
516
|
-
| `--import` | - | Python
|
|
516
|
+
| `--import` | - | Python file(s) to import for custom models (supports multiple) |
|
|
517
517
|
| `--batch_size` | `128` | Per-GPU batch size |
|
|
518
518
|
| `--lr` | `1e-3` | Learning rate |
|
|
519
519
|
| `--epochs` | `1000` | Maximum epochs |
|
|
@@ -573,14 +573,19 @@ WaveDL automatically enables performance optimizations for modern GPUs:
|
|
|
573
573
|
</details>
|
|
574
574
|
|
|
575
575
|
<details>
|
|
576
|
-
<summary><b>
|
|
576
|
+
<summary><b>HPC CLI Arguments (wavedl-hpc)</b></summary>
|
|
577
|
+
|
|
578
|
+
| Argument | Default | Description |
|
|
579
|
+
|----------|---------|-------------|
|
|
580
|
+
| `--num_gpus` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override |
|
|
581
|
+
| `--num_machines` | `1` | Number of machines in distributed setup |
|
|
582
|
+
| `--mixed_precision` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
|
|
583
|
+
| `--dynamo_backend` | `no` | PyTorch Dynamo backend |
|
|
584
|
+
|
|
585
|
+
**Environment Variables (for logging):**
|
|
577
586
|
|
|
578
587
|
| Variable | Default | Description |
|
|
579
588
|
|----------|---------|-------------|
|
|
580
|
-
| `NUM_GPUS` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override (e.g., `NUM_GPUS=2`) |
|
|
581
|
-
| `NUM_MACHINES` | `1` | Number of machines in distributed setup |
|
|
582
|
-
| `MIXED_PRECISION` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
|
|
583
|
-
| `DYNAMO_BACKEND` | `no` | PyTorch Dynamo backend |
|
|
584
589
|
| `WANDB_MODE` | `offline` | WandB mode: `offline` or `online` |
|
|
585
590
|
|
|
586
591
|
</details>
|
|
@@ -1219,6 +1224,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
|
|
|
1219
1224
|
[](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
|
|
1220
1225
|
[](https://www.researchgate.net/profile/Ductho-Le)
|
|
1221
1226
|
|
|
1222
|
-
<sub>
|
|
1227
|
+
<sub>May your signals be strong and your attenuation low 👋</sub>
|
|
1223
1228
|
|
|
1224
1229
|
</div>
|
|
@@ -159,7 +159,7 @@ Deploy models anywhere:
|
|
|
159
159
|
pip install wavedl
|
|
160
160
|
```
|
|
161
161
|
|
|
162
|
-
This installs everything you need: training, inference, HPO, ONNX export
|
|
162
|
+
This installs everything you need: training, inference, HPO, ONNX export.
|
|
163
163
|
|
|
164
164
|
#### From Source (for development)
|
|
165
165
|
|
|
@@ -256,8 +256,8 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
|
256
256
|
|
|
257
257
|
**Requirements** (your model must):
|
|
258
258
|
1. Inherit from `BaseModel`
|
|
259
|
-
2. Accept `
|
|
260
|
-
3. Return a tensor of shape `(batch,
|
|
259
|
+
2. Accept `in_shape`, `out_size` in `__init__`
|
|
260
|
+
3. Return a tensor of shape `(batch, out_size)` from `forward()`
|
|
261
261
|
|
|
262
262
|
---
|
|
263
263
|
|
|
@@ -270,29 +270,28 @@ from wavedl.models import BaseModel, register_model
|
|
|
270
270
|
|
|
271
271
|
@register_model("my_model") # This name is used with --model flag
|
|
272
272
|
class MyModel(BaseModel):
|
|
273
|
-
def __init__(self,
|
|
274
|
-
#
|
|
275
|
-
#
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
|
|
273
|
+
def __init__(self, in_shape, out_size, **kwargs):
|
|
274
|
+
# in_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
|
|
275
|
+
# out_size: number of parameters to predict (auto-detected from data)
|
|
276
|
+
super().__init__(in_shape, out_size)
|
|
277
|
+
|
|
278
|
+
# Define your layers (this is just an example for 2D)
|
|
279
|
+
self.conv1 = nn.Conv2d(1, 64, 3, padding=1) # Input always has 1 channel
|
|
281
280
|
self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
|
|
282
|
-
self.fc = nn.Linear(128,
|
|
281
|
+
self.fc = nn.Linear(128, out_size)
|
|
283
282
|
|
|
284
283
|
def forward(self, x):
|
|
285
|
-
# Input x has shape: (batch,
|
|
284
|
+
# Input x has shape: (batch, 1, *in_shape)
|
|
286
285
|
x = F.relu(self.conv1(x))
|
|
287
286
|
x = F.relu(self.conv2(x))
|
|
288
287
|
x = x.mean(dim=[-2, -1]) # Global average pooling
|
|
289
|
-
return self.fc(x) # Output shape: (batch,
|
|
288
|
+
return self.fc(x) # Output shape: (batch, out_size)
|
|
290
289
|
```
|
|
291
290
|
|
|
292
291
|
**Step 2: Train**
|
|
293
292
|
|
|
294
293
|
```bash
|
|
295
|
-
wavedl-hpc --import my_model --model my_model --data_path train.npz
|
|
294
|
+
wavedl-hpc --import my_model.py --model my_model --data_path train.npz
|
|
296
295
|
```
|
|
297
296
|
|
|
298
297
|
WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
|
|
@@ -468,7 +467,7 @@ print('\\n✓ All pretrained weights cached!')
|
|
|
468
467
|
| Argument | Default | Description |
|
|
469
468
|
|----------|---------|-------------|
|
|
470
469
|
| `--model` | `cnn` | Model architecture |
|
|
471
|
-
| `--import` | - | Python
|
|
470
|
+
| `--import` | - | Python file(s) to import for custom models (supports multiple) |
|
|
472
471
|
| `--batch_size` | `128` | Per-GPU batch size |
|
|
473
472
|
| `--lr` | `1e-3` | Learning rate |
|
|
474
473
|
| `--epochs` | `1000` | Maximum epochs |
|
|
@@ -528,14 +527,19 @@ WaveDL automatically enables performance optimizations for modern GPUs:
|
|
|
528
527
|
</details>
|
|
529
528
|
|
|
530
529
|
<details>
|
|
531
|
-
<summary><b>
|
|
530
|
+
<summary><b>HPC CLI Arguments (wavedl-hpc)</b></summary>
|
|
531
|
+
|
|
532
|
+
| Argument | Default | Description |
|
|
533
|
+
|----------|---------|-------------|
|
|
534
|
+
| `--num_gpus` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override |
|
|
535
|
+
| `--num_machines` | `1` | Number of machines in distributed setup |
|
|
536
|
+
| `--mixed_precision` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
|
|
537
|
+
| `--dynamo_backend` | `no` | PyTorch Dynamo backend |
|
|
538
|
+
|
|
539
|
+
**Environment Variables (for logging):**
|
|
532
540
|
|
|
533
541
|
| Variable | Default | Description |
|
|
534
542
|
|----------|---------|-------------|
|
|
535
|
-
| `NUM_GPUS` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override (e.g., `NUM_GPUS=2`) |
|
|
536
|
-
| `NUM_MACHINES` | `1` | Number of machines in distributed setup |
|
|
537
|
-
| `MIXED_PRECISION` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
|
|
538
|
-
| `DYNAMO_BACKEND` | `no` | PyTorch Dynamo backend |
|
|
539
543
|
| `WANDB_MODE` | `offline` | WandB mode: `offline` or `online` |
|
|
540
544
|
|
|
541
545
|
</details>
|
|
@@ -1174,6 +1178,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
|
|
|
1174
1178
|
[](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
|
|
1175
1179
|
[](https://www.researchgate.net/profile/Ductho-Le)
|
|
1176
1180
|
|
|
1177
|
-
<sub>
|
|
1181
|
+
<sub>May your signals be strong and your attenuation low 👋</sub>
|
|
1178
1182
|
|
|
1179
1183
|
</div>
|
|
@@ -70,13 +70,16 @@ dependencies = [
|
|
|
70
70
|
# ONNX export
|
|
71
71
|
"onnx>=1.14.0",
|
|
72
72
|
"onnxruntime>=1.15.0",
|
|
73
|
-
#
|
|
73
|
+
# torch.compile backend (Linux only)
|
|
74
|
+
"triton>=2.0.0; sys_platform == 'linux'",
|
|
75
|
+
]
|
|
76
|
+
|
|
77
|
+
[project.optional-dependencies]
|
|
78
|
+
dev = [
|
|
74
79
|
"pytest>=7.0.0",
|
|
75
80
|
"pytest-xdist>=3.5.0",
|
|
76
81
|
"ruff>=0.8.0",
|
|
77
82
|
"pre-commit>=3.5.0",
|
|
78
|
-
# torch.compile backend (Linux only)
|
|
79
|
-
"triton>=2.0.0; sys_platform == 'linux'",
|
|
80
83
|
]
|
|
81
84
|
|
|
82
85
|
[project.scripts]
|
|
@@ -57,30 +57,35 @@ def setup_hpc_environment() -> None:
|
|
|
57
57
|
"""Configure environment variables for HPC systems.
|
|
58
58
|
|
|
59
59
|
Handles restricted home directories (e.g., Compute Canada) and
|
|
60
|
-
offline logging configurations.
|
|
60
|
+
offline logging configurations. Always uses CWD-based TORCH_HOME
|
|
61
|
+
since compute nodes typically lack internet access.
|
|
61
62
|
"""
|
|
62
|
-
#
|
|
63
|
+
# Use CWD for cache base since HPC compute nodes typically lack internet
|
|
64
|
+
cache_base = os.getcwd()
|
|
65
|
+
|
|
66
|
+
# TORCH_HOME always set to CWD - compute nodes need pre-cached weights
|
|
67
|
+
os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
|
|
68
|
+
Path(os.environ["TORCH_HOME"]).mkdir(parents=True, exist_ok=True)
|
|
69
|
+
|
|
70
|
+
# Triton/Inductor caches - prevents permission errors with --compile
|
|
71
|
+
# These MUST be set before any torch.compile calls
|
|
72
|
+
os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
|
|
73
|
+
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache")
|
|
74
|
+
Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
|
|
75
|
+
Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
|
|
76
|
+
|
|
77
|
+
# Check if home is writable for other caches
|
|
63
78
|
home = os.path.expanduser("~")
|
|
64
79
|
home_writable = os.access(home, os.W_OK)
|
|
65
80
|
|
|
66
|
-
#
|
|
67
|
-
if home_writable:
|
|
68
|
-
# Local machine - let libraries use defaults
|
|
69
|
-
cache_base = None
|
|
70
|
-
else:
|
|
71
|
-
# HPC with restricted home - use CWD for persistent caches
|
|
72
|
-
cache_base = os.getcwd()
|
|
73
|
-
|
|
74
|
-
# Only set environment variables if home is not writable
|
|
75
|
-
if cache_base:
|
|
76
|
-
os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
|
|
81
|
+
# Other caches only if home is not writable
|
|
82
|
+
if not home_writable:
|
|
77
83
|
os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
|
|
78
84
|
os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
|
|
79
85
|
os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
|
|
80
86
|
|
|
81
87
|
# Ensure directories exist
|
|
82
88
|
for env_var in [
|
|
83
|
-
"TORCH_HOME",
|
|
84
89
|
"MPLCONFIGDIR",
|
|
85
90
|
"FONTCONFIG_CACHE",
|
|
86
91
|
"XDG_CACHE_HOME",
|
|
@@ -89,10 +94,9 @@ def setup_hpc_environment() -> None:
|
|
|
89
94
|
|
|
90
95
|
# WandB configuration (offline by default for HPC)
|
|
91
96
|
os.environ.setdefault("WANDB_MODE", "offline")
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
|
|
97
|
+
os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
|
|
98
|
+
os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
|
|
99
|
+
os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
|
|
96
100
|
|
|
97
101
|
# Suppress non-critical warnings
|
|
98
102
|
os.environ.setdefault(
|
|
@@ -49,6 +49,36 @@ def _get_conv_layers(
|
|
|
49
49
|
raise ValueError(f"Unsupported dimensionality: {dim}D. Supported: 1D, 2D, 3D.")
|
|
50
50
|
|
|
51
51
|
|
|
52
|
+
def _get_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
|
|
53
|
+
"""
|
|
54
|
+
Get valid num_groups for GroupNorm that divides num_channels evenly.
|
|
55
|
+
|
|
56
|
+
Args:
|
|
57
|
+
num_channels: Number of channels to normalize
|
|
58
|
+
preferred_groups: Preferred number of groups (default: 32)
|
|
59
|
+
|
|
60
|
+
Returns:
|
|
61
|
+
Valid num_groups that divides num_channels
|
|
62
|
+
|
|
63
|
+
Raises:
|
|
64
|
+
ValueError: If no valid divisor found (shouldn't happen with power-of-2 channels)
|
|
65
|
+
"""
|
|
66
|
+
# Try preferred groups first, then decrease
|
|
67
|
+
for groups in [preferred_groups, 16, 8, 4, 2, 1]:
|
|
68
|
+
if groups <= num_channels and num_channels % groups == 0:
|
|
69
|
+
return groups
|
|
70
|
+
|
|
71
|
+
# Fallback: find any valid divisor
|
|
72
|
+
for groups in range(min(32, num_channels), 0, -1):
|
|
73
|
+
if num_channels % groups == 0:
|
|
74
|
+
return groups
|
|
75
|
+
|
|
76
|
+
raise ValueError(
|
|
77
|
+
f"Cannot find valid num_groups for {num_channels} channels. "
|
|
78
|
+
f"Consider using base_width that is a power of 2 (e.g., 32, 64, 128)."
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
|
|
52
82
|
class BasicBlock(nn.Module):
|
|
53
83
|
"""
|
|
54
84
|
Basic residual block for ResNet-18/34.
|
|
@@ -77,12 +107,12 @@ class BasicBlock(nn.Module):
|
|
|
77
107
|
padding=1,
|
|
78
108
|
bias=False,
|
|
79
109
|
)
|
|
80
|
-
self.gn1 = nn.GroupNorm(
|
|
110
|
+
self.gn1 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
|
|
81
111
|
self.relu = nn.ReLU(inplace=True)
|
|
82
112
|
self.conv2 = Conv(
|
|
83
113
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
|
|
84
114
|
)
|
|
85
|
-
self.gn2 = nn.GroupNorm(
|
|
115
|
+
self.gn2 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
|
|
86
116
|
self.downsample = downsample
|
|
87
117
|
|
|
88
118
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
@@ -125,7 +155,7 @@ class Bottleneck(nn.Module):
|
|
|
125
155
|
|
|
126
156
|
# 1x1 reduce
|
|
127
157
|
self.conv1 = Conv(in_channels, out_channels, kernel_size=1, bias=False)
|
|
128
|
-
self.gn1 = nn.GroupNorm(
|
|
158
|
+
self.gn1 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
|
|
129
159
|
|
|
130
160
|
# 3x3 conv
|
|
131
161
|
self.conv2 = Conv(
|
|
@@ -136,15 +166,14 @@ class Bottleneck(nn.Module):
|
|
|
136
166
|
padding=1,
|
|
137
167
|
bias=False,
|
|
138
168
|
)
|
|
139
|
-
self.gn2 = nn.GroupNorm(
|
|
169
|
+
self.gn2 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
|
|
140
170
|
|
|
141
171
|
# 1x1 expand
|
|
142
172
|
self.conv3 = Conv(
|
|
143
173
|
out_channels, out_channels * self.expansion, kernel_size=1, bias=False
|
|
144
174
|
)
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
)
|
|
175
|
+
expanded_channels = out_channels * self.expansion
|
|
176
|
+
self.gn3 = nn.GroupNorm(_get_num_groups(expanded_channels), expanded_channels)
|
|
148
177
|
|
|
149
178
|
self.relu = nn.ReLU(inplace=True)
|
|
150
179
|
self.downsample = downsample
|
|
@@ -200,7 +229,7 @@ class ResNetBase(BaseModel):
|
|
|
200
229
|
|
|
201
230
|
# Stem: 7x7 conv (or equivalent for 1D/3D)
|
|
202
231
|
self.conv1 = Conv(1, base_width, kernel_size=7, stride=2, padding=3, bias=False)
|
|
203
|
-
self.gn1 = nn.GroupNorm(
|
|
232
|
+
self.gn1 = nn.GroupNorm(_get_num_groups(base_width), base_width)
|
|
204
233
|
self.relu = nn.ReLU(inplace=True)
|
|
205
234
|
self.maxpool = MaxPool(kernel_size=3, stride=2, padding=1)
|
|
206
235
|
|
|
@@ -246,7 +275,7 @@ class ResNetBase(BaseModel):
|
|
|
246
275
|
bias=False,
|
|
247
276
|
),
|
|
248
277
|
nn.GroupNorm(
|
|
249
|
-
|
|
278
|
+
_get_num_groups(out_channels * block.expansion),
|
|
250
279
|
out_channels * block.expansion,
|
|
251
280
|
),
|
|
252
281
|
)
|
|
@@ -191,22 +191,33 @@ class SwinTransformerBase(BaseModel):
|
|
|
191
191
|
Returns:
|
|
192
192
|
List of parameter group dictionaries
|
|
193
193
|
"""
|
|
194
|
-
# Separate parameters
|
|
194
|
+
# Separate parameters into 4 groups for proper LR decay:
|
|
195
|
+
# 1. Head params with decay (full LR)
|
|
196
|
+
# 2. Backbone params with decay (0.1× LR)
|
|
197
|
+
# 3. Head bias/norm without decay (full LR)
|
|
198
|
+
# 4. Backbone bias/norm without decay (0.1× LR)
|
|
195
199
|
head_params = []
|
|
196
200
|
backbone_params = []
|
|
197
|
-
|
|
201
|
+
head_no_decay = []
|
|
202
|
+
backbone_no_decay = []
|
|
198
203
|
|
|
199
204
|
for name, param in self.backbone.named_parameters():
|
|
200
205
|
if not param.requires_grad:
|
|
201
206
|
continue
|
|
202
207
|
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
+
is_head = "head" in name
|
|
209
|
+
is_no_decay = "bias" in name or "norm" in name
|
|
210
|
+
|
|
211
|
+
if is_head:
|
|
212
|
+
if is_no_decay:
|
|
213
|
+
head_no_decay.append(param)
|
|
214
|
+
else:
|
|
215
|
+
head_params.append(param)
|
|
208
216
|
else:
|
|
209
|
-
|
|
217
|
+
if is_no_decay:
|
|
218
|
+
backbone_no_decay.append(param)
|
|
219
|
+
else:
|
|
220
|
+
backbone_params.append(param)
|
|
210
221
|
|
|
211
222
|
groups = []
|
|
212
223
|
|
|
@@ -229,15 +240,25 @@ class SwinTransformerBase(BaseModel):
|
|
|
229
240
|
}
|
|
230
241
|
)
|
|
231
242
|
|
|
232
|
-
if
|
|
243
|
+
if head_no_decay:
|
|
233
244
|
groups.append(
|
|
234
245
|
{
|
|
235
|
-
"params":
|
|
246
|
+
"params": head_no_decay,
|
|
236
247
|
"lr": base_lr,
|
|
237
248
|
"weight_decay": 0.0,
|
|
238
249
|
}
|
|
239
250
|
)
|
|
240
251
|
|
|
252
|
+
if backbone_no_decay:
|
|
253
|
+
# Backbone bias/norm also gets 0.1× LR to match intended decay
|
|
254
|
+
groups.append(
|
|
255
|
+
{
|
|
256
|
+
"params": backbone_no_decay,
|
|
257
|
+
"lr": base_lr * 0.1,
|
|
258
|
+
"weight_decay": 0.0,
|
|
259
|
+
}
|
|
260
|
+
)
|
|
261
|
+
|
|
241
262
|
return groups if groups else [{"params": self.parameters(), "lr": base_lr}]
|
|
242
263
|
|
|
243
264
|
|
|
@@ -69,6 +69,39 @@ _setup_cache_dir("XDG_DATA_HOME", "local/share")
|
|
|
69
69
|
_setup_cache_dir("XDG_STATE_HOME", "local/state")
|
|
70
70
|
_setup_cache_dir("XDG_CACHE_HOME", "cache")
|
|
71
71
|
|
|
72
|
+
|
|
73
|
+
def _setup_per_rank_compile_cache() -> None:
|
|
74
|
+
"""Set per-GPU Triton/Inductor cache to prevent multi-process race warnings.
|
|
75
|
+
|
|
76
|
+
When using torch.compile with multiple GPUs, all processes try to write to
|
|
77
|
+
the same cache directory, causing 'Directory is not empty - skipping!' warnings.
|
|
78
|
+
This gives each GPU rank its own isolated cache subdirectory.
|
|
79
|
+
"""
|
|
80
|
+
# Get local rank from environment (set by accelerate/torchrun)
|
|
81
|
+
local_rank = os.environ.get("LOCAL_RANK", "0")
|
|
82
|
+
|
|
83
|
+
# Get cache base from environment or use CWD
|
|
84
|
+
cache_base = os.environ.get(
|
|
85
|
+
"TRITON_CACHE_DIR", os.path.join(os.getcwd(), ".triton_cache")
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
# Set per-rank cache directories
|
|
89
|
+
os.environ["TRITON_CACHE_DIR"] = os.path.join(cache_base, f"rank_{local_rank}")
|
|
90
|
+
os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(
|
|
91
|
+
os.environ.get(
|
|
92
|
+
"TORCHINDUCTOR_CACHE_DIR", os.path.join(os.getcwd(), ".inductor_cache")
|
|
93
|
+
),
|
|
94
|
+
f"rank_{local_rank}",
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
# Create directories
|
|
98
|
+
os.makedirs(os.environ["TRITON_CACHE_DIR"], exist_ok=True)
|
|
99
|
+
os.makedirs(os.environ["TORCHINDUCTOR_CACHE_DIR"], exist_ok=True)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Setup per-rank compile caches (before torch imports)
|
|
103
|
+
_setup_per_rank_compile_cache()
|
|
104
|
+
|
|
72
105
|
# =============================================================================
|
|
73
106
|
# Standard imports (after environment setup)
|
|
74
107
|
# =============================================================================
|
|
@@ -89,6 +122,7 @@ import matplotlib.pyplot as plt
|
|
|
89
122
|
import numpy as np
|
|
90
123
|
import pandas as pd
|
|
91
124
|
import torch
|
|
125
|
+
import torch.distributed as dist
|
|
92
126
|
from accelerate import Accelerator
|
|
93
127
|
from accelerate.utils import set_seed
|
|
94
128
|
from sklearn.metrics import r2_score
|
|
@@ -437,15 +471,19 @@ def main():
|
|
|
437
471
|
try:
|
|
438
472
|
# Handle both module names (my_model) and file paths (./my_model.py)
|
|
439
473
|
if module_name.endswith(".py"):
|
|
440
|
-
# Import from file path
|
|
474
|
+
# Import from file path with unique module name
|
|
441
475
|
import importlib.util
|
|
442
476
|
|
|
477
|
+
# Derive unique module name from filename to avoid collisions
|
|
478
|
+
base_name = os.path.splitext(os.path.basename(module_name))[0]
|
|
479
|
+
unique_name = f"wavedl_custom_{base_name}"
|
|
480
|
+
|
|
443
481
|
spec = importlib.util.spec_from_file_location(
|
|
444
|
-
|
|
482
|
+
unique_name, module_name
|
|
445
483
|
)
|
|
446
484
|
if spec and spec.loader:
|
|
447
485
|
module = importlib.util.module_from_spec(spec)
|
|
448
|
-
sys.modules[
|
|
486
|
+
sys.modules[unique_name] = module
|
|
449
487
|
spec.loader.exec_module(module)
|
|
450
488
|
print(f"✓ Imported custom module from: {module_name}")
|
|
451
489
|
else:
|
|
@@ -908,7 +946,6 @@ def main():
|
|
|
908
946
|
logger.info("=" * len(header))
|
|
909
947
|
|
|
910
948
|
try:
|
|
911
|
-
time.time()
|
|
912
949
|
total_training_time = 0.0
|
|
913
950
|
|
|
914
951
|
for epoch in range(start_epoch, args.epochs):
|
|
@@ -1002,49 +1039,29 @@ def main():
|
|
|
1002
1039
|
local_preds.append(pred.detach().cpu())
|
|
1003
1040
|
local_targets.append(y.detach().cpu())
|
|
1004
1041
|
|
|
1005
|
-
# Concatenate locally on
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
# Gather predictions and targets
|
|
1010
|
-
#
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
torch.zeros_like(gpu_targets)
|
|
1025
|
-
for _ in range(accelerator.num_processes)
|
|
1026
|
-
]
|
|
1027
|
-
torch.distributed.gather(
|
|
1028
|
-
gpu_preds, gather_list=all_preds_list, dst=0
|
|
1029
|
-
)
|
|
1030
|
-
torch.distributed.gather(
|
|
1031
|
-
gpu_targets, gather_list=all_targets_list, dst=0
|
|
1032
|
-
)
|
|
1033
|
-
# Move back to CPU for metric computation
|
|
1034
|
-
gathered = [
|
|
1035
|
-
(
|
|
1036
|
-
torch.cat(all_preds_list).cpu(),
|
|
1037
|
-
torch.cat(all_targets_list).cpu(),
|
|
1038
|
-
)
|
|
1039
|
-
]
|
|
1040
|
-
else:
|
|
1041
|
-
# Other ranks: send to rank 0, don't allocate gather buffers
|
|
1042
|
-
torch.distributed.gather(gpu_preds, gather_list=None, dst=0)
|
|
1043
|
-
torch.distributed.gather(gpu_targets, gather_list=None, dst=0)
|
|
1044
|
-
gathered = [(cpu_preds, cpu_targets)] # Placeholder, not used
|
|
1042
|
+
# Concatenate locally (keep on GPU for gather_for_metrics compatibility)
|
|
1043
|
+
local_preds_cat = torch.cat(local_preds)
|
|
1044
|
+
local_targets_cat = torch.cat(local_targets)
|
|
1045
|
+
|
|
1046
|
+
# Gather predictions and targets using Accelerate's CPU-efficient utility
|
|
1047
|
+
# gather_for_metrics handles:
|
|
1048
|
+
# - DDP padding removal (no need to trim manually)
|
|
1049
|
+
# - Efficient cross-rank gathering without GPU memory spike
|
|
1050
|
+
# - Returns concatenated tensors on CPU for metric computation
|
|
1051
|
+
if accelerator.num_processes > 1:
|
|
1052
|
+
# Move to GPU for gather (required by NCCL), then back to CPU
|
|
1053
|
+
# gather_for_metrics is more memory-efficient than manual gather
|
|
1054
|
+
# as it processes in chunks internally
|
|
1055
|
+
gathered_preds = accelerator.gather_for_metrics(
|
|
1056
|
+
local_preds_cat.to(accelerator.device)
|
|
1057
|
+
).cpu()
|
|
1058
|
+
gathered_targets = accelerator.gather_for_metrics(
|
|
1059
|
+
local_targets_cat.to(accelerator.device)
|
|
1060
|
+
).cpu()
|
|
1045
1061
|
else:
|
|
1046
1062
|
# Single-GPU mode: no gathering needed
|
|
1047
|
-
|
|
1063
|
+
gathered_preds = local_preds_cat
|
|
1064
|
+
gathered_targets = local_targets_cat
|
|
1048
1065
|
|
|
1049
1066
|
# Synchronize validation metrics (scalars only - efficient)
|
|
1050
1067
|
val_loss_scalar = val_loss_sum.item()
|
|
@@ -1069,20 +1086,10 @@ def main():
|
|
|
1069
1086
|
|
|
1070
1087
|
# ==================== LOGGING & CHECKPOINTING ====================
|
|
1071
1088
|
if accelerator.is_main_process:
|
|
1072
|
-
# Concatenate gathered tensors from all ranks (only on rank 0)
|
|
1073
|
-
# gathered is list of tuples: [(preds_rank0, targs_rank0), (preds_rank1, targs_rank1), ...]
|
|
1074
|
-
all_preds = torch.cat([item[0] for item in gathered])
|
|
1075
|
-
all_targets = torch.cat([item[1] for item in gathered])
|
|
1076
|
-
|
|
1077
1089
|
# Scientific metrics - cast to float32 before numpy
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
# Trim DDP padding
|
|
1082
|
-
real_len = len(val_dl.dataset)
|
|
1083
|
-
if len(y_pred) > real_len:
|
|
1084
|
-
y_pred = y_pred[:real_len]
|
|
1085
|
-
y_true = y_true[:real_len]
|
|
1090
|
+
# gather_for_metrics already handles DDP padding removal
|
|
1091
|
+
y_pred = gathered_preds.float().numpy()
|
|
1092
|
+
y_true = gathered_targets.float().numpy()
|
|
1086
1093
|
|
|
1087
1094
|
# Guard against tiny validation sets (R² undefined for <2 samples)
|
|
1088
1095
|
if len(y_true) >= 2:
|
|
@@ -1248,9 +1255,32 @@ def main():
|
|
|
1248
1255
|
)
|
|
1249
1256
|
|
|
1250
1257
|
# Learning rate scheduling (epoch-based schedulers only)
|
|
1258
|
+
# NOTE: For ReduceLROnPlateau with DDP, we must step only on main process
|
|
1259
|
+
# to avoid patience counter being incremented by all GPU processes.
|
|
1260
|
+
# Then we sync the new LR to all processes to keep them consistent.
|
|
1251
1261
|
if not scheduler_step_per_batch:
|
|
1252
1262
|
if args.scheduler == "plateau":
|
|
1253
|
-
|
|
1263
|
+
# Step only on main process to avoid multi-GPU patience bug
|
|
1264
|
+
if accelerator.is_main_process:
|
|
1265
|
+
scheduler.step(avg_val_loss)
|
|
1266
|
+
|
|
1267
|
+
# Sync LR across all processes after main process updates it
|
|
1268
|
+
accelerator.wait_for_everyone()
|
|
1269
|
+
|
|
1270
|
+
# Broadcast new LR from rank 0 to all processes
|
|
1271
|
+
if dist.is_initialized():
|
|
1272
|
+
if accelerator.is_main_process:
|
|
1273
|
+
new_lr = optimizer.param_groups[0]["lr"]
|
|
1274
|
+
else:
|
|
1275
|
+
new_lr = 0.0
|
|
1276
|
+
new_lr_tensor = torch.tensor(
|
|
1277
|
+
new_lr, device=accelerator.device, dtype=torch.float32
|
|
1278
|
+
)
|
|
1279
|
+
dist.broadcast(new_lr_tensor, src=0)
|
|
1280
|
+
# Update LR on non-main processes
|
|
1281
|
+
if not accelerator.is_main_process:
|
|
1282
|
+
for param_group in optimizer.param_groups:
|
|
1283
|
+
param_group["lr"] = new_lr_tensor.item()
|
|
1254
1284
|
else:
|
|
1255
1285
|
scheduler.step()
|
|
1256
1286
|
|
|
@@ -183,9 +183,11 @@ def save_config(
|
|
|
183
183
|
config[key] = value
|
|
184
184
|
|
|
185
185
|
# Add metadata
|
|
186
|
+
from wavedl import __version__
|
|
187
|
+
|
|
186
188
|
config["_metadata"] = {
|
|
187
189
|
"saved_at": datetime.now().isoformat(),
|
|
188
|
-
"wavedl_version":
|
|
190
|
+
"wavedl_version": __version__,
|
|
189
191
|
}
|
|
190
192
|
|
|
191
193
|
output_path = Path(output_path)
|
|
@@ -337,6 +337,17 @@ def run_cross_validation(
|
|
|
337
337
|
torch.cuda.manual_seed_all(seed)
|
|
338
338
|
|
|
339
339
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
340
|
+
|
|
341
|
+
# Auto-detect optimal DataLoader workers if not specified (matches train.py behavior)
|
|
342
|
+
if workers < 0:
|
|
343
|
+
cpu_count = os.cpu_count() or 4
|
|
344
|
+
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
|
|
345
|
+
# Heuristic: 4-16 workers per GPU, bounded by available CPU cores
|
|
346
|
+
workers = min(16, max(2, (cpu_count - 2) // max(1, num_gpus)))
|
|
347
|
+
logger.info(
|
|
348
|
+
f"⚙️ Auto-detected workers: {workers} (CPUs: {cpu_count}, GPUs: {num_gpus})"
|
|
349
|
+
)
|
|
350
|
+
|
|
340
351
|
logger.info(f"🚀 K-Fold Cross-Validation ({folds} folds)")
|
|
341
352
|
logger.info(f" Model: {model_name} | Device: {device}")
|
|
342
353
|
logger.info(
|
|
@@ -763,9 +763,74 @@ def load_test_data(
|
|
|
763
763
|
k for k in OUTPUT_KEYS if k != "output_test"
|
|
764
764
|
]
|
|
765
765
|
|
|
766
|
-
# Load data using appropriate source
|
|
766
|
+
# Load data using appropriate source with test-key priority
|
|
767
|
+
# We detect keys first to ensure input_test/output_test are used when present
|
|
767
768
|
try:
|
|
768
|
-
|
|
769
|
+
if format == "npz":
|
|
770
|
+
with np.load(path, allow_pickle=False) as probe:
|
|
771
|
+
keys = list(probe.keys())
|
|
772
|
+
inp_key = DataSource._find_key(keys, custom_input_keys)
|
|
773
|
+
out_key = DataSource._find_key(keys, custom_output_keys)
|
|
774
|
+
if inp_key is None:
|
|
775
|
+
raise KeyError(
|
|
776
|
+
f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
|
|
777
|
+
)
|
|
778
|
+
data = NPZSource._safe_load(
|
|
779
|
+
path, [inp_key] + ([out_key] if out_key else [])
|
|
780
|
+
)
|
|
781
|
+
inp = data[inp_key]
|
|
782
|
+
if inp.dtype == object:
|
|
783
|
+
inp = np.array(
|
|
784
|
+
[x.toarray() if hasattr(x, "toarray") else x for x in inp]
|
|
785
|
+
)
|
|
786
|
+
outp = data[out_key] if out_key else None
|
|
787
|
+
elif format == "hdf5":
|
|
788
|
+
with h5py.File(path, "r") as f:
|
|
789
|
+
keys = list(f.keys())
|
|
790
|
+
inp_key = DataSource._find_key(keys, custom_input_keys)
|
|
791
|
+
out_key = DataSource._find_key(keys, custom_output_keys)
|
|
792
|
+
if inp_key is None:
|
|
793
|
+
raise KeyError(
|
|
794
|
+
f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
|
|
795
|
+
)
|
|
796
|
+
# OOM guard: warn if dataset is very large
|
|
797
|
+
n_samples = f[inp_key].shape[0]
|
|
798
|
+
if n_samples > 100000:
|
|
799
|
+
raise ValueError(
|
|
800
|
+
f"Dataset has {n_samples:,} samples. load_test_data() loads "
|
|
801
|
+
f"everything into RAM which may cause OOM. For large inference "
|
|
802
|
+
f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
|
|
803
|
+
)
|
|
804
|
+
inp = f[inp_key][:]
|
|
805
|
+
outp = f[out_key][:] if out_key else None
|
|
806
|
+
elif format == "mat":
|
|
807
|
+
mat_source = MATSource()
|
|
808
|
+
with h5py.File(path, "r") as f:
|
|
809
|
+
keys = list(f.keys())
|
|
810
|
+
inp_key = DataSource._find_key(keys, custom_input_keys)
|
|
811
|
+
out_key = DataSource._find_key(keys, custom_output_keys)
|
|
812
|
+
if inp_key is None:
|
|
813
|
+
raise KeyError(
|
|
814
|
+
f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
|
|
815
|
+
)
|
|
816
|
+
# OOM guard: warn if dataset is very large (MAT is transposed)
|
|
817
|
+
n_samples = f[inp_key].shape[-1]
|
|
818
|
+
if n_samples > 100000:
|
|
819
|
+
raise ValueError(
|
|
820
|
+
f"Dataset has {n_samples:,} samples. load_test_data() loads "
|
|
821
|
+
f"everything into RAM which may cause OOM. For large inference "
|
|
822
|
+
f"sets, use a DataLoader with MATSource.load_mmap() instead."
|
|
823
|
+
)
|
|
824
|
+
inp = mat_source._load_dataset(f, inp_key)
|
|
825
|
+
if out_key:
|
|
826
|
+
outp = mat_source._load_dataset(f, out_key)
|
|
827
|
+
if outp.ndim == 2 and outp.shape[0] == 1:
|
|
828
|
+
outp = outp.T
|
|
829
|
+
else:
|
|
830
|
+
outp = None
|
|
831
|
+
else:
|
|
832
|
+
# Fallback to default source.load() for unknown formats
|
|
833
|
+
inp, outp = source.load(path)
|
|
769
834
|
except KeyError:
|
|
770
835
|
# Try with just inputs if outputs not found (inference-only mode)
|
|
771
836
|
if format == "npz":
|
|
@@ -1077,6 +1142,17 @@ def prepare_data(
|
|
|
1077
1142
|
|
|
1078
1143
|
if not cache_exists:
|
|
1079
1144
|
if accelerator.is_main_process:
|
|
1145
|
+
# Delete stale cache files to force regeneration
|
|
1146
|
+
# This prevents silent reuse of old data when metadata invalidates cache
|
|
1147
|
+
for stale_file in [CACHE_FILE, SCALER_FILE]:
|
|
1148
|
+
if os.path.exists(stale_file):
|
|
1149
|
+
try:
|
|
1150
|
+
os.remove(stale_file)
|
|
1151
|
+
logger.debug(f" Removed stale cache: {stale_file}")
|
|
1152
|
+
except OSError as e:
|
|
1153
|
+
logger.warning(
|
|
1154
|
+
f" Failed to remove stale cache {stale_file}: {e}"
|
|
1155
|
+
)
|
|
1080
1156
|
# RANK 0: Create cache (can take a long time for large datasets)
|
|
1081
1157
|
# Other ranks will wait at the barrier below
|
|
1082
1158
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.5.2
|
|
3
|
+
Version: 1.5.4
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -37,11 +37,12 @@ Requires-Dist: wandb>=0.15.0
|
|
|
37
37
|
Requires-Dist: optuna>=3.0.0
|
|
38
38
|
Requires-Dist: onnx>=1.14.0
|
|
39
39
|
Requires-Dist: onnxruntime>=1.15.0
|
|
40
|
-
Requires-Dist: pytest>=7.0.0
|
|
41
|
-
Requires-Dist: pytest-xdist>=3.5.0
|
|
42
|
-
Requires-Dist: ruff>=0.8.0
|
|
43
|
-
Requires-Dist: pre-commit>=3.5.0
|
|
44
40
|
Requires-Dist: triton>=2.0.0; sys_platform == "linux"
|
|
41
|
+
Provides-Extra: dev
|
|
42
|
+
Requires-Dist: pytest>=7.0.0; extra == "dev"
|
|
43
|
+
Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
|
|
44
|
+
Requires-Dist: ruff>=0.8.0; extra == "dev"
|
|
45
|
+
Requires-Dist: pre-commit>=3.5.0; extra == "dev"
|
|
45
46
|
|
|
46
47
|
<div align="center">
|
|
47
48
|
|
|
@@ -204,7 +205,7 @@ Deploy models anywhere:
|
|
|
204
205
|
pip install wavedl
|
|
205
206
|
```
|
|
206
207
|
|
|
207
|
-
This installs everything you need: training, inference, HPO, ONNX export
|
|
208
|
+
This installs everything you need: training, inference, HPO, ONNX export.
|
|
208
209
|
|
|
209
210
|
#### From Source (for development)
|
|
210
211
|
|
|
@@ -301,8 +302,8 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
|
301
302
|
|
|
302
303
|
**Requirements** (your model must):
|
|
303
304
|
1. Inherit from `BaseModel`
|
|
304
|
-
2. Accept `
|
|
305
|
-
3. Return a tensor of shape `(batch,
|
|
305
|
+
2. Accept `in_shape`, `out_size` in `__init__`
|
|
306
|
+
3. Return a tensor of shape `(batch, out_size)` from `forward()`
|
|
306
307
|
|
|
307
308
|
---
|
|
308
309
|
|
|
@@ -315,29 +316,28 @@ from wavedl.models import BaseModel, register_model
|
|
|
315
316
|
|
|
316
317
|
@register_model("my_model") # This name is used with --model flag
|
|
317
318
|
class MyModel(BaseModel):
|
|
318
|
-
def __init__(self,
|
|
319
|
-
#
|
|
320
|
-
#
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
|
|
319
|
+
def __init__(self, in_shape, out_size, **kwargs):
|
|
320
|
+
# in_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
|
|
321
|
+
# out_size: number of parameters to predict (auto-detected from data)
|
|
322
|
+
super().__init__(in_shape, out_size)
|
|
323
|
+
|
|
324
|
+
# Define your layers (this is just an example for 2D)
|
|
325
|
+
self.conv1 = nn.Conv2d(1, 64, 3, padding=1) # Input always has 1 channel
|
|
326
326
|
self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
|
|
327
|
-
self.fc = nn.Linear(128,
|
|
327
|
+
self.fc = nn.Linear(128, out_size)
|
|
328
328
|
|
|
329
329
|
def forward(self, x):
|
|
330
|
-
# Input x has shape: (batch,
|
|
330
|
+
# Input x has shape: (batch, 1, *in_shape)
|
|
331
331
|
x = F.relu(self.conv1(x))
|
|
332
332
|
x = F.relu(self.conv2(x))
|
|
333
333
|
x = x.mean(dim=[-2, -1]) # Global average pooling
|
|
334
|
-
return self.fc(x) # Output shape: (batch,
|
|
334
|
+
return self.fc(x) # Output shape: (batch, out_size)
|
|
335
335
|
```
|
|
336
336
|
|
|
337
337
|
**Step 2: Train**
|
|
338
338
|
|
|
339
339
|
```bash
|
|
340
|
-
wavedl-hpc --import my_model --model my_model --data_path train.npz
|
|
340
|
+
wavedl-hpc --import my_model.py --model my_model --data_path train.npz
|
|
341
341
|
```
|
|
342
342
|
|
|
343
343
|
WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
|
|
@@ -513,7 +513,7 @@ print('\\n✓ All pretrained weights cached!')
|
|
|
513
513
|
| Argument | Default | Description |
|
|
514
514
|
|----------|---------|-------------|
|
|
515
515
|
| `--model` | `cnn` | Model architecture |
|
|
516
|
-
| `--import` | - | Python
|
|
516
|
+
| `--import` | - | Python file(s) to import for custom models (supports multiple) |
|
|
517
517
|
| `--batch_size` | `128` | Per-GPU batch size |
|
|
518
518
|
| `--lr` | `1e-3` | Learning rate |
|
|
519
519
|
| `--epochs` | `1000` | Maximum epochs |
|
|
@@ -573,14 +573,19 @@ WaveDL automatically enables performance optimizations for modern GPUs:
|
|
|
573
573
|
</details>
|
|
574
574
|
|
|
575
575
|
<details>
|
|
576
|
-
<summary><b>
|
|
576
|
+
<summary><b>HPC CLI Arguments (wavedl-hpc)</b></summary>
|
|
577
|
+
|
|
578
|
+
| Argument | Default | Description |
|
|
579
|
+
|----------|---------|-------------|
|
|
580
|
+
| `--num_gpus` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override |
|
|
581
|
+
| `--num_machines` | `1` | Number of machines in distributed setup |
|
|
582
|
+
| `--mixed_precision` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
|
|
583
|
+
| `--dynamo_backend` | `no` | PyTorch Dynamo backend |
|
|
584
|
+
|
|
585
|
+
**Environment Variables (for logging):**
|
|
577
586
|
|
|
578
587
|
| Variable | Default | Description |
|
|
579
588
|
|----------|---------|-------------|
|
|
580
|
-
| `NUM_GPUS` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override (e.g., `NUM_GPUS=2`) |
|
|
581
|
-
| `NUM_MACHINES` | `1` | Number of machines in distributed setup |
|
|
582
|
-
| `MIXED_PRECISION` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
|
|
583
|
-
| `DYNAMO_BACKEND` | `no` | PyTorch Dynamo backend |
|
|
584
589
|
| `WANDB_MODE` | `offline` | WandB mode: `offline` or `online` |
|
|
585
590
|
|
|
586
591
|
</details>
|
|
@@ -1219,6 +1224,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
|
|
|
1219
1224
|
[](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
|
|
1220
1225
|
[](https://www.researchgate.net/profile/Ductho-Le)
|
|
1221
1226
|
|
|
1222
|
-
<sub>
|
|
1227
|
+
<sub>May your signals be strong and your attenuation low 👋</sub>
|
|
1223
1228
|
|
|
1224
1229
|
</div>
|
|
@@ -14,10 +14,12 @@ wandb>=0.15.0
|
|
|
14
14
|
optuna>=3.0.0
|
|
15
15
|
onnx>=1.14.0
|
|
16
16
|
onnxruntime>=1.15.0
|
|
17
|
+
|
|
18
|
+
[:sys_platform == "linux"]
|
|
19
|
+
triton>=2.0.0
|
|
20
|
+
|
|
21
|
+
[dev]
|
|
17
22
|
pytest>=7.0.0
|
|
18
23
|
pytest-xdist>=3.5.0
|
|
19
24
|
ruff>=0.8.0
|
|
20
25
|
pre-commit>=3.5.0
|
|
21
|
-
|
|
22
|
-
[:sys_platform == "linux"]
|
|
23
|
-
triton>=2.0.0
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|