wavedl 1.5.3__tar.gz → 1.5.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {wavedl-1.5.3/src/wavedl.egg-info → wavedl-1.5.4}/PKG-INFO +10 -9
  2. {wavedl-1.5.3 → wavedl-1.5.4}/README.md +4 -4
  3. {wavedl-1.5.3 → wavedl-1.5.4}/pyproject.toml +6 -3
  4. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/__init__.py +1 -1
  5. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/swin.py +31 -10
  6. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/train.py +32 -4
  7. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/utils/data.py +27 -0
  8. {wavedl-1.5.3 → wavedl-1.5.4/src/wavedl.egg-info}/PKG-INFO +10 -9
  9. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl.egg-info/requires.txt +5 -3
  10. {wavedl-1.5.3 → wavedl-1.5.4}/LICENSE +0 -0
  11. {wavedl-1.5.3 → wavedl-1.5.4}/setup.cfg +0 -0
  12. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/hpc.py +0 -0
  13. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/hpo.py +0 -0
  14. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/__init__.py +0 -0
  15. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/_template.py +0 -0
  16. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/base.py +0 -0
  17. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/cnn.py +0 -0
  18. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/convnext.py +0 -0
  19. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/densenet.py +0 -0
  20. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/efficientnet.py +0 -0
  21. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/efficientnetv2.py +0 -0
  22. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/mobilenetv3.py +0 -0
  23. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/registry.py +0 -0
  24. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/regnet.py +0 -0
  25. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/resnet.py +0 -0
  26. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/resnet3d.py +0 -0
  27. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/tcn.py +0 -0
  28. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/unet.py +0 -0
  29. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/models/vit.py +0 -0
  30. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/test.py +0 -0
  31. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/utils/__init__.py +0 -0
  32. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/utils/config.py +0 -0
  33. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/utils/constraints.py +0 -0
  34. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/utils/cross_validation.py +0 -0
  35. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/utils/distributed.py +0 -0
  36. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/utils/losses.py +0 -0
  37. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/utils/metrics.py +0 -0
  38. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/utils/optimizers.py +0 -0
  39. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl/utils/schedulers.py +0 -0
  40. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl.egg-info/SOURCES.txt +0 -0
  41. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl.egg-info/dependency_links.txt +0 -0
  42. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl.egg-info/entry_points.txt +0 -0
  43. {wavedl-1.5.3 → wavedl-1.5.4}/src/wavedl.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.5.3
3
+ Version: 1.5.4
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -37,11 +37,12 @@ Requires-Dist: wandb>=0.15.0
37
37
  Requires-Dist: optuna>=3.0.0
38
38
  Requires-Dist: onnx>=1.14.0
39
39
  Requires-Dist: onnxruntime>=1.15.0
40
- Requires-Dist: pytest>=7.0.0
41
- Requires-Dist: pytest-xdist>=3.5.0
42
- Requires-Dist: ruff>=0.8.0
43
- Requires-Dist: pre-commit>=3.5.0
44
40
  Requires-Dist: triton>=2.0.0; sys_platform == "linux"
41
+ Provides-Extra: dev
42
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
43
+ Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
44
+ Requires-Dist: ruff>=0.8.0; extra == "dev"
45
+ Requires-Dist: pre-commit>=3.5.0; extra == "dev"
45
46
 
46
47
  <div align="center">
47
48
 
@@ -204,7 +205,7 @@ Deploy models anywhere:
204
205
  pip install wavedl
205
206
  ```
206
207
 
207
- This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
208
+ This installs everything you need: training, inference, HPO, ONNX export.
208
209
 
209
210
  #### From Source (for development)
210
211
 
@@ -336,7 +337,7 @@ class MyModel(BaseModel):
336
337
  **Step 2: Train**
337
338
 
338
339
  ```bash
339
- wavedl-hpc --import my_model --model my_model --data_path train.npz
340
+ wavedl-hpc --import my_model.py --model my_model --data_path train.npz
340
341
  ```
341
342
 
342
343
  WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -512,7 +513,7 @@ print('\\n✓ All pretrained weights cached!')
512
513
  | Argument | Default | Description |
513
514
  |----------|---------|-------------|
514
515
  | `--model` | `cnn` | Model architecture |
515
- | `--import` | - | Python modules to import (for custom models) |
516
+ | `--import` | - | Python file(s) to import for custom models (supports multiple) |
516
517
  | `--batch_size` | `128` | Per-GPU batch size |
517
518
  | `--lr` | `1e-3` | Learning rate |
518
519
  | `--epochs` | `1000` | Maximum epochs |
@@ -1223,6 +1224,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
1223
1224
  [![Google Scholar](https://img.shields.io/badge/Google_Scholar-4285F4?style=plastic&logo=google-scholar&logoColor=white)](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
1224
1225
  [![ResearchGate](https://img.shields.io/badge/ResearchGate-00CCBB?style=plastic&logo=researchgate&logoColor=white)](https://www.researchgate.net/profile/Ductho-Le)
1225
1226
 
1226
- <sub>Released under the MIT License</sub>
1227
+ <sub>May your signals be strong and your attenuation low 👋</sub>
1227
1228
 
1228
1229
  </div>
@@ -159,7 +159,7 @@ Deploy models anywhere:
159
159
  pip install wavedl
160
160
  ```
