wavedl 1.5.3__py3-none-any.whl → 1.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
      # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
  """

- __version__ = "1.5.3"
+ __version__ = "1.5.4"
  __author__ = "Ductho Le"
  __email__ = "ductho.le@outlook.com"

wavedl/models/swin.py CHANGED
@@ -191,22 +191,33 @@ class SwinTransformerBase(BaseModel):
          Returns:
              List of parameter group dictionaries
          """
-         # Separate parameters: head (full LR) vs backbone (decayed LR)
+         # Separate parameters into 4 groups for proper LR decay:
+         # 1. Head params with decay (full LR)
+         # 2. Backbone params with decay (0.1× LR)
+         # 3. Head bias/norm without decay (full LR)
+         # 4. Backbone bias/norm without decay (0.1× LR)
          head_params = []
          backbone_params = []
-         no_decay_params = []
+         head_no_decay = []
+         backbone_no_decay = []

          for name, param in self.backbone.named_parameters():
              if not param.requires_grad:
                  continue

-             # No weight decay for bias and normalization
-             if "bias" in name or "norm" in name:
-                 no_decay_params.append(param)
-             elif "head" in name:
-                 head_params.append(param)
+             is_head = "head" in name
+             is_no_decay = "bias" in name or "norm" in name
+
+             if is_head:
+                 if is_no_decay:
+                     head_no_decay.append(param)
+                 else:
+                     head_params.append(param)
              else:
-                 backbone_params.append(param)
+                 if is_no_decay:
+                     backbone_no_decay.append(param)
+                 else:
+                     backbone_params.append(param)

          groups = []

@@ -229,15 +240,25 @@ class SwinTransformerBase(BaseModel):
                  }
              )

-         if no_decay_params:
+         if head_no_decay:
              groups.append(
                  {
-                     "params": no_decay_params,
+                     "params": head_no_decay,
                      "lr": base_lr,
                      "weight_decay": 0.0,
                  }
              )

+         if backbone_no_decay:
+             # Backbone bias/norm also gets 0.1× LR to match intended decay
+             groups.append(
+                 {
+                     "params": backbone_no_decay,
+                     "lr": base_lr * 0.1,
+                     "weight_decay": 0.0,
+                 }
+             )
+
          return groups if groups else [{"params": self.parameters(), "lr": base_lr}]

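Note: the practical effect of the swin.py change is that backbone bias/norm parameters now follow the 0.1× backbone multiplier instead of riding at the full head LR. A minimal, self-contained sketch of the same four-way split feeding a PyTorch optimizer (the toy model, the AdamW choice, and the 0.05 weight decay are illustrative assumptions; the grouping predicates mirror the hunk):

```python
import torch
import torch.nn as nn

# Toy stand-in for the real Swin backbone/head split (illustrative only)
model = nn.ModuleDict({
    "backbone": nn.ModuleDict({"proj": nn.Linear(8, 8), "norm": nn.LayerNorm(8)}),
    "head": nn.Linear(8, 2),
})

base_lr = 1e-3
head_params, backbone_params = [], []
head_no_decay, backbone_no_decay = [], []
for name, param in model.named_parameters():
    is_head = "head" in name
    is_no_decay = "bias" in name or "norm" in name
    if is_head and is_no_decay:
        head_no_decay.append(param)
    elif is_head:
        head_params.append(param)
    elif is_no_decay:
        backbone_no_decay.append(param)
    else:
        backbone_params.append(param)

# Each dict overrides the optimizer defaults for its parameters
optimizer = torch.optim.AdamW([
    {"params": head_params, "lr": base_lr, "weight_decay": 0.05},
    {"params": backbone_params, "lr": base_lr * 0.1, "weight_decay": 0.05},
    {"params": head_no_decay, "lr": base_lr, "weight_decay": 0.0},
    {"params": backbone_no_decay, "lr": base_lr * 0.1, "weight_decay": 0.0},
])
```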
wavedl/train.py CHANGED
@@ -122,6 +122,7 @@ import matplotlib.pyplot as plt
  import numpy as np
  import pandas as pd
  import torch
+ import torch.distributed as dist
  from accelerate import Accelerator
  from accelerate.utils import set_seed
  from sklearn.metrics import r2_score
@@ -470,15 +471,19 @@ def main():
      try:
          # Handle both module names (my_model) and file paths (./my_model.py)
          if module_name.endswith(".py"):
-             # Import from file path
+             # Import from file path with unique module name
              import importlib.util

+             # Derive unique module name from filename to avoid collisions
+             base_name = os.path.splitext(os.path.basename(module_name))[0]
+             unique_name = f"wavedl_custom_{base_name}"
+
              spec = importlib.util.spec_from_file_location(
-                 "custom_module", module_name
+                 unique_name, module_name
              )
              if spec and spec.loader:
                  module = importlib.util.module_from_spec(spec)
-                 sys.modules["custom_module"] = module
+                 sys.modules[unique_name] = module
                  spec.loader.exec_module(module)
                  print(f"✓ Imported custom module from: {module_name}")
          else:
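Note: before this change, every `--import path/to/file.py` was registered as `sys.modules["custom_module"]`, so importing a second file silently replaced the first. A standalone sketch of the collision-safe pattern the hunk adopts (the `import_from_path` helper name is hypothetical; the importlib calls are standard library):

```python
import importlib.util
import os
import sys

def import_from_path(path: str):
    """Hypothetical helper showing the collision-safe file import above."""
    base_name = os.path.splitext(os.path.basename(path))[0]
    unique_name = f"wavedl_custom_{base_name}"  # prefix mirrors the hunk
    spec = importlib.util.spec_from_file_location(unique_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {path}")
    module = importlib.util.module_from_spec(spec)
    # Register before exec so imports inside the module can resolve it
    sys.modules[unique_name] = module
    spec.loader.exec_module(module)
    return module

# import_from_path("./my_model.py") registers wavedl_custom_my_model, so a
# second custom file no longer clobbers the first one in sys.modules.
```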
@@ -1250,9 +1255,32 @@ def main():
          )

      # Learning rate scheduling (epoch-based schedulers only)
+     # NOTE: For ReduceLROnPlateau with DDP, we must step only on the main
+     # process so the patience counter is not advanced once per GPU process.
+     # We then sync the new LRs to all processes to keep them consistent.
      if not scheduler_step_per_batch:
          if args.scheduler == "plateau":
-             scheduler.step(avg_val_loss)
+             # Step only on the main process to avoid the multi-GPU patience bug
+             if accelerator.is_main_process:
+                 scheduler.step(avg_val_loss)
+
+             # Sync LRs across all processes after the main process updates them
+             accelerator.wait_for_everyone()
+
+             # Broadcast the new per-group LRs from rank 0 to all processes
+             if dist.is_initialized():
+                 lr_tensor = torch.tensor(
+                     [group["lr"] for group in optimizer.param_groups],
+                     device=accelerator.device,
+                     dtype=torch.float32,
+                 )
+                 dist.broadcast(lr_tensor, src=0)
+                 # Apply on non-main processes; one entry per group, so groups
+                 # with different base LRs keep their relative ratios
+                 if not accelerator.is_main_process:
+                     for group, lr in zip(optimizer.param_groups, lr_tensor.tolist()):
+                         group["lr"] = lr
          else:
              scheduler.step()

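Note: stepping ReduceLROnPlateau on rank 0 only is safe provided every rank agrees on the validation metric beforehand. A short sketch of making that explicit with Accelerate (`local_val_loss` is an assumed per-rank value from the surrounding loop; `accelerator.reduce` is the real Accelerate API):

```python
import torch

# Average the per-rank validation loss across processes so rank 0 steps
# the plateau scheduler on a metric every rank agrees on.
loss_tensor = torch.tensor(local_val_loss, device=accelerator.device)
avg_val_loss = accelerator.reduce(loss_tensor, reduction="mean").item()
```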
wavedl/utils/data.py CHANGED
@@ -793,6 +793,14 @@ def load_test_data(
              raise KeyError(
                  f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
              )
+         # OOM guard: raise if the dataset is too large to load into RAM
+         n_samples = f[inp_key].shape[0]
+         if n_samples > 100000:
+             raise ValueError(
+                 f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                 f"everything into RAM which may cause OOM. For large inference "
+                 f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
+             )
          inp = f[inp_key][:]
          outp = f[out_key][:] if out_key else None
      elif format == "mat":
@@ -805,6 +813,14 @@ def load_test_data(
              raise KeyError(
                  f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
              )
+         # OOM guard: raise if the dataset is too large (MAT is transposed,
+         # so the sample count is on the last axis)
+         n_samples = f[inp_key].shape[-1]
+         if n_samples > 100000:
+             raise ValueError(
+                 f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                 f"everything into RAM which may cause OOM. For large inference "
+                 f"sets, use a DataLoader with MATSource.load_mmap() instead."
+             )
          inp = mat_source._load_dataset(f, inp_key)
          if out_key:
              outp = mat_source._load_dataset(f, out_key)
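Note: both OOM guards point users at a memory-mapped path. Without relying on the exact `load_mmap()` signature, the general shape of a lazy alternative with plain h5py looks like this (the file name and the `"inputs"` key are illustrative):

```python
import h5py
import torch
from torch.utils.data import DataLoader, Dataset

class LazyH5Dataset(Dataset):
    """Reads one sample at a time from HDF5 instead of f[key][:] into RAM."""

    def __init__(self, path: str, key: str = "inputs"):
        self.path, self.key = path, key
        with h5py.File(path, "r") as f:
            self._len = f[key].shape[0]
        self._file = None  # opened lazily, once per DataLoader worker

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        if self._file is None:
            self._file = h5py.File(self.path, "r")
        return torch.from_numpy(self._file[self.key][idx])

# Only one batch of samples is resident in memory at a time
loader = DataLoader(LazyH5Dataset("test.h5"), batch_size=256, num_workers=4)
```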
@@ -1126,6 +1142,17 @@ def prepare_data(

      if not cache_exists:
          if accelerator.is_main_process:
+             # Delete stale cache files to force regeneration. This prevents
+             # silent reuse of old data when the metadata check invalidates the cache.
+             for stale_file in [CACHE_FILE, SCALER_FILE]:
+                 if os.path.exists(stale_file):
+                     try:
+                         os.remove(stale_file)
+                         logger.debug(f"  Removed stale cache: {stale_file}")
+                     except OSError as e:
+                         logger.warning(
+                             f"  Failed to remove stale cache {stale_file}: {e}"
+                         )
              # RANK 0: Create cache (can take a long time for large datasets)
              # Other ranks will wait at the barrier below

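Note: the deletion above slots into the usual rank-0-writes, everyone-waits cache flow. A hedged sketch of that surrounding pattern (`build_cache` and the cache file names are hypothetical stand-ins, not WaveDL APIs):

```python
import os
from accelerate import Accelerator

accelerator = Accelerator()
CACHE_FILE, SCALER_FILE = "cache.npz", "scaler.pkl"  # illustrative names

def build_cache(*paths):
    # Stub: the real code regenerates preprocessed arrays and scalers here
    for p in paths:
        open(p, "wb").close()

if accelerator.is_main_process:
    build_cache(CACHE_FILE, SCALER_FILE)  # rank 0 does the slow work once
accelerator.wait_for_everyone()           # other ranks block at this barrier
assert os.path.exists(CACHE_FILE)         # all ranks now see the same cache
```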
wavedl-1.5.4.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: wavedl
- Version: 1.5.3
+ Version: 1.5.4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
  Author: Ductho Le
  License: MIT
@@ -37,11 +37,12 @@ Requires-Dist: wandb>=0.15.0
  Requires-Dist: optuna>=3.0.0
  Requires-Dist: onnx>=1.14.0
  Requires-Dist: onnxruntime>=1.15.0
- Requires-Dist: pytest>=7.0.0
- Requires-Dist: pytest-xdist>=3.5.0
- Requires-Dist: ruff>=0.8.0
- Requires-Dist: pre-commit>=3.5.0
  Requires-Dist: triton>=2.0.0; sys_platform == "linux"
+ Provides-Extra: dev
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
+ Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
+ Requires-Dist: ruff>=0.8.0; extra == "dev"
+ Requires-Dist: pre-commit>=3.5.0; extra == "dev"

  <div align="center">

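Note: dev tooling now sits behind an extra, so a plain `pip install wavedl` no longer pulls in test and lint dependencies; contributors opt in with `pip install wavedl[dev]` (standard pip extras syntax, implied by the `Provides-Extra: dev` line above).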
@@ -204,7 +205,7 @@ Deploy models anywhere:
  pip install wavedl
  ```

- This installs everything you need: training, inference, HPO, ONNX export, and dev tools.
+ This installs everything you need: training, inference, HPO, and ONNX export.

  #### From Source (for development)

@@ -336,7 +337,7 @@ class MyModel(BaseModel):
  **Step 2: Train**

  ```bash
- wavedl-hpc --import my_model --model my_model --data_path train.npz
+ wavedl-hpc --import my_model.py --model my_model --data_path train.npz
  ```

  WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -512,7 +513,7 @@ print('\\n✓ All pretrained weights cached!')
  | Argument | Default | Description |
  |----------|---------|-------------|
  | `--model` | `cnn` | Model architecture |
- | `--import` | - | Python modules to import (for custom models) |
+ | `--import` | - | Python file(s) to import for custom models (supports multiple) |
  | `--batch_size` | `128` | Per-GPU batch size |
  | `--lr` | `1e-3` | Learning rate |
  | `--epochs` | `1000` | Maximum epochs |
@@ -1223,6 +1224,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
  [![Google Scholar](https://img.shields.io/badge/Google_Scholar-4285F4?style=plastic&logo=google-scholar&logoColor=white)](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
  [![ResearchGate](https://img.shields.io/badge/ResearchGate-00CCBB?style=plastic&logo=researchgate&logoColor=white)](https://www.researchgate.net/profile/Ductho-Le)

- <sub>Released under the MIT License</sub>
+ <sub>May your signals be strong and your attenuation low 👋</sub>

  </div>
wavedl-1.5.4.dist-info/RECORD CHANGED
@@ -1,8 +1,8 @@
- wavedl/__init__.py,sha256=1h6l9c3ms45mYhJZskUm28my7Lrq9tXMUs4BtMTiK_s,1177
+ wavedl/__init__.py,sha256=L3ckuWk3BDr6h9oiADkGP_JKcGSF669qDkuzofh86IU,1177
  wavedl/hpc.py,sha256=6rV38nozzMt0-jKZbVJNwvQZXK0wUsIZmr9lgWN_XUw,9212
  wavedl/hpo.py,sha256=DGCGyt2yhr3WAifAuljhE26gg07CHdaQW4wpDaTKbyo,14968
  wavedl/test.py,sha256=WIHG3HWT-uF399FQApPpxjggBVFn59cC54HAL4990QU,38550
- wavedl/train.py,sha256=Aao8ofyYALqPrMTQarRn4rPWzDLZD-PeuKNVJ76IrVQ,54344
+ wavedl/train.py,sha256=7AVaCORFUv2_IgdYSPKdHLxbi11GzMOyu4RcNc4Uf_I,55963
  wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
  wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
  wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
@@ -16,7 +16,7 @@ wavedl/models/registry.py,sha256=InYAXX2xbRvsFDFnYUPCptJh0F9lHlFPN77A9kqHRT0,298
  wavedl/models/regnet.py,sha256=Yf9gAoDLv0j4uEuoKC822gizHNh59LCbvFCMP11Q1C0,13116
  wavedl/models/resnet.py,sha256=laePTbIgINijh-Xkcp4iui8-1F17NJAjyAuA4T11eG4,18027
  wavedl/models/resnet3d.py,sha256=C7CL4XeSnRlIBuwf5Ei-z183uzIBObrXfkM9Iwuc5e0,8746
- wavedl/models/swin.py,sha256=p-okfq3Qm4_neJTxCcMzoHoVzC0BHW3BMnbpr_Ri2U0,13224
+ wavedl/models/swin.py,sha256=cbV_iqIS4no-EAUR8j_93gqd59AkAkfM5DYo6VryLEg,13937
  wavedl/models/tcn.py,sha256=RtY13QpFHqz72b4ultv2lStCIDxfvjySVe5JaTx_GaM,12601
  wavedl/models/unet.py,sha256=LqIXhasdBygwP7SZNNmiW1bHMPaJTVBpaeHtPgEHkdU,7790
  wavedl/models/vit.py,sha256=68o9nNjkftvHFArAPupU2ew5e5yCsI2AYaT9TQinVMk,12075
@@ -24,15 +24,15 @@ wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
  wavedl/utils/config.py,sha256=AsGwb3XtxmbTLb59BLl5AA4wzMNgVTpl7urOJ6IGqfM,10901
  wavedl/utils/constraints.py,sha256=Pof5hzeTSGsPY_E6Sc8iMQDaXc_zfEasQI2tCszk_gw,17614
  wavedl/utils/cross_validation.py,sha256=gwXSFTx5oxWndPjWLJAJzB6nnq2f1t9f86SbjbF-jNI,18475
- wavedl/utils/data.py,sha256=H5crttnSfJZBMWQOvM7Cq7nkefnhVlgO0O6J71zJdgI,52651
+ wavedl/utils/data.py,sha256=JusSrIZd98t9oiN0xTy2V2mfVyuBCIu0MLAQGcaC0vQ,54194
  wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
  wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
  wavedl/utils/metrics.py,sha256=EJmJvF7gACQsUoKYldlladN_SbnRiuE-Smj0eSnbraQ,39394
  wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
  wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
- wavedl-1.5.3.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
- wavedl-1.5.3.dist-info/METADATA,sha256=bPNcR8sYE9U7a001lvMFn9oHfmcmkpHUDdGRowLjJEs,45488
- wavedl-1.5.3.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
- wavedl-1.5.3.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
- wavedl-1.5.3.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
- wavedl-1.5.3.dist-info/RECORD,,
+ wavedl-1.5.4.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
+ wavedl-1.5.4.dist-info/METADATA,sha256=D7_MbjGWyVEIEH2m23GrJInZO4pcfHAINlY1FIUgD-A,45604
+ wavedl-1.5.4.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
+ wavedl-1.5.4.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
+ wavedl-1.5.4.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
+ wavedl-1.5.4.dist-info/RECORD,,