wavedl 1.5.1.tar.gz → 1.5.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {wavedl-1.5.1/src/wavedl.egg-info → wavedl-1.5.3}/PKG-INFO +23 -19
  2. {wavedl-1.5.1 → wavedl-1.5.3}/README.md +22 -18
  3. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/__init__.py +1 -1
  4. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/hpc.py +22 -18
  5. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/resnet.py +38 -9
  6. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/train.py +57 -55
  7. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/utils/config.py +3 -1
  8. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/utils/cross_validation.py +11 -0
  9. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/utils/data.py +87 -17
  10. {wavedl-1.5.1 → wavedl-1.5.3/src/wavedl.egg-info}/PKG-INFO +23 -19
  11. {wavedl-1.5.1 → wavedl-1.5.3}/LICENSE +0 -0
  12. {wavedl-1.5.1 → wavedl-1.5.3}/pyproject.toml +0 -0
  13. {wavedl-1.5.1 → wavedl-1.5.3}/setup.cfg +0 -0
  14. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/hpo.py +0 -0
  15. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/__init__.py +0 -0
  16. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/_template.py +0 -0
  17. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/base.py +0 -0
  18. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/cnn.py +0 -0
  19. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/convnext.py +0 -0
  20. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/densenet.py +0 -0
  21. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/efficientnet.py +0 -0
  22. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/efficientnetv2.py +0 -0
  23. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/mobilenetv3.py +0 -0
  24. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/registry.py +0 -0
  25. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/regnet.py +0 -0
  26. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/resnet3d.py +0 -0
  27. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/swin.py +0 -0
  28. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/tcn.py +0 -0
  29. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/unet.py +0 -0
  30. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/vit.py +0 -0
  31. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/test.py +0 -0
  32. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/utils/__init__.py +0 -0
  33. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/utils/constraints.py +0 -0
  34. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/utils/distributed.py +0 -0
  35. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/utils/losses.py +0 -0
  36. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/utils/metrics.py +0 -0
  37. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/utils/optimizers.py +0 -0
  38. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/utils/schedulers.py +0 -0
  39. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl.egg-info/SOURCES.txt +0 -0
  40. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl.egg-info/dependency_links.txt +0 -0
  41. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl.egg-info/entry_points.txt +0 -0
  42. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl.egg-info/requires.txt +0 -0
  43. {wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl.egg-info/top_level.txt +0 -0
{wavedl-1.5.1/src/wavedl.egg-info → wavedl-1.5.3}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.1
+Version: 1.5.3
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT

@@ -301,8 +301,8 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
 
 **Requirements** (your model must):
 1. Inherit from `BaseModel`
-2. Accept `in_channels`, `num_outputs`, `input_shape` in `__init__`
-3. Return a tensor of shape `(batch, num_outputs)` from `forward()`
+2. Accept `in_shape`, `out_size` in `__init__`
+3. Return a tensor of shape `(batch, out_size)` from `forward()`
 
 ---
 

@@ -315,23 +315,22 @@ from wavedl.models import BaseModel, register_model
 
 @register_model("my_model")  # This name is used with --model flag
 class MyModel(BaseModel):
-    def __init__(self, in_channels, num_outputs, input_shape):
-        # in_channels: number of input channels (auto-detected from data)
-        # num_outputs: number of parameters to predict (auto-detected from data)
-        # input_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
-        super().__init__(in_channels, num_outputs, input_shape)
-
-        # Define your layers (this is just an example)
-        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
+    def __init__(self, in_shape, out_size, **kwargs):
+        # in_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
+        # out_size: number of parameters to predict (auto-detected from data)
+        super().__init__(in_shape, out_size)
+
+        # Define your layers (this is just an example for 2D)
+        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)  # Input always has 1 channel
         self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
-        self.fc = nn.Linear(128, num_outputs)
+        self.fc = nn.Linear(128, out_size)
 
     def forward(self, x):
-        # Input x has shape: (batch, in_channels, *input_shape)
+        # Input x has shape: (batch, 1, *in_shape)
         x = F.relu(self.conv1(x))
         x = F.relu(self.conv2(x))
         x = x.mean(dim=[-2, -1])  # Global average pooling
-        return self.fc(x)  # Output shape: (batch, num_outputs)
+        return self.fc(x)  # Output shape: (batch, out_size)
 ```
 

@@ -573,14 +572,19 @@ WaveDL automatically enables performance optimizations for modern GPUs:
 </details>
 
 <details>
-<summary><b>Environment Variables (wavedl-hpc)</b></summary>
+<summary><b>HPC CLI Arguments (wavedl-hpc)</b></summary>
+
+| Argument | Default | Description |
+|----------|---------|-------------|
+| `--num_gpus` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override |
+| `--num_machines` | `1` | Number of machines in distributed setup |
+| `--mixed_precision` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
+| `--dynamo_backend` | `no` | PyTorch Dynamo backend |
+
+**Environment Variables (for logging):**
 
 | Variable | Default | Description |
 |----------|---------|-------------|
-| `NUM_GPUS` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override (e.g., `NUM_GPUS=2`) |
-| `NUM_MACHINES` | `1` | Number of machines in distributed setup |
-| `MIXED_PRECISION` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
-| `DYNAMO_BACKEND` | `no` | PyTorch Dynamo backend |
 | `WANDB_MODE` | `offline` | WandB mode: `offline` or `online` |
 
 </details>
{wavedl-1.5.1 → wavedl-1.5.3}/README.md

@@ -256,8 +256,8 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
 
 **Requirements** (your model must):
 1. Inherit from `BaseModel`
-2. Accept `in_channels`, `num_outputs`, `input_shape` in `__init__`
-3. Return a tensor of shape `(batch, num_outputs)` from `forward()`
+2. Accept `in_shape`, `out_size` in `__init__`
+3. Return a tensor of shape `(batch, out_size)` from `forward()`
 
 ---
 

@@ -270,23 +270,22 @@ from wavedl.models import BaseModel, register_model
 
 @register_model("my_model")  # This name is used with --model flag
 class MyModel(BaseModel):
-    def __init__(self, in_channels, num_outputs, input_shape):
-        # in_channels: number of input channels (auto-detected from data)
-        # num_outputs: number of parameters to predict (auto-detected from data)
-        # input_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
-        super().__init__(in_channels, num_outputs, input_shape)
-
-        # Define your layers (this is just an example)
-        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
+    def __init__(self, in_shape, out_size, **kwargs):
+        # in_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
+        # out_size: number of parameters to predict (auto-detected from data)
+        super().__init__(in_shape, out_size)
+
+        # Define your layers (this is just an example for 2D)
+        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)  # Input always has 1 channel
         self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
-        self.fc = nn.Linear(128, num_outputs)
+        self.fc = nn.Linear(128, out_size)
 
     def forward(self, x):
-        # Input x has shape: (batch, in_channels, *input_shape)
+        # Input x has shape: (batch, 1, *in_shape)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.mean(dim=[-2, -1])  # Global average pooling
-        return self.fc(x)  # Output shape: (batch, num_outputs)
+        return self.fc(x)  # Output shape: (batch, out_size)
 ```
 
 **Step 2: Train**
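Note on the API change above: `in_channels` disappears because inputs are now fixed at a single channel, while `input_shape`/`num_outputs` become `in_shape`/`out_size`. A self-contained version of the README snippet under the new signature, assuming a 2D `(64, 64)` input and that `BaseModel` requires nothing beyond the `super().__init__` call shown in the diff (the model name here is hypothetical):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from wavedl.models import BaseModel, register_model


@register_model("my_model_demo")  # hypothetical registry name for this sketch
class MyModelDemo(BaseModel):
    def __init__(self, in_shape, out_size, **kwargs):
        super().__init__(in_shape, out_size)
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)   # single input channel
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.fc = nn.Linear(128, out_size)

    def forward(self, x):
        # x: (batch, 1, *in_shape)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.mean(dim=[-2, -1])                      # global average pooling
        return self.fc(x)                             # (batch, out_size)


# Smoke test: batch of 2, one channel, 64x64 input, 3 regression targets
model = MyModelDemo((64, 64), 3)
print(model(torch.randn(2, 1, 64, 64)).shape)  # torch.Size([2, 3])
```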
@@ -528,14 +527,19 @@ WaveDL automatically enables performance optimizations for modern GPUs:
 </details>
 
 <details>
-<summary><b>Environment Variables (wavedl-hpc)</b></summary>
+<summary><b>HPC CLI Arguments (wavedl-hpc)</b></summary>
+
+| Argument | Default | Description |
+|----------|---------|-------------|
+| `--num_gpus` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override |
+| `--num_machines` | `1` | Number of machines in distributed setup |
+| `--mixed_precision` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
+| `--dynamo_backend` | `no` | PyTorch Dynamo backend |
+
+**Environment Variables (for logging):**
 
 | Variable | Default | Description |
 |----------|---------|-------------|
-| `NUM_GPUS` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override (e.g., `NUM_GPUS=2`) |
-| `NUM_MACHINES` | `1` | Number of machines in distributed setup |
-| `MIXED_PRECISION` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
-| `DYNAMO_BACKEND` | `no` | PyTorch Dynamo backend |
 | `WANDB_MODE` | `offline` | WandB mode: `offline` or `online` |
 
 </details>
{wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/__init__.py

@@ -18,7 +18,7 @@ For inference:
 # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
 """
 
-__version__ = "1.5.1"
+__version__ = "1.5.3"
 __author__ = "Ductho Le"
 __email__ = "ductho.le@outlook.com"
 
{wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/hpc.py

@@ -57,30 +57,35 @@ def setup_hpc_environment() -> None:
     """Configure environment variables for HPC systems.
 
     Handles restricted home directories (e.g., Compute Canada) and
-    offline logging configurations.
+    offline logging configurations. Always uses CWD-based TORCH_HOME
+    since compute nodes typically lack internet access.
     """
-    # Check if home is writable
+    # Use CWD for cache base since HPC compute nodes typically lack internet
+    cache_base = os.getcwd()
+
+    # TORCH_HOME always set to CWD - compute nodes need pre-cached weights
+    os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
+    Path(os.environ["TORCH_HOME"]).mkdir(parents=True, exist_ok=True)
+
+    # Triton/Inductor caches - prevents permission errors with --compile
+    # These MUST be set before any torch.compile calls
+    os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
+    os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache")
+    Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
+    Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
+
+    # Check if home is writable for other caches
     home = os.path.expanduser("~")
     home_writable = os.access(home, os.W_OK)
 
-    # Use SLURM_TMPDIR if available, otherwise CWD for HPC, or system temp
-    if home_writable:
-        # Local machine - let libraries use defaults
-        cache_base = None
-    else:
-        # HPC with restricted home - use CWD for persistent caches
-        cache_base = os.getcwd()
-
-    # Only set environment variables if home is not writable
-    if cache_base:
-        os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
+    # Other caches only if home is not writable
+    if not home_writable:
         os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
         os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
         os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
 
         # Ensure directories exist
         for env_var in [
-            "TORCH_HOME",
             "MPLCONFIGDIR",
             "FONTCONFIG_CACHE",
             "XDG_CACHE_HOME",

@@ -89,10 +94,9 @@ def setup_hpc_environment() -> None:
 
     # WandB configuration (offline by default for HPC)
     os.environ.setdefault("WANDB_MODE", "offline")
-    if cache_base:
-        os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
-        os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
-        os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
+    os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
+    os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
+    os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
 
     # Suppress non-critical warnings
     os.environ.setdefault(
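Note: after this refactor, `cache_base` is unconditionally `os.getcwd()`, so the WandB and compile caches are redirected on every machine, not only when `$HOME` is read-only. A minimal sketch of the resulting behavior, assuming `setup_hpc_environment` is importable from `wavedl.hpc` as the hunk header suggests:

```python
import os

from wavedl.hpc import setup_hpc_environment

# Must run before anything that triggers torch.compile, per the diff's comment
setup_hpc_environment()

# Regardless of whether $HOME is writable, these now point at CWD-based caches:
for var in ("TORCH_HOME", "TRITON_CACHE_DIR", "TORCHINDUCTOR_CACHE_DIR", "WANDB_DIR"):
    print(var, "->", os.environ.get(var))
```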
{wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/models/resnet.py

@@ -49,6 +49,36 @@ def _get_conv_layers(
     raise ValueError(f"Unsupported dimensionality: {dim}D. Supported: 1D, 2D, 3D.")
 
 
+def _get_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
+    """
+    Get valid num_groups for GroupNorm that divides num_channels evenly.
+
+    Args:
+        num_channels: Number of channels to normalize
+        preferred_groups: Preferred number of groups (default: 32)
+
+    Returns:
+        Valid num_groups that divides num_channels
+
+    Raises:
+        ValueError: If no valid divisor found (shouldn't happen with power-of-2 channels)
+    """
+    # Try preferred groups first, then decrease
+    for groups in [preferred_groups, 16, 8, 4, 2, 1]:
+        if groups <= num_channels and num_channels % groups == 0:
+            return groups
+
+    # Fallback: find any valid divisor
+    for groups in range(min(32, num_channels), 0, -1):
+        if num_channels % groups == 0:
+            return groups
+
+    raise ValueError(
+        f"Cannot find valid num_groups for {num_channels} channels. "
+        f"Consider using base_width that is a power of 2 (e.g., 32, 64, 128)."
+    )
+
+
 class BasicBlock(nn.Module):
     """
     Basic residual block for ResNet-18/34.
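The helper fixes a real construction-time failure: `nn.GroupNorm(num_groups, num_channels)` requires `num_groups` to divide `num_channels`, so the old `min(32, out_channels)` pattern breaks for widths like 48. A short standalone sketch of the failure and the divisor the helper would pick:

```python
import torch
import torch.nn as nn

# Old pattern: min(32, 48) == 32, but 32 does not divide 48, so GroupNorm
# raises at construction time.
try:
    nn.GroupNorm(min(32, 48), 48)
except ValueError as e:
    print(e)  # num_channels must be divisible by num_groups

# _get_num_groups(48) walks [32, 16, 8, 4, 2, 1] and returns 16:
gn = nn.GroupNorm(16, 48)
print(gn(torch.randn(2, 48, 8, 8)).shape)  # torch.Size([2, 48, 8, 8])
```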
@@ -77,12 +107,12 @@ class BasicBlock(nn.Module):
             padding=1,
             bias=False,
         )
-        self.gn1 = nn.GroupNorm(min(32, out_channels), out_channels)
+        self.gn1 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
         self.relu = nn.ReLU(inplace=True)
         self.conv2 = Conv(
             out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
         )
-        self.gn2 = nn.GroupNorm(min(32, out_channels), out_channels)
+        self.gn2 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
         self.downsample = downsample
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -125,7 +155,7 @@ class Bottleneck(nn.Module):
 
         # 1x1 reduce
         self.conv1 = Conv(in_channels, out_channels, kernel_size=1, bias=False)
-        self.gn1 = nn.GroupNorm(min(32, out_channels), out_channels)
+        self.gn1 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
 
         # 3x3 conv
         self.conv2 = Conv(

@@ -136,15 +166,14 @@ class Bottleneck(nn.Module):
             padding=1,
             bias=False,
         )
-        self.gn2 = nn.GroupNorm(min(32, out_channels), out_channels)
+        self.gn2 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
 
         # 1x1 expand
         self.conv3 = Conv(
             out_channels, out_channels * self.expansion, kernel_size=1, bias=False
         )
-        self.gn3 = nn.GroupNorm(
-            min(32, out_channels * self.expansion), out_channels * self.expansion
-        )
+        expanded_channels = out_channels * self.expansion
+        self.gn3 = nn.GroupNorm(_get_num_groups(expanded_channels), expanded_channels)
 
         self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample

@@ -200,7 +229,7 @@ class ResNetBase(BaseModel):
 
         # Stem: 7x7 conv (or equivalent for 1D/3D)
         self.conv1 = Conv(1, base_width, kernel_size=7, stride=2, padding=3, bias=False)
-        self.gn1 = nn.GroupNorm(min(32, base_width), base_width)
+        self.gn1 = nn.GroupNorm(_get_num_groups(base_width), base_width)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = MaxPool(kernel_size=3, stride=2, padding=1)
 

@@ -246,7 +275,7 @@ class ResNetBase(BaseModel):
                     bias=False,
                 ),
                 nn.GroupNorm(
-                    min(32, out_channels * block.expansion),
+                    _get_num_groups(out_channels * block.expansion),
                     out_channels * block.expansion,
                 ),
             )
{wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/train.py

@@ -69,6 +69,39 @@ _setup_cache_dir("XDG_DATA_HOME", "local/share")
 _setup_cache_dir("XDG_STATE_HOME", "local/state")
 _setup_cache_dir("XDG_CACHE_HOME", "cache")
 
+
+def _setup_per_rank_compile_cache() -> None:
+    """Set per-GPU Triton/Inductor cache to prevent multi-process race warnings.
+
+    When using torch.compile with multiple GPUs, all processes try to write to
+    the same cache directory, causing 'Directory is not empty - skipping!' warnings.
+    This gives each GPU rank its own isolated cache subdirectory.
+    """
+    # Get local rank from environment (set by accelerate/torchrun)
+    local_rank = os.environ.get("LOCAL_RANK", "0")
+
+    # Get cache base from environment or use CWD
+    cache_base = os.environ.get(
+        "TRITON_CACHE_DIR", os.path.join(os.getcwd(), ".triton_cache")
+    )
+
+    # Set per-rank cache directories
+    os.environ["TRITON_CACHE_DIR"] = os.path.join(cache_base, f"rank_{local_rank}")
+    os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(
+        os.environ.get(
+            "TORCHINDUCTOR_CACHE_DIR", os.path.join(os.getcwd(), ".inductor_cache")
+        ),
+        f"rank_{local_rank}",
+    )
+
+    # Create directories
+    os.makedirs(os.environ["TRITON_CACHE_DIR"], exist_ok=True)
+    os.makedirs(os.environ["TORCHINDUCTOR_CACHE_DIR"], exist_ok=True)
+
+
+# Setup per-rank compile caches (before torch imports)
+_setup_per_rank_compile_cache()
+
 # =============================================================================
 # Standard imports (after environment setup)
 # =============================================================================
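The path arithmetic is easiest to see in isolation. A standalone sketch (with `LOCAL_RANK` set by hand here; in practice the launcher, e.g. `accelerate` or `torchrun`, sets it, and the cache base is invented):

```python
import os

os.environ["LOCAL_RANK"] = "1"                        # simulated launcher value
os.environ["TRITON_CACHE_DIR"] = "/scratch/.triton_cache"  # hypothetical base

local_rank = os.environ.get("LOCAL_RANK", "0")
cache_base = os.environ.get(
    "TRITON_CACHE_DIR", os.path.join(os.getcwd(), ".triton_cache")
)
# Each rank gets its own subdirectory, so concurrent compiles don't collide:
print(os.path.join(cache_base, f"rank_{local_rank}"))
# /scratch/.triton_cache/rank_1
```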
@@ -908,7 +941,6 @@ def main():
     logger.info("=" * len(header))
 
     try:
-        time.time()
         total_training_time = 0.0
 
         for epoch in range(start_epoch, args.epochs):

@@ -1002,49 +1034,29 @@ def main():
                     local_preds.append(pred.detach().cpu())
                     local_targets.append(y.detach().cpu())
 
-            # Concatenate locally on CPU (no GPU memory spike)
-            cpu_preds = torch.cat(local_preds)
-            cpu_targets = torch.cat(local_targets)
-
-            # Gather predictions and targets to rank 0 only (memory-efficient)
-            # Avoids duplicating full validation set on every GPU
-            if torch.distributed.is_initialized():
-                # DDP mode: gather only to rank 0
-                # NCCL backend requires CUDA tensors for collective ops
-                gpu_preds = cpu_preds.to(accelerator.device)
-                gpu_targets = cpu_targets.to(accelerator.device)
-
-                if accelerator.is_main_process:
-                    # Rank 0: allocate gather buffers on GPU
-                    all_preds_list = [
-                        torch.zeros_like(gpu_preds)
-                        for _ in range(accelerator.num_processes)
-                    ]
-                    all_targets_list = [
-                        torch.zeros_like(gpu_targets)
-                        for _ in range(accelerator.num_processes)
-                    ]
-                    torch.distributed.gather(
-                        gpu_preds, gather_list=all_preds_list, dst=0
-                    )
-                    torch.distributed.gather(
-                        gpu_targets, gather_list=all_targets_list, dst=0
-                    )
-                    # Move back to CPU for metric computation
-                    gathered = [
-                        (
-                            torch.cat(all_preds_list).cpu(),
-                            torch.cat(all_targets_list).cpu(),
-                        )
-                    ]
-                else:
-                    # Other ranks: send to rank 0, don't allocate gather buffers
-                    torch.distributed.gather(gpu_preds, gather_list=None, dst=0)
-                    torch.distributed.gather(gpu_targets, gather_list=None, dst=0)
-                    gathered = [(cpu_preds, cpu_targets)]  # Placeholder, not used
+            # Concatenate locally (keep on GPU for gather_for_metrics compatibility)
+            local_preds_cat = torch.cat(local_preds)
+            local_targets_cat = torch.cat(local_targets)
+
+            # Gather predictions and targets using Accelerate's CPU-efficient utility
+            # gather_for_metrics handles:
+            # - DDP padding removal (no need to trim manually)
+            # - Efficient cross-rank gathering without GPU memory spike
+            # - Returns concatenated tensors on CPU for metric computation
+            if accelerator.num_processes > 1:
+                # Move to GPU for gather (required by NCCL), then back to CPU
+                # gather_for_metrics is more memory-efficient than manual gather
+                # as it processes in chunks internally
+                gathered_preds = accelerator.gather_for_metrics(
+                    local_preds_cat.to(accelerator.device)
+                ).cpu()
+                gathered_targets = accelerator.gather_for_metrics(
+                    local_targets_cat.to(accelerator.device)
+                ).cpu()
             else:
                 # Single-GPU mode: no gathering needed
-                gathered = [(cpu_preds, cpu_targets)]
+                gathered_preds = local_preds_cat
+                gathered_targets = local_targets_cat
 
             # Synchronize validation metrics (scalars only - efficient)
             val_loss_scalar = val_loss_sum.item()

@@ -1069,20 +1081,10 @@ def main():
 
             # ==================== LOGGING & CHECKPOINTING ====================
             if accelerator.is_main_process:
-                # Concatenate gathered tensors from all ranks (only on rank 0)
-                # gathered is list of tuples: [(preds_rank0, targs_rank0), (preds_rank1, targs_rank1), ...]
-                all_preds = torch.cat([item[0] for item in gathered])
-                all_targets = torch.cat([item[1] for item in gathered])
-
                 # Scientific metrics - cast to float32 before numpy
-                y_pred = all_preds.float().numpy()
-                y_true = all_targets.float().numpy()
-
-                # Trim DDP padding
-                real_len = len(val_dl.dataset)
-                if len(y_pred) > real_len:
-                    y_pred = y_pred[:real_len]
-                    y_true = y_true[:real_len]
+                # gather_for_metrics already handles DDP padding removal
+                y_pred = gathered_preds.float().numpy()
+                y_true = gathered_targets.float().numpy()
 
                 # Guard against tiny validation sets (R² undefined for <2 samples)
                 if len(y_true) >= 2:
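The net effect of these two hunks: roughly forty lines of manual `torch.distributed.gather` bookkeeping collapse into `Accelerator.gather_for_metrics`, which also makes the rank-0 padding trim unnecessary. A minimal sketch of the utility in isolation (`gather_for_metrics` is a public `accelerate.Accelerator` method; the tensor shapes here are invented):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
local_preds = torch.randn(8, 3, device=accelerator.device)  # per-rank outputs

if accelerator.num_processes > 1:
    # Gathers across ranks and drops samples duplicated by the DDP sampler
    all_preds = accelerator.gather_for_metrics(local_preds).cpu()
else:
    all_preds = local_preds.cpu()

print(all_preds.shape)  # (total samples across ranks, minus DDP padding, 3)
```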
{wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/utils/config.py

@@ -183,9 +183,11 @@ def save_config(
         config[key] = value
 
     # Add metadata
+    from wavedl import __version__
+
     config["_metadata"] = {
         "saved_at": datetime.now().isoformat(),
-        "wavedl_version": "1.0.0",
+        "wavedl_version": __version__,
     }
 
     output_path = Path(output_path)
{wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/utils/cross_validation.py

@@ -337,6 +337,17 @@ def run_cross_validation(
         torch.cuda.manual_seed_all(seed)
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Auto-detect optimal DataLoader workers if not specified (matches train.py behavior)
+    if workers < 0:
+        cpu_count = os.cpu_count() or 4
+        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
+        # Heuristic: 4-16 workers per GPU, bounded by available CPU cores
+        workers = min(16, max(2, (cpu_count - 2) // max(1, num_gpus)))
+        logger.info(
+            f"⚙️ Auto-detected workers: {workers} (CPUs: {cpu_count}, GPUs: {num_gpus})"
+        )
+
     logger.info(f"🚀 K-Fold Cross-Validation ({folds} folds)")
     logger.info(f"   Model: {model_name} | Device: {device}")
     logger.info(
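Worked through on a few machine shapes, the heuristic `min(16, max(2, (cpu_count - 2) // max(1, num_gpus)))` behaves like this (the hardware combinations are invented for illustration):

```python
for cpu_count, num_gpus in [(32, 4), (8, 1), (4, 1), (128, 8)]:
    workers = min(16, max(2, (cpu_count - 2) // max(1, num_gpus)))
    print(f"{cpu_count} CPUs / {num_gpus} GPU(s) -> {workers} workers")
# 32 CPUs / 4 GPU(s) -> 7 workers
# 8 CPUs / 1 GPU(s) -> 6 workers
# 4 CPUs / 1 GPU(s) -> 2 workers    (floor of 2)
# 128 CPUs / 8 GPU(s) -> 15 workers (just under the cap of 16)
```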
{wavedl-1.5.1 → wavedl-1.5.3}/src/wavedl/utils/data.py

@@ -202,18 +202,31 @@ class NPZSource(DataSource):
     """Load data from NumPy .npz archives."""
 
     @staticmethod
-    def _safe_load(path: str, mmap_mode: str | None = None):
-        """Load NPZ with pickle only if needed (sparse matrix support)."""
+    def _safe_load(path: str, keys_to_probe: list[str], mmap_mode: str | None = None):
+        """Load NPZ with pickle only if needed (sparse matrix support).
+
+        The error for object arrays happens at ACCESS time, not load time.
+        So we need to probe the keys to detect if pickle is required.
+        """
+        data = np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
         try:
-            return np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
-        except ValueError:
-            # Fallback for sparse matrices stored as object arrays
-            return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
+            # Probe keys to trigger error if object arrays exist
+            for key in keys_to_probe:
+                if key in data:
+                    _ = data[key]  # This raises ValueError for object arrays
+            return data
+        except ValueError as e:
+            if "allow_pickle=False" in str(e):
+                # Fallback for sparse matrices stored as object arrays
+                data.close() if hasattr(data, "close") else None
+                return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
+            raise
 
     def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
         """Load NPZ file (pickle enabled only for sparse matrices)."""
-        data = self._safe_load(path)
-        keys = list(data.keys())
+        # First pass to find keys without loading data
+        with np.load(path, allow_pickle=False) as probe:
+            keys = list(probe.keys())
 
         input_key = self._find_key(keys, INPUT_KEYS)
         output_key = self._find_key(keys, OUTPUT_KEYS)
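The docstring's claim is easy to confirm: `np.load` on an object-array `.npz` succeeds even with `allow_pickle=False`, and the `ValueError` only fires when the offending key is read. A small reproduction (the file name and payload are invented):

```python
import numpy as np

arr = np.empty(1, dtype=object)
arr[0] = {"sparse": "stand-in"}          # any object payload
np.savez("demo.npz", y=arr)

data = np.load("demo.npz", allow_pickle=False)  # no error at load time
try:
    _ = data["y"]                        # error surfaces at access time
except ValueError as e:
    print(e)  # mentions allow_pickle=False, which the fallback checks for
```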
@@ -225,6 +238,7 @@ class NPZSource(DataSource):
                 f"Found: {keys}"
             )
 
+        data = self._safe_load(path, [input_key, output_key])
         inp = data[input_key]
         outp = data[output_key]
 

@@ -243,8 +257,9 @@ class NPZSource(DataSource):
 
         Note: Returns memory-mapped arrays - do NOT modify them.
         """
-        data = self._safe_load(path, mmap_mode="r")
-        keys = list(data.keys())
+        # First pass to find keys without loading data
+        with np.load(path, allow_pickle=False) as probe:
+            keys = list(probe.keys())
 
         input_key = self._find_key(keys, INPUT_KEYS)
         output_key = self._find_key(keys, OUTPUT_KEYS)

@@ -256,6 +271,7 @@ class NPZSource(DataSource):
                 f"Found: {keys}"
             )
 
+        data = self._safe_load(path, [input_key, output_key], mmap_mode="r")
         inp = data[input_key]
         outp = data[output_key]
 

@@ -263,8 +279,9 @@ class NPZSource(DataSource):
 
     def load_outputs_only(self, path: str) -> np.ndarray:
         """Load only targets from NPZ (avoids loading large input arrays)."""
-        data = self._safe_load(path)
-        keys = list(data.keys())
+        # First pass to find keys without loading data
+        with np.load(path, allow_pickle=False) as probe:
+            keys = list(probe.keys())
 
         output_key = self._find_key(keys, OUTPUT_KEYS)
         if output_key is None:

@@ -273,6 +290,7 @@ class NPZSource(DataSource):
                 f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
             )
 
+        data = self._safe_load(path, [output_key])
         return data[output_key]
 
 

@@ -745,25 +763,77 @@ def load_test_data(
             k for k in OUTPUT_KEYS if k != "output_test"
         ]
 
-    # Load data using appropriate source
+    # Load data using appropriate source with test-key priority
+    # We detect keys first to ensure input_test/output_test are used when present
     try:
-        inp, outp = source.load(path)
+        if format == "npz":
+            with np.load(path, allow_pickle=False) as probe:
+                keys = list(probe.keys())
+            inp_key = DataSource._find_key(keys, custom_input_keys)
+            out_key = DataSource._find_key(keys, custom_output_keys)
+            if inp_key is None:
+                raise KeyError(
+                    f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
+                )
+            data = NPZSource._safe_load(
+                path, [inp_key] + ([out_key] if out_key else [])
+            )
+            inp = data[inp_key]
+            if inp.dtype == object:
+                inp = np.array(
+                    [x.toarray() if hasattr(x, "toarray") else x for x in inp]
+                )
+            outp = data[out_key] if out_key else None
+        elif format == "hdf5":
+            with h5py.File(path, "r") as f:
+                keys = list(f.keys())
+                inp_key = DataSource._find_key(keys, custom_input_keys)
+                out_key = DataSource._find_key(keys, custom_output_keys)
+                if inp_key is None:
+                    raise KeyError(
+                        f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
+                    )
+                inp = f[inp_key][:]
+                outp = f[out_key][:] if out_key else None
+        elif format == "mat":
+            mat_source = MATSource()
+            with h5py.File(path, "r") as f:
+                keys = list(f.keys())
+                inp_key = DataSource._find_key(keys, custom_input_keys)
+                out_key = DataSource._find_key(keys, custom_output_keys)
+                if inp_key is None:
+                    raise KeyError(
+                        f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
+                    )
+                inp = mat_source._load_dataset(f, inp_key)
+                if out_key:
+                    outp = mat_source._load_dataset(f, out_key)
+                    if outp.ndim == 2 and outp.shape[0] == 1:
+                        outp = outp.T
+                else:
+                    outp = None
+        else:
+            # Fallback to default source.load() for unknown formats
+            inp, outp = source.load(path)
     except KeyError:
         # Try with just inputs if outputs not found (inference-only mode)
         if format == "npz":
-            data = NPZSource._safe_load(path)
-            keys = list(data.keys())
+            # First pass to find keys
+            with np.load(path, allow_pickle=False) as probe:
+                keys = list(probe.keys())
             inp_key = DataSource._find_key(keys, custom_input_keys)
             if inp_key is None:
                 raise KeyError(
                     f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
                 )
+            out_key = DataSource._find_key(keys, custom_output_keys)
+            keys_to_probe = [inp_key] + ([out_key] if out_key else [])
+            data = NPZSource._safe_load(path, keys_to_probe)
             inp = data[inp_key]
             if inp.dtype == object:
                 inp = np.array(
                     [x.toarray() if hasattr(x, "toarray") else x for x in inp]
                 )
-            out_key = DataSource._find_key(keys, custom_output_keys)
             outp = data[out_key] if out_key else None
         elif format == "hdf5":
             # HDF5: input-only loading for inference
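To see the test-key priority in action, a brief sketch: given an NPZ that carries `*_test` keys, the cheap probe pass used throughout this refactor finds them before any array data is materialized. File name and shapes are invented:

```python
import numpy as np

np.savez(
    "test_set.npz",
    input_test=np.random.rand(10, 64, 64).astype(np.float32),
    output_test=np.random.rand(10, 3).astype(np.float32),
)

# First pass: keys only, no arrays read; these keys are then handed to
# NPZSource._safe_load as in the hunk above.
with np.load("test_set.npz", allow_pickle=False) as probe:
    keys = list(probe.keys())
print(keys)  # ['input_test', 'output_test']
```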
{wavedl-1.5.1 → wavedl-1.5.3/src/wavedl.egg-info}/PKG-INFO

(Identical to the PKG-INFO diff at the top of this page; the sdist carries the same PKG-INFO both at the package root and under src/wavedl.egg-info.)
The remaining files listed above are unchanged between 1.5.1 and 1.5.3 (+0 -0).