wavedl 1.5.3.tar.gz → 1.5.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {wavedl-1.5.3/src/wavedl.egg-info → wavedl-1.5.5}/PKG-INFO +11 -10
  2. {wavedl-1.5.3 → wavedl-1.5.5}/README.md +5 -5
  3. {wavedl-1.5.3 → wavedl-1.5.5}/pyproject.toml +6 -3
  4. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/__init__.py +1 -1
  5. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/hpo.py +2 -1
  6. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/swin.py +31 -10
  7. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/train.py +32 -4
  8. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/utils/data.py +104 -13
  9. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/utils/metrics.py +22 -1
  10. {wavedl-1.5.3 → wavedl-1.5.5/src/wavedl.egg-info}/PKG-INFO +11 -10
  11. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl.egg-info/requires.txt +5 -3
  12. {wavedl-1.5.3 → wavedl-1.5.5}/LICENSE +0 -0
  13. {wavedl-1.5.3 → wavedl-1.5.5}/setup.cfg +0 -0
  14. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/hpc.py +0 -0
  15. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/__init__.py +0 -0
  16. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/_template.py +0 -0
  17. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/base.py +0 -0
  18. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/cnn.py +0 -0
  19. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/convnext.py +0 -0
  20. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/densenet.py +0 -0
  21. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/efficientnet.py +0 -0
  22. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/efficientnetv2.py +0 -0
  23. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/mobilenetv3.py +0 -0
  24. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/registry.py +0 -0
  25. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/regnet.py +0 -0
  26. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/resnet.py +0 -0
  27. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/resnet3d.py +0 -0
  28. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/tcn.py +0 -0
  29. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/unet.py +0 -0
  30. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/vit.py +0 -0
  31. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/test.py +0 -0
  32. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/utils/__init__.py +0 -0
  33. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/utils/config.py +0 -0
  34. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/utils/constraints.py +0 -0
  35. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/utils/cross_validation.py +0 -0
  36. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/utils/distributed.py +0 -0
  37. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/utils/losses.py +0 -0
  38. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/utils/optimizers.py +0 -0
  39. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/utils/schedulers.py +0 -0
  40. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl.egg-info/SOURCES.txt +0 -0
  41. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl.egg-info/dependency_links.txt +0 -0
  42. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl.egg-info/entry_points.txt +0 -0
  43. {wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl.egg-info/top_level.txt +0 -0
{wavedl-1.5.3/src/wavedl.egg-info → wavedl-1.5.5}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.3
+Version: 1.5.5
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -37,11 +37,12 @@ Requires-Dist: wandb>=0.15.0
 Requires-Dist: optuna>=3.0.0
 Requires-Dist: onnx>=1.14.0
 Requires-Dist: onnxruntime>=1.15.0
-Requires-Dist: pytest>=7.0.0
-Requires-Dist: pytest-xdist>=3.5.0
-Requires-Dist: ruff>=0.8.0
-Requires-Dist: pre-commit>=3.5.0
 Requires-Dist: triton>=2.0.0; sys_platform == "linux"
+Provides-Extra: dev
+Requires-Dist: pytest>=7.0.0; extra == "dev"
+Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
+Requires-Dist: ruff>=0.8.0; extra == "dev"
+Requires-Dist: pre-commit>=3.5.0; extra == "dev"
 
 <div align="center">
 
@@ -204,7 +205,7 @@ Deploy models anywhere:
 pip install wavedl
 ```
 
-This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
+This installs everything you need: training, inference, HPO, ONNX export.
 
 #### From Source (for development)
 
@@ -336,7 +337,7 @@ class MyModel(BaseModel):
 **Step 2: Train**
 
 ```bash
-wavedl-hpc --import my_model --model my_model --data_path train.npz
+wavedl-hpc --import my_model.py --model my_model --data_path train.npz
 ```
 
 WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -387,7 +388,7 @@ WaveDL/
 ├── configs/        # YAML config templates
 ├── examples/       # Ready-to-run examples
 ├── notebooks/      # Jupyter notebooks
-├── unit_tests/     # Pytest test suite (725 tests)
+├── unit_tests/     # Pytest test suite (731 tests)
 
 ├── pyproject.toml  # Package config, dependencies
 ├── CHANGELOG.md    # Version history
@@ -512,7 +513,7 @@ print('\n✓ All pretrained weights cached!')
 | Argument | Default | Description |
 |----------|---------|-------------|
 | `--model` | `cnn` | Model architecture |
-| `--import` | - | Python modules to import (for custom models) |
+| `--import` | - | Python file(s) to import for custom models (supports multiple) |
 | `--batch_size` | `128` | Per-GPU batch size |
 | `--lr` | `1e-3` | Learning rate |
 | `--epochs` | `1000` | Maximum epochs |
@@ -1223,6 +1224,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
 [![Google Scholar](https://img.shields.io/badge/Google_Scholar-4285F4?style=plastic&logo=google-scholar&logoColor=white)](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
 [![ResearchGate](https://img.shields.io/badge/ResearchGate-00CCBB?style=plastic&logo=researchgate&logoColor=white)](https://www.researchgate.net/profile/Ductho-Le)
 
-<sub>Released under the MIT License</sub>
+<sub>May your signals be strong and your attenuation low 👋</sub>
 
 </div>
{wavedl-1.5.3 → wavedl-1.5.5}/README.md

@@ -159,7 +159,7 @@ Deploy models anywhere:
 pip install wavedl
 ```
 
-This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
+This installs everything you need: training, inference, HPO, ONNX export.
 
 #### From Source (for development)
 
@@ -291,7 +291,7 @@ class MyModel(BaseModel):
 **Step 2: Train**
 
 ```bash
-wavedl-hpc --import my_model --model my_model --data_path train.npz
+wavedl-hpc --import my_model.py --model my_model --data_path train.npz
 ```
 
 WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -342,7 +342,7 @@ WaveDL/
 ├── configs/        # YAML config templates
 ├── examples/       # Ready-to-run examples
 ├── notebooks/      # Jupyter notebooks
-├── unit_tests/     # Pytest test suite (725 tests)
+├── unit_tests/     # Pytest test suite (731 tests)
 
 ├── pyproject.toml  # Package config, dependencies
 ├── CHANGELOG.md    # Version history
@@ -467,7 +467,7 @@ print('\n✓ All pretrained weights cached!')
 | Argument | Default | Description |
 |----------|---------|-------------|
 | `--model` | `cnn` | Model architecture |
-| `--import` | - | Python modules to import (for custom models) |
+| `--import` | - | Python file(s) to import for custom models (supports multiple) |
 | `--batch_size` | `128` | Per-GPU batch size |
 | `--lr` | `1e-3` | Learning rate |
 | `--epochs` | `1000` | Maximum epochs |
@@ -1178,6 +1178,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
 [![Google Scholar](https://img.shields.io/badge/Google_Scholar-4285F4?style=plastic&logo=google-scholar&logoColor=white)](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
 [![ResearchGate](https://img.shields.io/badge/ResearchGate-00CCBB?style=plastic&logo=researchgate&logoColor=white)](https://www.researchgate.net/profile/Ductho-Le)
 
-<sub>Released under the MIT License</sub>
+<sub>May your signals be strong and your attenuation low 👋</sub>
 
 </div>
{wavedl-1.5.3 → wavedl-1.5.5}/pyproject.toml

@@ -70,13 +70,16 @@ dependencies = [
     # ONNX export
     "onnx>=1.14.0",
     "onnxruntime>=1.15.0",
-    # Development tools
+    # torch.compile backend (Linux only)
+    "triton>=2.0.0; sys_platform == 'linux'",
+]
+
+[project.optional-dependencies]
+dev = [
     "pytest>=7.0.0",
     "pytest-xdist>=3.5.0",
     "ruff>=0.8.0",
     "pre-commit>=3.5.0",
-    # torch.compile backend (Linux only)
-    "triton>=2.0.0; sys_platform == 'linux'",
 ]
 
 [project.scripts]
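With the dev tools moved behind a `dev` extra, a plain `pip install wavedl` no longer pulls pytest, ruff, or pre-commit; contributors would install `wavedl[dev]` instead. A minimal sketch of how one could verify the new markers from an installed copy (assumes wavedl ≥ 1.5.5 is installed):

```python
# Hedged sketch: inspect installed metadata to confirm the dev tools
# are now gated behind the "dev" extra marker.
from importlib.metadata import requires

for req in requires("wavedl") or []:
    if 'extra == "dev"' in req:
        print(req)  # e.g. pytest>=7.0.0; extra == "dev"
```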
{wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/__init__.py

@@ -18,7 +18,7 @@ For inference:
 # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
 """
 
-__version__ = "1.5.3"
+__version__ = "1.5.5"
 __author__ = "Ductho Le"
 __email__ = "ductho.le@outlook.com"
 
{wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/hpo.py

@@ -175,13 +175,14 @@ def create_objective(args):
         env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
 
         # Run training
+        # Note: We inherit the user's cwd instead of setting cwd=Path(__file__).parent
+        # because site-packages may be read-only and train.py creates cache directories
         try:
             result = subprocess.run(
                 cmd,
                 capture_output=True,
                 text=True,
                 timeout=args.timeout,
-                cwd=Path(__file__).parent,
                 env=env,
             )
 
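Dropping the `cwd=` argument means each trial subprocess now runs in whatever directory the user launched HPO from. A standalone sketch of the resulting behaviour (the `wavedl.train` module path comes from the diff; the rest is illustrative):

```python
import os
import subprocess
import sys

# Without cwd=, the child inherits os.getcwd(), so any cache
# directories it creates land in a user-writable location rather
# than inside site-packages (which may be read-only).
env = dict(os.environ, CUDA_VISIBLE_DEVICES="0")
result = subprocess.run(
    [sys.executable, "-m", "wavedl.train", "--help"],
    capture_output=True, text=True, timeout=120, env=env,
)
print(result.returncode, os.getcwd())
```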
{wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/models/swin.py

@@ -191,22 +191,33 @@ class SwinTransformerBase(BaseModel):
         Returns:
             List of parameter group dictionaries
         """
-        # Separate parameters: head (full LR) vs backbone (decayed LR)
+        # Separate parameters into 4 groups for proper LR decay:
+        # 1. Head params with decay (full LR)
+        # 2. Backbone params with decay (0.1× LR)
+        # 3. Head bias/norm without decay (full LR)
+        # 4. Backbone bias/norm without decay (0.1× LR)
         head_params = []
         backbone_params = []
-        no_decay_params = []
+        head_no_decay = []
+        backbone_no_decay = []
 
         for name, param in self.backbone.named_parameters():
             if not param.requires_grad:
                 continue
 
-            # No weight decay for bias and normalization
-            if "bias" in name or "norm" in name:
-                no_decay_params.append(param)
-            elif "head" in name:
-                head_params.append(param)
+            is_head = "head" in name
+            is_no_decay = "bias" in name or "norm" in name
+
+            if is_head:
+                if is_no_decay:
+                    head_no_decay.append(param)
+                else:
+                    head_params.append(param)
             else:
-                backbone_params.append(param)
+                if is_no_decay:
+                    backbone_no_decay.append(param)
+                else:
+                    backbone_params.append(param)
 
         groups = []
 
@@ -229,15 +240,25 @@ class SwinTransformerBase(BaseModel):
                 }
             )
 
-        if no_decay_params:
+        if head_no_decay:
             groups.append(
                 {
-                    "params": no_decay_params,
+                    "params": head_no_decay,
                     "lr": base_lr,
                     "weight_decay": 0.0,
                 }
             )
 
+        if backbone_no_decay:
+            # Backbone bias/norm also gets 0.1× LR to match intended decay
+            groups.append(
+                {
+                    "params": backbone_no_decay,
+                    "lr": base_lr * 0.1,
+                    "weight_decay": 0.0,
+                }
+            )
+
         return groups if groups else [{"params": self.parameters(), "lr": base_lr}]
 
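For context, a self-contained sketch (not wavedl code) of why the fourth group matters: PyTorch optimizers honour `lr` and `weight_decay` per parameter group, so before this fix the backbone's bias/norm parameters sat in the single no-decay group at the full learning rate instead of the 0.1× backbone rate.

```python
import torch

# Toy stand-ins for head and backbone parameters.
head = torch.nn.Linear(8, 1)
backbone = torch.nn.Linear(8, 8)
base_lr = 1e-3

groups = [
    {"params": [head.weight], "lr": base_lr, "weight_decay": 0.05},
    {"params": [head.bias], "lr": base_lr, "weight_decay": 0.0},
    {"params": [backbone.weight], "lr": base_lr * 0.1, "weight_decay": 0.05},
    # The new fourth group: backbone bias/norm at 0.1x LR, no decay.
    {"params": [backbone.bias], "lr": base_lr * 0.1, "weight_decay": 0.0},
]
opt = torch.optim.AdamW(groups)
for g in opt.param_groups:
    print(g["lr"], g["weight_decay"])
```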
{wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/train.py

@@ -122,6 +122,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import torch
+import torch.distributed as dist
 from accelerate import Accelerator
 from accelerate.utils import set_seed
 from sklearn.metrics import r2_score
@@ -470,15 +471,19 @@ def main():
        try:
            # Handle both module names (my_model) and file paths (./my_model.py)
            if module_name.endswith(".py"):
-                # Import from file path
+                # Import from file path with unique module name
                import importlib.util
 
+                # Derive unique module name from filename to avoid collisions
+                base_name = os.path.splitext(os.path.basename(module_name))[0]
+                unique_name = f"wavedl_custom_{base_name}"
+
                spec = importlib.util.spec_from_file_location(
-                    "custom_module", module_name
+                    unique_name, module_name
                )
                if spec and spec.loader:
                    module = importlib.util.module_from_spec(spec)
-                    sys.modules["custom_module"] = module
+                    sys.modules[unique_name] = module
                    spec.loader.exec_module(module)
                    print(f"✓ Imported custom module from: {module_name}")
            else:
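The same pattern, lifted into a standalone helper for illustration (the `wavedl_custom_` prefix comes from the diff; the helper name is hypothetical):

```python
import importlib.util
import os
import sys

def import_from_path(path: str):
    """Import a .py file under a name derived from its filename, so two
    different custom-model files no longer overwrite each other in
    sys.modules (the old code always registered "custom_module")."""
    base = os.path.splitext(os.path.basename(path))[0]
    name = f"wavedl_custom_{base}"
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module  # register before exec, as train.py does
    spec.loader.exec_module(module)
    return module

# import_from_path("./my_model.py")
```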
@@ -1250,9 +1255,32 @@ def main():
            )
 
            # Learning rate scheduling (epoch-based schedulers only)
+            # NOTE: For ReduceLROnPlateau with DDP, we must step only on main process
+            # to avoid patience counter being incremented by all GPU processes.
+            # Then we sync the new LR to all processes to keep them consistent.
            if not scheduler_step_per_batch:
                if args.scheduler == "plateau":
-                    scheduler.step(avg_val_loss)
+                    # Step only on main process to avoid multi-GPU patience bug
+                    if accelerator.is_main_process:
+                        scheduler.step(avg_val_loss)
+
+                    # Sync LR across all processes after main process updates it
+                    accelerator.wait_for_everyone()
+
+                    # Broadcast new LR from rank 0 to all processes
+                    if dist.is_initialized():
+                        if accelerator.is_main_process:
+                            new_lr = optimizer.param_groups[0]["lr"]
+                        else:
+                            new_lr = 0.0
+                        new_lr_tensor = torch.tensor(
+                            new_lr, device=accelerator.device, dtype=torch.float32
+                        )
+                        dist.broadcast(new_lr_tensor, src=0)
+                        # Update LR on non-main processes
+                        if not accelerator.is_main_process:
+                            for param_group in optimizer.param_groups:
+                                param_group["lr"] = new_lr_tensor.item()
                else:
                    scheduler.step()
 
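The broadcast logic above, condensed into a standalone sketch (function name hypothetical; it no-ops when torch.distributed was never initialized, matching the single-GPU path):

```python
import torch
import torch.distributed as dist

def sync_plateau_lr(optimizer, device, is_main_process: bool):
    """Rank 0 steps ReduceLROnPlateau; all other ranks copy its LR."""
    if not dist.is_initialized():
        return  # single process: scheduler.step() already did everything
    lr = optimizer.param_groups[0]["lr"] if is_main_process else 0.0
    lr_tensor = torch.tensor(lr, device=device, dtype=torch.float32)
    dist.broadcast(lr_tensor, src=0)
    if not is_main_process:
        for group in optimizer.param_groups:
            group["lr"] = lr_tensor.item()
```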
{wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/utils/data.py

@@ -207,6 +207,10 @@ class NPZSource(DataSource):
 
         The error for object arrays happens at ACCESS time, not load time.
         So we need to probe the keys to detect if pickle is required.
+
+        WARNING: When mmap_mode is not None, the returned NpzFile must be kept
+        open for arrays to remain valid. Caller is responsible for closing.
+        For non-mmap loading, use _load_and_copy() instead to avoid leaks.
         """
         data = np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
         try:
@@ -222,6 +226,26 @@
                 return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
             raise
 
+    @staticmethod
+    def _load_and_copy(path: str, keys: list[str]) -> dict[str, np.ndarray]:
+        """Load NPZ and copy arrays, ensuring file is properly closed.
+
+        This prevents file descriptor leaks by copying arrays before closing.
+        Use this for eager loading; use _safe_load for memory-mapped access.
+        """
+        data = NPZSource._safe_load(path, keys, mmap_mode=None)
+        try:
+            result = {}
+            for key in keys:
+                if key in data:
+                    arr = data[key]
+                    # Copy ensures we don't hold reference to mmap
+                    result[key] = arr.copy() if hasattr(arr, "copy") else arr
+            return result
+        finally:
+            if hasattr(data, "close"):
+                data.close()
+
     def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
         """Load NPZ file (pickle enabled only for sparse matrices)."""
         # First pass to find keys without loading data
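A standalone illustration of the copy-then-close pattern `_load_and_copy` adopts, using plain NumPy (file name illustrative):

```python
import numpy as np

# NpzFile holds an open file handle; copying the arrays out first
# lets us close it immediately instead of leaking the descriptor.
np.savez("demo.npz", inputs=np.zeros((4, 3)), outputs=np.ones((4, 1)))
data = np.load("demo.npz")
try:
    arrays = {k: data[k].copy() for k in ("inputs", "outputs")}
finally:
    data.close()
print(arrays["inputs"].shape)  # arrays remain valid after close
```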
@@ -238,7 +262,7 @@
                 f"Found: {keys}"
             )
 
-        data = self._safe_load(path, [input_key, output_key])
+        data = self._load_and_copy(path, [input_key, output_key])
         inp = data[input_key]
         outp = data[output_key]
 
@@ -290,7 +314,7 @@
                 f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
             )
 
-        data = self._safe_load(path, [output_key])
+        data = self._load_and_copy(path, [output_key])
         return data[output_key]
 
 
@@ -527,9 +551,17 @@ class MATSource(DataSource):
                inp = self._load_dataset(f, input_key)
                outp = self._load_dataset(f, output_key)
 
-                # Handle 1D outputs that become (1, N) after transpose
-                if outp.ndim == 2 and outp.shape[0] == 1:
-                    outp = outp.T
+                # Handle transposed outputs from MATLAB.
+                # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
+                # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
+                num_samples = inp.shape[0]  # inp is already transposed
+                if outp.ndim == 2:
+                    if outp.shape[0] == 1 and outp.shape[1] == num_samples:
+                        # 1D vector: (1, N) → (N, 1)
+                        outp = outp.T
+                    elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
+                        # Single sample with multiple targets: (T, 1) → (1, T)
+                        outp = outp.T
 
        except OSError as e:
            raise ValueError(
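The disambiguation rule above, extracted into a pure function for illustration (name hypothetical):

```python
import numpy as np

def fix_matlab_output(outp: np.ndarray, num_samples: int) -> np.ndarray:
    """Transpose MATLAB column-major outputs only when the shape says the
    sample axis is in the wrong place, instead of blindly flipping every
    (1, N) array as the old code did."""
    if outp.ndim == 2:
        if outp.shape[0] == 1 and outp.shape[1] == num_samples:
            return outp.T  # (1, N) -> (N, 1): N samples, one target
        if outp.shape[1] == 1 and outp.shape[0] != num_samples:
            return outp.T  # (T, 1) -> (1, T): one sample, T targets
    return outp

print(fix_matlab_output(np.zeros((1, 5)), num_samples=5).shape)  # (5, 1)
print(fix_matlab_output(np.zeros((3, 1)), num_samples=1).shape)  # (1, 3)
```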
@@ -614,7 +646,10 @@
        # Load with sparse matrix support
        outp = self._load_dataset(f, output_key)
 
-        # Handle 1D outputs
+        # Handle 1D outputs that become (1, N) after transpose.
+        # Note: This method has no input to compare against, so we can't
+        # distinguish single-sample outputs. This is acceptable for training
+        # data where single-sample is unlikely. For inference, use load_test_data.
        if outp.ndim == 2 and outp.shape[0] == 1:
            outp = outp.T
 
@@ -775,7 +810,7 @@ def load_test_data(
            raise KeyError(
                f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
            )
-        data = NPZSource._safe_load(
+        data = NPZSource._load_and_copy(
            path, [inp_key] + ([out_key] if out_key else [])
        )
        inp = data[inp_key]
@@ -793,6 +828,14 @@
            raise KeyError(
                f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
            )
+        # OOM guard: warn if dataset is very large
+        n_samples = f[inp_key].shape[0]
+        if n_samples > 100000:
+            raise ValueError(
+                f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                f"everything into RAM which may cause OOM. For large inference "
+                f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
+            )
        inp = f[inp_key][:]
        outp = f[out_key][:] if out_key else None
    elif format == "mat":
@@ -805,11 +848,28 @@
            raise KeyError(
                f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
            )
+        # OOM guard: warn if dataset is very large (MAT is transposed)
+        n_samples = f[inp_key].shape[-1]
+        if n_samples > 100000:
+            raise ValueError(
+                f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                f"everything into RAM which may cause OOM. For large inference "
+                f"sets, use a DataLoader with MATSource.load_mmap() instead."
+            )
        inp = mat_source._load_dataset(f, inp_key)
        if out_key:
            outp = mat_source._load_dataset(f, out_key)
-            if outp.ndim == 2 and outp.shape[0] == 1:
-                outp = outp.T
+            # Handle transposed outputs from MATLAB
+            # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
+            # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
+            num_samples = inp.shape[0]
+            if outp.ndim == 2:
+                if outp.shape[0] == 1 and outp.shape[1] == num_samples:
+                    # 1D vector: (1, N) → (N, 1)
+                    outp = outp.T
+                elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
+                    # Single sample with multiple targets: (T, 1) → (1, T)
+                    outp = outp.T
        else:
            outp = None
    else:
@@ -828,7 +888,7 @@
                )
            out_key = DataSource._find_key(keys, custom_output_keys)
            keys_to_probe = [inp_key] + ([out_key] if out_key else [])
-            data = NPZSource._safe_load(path, keys_to_probe)
+            data = NPZSource._load_and_copy(path, keys_to_probe)
            inp = data[inp_key]
            if inp.dtype == object:
                inp = np.array(
@@ -878,9 +938,17 @@
            out_key = DataSource._find_key(keys, custom_output_keys)
            if out_key:
                outp = mat_source._load_dataset(f, out_key)
-                # Handle 1D outputs that become (1, N) after transpose
-                if outp.ndim == 2 and outp.shape[0] == 1:
-                    outp = outp.T
+                # Handle transposed outputs from MATLAB
+                # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
+                # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
+                num_samples = inp.shape[0]
+                if outp.ndim == 2:
+                    if outp.shape[0] == 1 and outp.shape[1] == num_samples:
+                        # 1D vector: (1, N) → (N, 1)
+                        outp = outp.T
+                    elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
+                        # Single sample with multiple targets: (T, 1) → (1, T)
+                        outp = outp.T
            else:
                outp = None
        else:
@@ -1126,6 +1194,29 @@ def prepare_data(
 
    if not cache_exists:
        if accelerator.is_main_process:
+            # Delete stale cache files to force regeneration
+            # This prevents silent reuse of old data when metadata invalidates cache
+            for stale_file in [CACHE_FILE, SCALER_FILE]:
+                if os.path.exists(stale_file):
+                    try:
+                        os.remove(stale_file)
+                        logger.debug(f"  Removed stale cache: {stale_file}")
+                    except OSError as e:
+                        logger.warning(
+                            f"  Failed to remove stale cache {stale_file}: {e}"
+                        )
+
+            # Fail explicitly if stale cache files couldn't be removed
+            # This prevents silent reuse of outdated data
+            remaining_stale = [
+                f for f in [CACHE_FILE, SCALER_FILE] if os.path.exists(f)
+            ]
+            if remaining_stale:
+                raise RuntimeError(
+                    f"Cannot regenerate cache: stale files could not be removed. "
+                    f"Please manually delete: {remaining_stale}"
+                )
+
            # RANK 0: Create cache (can take a long time for large datasets)
            # Other ranks will wait at the barrier below
 
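The remove-then-verify rule in `prepare_data`, as a standalone sketch (file names hypothetical):

```python
import os

def clear_stale_cache(paths):
    """Delete stale cache files and fail loudly if any survive,
    rather than silently reusing outdated data."""
    for p in paths:
        try:
            os.remove(p)
        except FileNotFoundError:
            pass  # already gone
        except OSError as e:
            print(f"warning: could not remove {p}: {e}")
    leftover = [p for p in paths if os.path.exists(p)]
    if leftover:
        raise RuntimeError(f"stale cache files still present: {leftover}")

clear_stale_cache(["cache.npy", "scaler.pkl"])
```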
{wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl/utils/metrics.py

@@ -815,7 +815,28 @@ def plot_qq(
 
        # Standardize errors for QQ plot
        err = errors[:, i]
-        standardized = (err - np.mean(err)) / np.std(err)
+        std_err = np.std(err)
+
+        # Guard against zero variance (constant errors)
+        if std_err < 1e-10:
+            title = (
+                param_names[i] if param_names and i < len(param_names) else f"Param {i}"
+            )
+            ax.text(
+                0.5,
+                0.5,
+                "Zero variance\n(constant errors)",
+                ha="center",
+                va="center",
+                fontsize=10,
+                transform=ax.transAxes,
+            )
+            ax.set_title(f"{title}\n(zero variance)")
+            ax.set_xlabel("Theoretical Quantiles")
+            ax.set_ylabel("Sample Quantiles")
+            continue
+
+        standardized = (err - np.mean(err)) / std_err
 
        # Calculate theoretical quantiles and sample quantiles
        (osm, osr), (slope, intercept, r) = stats.probplot(standardized, dist="norm")
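Why the guard is needed, in a standalone snippet: standardising a constant error vector divides zero by zero, producing NaNs that would poison `scipy.stats.probplot`.

```python
import numpy as np
from scipy import stats

err = np.full(100, 0.25)  # constant errors -> zero variance
std = err.std()
if std < 1e-10:
    print("zero variance: annotate the panel and skip the QQ fit")
else:
    standardized = (err - err.mean()) / std
    (osm, osr), (slope, intercept, r) = stats.probplot(standardized, dist="norm")
```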
{wavedl-1.5.3 → wavedl-1.5.5/src/wavedl.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.3
+Version: 1.5.5
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -37,11 +37,12 @@ Requires-Dist: wandb>=0.15.0
 Requires-Dist: optuna>=3.0.0
 Requires-Dist: onnx>=1.14.0
 Requires-Dist: onnxruntime>=1.15.0
-Requires-Dist: pytest>=7.0.0
-Requires-Dist: pytest-xdist>=3.5.0
-Requires-Dist: ruff>=0.8.0
-Requires-Dist: pre-commit>=3.5.0
 Requires-Dist: triton>=2.0.0; sys_platform == "linux"
+Provides-Extra: dev
+Requires-Dist: pytest>=7.0.0; extra == "dev"
+Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
+Requires-Dist: ruff>=0.8.0; extra == "dev"
+Requires-Dist: pre-commit>=3.5.0; extra == "dev"
 
 <div align="center">
 
@@ -204,7 +205,7 @@ Deploy models anywhere:
 pip install wavedl
 ```
 
-This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
+This installs everything you need: training, inference, HPO, ONNX export.
 
 #### From Source (for development)
 
@@ -336,7 +337,7 @@ class MyModel(BaseModel):
 **Step 2: Train**
 
 ```bash
-wavedl-hpc --import my_model --model my_model --data_path train.npz
+wavedl-hpc --import my_model.py --model my_model --data_path train.npz
 ```
 
 WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -387,7 +388,7 @@ WaveDL/
 ├── configs/        # YAML config templates
 ├── examples/       # Ready-to-run examples
 ├── notebooks/      # Jupyter notebooks
-├── unit_tests/     # Pytest test suite (725 tests)
+├── unit_tests/     # Pytest test suite (731 tests)
 
 ├── pyproject.toml  # Package config, dependencies
 ├── CHANGELOG.md    # Version history
@@ -512,7 +513,7 @@ print('\n✓ All pretrained weights cached!')
 | Argument | Default | Description |
 |----------|---------|-------------|
 | `--model` | `cnn` | Model architecture |
-| `--import` | - | Python modules to import (for custom models) |
+| `--import` | - | Python file(s) to import for custom models (supports multiple) |
 | `--batch_size` | `128` | Per-GPU batch size |
 | `--lr` | `1e-3` | Learning rate |
 | `--epochs` | `1000` | Maximum epochs |
@@ -1223,6 +1224,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
 [![Google Scholar](https://img.shields.io/badge/Google_Scholar-4285F4?style=plastic&logo=google-scholar&logoColor=white)](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
 [![ResearchGate](https://img.shields.io/badge/ResearchGate-00CCBB?style=plastic&logo=researchgate&logoColor=white)](https://www.researchgate.net/profile/Ductho-Le)
 
-<sub>Released under the MIT License</sub>
+<sub>May your signals be strong and your attenuation low 👋</sub>
 
 </div>
{wavedl-1.5.3 → wavedl-1.5.5}/src/wavedl.egg-info/requires.txt

@@ -14,10 +14,12 @@ wandb>=0.15.0
 optuna>=3.0.0
 onnx>=1.14.0
 onnxruntime>=1.15.0
+
+[:sys_platform == "linux"]
+triton>=2.0.0
+
+[dev]
 pytest>=7.0.0
 pytest-xdist>=3.5.0
 ruff>=0.8.0
 pre-commit>=3.5.0
-
-[:sys_platform == "linux"]
-triton>=2.0.0