wavedl 1.5.3-py3-none-any.whl → 1.5.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/models/swin.py +31 -10
- wavedl/train.py +32 -4
- wavedl/utils/data.py +27 -0
- {wavedl-1.5.3.dist-info → wavedl-1.5.4.dist-info}/METADATA +10 -9
- {wavedl-1.5.3.dist-info → wavedl-1.5.4.dist-info}/RECORD +10 -10
- {wavedl-1.5.3.dist-info → wavedl-1.5.4.dist-info}/LICENSE +0 -0
- {wavedl-1.5.3.dist-info → wavedl-1.5.4.dist-info}/WHEEL +0 -0
- {wavedl-1.5.3.dist-info → wavedl-1.5.4.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.3.dist-info → wavedl-1.5.4.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/models/swin.py
CHANGED
@@ -191,22 +191,33 @@ class SwinTransformerBase(BaseModel):
         Returns:
             List of parameter group dictionaries
         """
-        # Separate parameters
+        # Separate parameters into 4 groups for proper LR decay:
+        # 1. Head params with decay (full LR)
+        # 2. Backbone params with decay (0.1× LR)
+        # 3. Head bias/norm without decay (full LR)
+        # 4. Backbone bias/norm without decay (0.1× LR)
         head_params = []
         backbone_params = []
-
+        head_no_decay = []
+        backbone_no_decay = []

         for name, param in self.backbone.named_parameters():
             if not param.requires_grad:
                 continue

-
-
-
-
-
+            is_head = "head" in name
+            is_no_decay = "bias" in name or "norm" in name
+
+            if is_head:
+                if is_no_decay:
+                    head_no_decay.append(param)
+                else:
+                    head_params.append(param)
             else:
-
+                if is_no_decay:
+                    backbone_no_decay.append(param)
+                else:
+                    backbone_params.append(param)

         groups = []
@@ -229,15 +240,25 @@ class SwinTransformerBase(BaseModel):
                 }
             )

-        if
+        if head_no_decay:
             groups.append(
                 {
-                    "params":
+                    "params": head_no_decay,
                     "lr": base_lr,
                     "weight_decay": 0.0,
                 }
             )

+        if backbone_no_decay:
+            # Backbone bias/norm also gets 0.1× LR to match intended decay
+            groups.append(
+                {
+                    "params": backbone_no_decay,
+                    "lr": base_lr * 0.1,
+                    "weight_decay": 0.0,
+                }
+            )
+
         return groups if groups else [{"params": self.parameters(), "lr": base_lr}]
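The hunks above split the Swin backbone's parameters along two axes (head vs. backbone, decay vs. no-decay), so the pretrained backbone fine-tunes at 0.1× the base LR and bias/norm parameters skip weight decay. A minimal standalone sketch of the same grouping pattern, not WaveDL's own method (the `build_param_groups` helper and the 0.05 decay value are illustrative):

```python
# Sketch: head/backbone × decay/no-decay parameter groups fed to AdamW.
import torch
import torch.nn as nn


def build_param_groups(backbone: nn.Module, base_lr: float):
    """Hypothetical helper mirroring the grouping logic in the diff."""
    buckets = {  # (is_head, is_no_decay) -> list of params
        (True, False): [], (True, True): [],
        (False, False): [], (False, True): [],
    }
    for name, param in backbone.named_parameters():
        if not param.requires_grad:
            continue
        is_head = "head" in name
        is_no_decay = "bias" in name or "norm" in name
        buckets[(is_head, is_no_decay)].append(param)

    groups = []
    for (is_head, is_no_decay), params in buckets.items():
        if not params:
            continue
        groups.append({
            "params": params,
            "lr": base_lr if is_head else base_lr * 0.1,   # backbone gets 0.1x LR
            "weight_decay": 0.0 if is_no_decay else 0.05,  # example decay value
        })
    return groups


# Usage: optimizer = torch.optim.AdamW(build_param_groups(model.backbone, 1e-3))
```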
wavedl/train.py
CHANGED
@@ -122,6 +122,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import torch
+import torch.distributed as dist
 from accelerate import Accelerator
 from accelerate.utils import set_seed
 from sklearn.metrics import r2_score
@@ -470,15 +471,19 @@ def main():
         try:
             # Handle both module names (my_model) and file paths (./my_model.py)
             if module_name.endswith(".py"):
-                # Import from file path
+                # Import from file path with unique module name
                 import importlib.util

+                # Derive unique module name from filename to avoid collisions
+                base_name = os.path.splitext(os.path.basename(module_name))[0]
+                unique_name = f"wavedl_custom_{base_name}"
+
                 spec = importlib.util.spec_from_file_location(
-
+                    unique_name, module_name
                 )
                 if spec and spec.loader:
                     module = importlib.util.module_from_spec(spec)
-                    sys.modules[
+                    sys.modules[unique_name] = module
                     spec.loader.exec_module(module)
                     print(f"✓ Imported custom module from: {module_name}")
                 else:
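The change above registers each imported file under a unique `sys.modules` key derived from its filename, so two custom model files cannot silently overwrite one another. A self-contained sketch of the same importlib pattern (the `import_from_path` helper and the `wavedl_custom` prefix are illustrative):

```python
# Sketch: import a model definition from a file path under a unique sys.modules key.
import importlib.util
import os
import sys


def import_from_path(path: str, prefix: str = "wavedl_custom"):
    base_name = os.path.splitext(os.path.basename(path))[0]
    unique_name = f"{prefix}_{base_name}"          # e.g. wavedl_custom_my_model
    spec = importlib.util.spec_from_file_location(unique_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[unique_name] = module              # register before executing
    spec.loader.exec_module(module)
    return module


# module = import_from_path("./my_model.py")
```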
@@ -1250,9 +1255,32 @@ def main():
         )

         # Learning rate scheduling (epoch-based schedulers only)
+        # NOTE: For ReduceLROnPlateau with DDP, we must step only on main process
+        # to avoid patience counter being incremented by all GPU processes.
+        # Then we sync the new LR to all processes to keep them consistent.
         if not scheduler_step_per_batch:
             if args.scheduler == "plateau":
-
+                # Step only on main process to avoid multi-GPU patience bug
+                if accelerator.is_main_process:
+                    scheduler.step(avg_val_loss)
+
+                # Sync LR across all processes after main process updates it
+                accelerator.wait_for_everyone()
+
+                # Broadcast new LR from rank 0 to all processes
+                if dist.is_initialized():
+                    if accelerator.is_main_process:
+                        new_lr = optimizer.param_groups[0]["lr"]
+                    else:
+                        new_lr = 0.0
+                    new_lr_tensor = torch.tensor(
+                        new_lr, device=accelerator.device, dtype=torch.float32
+                    )
+                    dist.broadcast(new_lr_tensor, src=0)
+                    # Update LR on non-main processes
+                    if not accelerator.is_main_process:
+                        for param_group in optimizer.param_groups:
+                            param_group["lr"] = new_lr_tensor.item()
             else:
                 scheduler.step()
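The new branch steps `ReduceLROnPlateau` on the main process only, so the patience counter advances once per epoch instead of once per GPU, and then broadcasts the resulting LR to the other ranks. A stripped-down sketch of that pattern with plain `torch.distributed`, assuming a process group is already initialized (the `step_plateau_ddp` helper is illustrative):

```python
# Sketch: step the plateau scheduler on rank 0, then broadcast the new LR.
import torch
import torch.distributed as dist


def step_plateau_ddp(scheduler, optimizer, val_loss: float, device) -> None:
    if dist.get_rank() == 0:
        scheduler.step(val_loss)                 # patience advances exactly once
        new_lr = optimizer.param_groups[0]["lr"]
    else:
        new_lr = 0.0
    lr_tensor = torch.tensor(new_lr, device=device, dtype=torch.float32)
    dist.broadcast(lr_tensor, src=0)             # rank 0 -> all ranks
    if dist.get_rank() != 0:
        for group in optimizer.param_groups:
            group["lr"] = lr_tensor.item()
```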
wavedl/utils/data.py
CHANGED
@@ -793,6 +793,14 @@ def load_test_data(
             raise KeyError(
                 f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
             )
+        # OOM guard: warn if dataset is very large
+        n_samples = f[inp_key].shape[0]
+        if n_samples > 100000:
+            raise ValueError(
+                f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                f"everything into RAM which may cause OOM. For large inference "
+                f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
+            )
         inp = f[inp_key][:]
         outp = f[out_key][:] if out_key else None
     elif format == "mat":
@@ -805,6 +813,14 @@ def load_test_data(
             raise KeyError(
                 f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
             )
+        # OOM guard: warn if dataset is very large (MAT is transposed)
+        n_samples = f[inp_key].shape[-1]
+        if n_samples > 100000:
+            raise ValueError(
+                f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                f"everything into RAM which may cause OOM. For large inference "
+                f"sets, use a DataLoader with MATSource.load_mmap() instead."
+            )
         inp = mat_source._load_dataset(f, inp_key)
         if out_key:
             outp = mat_source._load_dataset(f, out_key)
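The guard above refuses to pull more than 100,000 samples into RAM at once and points users to a DataLoader-based path instead. A generic sketch of what lazy, per-sample HDF5 reads look like; this is not WaveDL's `HDF5Source`, and the file path and `"X"` key are assumptions:

```python
# Sketch: stream a large HDF5 inference set sample-by-sample instead of
# loading everything into RAM (path and dataset key are illustrative).
import h5py
import torch
from torch.utils.data import DataLoader, Dataset


class LazyHDF5Dataset(Dataset):
    def __init__(self, path: str, input_key: str = "X"):
        self.path, self.input_key = path, input_key
        with h5py.File(path, "r") as f:
            self.length = f[input_key].shape[0]
        self._file = None  # opened lazily, once per worker process

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self._file is None:
            self._file = h5py.File(self.path, "r")
        sample = self._file[self.input_key][idx]   # reads one sample from disk
        return torch.as_tensor(sample)


# loader = DataLoader(LazyHDF5Dataset("test.h5"), batch_size=128, num_workers=4)
```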
@@ -1126,6 +1142,17 @@ def prepare_data(

     if not cache_exists:
         if accelerator.is_main_process:
+            # Delete stale cache files to force regeneration
+            # This prevents silent reuse of old data when metadata invalidates cache
+            for stale_file in [CACHE_FILE, SCALER_FILE]:
+                if os.path.exists(stale_file):
+                    try:
+                        os.remove(stale_file)
+                        logger.debug(f" Removed stale cache: {stale_file}")
+                    except OSError as e:
+                        logger.warning(
+                            f" Failed to remove stale cache {stale_file}: {e}"
+                        )
             # RANK 0: Create cache (can take a long time for large datasets)
             # Other ranks will wait at the barrier below
{wavedl-1.5.3.dist-info → wavedl-1.5.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.3
+Version: 1.5.4
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -37,11 +37,12 @@ Requires-Dist: wandb>=0.15.0
 Requires-Dist: optuna>=3.0.0
 Requires-Dist: onnx>=1.14.0
 Requires-Dist: onnxruntime>=1.15.0
-Requires-Dist: pytest>=7.0.0
-Requires-Dist: pytest-xdist>=3.5.0
-Requires-Dist: ruff>=0.8.0
-Requires-Dist: pre-commit>=3.5.0
 Requires-Dist: triton>=2.0.0; sys_platform == "linux"
+Provides-Extra: dev
+Requires-Dist: pytest>=7.0.0; extra == "dev"
+Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
+Requires-Dist: ruff>=0.8.0; extra == "dev"
+Requires-Dist: pre-commit>=3.5.0; extra == "dev"

 <div align="center">
@@ -204,7 +205,7 @@ Deploy models anywhere:
 pip install wavedl
 ```

-This installs everything you need: training, inference, HPO, ONNX export
+This installs everything you need: training, inference, HPO, ONNX export.

 #### From Source (for development)

@@ -336,7 +337,7 @@ class MyModel(BaseModel):
 **Step 2: Train**

 ```bash
-wavedl-hpc --import my_model --model my_model --data_path train.npz
+wavedl-hpc --import my_model.py --model my_model --data_path train.npz
 ```

 WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -512,7 +513,7 @@ print('\n✓ All pretrained weights cached!')
 | Argument | Default | Description |
 |----------|---------|-------------|
 | `--model` | `cnn` | Model architecture |
-| `--import` | - | Python
+| `--import` | - | Python file(s) to import for custom models (supports multiple) |
 | `--batch_size` | `128` | Per-GPU batch size |
 | `--lr` | `1e-3` | Learning rate |
 | `--epochs` | `1000` | Maximum epochs |
@@ -1223,6 +1224,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
 [](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
 [](https://www.researchgate.net/profile/Ductho-Le)

-<sub>
+<sub>May your signals be strong and your attenuation low 👋</sub>

 </div>
{wavedl-1.5.3.dist-info → wavedl-1.5.4.dist-info}/RECORD
CHANGED
@@ -1,8 +1,8 @@
-wavedl/__init__.py,sha256=
+wavedl/__init__.py,sha256=L3ckuWk3BDr6h9oiADkGP_JKcGSF669qDkuzofh86IU,1177
 wavedl/hpc.py,sha256=6rV38nozzMt0-jKZbVJNwvQZXK0wUsIZmr9lgWN_XUw,9212
 wavedl/hpo.py,sha256=DGCGyt2yhr3WAifAuljhE26gg07CHdaQW4wpDaTKbyo,14968
 wavedl/test.py,sha256=WIHG3HWT-uF399FQApPpxjggBVFn59cC54HAL4990QU,38550
-wavedl/train.py,sha256=
+wavedl/train.py,sha256=7AVaCORFUv2_IgdYSPKdHLxbi11GzMOyu4RcNc4Uf_I,55963
 wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
 wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
 wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
@@ -16,7 +16,7 @@ wavedl/models/registry.py,sha256=InYAXX2xbRvsFDFnYUPCptJh0F9lHlFPN77A9kqHRT0,298
 wavedl/models/regnet.py,sha256=Yf9gAoDLv0j4uEuoKC822gizHNh59LCbvFCMP11Q1C0,13116
 wavedl/models/resnet.py,sha256=laePTbIgINijh-Xkcp4iui8-1F17NJAjyAuA4T11eG4,18027
 wavedl/models/resnet3d.py,sha256=C7CL4XeSnRlIBuwf5Ei-z183uzIBObrXfkM9Iwuc5e0,8746
-wavedl/models/swin.py,sha256=
+wavedl/models/swin.py,sha256=cbV_iqIS4no-EAUR8j_93gqd59AkAkfM5DYo6VryLEg,13937
 wavedl/models/tcn.py,sha256=RtY13QpFHqz72b4ultv2lStCIDxfvjySVe5JaTx_GaM,12601
 wavedl/models/unet.py,sha256=LqIXhasdBygwP7SZNNmiW1bHMPaJTVBpaeHtPgEHkdU,7790
 wavedl/models/vit.py,sha256=68o9nNjkftvHFArAPupU2ew5e5yCsI2AYaT9TQinVMk,12075
@@ -24,15 +24,15 @@ wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
 wavedl/utils/config.py,sha256=AsGwb3XtxmbTLb59BLl5AA4wzMNgVTpl7urOJ6IGqfM,10901
 wavedl/utils/constraints.py,sha256=Pof5hzeTSGsPY_E6Sc8iMQDaXc_zfEasQI2tCszk_gw,17614
 wavedl/utils/cross_validation.py,sha256=gwXSFTx5oxWndPjWLJAJzB6nnq2f1t9f86SbjbF-jNI,18475
-wavedl/utils/data.py,sha256=
+wavedl/utils/data.py,sha256=JusSrIZd98t9oiN0xTy2V2mfVyuBCIu0MLAQGcaC0vQ,54194
 wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
 wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
 wavedl/utils/metrics.py,sha256=EJmJvF7gACQsUoKYldlladN_SbnRiuE-Smj0eSnbraQ,39394
 wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
 wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
-wavedl-1.5.
-wavedl-1.5.
-wavedl-1.5.
-wavedl-1.5.
-wavedl-1.5.
-wavedl-1.5.
+wavedl-1.5.4.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
+wavedl-1.5.4.dist-info/METADATA,sha256=D7_MbjGWyVEIEH2m23GrJInZO4pcfHAINlY1FIUgD-A,45604
+wavedl-1.5.4.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
+wavedl-1.5.4.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
+wavedl-1.5.4.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
+wavedl-1.5.4.dist-info/RECORD,,
{wavedl-1.5.3.dist-info → wavedl-1.5.4.dist-info}/LICENSE
File without changes
{wavedl-1.5.3.dist-info → wavedl-1.5.4.dist-info}/WHEEL
File without changes
{wavedl-1.5.3.dist-info → wavedl-1.5.4.dist-info}/entry_points.txt
File without changes
{wavedl-1.5.3.dist-info → wavedl-1.5.4.dist-info}/top_level.txt
File without changes