161
161
 
162
- This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
162
+ This installs everything you need: training, inference, HPO, ONNX export.
163
163
 
164
164
  #### From Source (for development)
165
165
 
@@ -291,7 +291,7 @@ class MyModel(BaseModel):
291
291
  **Step 2: Train**
292
292
 
293
293
  ```bash
294
- wavedl-hpc --import my_model --model my_model --data_path train.npz
294
+ wavedl-hpc --import my_model.py --model my_model --data_path train.npz
295
295
  ```
296
296
 
297
297
  WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -467,7 +467,7 @@ print('\\n✓ All pretrained weights cached!')
467
467
  | Argument | Default | Description |
468
468
  |----------|---------|-------------|
469
469
  | `--model` | `cnn` | Model architecture |
470
- | `--import` | - | Python modules to import (for custom models) |
470
+ | `--import` | - | Python file(s) to import for custom models (supports multiple) |
471
471
  | `--batch_size` | `128` | Per-GPU batch size |
472
472
  | `--lr` | `1e-3` | Learning rate |
473
473
  | `--epochs` | `1000` | Maximum epochs |
@@ -1178,6 +1178,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
1178
1178
  [![Google Scholar](https://img.shields.io/badge/Google_Scholar-4285F4?style=plastic&logo=google-scholar&logoColor=white)](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
1179
1179
  [![ResearchGate](https://img.shields.io/badge/ResearchGate-00CCBB?style=plastic&logo=researchgate&logoColor=white)](https://www.researchgate.net/profile/Ductho-Le)
1180
1180
 
1181
- <sub>Released under the MIT License</sub>
1181
+ <sub>May your signals be strong and your attenuation low 👋</sub>
1182
1182
 
1183
1183
  </div>
@@ -70,13 +70,16 @@ dependencies = [
70
70
  # ONNX export
71
71
  "onnx>=1.14.0",
72
72
  "onnxruntime>=1.15.0",
73
- # Development tools
73
+ # torch.compile backend (Linux only)
74
+ "triton>=2.0.0; sys_platform == 'linux'",
75
+ ]
76
+
77
+ [project.optional-dependencies]
78
+ dev = [
74
79
  "pytest>=7.0.0",
75
80
  "pytest-xdist>=3.5.0",
76
81
  "ruff>=0.8.0",
77
82
  "pre-commit>=3.5.0",
78
- # torch.compile backend (Linux only)
79
- "triton>=2.0.0; sys_platform == 'linux'",
80
83
  ]
81
84
 
82
85
  [project.scripts]
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.5.3"
21
+ __version__ = "1.5.4"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
@@ -191,22 +191,33 @@ class SwinTransformerBase(BaseModel):
191
191
  Returns:
192
192
  List of parameter group dictionaries
193
193
  """
194
- # Separate parameters: head (full LR) vs backbone (decayed LR)
194
+ # Separate parameters into 4 groups for proper LR decay:
195
+ # 1. Head params with decay (full LR)
196
+ # 2. Backbone params with decay (0.1× LR)
197
+ # 3. Head bias/norm without decay (full LR)
198
+ # 4. Backbone bias/norm without decay (0.1× LR)
195
199
  head_params = []
196
200
  backbone_params = []
197
- no_decay_params = []
201
+ head_no_decay = []
202
+ backbone_no_decay = []
198
203
 
199
204
  for name, param in self.backbone.named_parameters():
200
205
  if not param.requires_grad:
201
206
  continue
202
207
 
203
- # No weight decay for bias and normalization
204
- if "bias" in name or "norm" in name:
205
- no_decay_params.append(param)
206
- elif "head" in name:
207
- head_params.append(param)
208
+ is_head = "head" in name
209
+ is_no_decay = "bias" in name or "norm" in name
210
+
211
+ if is_head:
212
+ if is_no_decay:
213
+ head_no_decay.append(param)
214
+ else:
215
+ head_params.append(param)
208
216
  else:
209
- backbone_params.append(param)
217
+ if is_no_decay:
218
+ backbone_no_decay.append(param)
219
+ else:
220
+ backbone_params.append(param)
210
221
 
211
222
  groups = []
212
223
 
@@ -229,15 +240,25 @@ class SwinTransformerBase(BaseModel):
229
240
  }
230
241
  )
231
242
 
232
- if no_decay_params:
243
+ if head_no_decay:
233
244
  groups.append(
234
245
  {
235
- "params": no_decay_params,
246
+ "params": head_no_decay,
236
247
  "lr": base_lr,
237
248
  "weight_decay": 0.0,
238
249
  }
239
250
  )
240
251
 
252
+ if backbone_no_decay:
253
+ # Backbone bias/norm also gets 0.1× LR to match intended decay
254
+ groups.append(
255
+ {
256
+ "params": backbone_no_decay,
257
+ "lr": base_lr * 0.1,
258
+ "weight_decay": 0.0,
259
+ }
260
+ )
261
+
241
262
  return groups if groups else [{"params": self.parameters(), "lr": base_lr}]
242
263
 
243
264
 
@@ -122,6 +122,7 @@ import matplotlib.pyplot as plt
122
122
  import numpy as np
123
123
  import pandas as pd
124
124
  import torch
125
+ import torch.distributed as dist
125
126
  from accelerate import Accelerator
126
127
  from accelerate.utils import set_seed
127
128
  from sklearn.metrics import r2_score
@@ -470,15 +471,19 @@ def main():
470
471
  try:
471
472
  # Handle both module names (my_model) and file paths (./my_model.py)
472
473
  if module_name.endswith(".py"):
473
- # Import from file path
474
+ # Import from file path with unique module name
474
475
  import importlib.util
475
476
 
477
+ # Derive unique module name from filename to avoid collisions
478
+ base_name = os.path.splitext(os.path.basename(module_name))[0]
479
+ unique_name = f"wavedl_custom_{base_name}"
480
+
476
481
  spec = importlib.util.spec_from_file_location(
477
- "custom_module", module_name
482
+ unique_name, module_name
478
483
  )
479
484
  if spec and spec.loader:
480
485
  module = importlib.util.module_from_spec(spec)
481
- sys.modules["custom_module"] = module
486
+ sys.modules[unique_name] = module
482
487
  spec.loader.exec_module(module)
483
488
  print(f"✓ Imported custom module from: {module_name}")
484
489
  else:
@@ -1250,9 +1255,32 @@ def main():
1250
1255
  )
1251
1256
 
1252
1257
  # Learning rate scheduling (epoch-based schedulers only)
1258
+ # NOTE: For ReduceLROnPlateau with DDP, we must step only on main process
1259
+ # to avoid patience counter being incremented by all GPU processes.
1260
+ # Then we sync the new LR to all processes to keep them consistent.
1253
1261
  if not scheduler_step_per_batch:
1254
1262
  if args.scheduler == "plateau":
1255
- scheduler.step(avg_val_loss)
1263
+ # Step only on main process to avoid multi-GPU patience bug
1264
+ if accelerator.is_main_process:
1265
+ scheduler.step(avg_val_loss)
1266
+
1267
+ # Sync LR across all processes after main process updates it
1268
+ accelerator.wait_for_everyone()
1269
+
1270
+ # Broadcast new LR from rank 0 to all processes
1271
+ if dist.is_initialized():
1272
+ if accelerator.is_main_process:
1273
+ new_lr = optimizer.param_groups[0]["lr"]
1274
+ else:
1275
+ new_lr = 0.0
1276
+ new_lr_tensor = torch.tensor(
1277
+ new_lr, device=accelerator.device, dtype=torch.float32
1278
+ )
1279
+ dist.broadcast(new_lr_tensor, src=0)
1280
+ # Update LR on non-main processes
1281
+ if not accelerator.is_main_process:
1282
+ for param_group in optimizer.param_groups:
1283
+ param_group["lr"] = new_lr_tensor.item()
1256
1284
  else:
1257
1285
  scheduler.step()
1258
1286
 
@@ -793,6 +793,14 @@ def load_test_data(
793
793
  raise KeyError(
794
794
  f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
795
795
  )
796
+ # OOM guard: warn if dataset is very large
797
+ n_samples = f[inp_key].shape[0]
798
+ if n_samples > 100000:
799
+ raise ValueError(
800
+ f"Dataset has {n_samples:,} samples. load_test_data() loads "
801
+ f"everything into RAM which may cause OOM. For large inference "
802
+ f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
803
+ )
796
804
  inp = f[inp_key][:]
797
805
  outp = f[out_key][:] if out_key else None
798
806
  elif format == "mat":
@@ -805,6 +813,14 @@ def load_test_data(
805
813
  raise KeyError(
806
814
  f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
807
815
  )
816
+ # OOM guard: warn if dataset is very large (MAT is transposed)
817
+ n_samples = f[inp_key].shape[-1]
818
+ if n_samples > 100000:
819
+ raise ValueError(
820
+ f"Dataset has {n_samples:,} samples. load_test_data() loads "
821
+ f"everything into RAM which may cause OOM. For large inference "
822
+ f"sets, use a DataLoader with MATSource.load_mmap() instead."
823
+ )
808
824
  inp = mat_source._load_dataset(f, inp_key)
809
825
  if out_key:
810
826
  outp = mat_source._load_dataset(f, out_key)
@@ -1126,6 +1142,17 @@ def prepare_data(
1126
1142
 
1127
1143
  if not cache_exists:
1128
1144
  if accelerator.is_main_process:
1145
+ # Delete stale cache files to force regeneration
1146
+ # This prevents silent reuse of old data when metadata invalidates cache
1147
+ for stale_file in [CACHE_FILE, SCALER_FILE]:
1148
+ if os.path.exists(stale_file):
1149
+ try:
1150
+ os.remove(stale_file)
1151
+ logger.debug(f" Removed stale cache: {stale_file}")
1152
+ except OSError as e:
1153
+ logger.warning(
1154
+ f" Failed to remove stale cache {stale_file}: {e}"
1155
+ )
1129
1156
  # RANK 0: Create cache (can take a long time for large datasets)
1130
1157
  # Other ranks will wait at the barrier below
1131
1158
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.5.3
3
+ Version: 1.5.4
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -37,11 +37,12 @@ Requires-Dist: wandb>=0.15.0
37
37
  Requires-Dist: optuna>=3.0.0
38
38
  Requires-Dist: onnx>=1.14.0
39
39
  Requires-Dist: onnxruntime>=1.15.0
40
- Requires-Dist: pytest>=7.0.0
41
- Requires-Dist: pytest-xdist>=3.5.0
42
- Requires-Dist: ruff>=0.8.0
43
- Requires-Dist: pre-commit>=3.5.0
44
40
  Requires-Dist: triton>=2.0.0; sys_platform == "linux"
41
+ Provides-Extra: dev
42
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
43
+ Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
44
+ Requires-Dist: ruff>=0.8.0; extra == "dev"
45
+ Requires-Dist: pre-commit>=3.5.0; extra == "dev"
45
46
 
46
47
  <div align="center">
47
48
 
@@ -204,7 +205,7 @@ Deploy models anywhere:
204
205
  pip install wavedl
205
206
  ```
206
207
 
207
- This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
208
+ This installs everything you need: training, inference, HPO, ONNX export.
208
209
 
209
210
  #### From Source (for development)
210
211
 
@@ -336,7 +337,7 @@ class MyModel(BaseModel):
336
337
  **Step 2: Train**
337
338
 
338
339
  ```bash
339
- wavedl-hpc --import my_model --model my_model --data_path train.npz
340
+ wavedl-hpc --import my_model.py --model my_model --data_path train.npz
340
341
  ```
341
342
 
342
343
  WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -512,7 +513,7 @@ print('\\n✓ All pretrained weights cached!')
512
513
  | Argument | Default | Description |
513
514
  |----------|---------|-------------|
514
515
  | `--model` | `cnn` | Model architecture |
515
- | `--import` | - | Python modules to import (for custom models) |
516
+ | `--import` | - | Python file(s) to import for custom models (supports multiple) |
516
517
  | `--batch_size` | `128` | Per-GPU batch size |
517
518
  | `--lr` | `1e-3` | Learning rate |
518
519
  | `--epochs` | `1000` | Maximum epochs |
@@ -1223,6 +1224,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
1223
1224
  [![Google Scholar](https://img.shields.io/badge/Google_Scholar-4285F4?style=plastic&logo=google-scholar&logoColor=white)](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
1224
1225
  [![ResearchGate](https://img.shields.io/badge/ResearchGate-00CCBB?style=plastic&logo=researchgate&logoColor=white)](https://www.researchgate.net/profile/Ductho-Le)
1225
1226
 
1226
- <sub>Released under the MIT License</sub>
1227
+ <sub>May your signals be strong and your attenuation low 👋</sub>
1227
1228
 
1228
1229
  </div>
@@ -14,10 +14,12 @@ wandb>=0.15.0
14
14
  optuna>=3.0.0
15
15
  onnx>=1.14.0
16
16
  onnxruntime>=1.15.0
17
+
18
+ [:sys_platform == "linux"]
19
+ triton>=2.0.0
20
+
21
+ [dev]
17
22
  pytest>=7.0.0
18
23
  pytest-xdist>=3.5.0
19
24
  ruff>=0.8.0
20
25
  pre-commit>=3.5.0
21
-
22
- [:sys_platform == "linux"]
23
- triton>=2.0.0
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes