wavedl 1.5.0__tar.gz → 1.5.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavedl-1.5.0/src/wavedl.egg-info → wavedl-1.5.2}/PKG-INFO +8 -1
- {wavedl-1.5.0 → wavedl-1.5.2}/README.md +7 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/__init__.py +1 -1
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/vit.py +21 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/test.py +28 -5
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/train.py +49 -9
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/utils/cross_validation.py +12 -2
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/utils/data.py +52 -12
- {wavedl-1.5.0 → wavedl-1.5.2/src/wavedl.egg-info}/PKG-INFO +8 -1
- {wavedl-1.5.0 → wavedl-1.5.2}/LICENSE +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/pyproject.toml +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/setup.cfg +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/hpc.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/hpo.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/__init__.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/_template.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/base.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/cnn.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/convnext.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/densenet.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/efficientnet.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/efficientnetv2.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/mobilenetv3.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/registry.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/regnet.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/resnet.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/resnet3d.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/swin.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/tcn.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/models/unet.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/utils/__init__.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/utils/config.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/utils/constraints.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/utils/distributed.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/utils/losses.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/utils/metrics.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/utils/optimizers.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl/utils/schedulers.py +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl.egg-info/SOURCES.txt +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl.egg-info/dependency_links.txt +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl.egg-info/entry_points.txt +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl.egg-info/requires.txt +0 -0
- {wavedl-1.5.0 → wavedl-1.5.2}/src/wavedl.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.5.0
|
|
3
|
+
Version: 1.5.2
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -99,6 +99,7 @@ The framework handles the engineering challenges of large-scale deep learning
|
|
|
99
99
|
|
|
100
100
|
## ✨ Features
|
|
101
101
|
|
|
102
|
+
<div align="center">
|
|
102
103
|
<table width="100%">
|
|
103
104
|
<tr>
|
|
104
105
|
<td width="50%" valign="top">
|
|
@@ -189,6 +190,7 @@ Deploy models anywhere:
|
|
|
189
190
|
</td>
|
|
190
191
|
</tr>
|
|
191
192
|
</table>
|
|
193
|
+
</div>
|
|
192
194
|
|
|
193
195
|
---
|
|
194
196
|
|
|
@@ -277,6 +279,10 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
|
277
279
|
# Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
|
|
278
280
|
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
279
281
|
--export onnx --export_path <output_file.onnx>
|
|
282
|
+
|
|
283
|
+
# For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
|
|
284
|
+
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
285
|
+
--input_channels 1
|
|
280
286
|
```
|
|
281
287
|
|
|
282
288
|
**Output:**
|
|
@@ -372,6 +378,7 @@ WaveDL/
|
|
|
372
378
|
│ └── utils/ # Utilities
|
|
373
379
|
│ ├── data.py # Memory-mapped data pipeline
|
|
374
380
|
│ ├── metrics.py # R², Pearson, visualization
|
|
381
|
+
│ ├── constraints.py # Physical constraints for training
|
|
375
382
|
│ ├── distributed.py # DDP synchronization
|
|
376
383
|
│ ├── losses.py # Loss function factory
|
|
377
384
|
│ ├── optimizers.py # Optimizer factory
|
|
@@ -54,6 +54,7 @@ The framework handles the engineering challenges of large-scale deep learning
|
|
|
54
54
|
|
|
55
55
|
## ✨ Features
|
|
56
56
|
|
|
57
|
+
<div align="center">
|
|
57
58
|
<table width="100%">
|
|
58
59
|
<tr>
|
|
59
60
|
<td width="50%" valign="top">
|
|
@@ -144,6 +145,7 @@ Deploy models anywhere:
|
|
|
144
145
|
</td>
|
|
145
146
|
</tr>
|
|
146
147
|
</table>
|
|
148
|
+
</div>
|
|
147
149
|
|
|
148
150
|
---
|
|
149
151
|
|
|
@@ -232,6 +234,10 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
|
232
234
|
# Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
|
|
233
235
|
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
234
236
|
--export onnx --export_path <output_file.onnx>
|
|
237
|
+
|
|
238
|
+
# For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
|
|
239
|
+
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
240
|
+
--input_channels 1
|
|
235
241
|
```
|
|
236
242
|
|
|
237
243
|
**Output:**
|
|
@@ -327,6 +333,7 @@ WaveDL/
|
|
|
327
333
|
│ └── utils/ # Utilities
|
|
328
334
|
│ ├── data.py # Memory-mapped data pipeline
|
|
329
335
|
│ ├── metrics.py # R², Pearson, visualization
|
|
336
|
+
│ ├── constraints.py # Physical constraints for training
|
|
330
337
|
│ ├── distributed.py # DDP synchronization
|
|
331
338
|
│ ├── losses.py # Loss function factory
|
|
332
339
|
│ ├── optimizers.py # Optimizer factory
|
|
@@ -54,6 +54,16 @@ class PatchEmbed(nn.Module):
|
|
|
54
54
|
if self.dim == 1:
|
|
55
55
|
# 1D: segment patches
|
|
56
56
|
L = in_shape[0]
|
|
57
|
+
if L % patch_size != 0:
|
|
58
|
+
import warnings
|
|
59
|
+
|
|
60
|
+
warnings.warn(
|
|
61
|
+
f"Input length {L} not divisible by patch_size {patch_size}. "
|
|
62
|
+
f"Last {L % patch_size} elements will be dropped. "
|
|
63
|
+
f"Consider padding input to {((L // patch_size) + 1) * patch_size}.",
|
|
64
|
+
UserWarning,
|
|
65
|
+
stacklevel=2,
|
|
66
|
+
)
|
|
57
67
|
self.num_patches = L // patch_size
|
|
58
68
|
self.proj = nn.Conv1d(
|
|
59
69
|
1, embed_dim, kernel_size=patch_size, stride=patch_size
|
|
@@ -61,6 +71,17 @@ class PatchEmbed(nn.Module):
|
|
|
61
71
|
elif self.dim == 2:
|
|
62
72
|
# 2D: grid patches
|
|
63
73
|
H, W = in_shape
|
|
74
|
+
if H % patch_size != 0 or W % patch_size != 0:
|
|
75
|
+
import warnings
|
|
76
|
+
|
|
77
|
+
warnings.warn(
|
|
78
|
+
f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
|
|
79
|
+
f"Border pixels will be dropped (H: {H % patch_size}, W: {W % patch_size}). "
|
|
80
|
+
f"Consider padding to ({((H // patch_size) + 1) * patch_size}, "
|
|
81
|
+
f"{((W // patch_size) + 1) * patch_size}).",
|
|
82
|
+
UserWarning,
|
|
83
|
+
stacklevel=2,
|
|
84
|
+
)
|
|
64
85
|
self.num_patches = (H // patch_size) * (W // patch_size)
|
|
65
86
|
self.proj = nn.Conv2d(
|
|
66
87
|
1, embed_dim, kernel_size=patch_size, stride=patch_size
|
|
@@ -166,6 +166,13 @@ def parse_args() -> argparse.Namespace:
|
|
|
166
166
|
default=None,
|
|
167
167
|
help="Parameter names for output (e.g., 'h' 'v11' 'v12')",
|
|
168
168
|
)
|
|
169
|
+
parser.add_argument(
|
|
170
|
+
"--input_channels",
|
|
171
|
+
type=int,
|
|
172
|
+
default=None,
|
|
173
|
+
help="Explicit number of input channels. Bypasses auto-detection heuristics "
|
|
174
|
+
"for ambiguous 4D shapes (e.g., 3D volumes with small depth).",
|
|
175
|
+
)
|
|
169
176
|
|
|
170
177
|
# Inference options
|
|
171
178
|
parser.add_argument(
|
|
@@ -235,6 +242,7 @@ def load_data_for_inference(
|
|
|
235
242
|
format: str = "auto",
|
|
236
243
|
input_key: str | None = None,
|
|
237
244
|
output_key: str | None = None,
|
|
245
|
+
input_channels: int | None = None,
|
|
238
246
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
239
247
|
"""
|
|
240
248
|
Load test data for inference using the unified data loading pipeline.
|
|
@@ -278,7 +286,11 @@ def load_data_for_inference(
|
|
|
278
286
|
|
|
279
287
|
# Use the unified loader from utils.data
|
|
280
288
|
X, y = load_test_data(
|
|
281
|
-
file_path,
|
|
289
|
+
file_path,
|
|
290
|
+
format=format,
|
|
291
|
+
input_key=input_key,
|
|
292
|
+
output_key=output_key,
|
|
293
|
+
input_channels=input_channels,
|
|
282
294
|
)
|
|
283
295
|
|
|
284
296
|
# Log results
|
|
@@ -452,7 +464,12 @@ def run_inference(
|
|
|
452
464
|
predictions: Numpy array (N, out_size) - still in normalized space
|
|
453
465
|
"""
|
|
454
466
|
if device is None:
|
|
455
|
-
|
|
467
|
+
if torch.cuda.is_available():
|
|
468
|
+
device = torch.device("cuda")
|
|
469
|
+
elif torch.backends.mps.is_available():
|
|
470
|
+
device = torch.device("mps")
|
|
471
|
+
else:
|
|
472
|
+
device = torch.device("cpu")
|
|
456
473
|
|
|
457
474
|
model = model.to(device)
|
|
458
475
|
model.eval()
|
|
@@ -463,7 +480,7 @@ def run_inference(
|
|
|
463
480
|
batch_size=batch_size,
|
|
464
481
|
shuffle=False,
|
|
465
482
|
num_workers=num_workers,
|
|
466
|
-
pin_memory=device.type
|
|
483
|
+
pin_memory=device.type in ("cuda", "mps"),
|
|
467
484
|
)
|
|
468
485
|
|
|
469
486
|
predictions = []
|
|
@@ -919,8 +936,13 @@ def main():
|
|
|
919
936
|
)
|
|
920
937
|
logger = logging.getLogger("Tester")
|
|
921
938
|
|
|
922
|
-
# Device
|
|
923
|
-
|
|
939
|
+
# Device (CUDA > MPS > CPU)
|
|
940
|
+
if torch.cuda.is_available():
|
|
941
|
+
device = torch.device("cuda")
|
|
942
|
+
elif torch.backends.mps.is_available():
|
|
943
|
+
device = torch.device("mps")
|
|
944
|
+
else:
|
|
945
|
+
device = torch.device("cpu")
|
|
924
946
|
logger.info(f"Using device: {device}")
|
|
925
947
|
|
|
926
948
|
# Load test data
|
|
@@ -929,6 +951,7 @@ def main():
|
|
|
929
951
|
format=args.format,
|
|
930
952
|
input_key=args.input_key,
|
|
931
953
|
output_key=args.output_key,
|
|
954
|
+
input_channels=args.input_channels,
|
|
932
955
|
)
|
|
933
956
|
in_shape = tuple(X_test.shape[2:])
|
|
934
957
|
|
|
@@ -931,7 +931,11 @@ def main():
|
|
|
931
931
|
for x, y in pbar:
|
|
932
932
|
with accelerator.accumulate(model):
|
|
933
933
|
pred = model(x)
|
|
934
|
-
|
|
934
|
+
# Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
|
|
935
|
+
if isinstance(criterion, PhysicsConstrainedLoss):
|
|
936
|
+
loss = criterion(pred, y, x)
|
|
937
|
+
else:
|
|
938
|
+
loss = criterion(pred, y)
|
|
935
939
|
|
|
936
940
|
accelerator.backward(loss)
|
|
937
941
|
|
|
@@ -981,7 +985,11 @@ def main():
|
|
|
981
985
|
with torch.inference_mode():
|
|
982
986
|
for x, y in val_dl:
|
|
983
987
|
pred = model(x)
|
|
984
|
-
|
|
988
|
+
# Pass inputs for input-dependent constraints
|
|
989
|
+
if isinstance(criterion, PhysicsConstrainedLoss):
|
|
990
|
+
loss = criterion(pred, y, x)
|
|
991
|
+
else:
|
|
992
|
+
loss = criterion(pred, y)
|
|
985
993
|
|
|
986
994
|
val_loss_sum += loss.detach() * x.size(0)
|
|
987
995
|
val_samples += x.size(0)
|
|
@@ -998,13 +1006,45 @@ def main():
|
|
|
998
1006
|
cpu_preds = torch.cat(local_preds)
|
|
999
1007
|
cpu_targets = torch.cat(local_targets)
|
|
1000
1008
|
|
|
1001
|
-
# Gather predictions and targets
|
|
1002
|
-
#
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1009
|
+
# Gather predictions and targets to rank 0 only (memory-efficient)
|
|
1010
|
+
# Avoids duplicating full validation set on every GPU
|
|
1011
|
+
if torch.distributed.is_initialized():
|
|
1012
|
+
# DDP mode: gather only to rank 0
|
|
1013
|
+
# NCCL backend requires CUDA tensors for collective ops
|
|
1014
|
+
gpu_preds = cpu_preds.to(accelerator.device)
|
|
1015
|
+
gpu_targets = cpu_targets.to(accelerator.device)
|
|
1016
|
+
|
|
1017
|
+
if accelerator.is_main_process:
|
|
1018
|
+
# Rank 0: allocate gather buffers on GPU
|
|
1019
|
+
all_preds_list = [
|
|
1020
|
+
torch.zeros_like(gpu_preds)
|
|
1021
|
+
for _ in range(accelerator.num_processes)
|
|
1022
|
+
]
|
|
1023
|
+
all_targets_list = [
|
|
1024
|
+
torch.zeros_like(gpu_targets)
|
|
1025
|
+
for _ in range(accelerator.num_processes)
|
|
1026
|
+
]
|
|
1027
|
+
torch.distributed.gather(
|
|
1028
|
+
gpu_preds, gather_list=all_preds_list, dst=0
|
|
1029
|
+
)
|
|
1030
|
+
torch.distributed.gather(
|
|
1031
|
+
gpu_targets, gather_list=all_targets_list, dst=0
|
|
1032
|
+
)
|
|
1033
|
+
# Move back to CPU for metric computation
|
|
1034
|
+
gathered = [
|
|
1035
|
+
(
|
|
1036
|
+
torch.cat(all_preds_list).cpu(),
|
|
1037
|
+
torch.cat(all_targets_list).cpu(),
|
|
1038
|
+
)
|
|
1039
|
+
]
|
|
1040
|
+
else:
|
|
1041
|
+
# Other ranks: send to rank 0, don't allocate gather buffers
|
|
1042
|
+
torch.distributed.gather(gpu_preds, gather_list=None, dst=0)
|
|
1043
|
+
torch.distributed.gather(gpu_targets, gather_list=None, dst=0)
|
|
1044
|
+
gathered = [(cpu_preds, cpu_targets)] # Placeholder, not used
|
|
1045
|
+
else:
|
|
1046
|
+
# Single-GPU mode: no gathering needed
|
|
1047
|
+
gathered = [(cpu_preds, cpu_targets)]
|
|
1008
1048
|
|
|
1009
1049
|
# Synchronize validation metrics (scalars only - efficient)
|
|
1010
1050
|
val_loss_scalar = val_loss_sum.item()
|
|
@@ -128,6 +128,12 @@ def train_fold(
|
|
|
128
128
|
best_state = None
|
|
129
129
|
history = []
|
|
130
130
|
|
|
131
|
+
# Determine if scheduler steps per batch (OneCycleLR) or per epoch
|
|
132
|
+
# Use isinstance check since class name 'OneCycleLR' != 'onecycle' string in is_epoch_based
|
|
133
|
+
from torch.optim.lr_scheduler import OneCycleLR
|
|
134
|
+
|
|
135
|
+
step_per_batch = isinstance(scheduler, OneCycleLR)
|
|
136
|
+
|
|
131
137
|
for epoch in range(epochs):
|
|
132
138
|
# Training
|
|
133
139
|
model.train()
|
|
@@ -144,6 +150,10 @@ def train_fold(
|
|
|
144
150
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
|
145
151
|
optimizer.step()
|
|
146
152
|
|
|
153
|
+
# Per-batch LR scheduling (OneCycleLR)
|
|
154
|
+
if step_per_batch:
|
|
155
|
+
scheduler.step()
|
|
156
|
+
|
|
147
157
|
train_loss += loss.item() * x.size(0)
|
|
148
158
|
train_samples += x.size(0)
|
|
149
159
|
|
|
@@ -186,8 +196,8 @@ def train_fold(
|
|
|
186
196
|
}
|
|
187
197
|
)
|
|
188
198
|
|
|
189
|
-
# LR scheduling
|
|
190
|
-
if hasattr(scheduler, "step"):
|
|
199
|
+
# LR scheduling (epoch-based only, not for per-batch schedulers)
|
|
200
|
+
if not step_per_batch and hasattr(scheduler, "step"):
|
|
191
201
|
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
|
192
202
|
scheduler.step(avg_val_loss)
|
|
193
203
|
else:
|
|
@@ -201,9 +201,32 @@ class DataSource(ABC):
|
|
|
201
201
|
class NPZSource(DataSource):
|
|
202
202
|
"""Load data from NumPy .npz archives."""
|
|
203
203
|
|
|
204
|
+
@staticmethod
|
|
205
|
+
def _safe_load(path: str, keys_to_probe: list[str], mmap_mode: str | None = None):
|
|
206
|
+
"""Load NPZ with pickle only if needed (sparse matrix support).
|
|
207
|
+
|
|
208
|
+
The error for object arrays happens at ACCESS time, not load time.
|
|
209
|
+
So we need to probe the keys to detect if pickle is required.
|
|
210
|
+
"""
|
|
211
|
+
data = np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
|
|
212
|
+
try:
|
|
213
|
+
# Probe keys to trigger error if object arrays exist
|
|
214
|
+
for key in keys_to_probe:
|
|
215
|
+
if key in data:
|
|
216
|
+
_ = data[key] # This raises ValueError for object arrays
|
|
217
|
+
return data
|
|
218
|
+
except ValueError as e:
|
|
219
|
+
if "allow_pickle=False" in str(e):
|
|
220
|
+
# Fallback for sparse matrices stored as object arrays
|
|
221
|
+
data.close() if hasattr(data, "close") else None
|
|
222
|
+
return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
|
|
223
|
+
raise
|
|
224
|
+
|
|
204
225
|
def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
|
|
205
|
-
|
|
206
|
-
keys
|
|
226
|
+
"""Load NPZ file (pickle enabled only for sparse matrices)."""
|
|
227
|
+
# First pass to find keys without loading data
|
|
228
|
+
with np.load(path, allow_pickle=False) as probe:
|
|
229
|
+
keys = list(probe.keys())
|
|
207
230
|
|
|
208
231
|
input_key = self._find_key(keys, INPUT_KEYS)
|
|
209
232
|
output_key = self._find_key(keys, OUTPUT_KEYS)
|
|
@@ -215,6 +238,7 @@ class NPZSource(DataSource):
|
|
|
215
238
|
f"Found: {keys}"
|
|
216
239
|
)
|
|
217
240
|
|
|
241
|
+
data = self._safe_load(path, [input_key, output_key])
|
|
218
242
|
inp = data[input_key]
|
|
219
243
|
outp = data[output_key]
|
|
220
244
|
|
|
@@ -233,8 +257,9 @@ class NPZSource(DataSource):
|
|
|
233
257
|
|
|
234
258
|
Note: Returns memory-mapped arrays - do NOT modify them.
|
|
235
259
|
"""
|
|
236
|
-
|
|
237
|
-
|
|
260
|
+
# First pass to find keys without loading data
|
|
261
|
+
with np.load(path, allow_pickle=False) as probe:
|
|
262
|
+
keys = list(probe.keys())
|
|
238
263
|
|
|
239
264
|
input_key = self._find_key(keys, INPUT_KEYS)
|
|
240
265
|
output_key = self._find_key(keys, OUTPUT_KEYS)
|
|
@@ -246,6 +271,7 @@ class NPZSource(DataSource):
|
|
|
246
271
|
f"Found: {keys}"
|
|
247
272
|
)
|
|
248
273
|
|
|
274
|
+
data = self._safe_load(path, [input_key, output_key], mmap_mode="r")
|
|
249
275
|
inp = data[input_key]
|
|
250
276
|
outp = data[output_key]
|
|
251
277
|
|
|
@@ -253,8 +279,9 @@ class NPZSource(DataSource):
|
|
|
253
279
|
|
|
254
280
|
def load_outputs_only(self, path: str) -> np.ndarray:
|
|
255
281
|
"""Load only targets from NPZ (avoids loading large input arrays)."""
|
|
256
|
-
|
|
257
|
-
|
|
282
|
+
# First pass to find keys without loading data
|
|
283
|
+
with np.load(path, allow_pickle=False) as probe:
|
|
284
|
+
keys = list(probe.keys())
|
|
258
285
|
|
|
259
286
|
output_key = self._find_key(keys, OUTPUT_KEYS)
|
|
260
287
|
if output_key is None:
|
|
@@ -263,6 +290,7 @@ class NPZSource(DataSource):
|
|
|
263
290
|
f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
|
|
264
291
|
)
|
|
265
292
|
|
|
293
|
+
data = self._safe_load(path, [output_key])
|
|
266
294
|
return data[output_key]
|
|
267
295
|
|
|
268
296
|
|
|
@@ -677,6 +705,7 @@ def load_test_data(
|
|
|
677
705
|
format: str = "auto",
|
|
678
706
|
input_key: str | None = None,
|
|
679
707
|
output_key: str | None = None,
|
|
708
|
+
input_channels: int | None = None,
|
|
680
709
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
681
710
|
"""
|
|
682
711
|
Load test/inference data and return PyTorch tensors ready for model input.
|
|
@@ -698,6 +727,9 @@ def load_test_data(
|
|
|
698
727
|
format: Format hint ('npz', 'hdf5', 'mat', or 'auto' for detection)
|
|
699
728
|
input_key: Custom key for input data (overrides auto-detection)
|
|
700
729
|
output_key: Custom key for output data (overrides auto-detection)
|
|
730
|
+
input_channels: Explicit number of input channels. If provided, bypasses
|
|
731
|
+
the heuristic for 4D data. Use input_channels=1 for 3D volumes that
|
|
732
|
+
look like multi-channel 2D (e.g., depth ≤16).
|
|
701
733
|
|
|
702
734
|
Returns:
|
|
703
735
|
Tuple of:
|
|
@@ -737,19 +769,22 @@ def load_test_data(
|
|
|
737
769
|
except KeyError:
|
|
738
770
|
# Try with just inputs if outputs not found (inference-only mode)
|
|
739
771
|
if format == "npz":
|
|
740
|
-
|
|
741
|
-
|
|
772
|
+
# First pass to find keys
|
|
773
|
+
with np.load(path, allow_pickle=False) as probe:
|
|
774
|
+
keys = list(probe.keys())
|
|
742
775
|
inp_key = DataSource._find_key(keys, custom_input_keys)
|
|
743
776
|
if inp_key is None:
|
|
744
777
|
raise KeyError(
|
|
745
778
|
f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
|
|
746
779
|
)
|
|
780
|
+
out_key = DataSource._find_key(keys, custom_output_keys)
|
|
781
|
+
keys_to_probe = [inp_key] + ([out_key] if out_key else [])
|
|
782
|
+
data = NPZSource._safe_load(path, keys_to_probe)
|
|
747
783
|
inp = data[inp_key]
|
|
748
784
|
if inp.dtype == object:
|
|
749
785
|
inp = np.array(
|
|
750
786
|
[x.toarray() if hasattr(x, "toarray") else x for x in inp]
|
|
751
787
|
)
|
|
752
|
-
out_key = DataSource._find_key(keys, custom_output_keys)
|
|
753
788
|
outp = data[out_key] if out_key else None
|
|
754
789
|
elif format == "hdf5":
|
|
755
790
|
# HDF5: input-only loading for inference
|
|
@@ -822,15 +857,20 @@ def load_test_data(
|
|
|
822
857
|
# Add channel dimension if needed (dimension-agnostic)
|
|
823
858
|
# X.ndim == 2: 1D data (N, L) → (N, 1, L)
|
|
824
859
|
# X.ndim == 3: 2D data (N, H, W) → (N, 1, H, W)
|
|
825
|
-
# X.ndim == 4: Check if already has channel dim
|
|
860
|
+
# X.ndim == 4: Check if already has channel dim
|
|
826
861
|
if X.ndim == 2:
|
|
827
862
|
X = X.unsqueeze(1) # 1D signal: (N, L) → (N, 1, L)
|
|
828
863
|
elif X.ndim == 3:
|
|
829
864
|
X = X.unsqueeze(1) # 2D image: (N, H, W) → (N, 1, H, W)
|
|
830
865
|
elif X.ndim == 4:
|
|
831
866
|
# Could be 3D volume (N, D, H, W) or 2D with channel (N, C, H, W)
|
|
832
|
-
|
|
833
|
-
|
|
867
|
+
if input_channels is not None:
|
|
868
|
+
# Explicit override: user specifies channel count
|
|
869
|
+
if input_channels == 1:
|
|
870
|
+
X = X.unsqueeze(1) # Add channel: (N, D, H, W) → (N, 1, D, H, W)
|
|
871
|
+
# else: already has channels, leave as-is
|
|
872
|
+
elif X.shape[1] > 16:
|
|
873
|
+
# Heuristic fallback: large dim 1 suggests 3D volume needing channel
|
|
834
874
|
X = X.unsqueeze(1) # 3D volume: (N, D, H, W) → (N, 1, D, H, W)
|
|
835
875
|
# X.ndim >= 5: assume channel dimension already exists
|
|
836
876
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.5.0
|
|
3
|
+
Version: 1.5.2
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -99,6 +99,7 @@ The framework handles the engineering challenges of large-scale deep learning
|
|
|
99
99
|
|
|
100
100
|
## ✨ Features
|
|
101
101
|
|
|
102
|
+
<div align="center">
|
|
102
103
|
<table width="100%">
|
|
103
104
|
<tr>
|
|
104
105
|
<td width="50%" valign="top">
|
|
@@ -189,6 +190,7 @@ Deploy models anywhere:
|
|
|
189
190
|
</td>
|
|
190
191
|
</tr>
|
|
191
192
|
</table>
|
|
193
|
+
</div>
|
|
192
194
|
|
|
193
195
|
---
|
|
194
196
|
|
|
@@ -277,6 +279,10 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
|
277
279
|
# Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
|
|
278
280
|
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
279
281
|
--export onnx --export_path <output_file.onnx>
|
|
282
|
+
|
|
283
|
+
# For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
|
|
284
|
+
python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
285
|
+
--input_channels 1
|
|
280
286
|
```
|
|
281
287
|
|
|
282
288
|
**Output:**
|
|
@@ -372,6 +378,7 @@ WaveDL/
|
|
|
372
378
|
│ └── utils/ # Utilities
|
|
373
379
|
│ ├── data.py # Memory-mapped data pipeline
|
|
374
380
|
│ ├── metrics.py # R², Pearson, visualization
|
|
381
|
+
│ ├── constraints.py # Physical constraints for training
|
|
375
382
|
│ ├── distributed.py # DDP synchronization
|
|
376
383
|
│ ├── losses.py # Loss function factory
|
|
377
384
|
│ ├── optimizers.py # Optimizer factory
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|