wavedl 1.5.0__tar.gz → 1.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {wavedl-1.5.0/src/wavedl.egg-info → wavedl-1.5.1}/PKG-INFO +8 -1
  2. {wavedl-1.5.0 → wavedl-1.5.1}/README.md +7 -0
  3. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/__init__.py +1 -1
  4. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/vit.py +21 -0
  5. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/test.py +28 -5
  6. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/train.py +49 -9
  7. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/utils/cross_validation.py +12 -2
  8. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/utils/data.py +26 -7
  9. {wavedl-1.5.0 → wavedl-1.5.1/src/wavedl.egg-info}/PKG-INFO +8 -1
  10. {wavedl-1.5.0 → wavedl-1.5.1}/LICENSE +0 -0
  11. {wavedl-1.5.0 → wavedl-1.5.1}/pyproject.toml +0 -0
  12. {wavedl-1.5.0 → wavedl-1.5.1}/setup.cfg +0 -0
  13. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/hpc.py +0 -0
  14. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/hpo.py +0 -0
  15. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/__init__.py +0 -0
  16. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/_template.py +0 -0
  17. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/base.py +0 -0
  18. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/cnn.py +0 -0
  19. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/convnext.py +0 -0
  20. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/densenet.py +0 -0
  21. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/efficientnet.py +0 -0
  22. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/efficientnetv2.py +0 -0
  23. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/mobilenetv3.py +0 -0
  24. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/registry.py +0 -0
  25. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/regnet.py +0 -0
  26. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/resnet.py +0 -0
  27. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/resnet3d.py +0 -0
  28. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/swin.py +0 -0
  29. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/tcn.py +0 -0
  30. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/unet.py +0 -0
  31. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/utils/__init__.py +0 -0
  32. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/utils/config.py +0 -0
  33. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/utils/constraints.py +0 -0
  34. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/utils/distributed.py +0 -0
  35. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/utils/losses.py +0 -0
  36. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/utils/metrics.py +0 -0
  37. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/utils/optimizers.py +0 -0
  38. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/utils/schedulers.py +0 -0
  39. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl.egg-info/SOURCES.txt +0 -0
  40. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl.egg-info/dependency_links.txt +0 -0
  41. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl.egg-info/entry_points.txt +0 -0
  42. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl.egg-info/requires.txt +0 -0
  43. {wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl.egg-info/top_level.txt +0 -0
{wavedl-1.5.0/src/wavedl.egg-info → wavedl-1.5.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.0
+Version: 1.5.1
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -99,6 +99,7 @@ The framework handles the engineering challenges of large-scale deep learning
 
 ## ✨ Features
 
+<div align="center">
 <table width="100%">
 <tr>
 <td width="50%" valign="top">
@@ -189,6 +190,7 @@ Deploy models anywhere:
 </td>
 </tr>
 </table>
+</div>
 
 ---
 
@@ -277,6 +279,10 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
 # Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
 python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
     --export onnx --export_path <output_file.onnx>
+
+# For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
+python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
+    --input_channels 1
 ```
 
 **Output:**
@@ -372,6 +378,7 @@ WaveDL/
 │   └── utils/              # Utilities
 │       ├── data.py         # Memory-mapped data pipeline
 │       ├── metrics.py      # R², Pearson, visualization
+│       ├── constraints.py  # Physical constraints for training
 │       ├── distributed.py  # DDP synchronization
 │       ├── losses.py       # Loss function factory
 │       ├── optimizers.py   # Optimizer factory
{wavedl-1.5.0 → wavedl-1.5.1}/README.md
@@ -54,6 +54,7 @@ The framework handles the engineering challenges of large-scale deep learning
 
 ## ✨ Features
 
+<div align="center">
 <table width="100%">
 <tr>
 <td width="50%" valign="top">
@@ -144,6 +145,7 @@ Deploy models anywhere:
 </td>
 </tr>
 </table>
+</div>
 
 ---
 
@@ -232,6 +234,10 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
 # Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
 python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
     --export onnx --export_path <output_file.onnx>
+
+# For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
+python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
+    --input_channels 1
 ```
 
 **Output:**
@@ -327,6 +333,7 @@ WaveDL/
 │   └── utils/              # Utilities
 │       ├── data.py         # Memory-mapped data pipeline
 │       ├── metrics.py      # R², Pearson, visualization
+│       ├── constraints.py  # Physical constraints for training
 │       ├── distributed.py  # DDP synchronization
 │       ├── losses.py       # Loss function factory
 │       ├── optimizers.py   # Optimizer factory
{wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/__init__.py
@@ -18,7 +18,7 @@ For inference:
 # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
 """
 
-__version__ = "1.5.0"
+__version__ = "1.5.1"
 __author__ = "Ductho Le"
 __email__ = "ductho.le@outlook.com"
 
{wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/models/vit.py
@@ -54,6 +54,16 @@ class PatchEmbed(nn.Module):
         if self.dim == 1:
             # 1D: segment patches
             L = in_shape[0]
+            if L % patch_size != 0:
+                import warnings
+
+                warnings.warn(
+                    f"Input length {L} not divisible by patch_size {patch_size}. "
+                    f"Last {L % patch_size} elements will be dropped. "
+                    f"Consider padding input to {((L // patch_size) + 1) * patch_size}.",
+                    UserWarning,
+                    stacklevel=2,
+                )
             self.num_patches = L // patch_size
             self.proj = nn.Conv1d(
                 1, embed_dim, kernel_size=patch_size, stride=patch_size
@@ -61,6 +71,17 @@ class PatchEmbed(nn.Module):
         elif self.dim == 2:
             # 2D: grid patches
             H, W = in_shape
+            if H % patch_size != 0 or W % patch_size != 0:
+                import warnings
+
+                warnings.warn(
+                    f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
+                    f"Border pixels will be dropped (H: {H % patch_size}, W: {W % patch_size}). "
+                    f"Consider padding to ({((H // patch_size) + 1) * patch_size}, "
+                    f"{((W // patch_size) + 1) * patch_size}).",
+                    UserWarning,
+                    stacklevel=2,
+                )
             self.num_patches = (H // patch_size) * (W // patch_size)
             self.proj = nn.Conv2d(
                 1, embed_dim, kernel_size=patch_size, stride=patch_size
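The two warnings above recommend padding rather than silently dropping border elements. A minimal sketch of such padding for the 2D case, using only standard PyTorch; `pad_to_multiple` is a hypothetical helper, not part of wavedl:

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Zero-pad the last two dims of (N, C, H, W) up to a multiple of patch_size."""
    _, _, h, w = x.shape
    pad_h = (-h) % patch_size  # 0 if already divisible
    pad_w = (-w) % patch_size
    # F.pad pads the last dims first: (w_left, w_right, h_top, h_bottom)
    return F.pad(x, (0, pad_w, 0, pad_h))


x = torch.randn(2, 1, 130, 94)
print(pad_to_multiple(x, patch_size=16).shape)  # torch.Size([2, 1, 144, 96])
```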
{wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/test.py
@@ -166,6 +166,13 @@ def parse_args() -> argparse.Namespace:
         default=None,
         help="Parameter names for output (e.g., 'h' 'v11' 'v12')",
     )
+    parser.add_argument(
+        "--input_channels",
+        type=int,
+        default=None,
+        help="Explicit number of input channels. Bypasses auto-detection heuristics "
+        "for ambiguous 4D shapes (e.g., 3D volumes with small depth).",
+    )
 
     # Inference options
     parser.add_argument(
@@ -235,6 +242,7 @@ def load_data_for_inference(
     format: str = "auto",
     input_key: str | None = None,
     output_key: str | None = None,
+    input_channels: int | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     Load test data for inference using the unified data loading pipeline.
@@ -278,7 +286,11 @@ def load_data_for_inference(
 
     # Use the unified loader from utils.data
     X, y = load_test_data(
-        file_path, format=format, input_key=input_key, output_key=output_key
+        file_path,
+        format=format,
+        input_key=input_key,
+        output_key=output_key,
+        input_channels=input_channels,
     )
 
     # Log results
@@ -452,7 +464,12 @@ def run_inference(
         predictions: Numpy array (N, out_size) - still in normalized space
     """
     if device is None:
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
 
     model = model.to(device)
     model.eval()
@@ -463,7 +480,7 @@
         batch_size=batch_size,
         shuffle=False,
         num_workers=num_workers,
-        pin_memory=device.type == "cuda",
+        pin_memory=device.type in ("cuda", "mps"),
     )
 
     predictions = []
@@ -919,8 +936,13 @@ def main():
     )
     logger = logging.getLogger("Tester")
 
-    # Device
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # Device (CUDA > MPS > CPU)
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
     logger.info(f"Using device: {device}")
 
     # Load test data
@@ -929,6 +951,7 @@
         format=args.format,
         input_key=args.input_key,
         output_key=args.output_key,
+        input_channels=args.input_channels,
     )
     in_shape = tuple(X_test.shape[2:])
 
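The device changes in this file implement a CUDA > MPS > CPU preference and extend `pin_memory` to MPS. A standalone sketch of the same selection logic, outside the wavedl codebase:

```python
import torch


def pick_device() -> torch.device:
    """Prefer CUDA, then Apple-silicon MPS, then CPU (mirrors the selection above)."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
# pin_memory speeds up host-to-accelerator copies, hence the ("cuda", "mps") check above
loader_kwargs = {"pin_memory": device.type in ("cuda", "mps")}
```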
{wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/train.py
@@ -931,7 +931,11 @@
         for x, y in pbar:
             with accelerator.accumulate(model):
                 pred = model(x)
-                loss = criterion(pred, y)
+                # Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
+                if isinstance(criterion, PhysicsConstrainedLoss):
+                    loss = criterion(pred, y, x)
+                else:
+                    loss = criterion(pred, y)
 
                 accelerator.backward(loss)
 
@@ -981,7 +985,11 @@
     with torch.inference_mode():
         for x, y in val_dl:
             pred = model(x)
-            loss = criterion(pred, y)
+            # Pass inputs for input-dependent constraints
+            if isinstance(criterion, PhysicsConstrainedLoss):
+                loss = criterion(pred, y, x)
+            else:
+                loss = criterion(pred, y)
 
             val_loss_sum += loss.detach() * x.size(0)
             val_samples += x.size(0)
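The branches above forward the batch inputs so that input-dependent constraints can be evaluated. wavedl's real `PhysicsConstrainedLoss` lives in the package (see `constraints.py` in the README tree); the toy class below only illustrates the assumed `(pred, target, inputs)` calling convention with a made-up penalty term:

```python
import torch
import torch.nn as nn


class ToyConstrainedLoss(nn.Module):
    """Illustrative only: a base loss plus a hypothetical input-dependent penalty."""

    def __init__(self, base: nn.Module, weight: float = 0.1):
        super().__init__()
        self.base = base
        self.weight = weight

    def forward(self, pred, target, inputs=None):
        loss = self.base(pred, target)
        if inputs is not None:
            # Made-up constraint: mean prediction should not exceed the per-sample input mean
            x_mean = inputs.mean(dim=tuple(range(1, inputs.ndim)))
            loss = loss + self.weight * torch.relu(pred.mean(dim=1) - x_mean).mean()
        return loss


criterion = ToyConstrainedLoss(nn.MSELoss())
loss = criterion(torch.randn(4, 3), torch.randn(4, 3), inputs=torch.randn(4, 1, 8, 8))
```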
@@ -998,13 +1006,45 @@
         cpu_preds = torch.cat(local_preds)
         cpu_targets = torch.cat(local_targets)
 
-        # Gather predictions and targets across all ranks
-        # Use accelerator.gather (works with all accelerate versions)
-        gpu_preds = cpu_preds.to(accelerator.device)
-        gpu_targets = cpu_targets.to(accelerator.device)
-        all_preds_gathered = accelerator.gather(gpu_preds).cpu()
-        all_targets_gathered = accelerator.gather(gpu_targets).cpu()
-        gathered = [(all_preds_gathered, all_targets_gathered)]
+        # Gather predictions and targets to rank 0 only (memory-efficient)
+        # Avoids duplicating full validation set on every GPU
+        if torch.distributed.is_initialized():
+            # DDP mode: gather only to rank 0
+            # NCCL backend requires CUDA tensors for collective ops
+            gpu_preds = cpu_preds.to(accelerator.device)
+            gpu_targets = cpu_targets.to(accelerator.device)
+
+            if accelerator.is_main_process:
+                # Rank 0: allocate gather buffers on GPU
+                all_preds_list = [
+                    torch.zeros_like(gpu_preds)
+                    for _ in range(accelerator.num_processes)
+                ]
+                all_targets_list = [
+                    torch.zeros_like(gpu_targets)
+                    for _ in range(accelerator.num_processes)
+                ]
+                torch.distributed.gather(
+                    gpu_preds, gather_list=all_preds_list, dst=0
+                )
+                torch.distributed.gather(
+                    gpu_targets, gather_list=all_targets_list, dst=0
+                )
+                # Move back to CPU for metric computation
+                gathered = [
+                    (
+                        torch.cat(all_preds_list).cpu(),
+                        torch.cat(all_targets_list).cpu(),
+                    )
+                ]
+            else:
+                # Other ranks: send to rank 0, don't allocate gather buffers
+                torch.distributed.gather(gpu_preds, gather_list=None, dst=0)
+                torch.distributed.gather(gpu_targets, gather_list=None, dst=0)
+                gathered = [(cpu_preds, cpu_targets)]  # Placeholder, not used
+        else:
+            # Single-GPU mode: no gathering needed
+            gathered = [(cpu_preds, cpu_targets)]
 
         # Synchronize validation metrics (scalars only - efficient)
         val_loss_scalar = val_loss_sum.item()
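The replacement above swaps `accelerator.gather` (which materializes the full validation set on every rank) for a `dst=0` collective. A minimal standalone sketch of that pattern, assuming an already-initialized process group and equal-sized tensors on every rank:

```python
import torch
import torch.distributed as dist


def gather_to_rank0(t: torch.Tensor) -> torch.Tensor | None:
    """Collect `t` from all ranks onto rank 0 only; other ranks return None."""
    world_size = dist.get_world_size()
    if dist.get_rank() == 0:
        bufs = [torch.zeros_like(t) for _ in range(world_size)]
        dist.gather(t, gather_list=bufs, dst=0)  # rank 0 receives every shard
        return torch.cat(bufs)
    dist.gather(t, gather_list=None, dst=0)      # other ranks only send
    return None


# With the NCCL backend, `t` must already live on the rank's GPU (as noted in the diff above).
```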
{wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/utils/cross_validation.py
@@ -128,6 +128,12 @@ def train_fold(
     best_state = None
     history = []
 
+    # Determine if scheduler steps per batch (OneCycleLR) or per epoch
+    # Use isinstance check since class name 'OneCycleLR' != 'onecycle' string in is_epoch_based
+    from torch.optim.lr_scheduler import OneCycleLR
+
+    step_per_batch = isinstance(scheduler, OneCycleLR)
+
     for epoch in range(epochs):
         # Training
         model.train()
@@ -144,6 +150,10 @@ def train_fold(
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
             optimizer.step()
 
+            # Per-batch LR scheduling (OneCycleLR)
+            if step_per_batch:
+                scheduler.step()
+
             train_loss += loss.item() * x.size(0)
             train_samples += x.size(0)
 
@@ -186,8 +196,8 @@ def train_fold(
             }
         )
 
-        # LR scheduling
-        if hasattr(scheduler, "step"):
+        # LR scheduling (epoch-based only, not for per-batch schedulers)
+        if not step_per_batch and hasattr(scheduler, "step"):
             if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                 scheduler.step(avg_val_loss)
             else:
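OneCycleLR is parameterized by a fixed number of optimizer steps, so it must be stepped once per batch, while other schedulers step once per epoch. A compact sketch of the distinction, with illustrative sizes and no real data:

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR, ReduceLROnPlateau

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
epochs, steps_per_epoch = 5, 100
scheduler = OneCycleLR(optimizer, max_lr=1e-3, epochs=epochs, steps_per_epoch=steps_per_epoch)
step_per_batch = isinstance(scheduler, OneCycleLR)

for epoch in range(epochs):
    for _ in range(steps_per_epoch):
        optimizer.step()
        if step_per_batch:
            scheduler.step()  # OneCycleLR: exactly epochs * steps_per_epoch calls
    if not step_per_batch and hasattr(scheduler, "step"):
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(0.0)  # plateau schedulers need the validation metric
        else:
            scheduler.step()
```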
{wavedl-1.5.0 → wavedl-1.5.1}/src/wavedl/utils/data.py
@@ -201,8 +201,18 @@ class DataSource(ABC):
 class NPZSource(DataSource):
     """Load data from NumPy .npz archives."""
 
+    @staticmethod
+    def _safe_load(path: str, mmap_mode: str | None = None):
+        """Load NPZ with pickle only if needed (sparse matrix support)."""
+        try:
+            return np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
+        except ValueError:
+            # Fallback for sparse matrices stored as object arrays
+            return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
+
     def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
-        data = np.load(path, allow_pickle=True)
+        """Load NPZ file (pickle enabled only for sparse matrices)."""
+        data = self._safe_load(path)
         keys = list(data.keys())
 
         input_key = self._find_key(keys, INPUT_KEYS)
@@ -233,7 +243,7 @@ class NPZSource(DataSource):
 
         Note: Returns memory-mapped arrays - do NOT modify them.
         """
-        data = np.load(path, allow_pickle=True, mmap_mode="r")
+        data = self._safe_load(path, mmap_mode="r")
         keys = list(data.keys())
 
         input_key = self._find_key(keys, INPUT_KEYS)
@@ -253,7 +263,7 @@ class NPZSource(DataSource):
 
     def load_outputs_only(self, path: str) -> np.ndarray:
         """Load only targets from NPZ (avoids loading large input arrays)."""
-        data = np.load(path, allow_pickle=True)
+        data = self._safe_load(path)
         keys = list(data.keys())
 
         output_key = self._find_key(keys, OUTPUT_KEYS)
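`_safe_load` tries `allow_pickle=False` first and only falls back to pickle when NumPy refuses to read object arrays (e.g., stored sparse matrices). A tiny round-trip showing that ordinary numeric archives load on the safe path; the file name is illustrative:

```python
import numpy as np

# Plain numeric arrays load fine with allow_pickle=False (the safe first attempt above)
np.savez("example.npz", inputs=np.random.rand(4, 32, 32), outputs=np.random.rand(4, 3))
with np.load("example.npz", allow_pickle=False) as data:
    print(list(data.keys()), data["inputs"].shape)  # ['inputs', 'outputs'] (4, 32, 32)
```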
@@ -677,6 +687,7 @@ def load_test_data(
     format: str = "auto",
     input_key: str | None = None,
     output_key: str | None = None,
+    input_channels: int | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     Load test/inference data and return PyTorch tensors ready for model input.
@@ -698,6 +709,9 @@ def load_test_data(
         format: Format hint ('npz', 'hdf5', 'mat', or 'auto' for detection)
         input_key: Custom key for input data (overrides auto-detection)
         output_key: Custom key for output data (overrides auto-detection)
+        input_channels: Explicit number of input channels. If provided, bypasses
+            the heuristic for 4D data. Use input_channels=1 for 3D volumes that
+            look like multi-channel 2D (e.g., depth ≤16).
 
     Returns:
         Tuple of:
@@ -737,7 +751,7 @@ def load_test_data(
     except KeyError:
         # Try with just inputs if outputs not found (inference-only mode)
         if format == "npz":
-            data = np.load(path, allow_pickle=True)
+            data = NPZSource._safe_load(path)
             keys = list(data.keys())
             inp_key = DataSource._find_key(keys, custom_input_keys)
             if inp_key is None:
@@ -822,15 +836,20 @@ def load_test_data(
     # Add channel dimension if needed (dimension-agnostic)
     # X.ndim == 2: 1D data (N, L) → (N, 1, L)
     # X.ndim == 3: 2D data (N, H, W) → (N, 1, H, W)
-    # X.ndim == 4: Check if already has channel dim (C <= 16 heuristic)
+    # X.ndim == 4: Check if already has channel dim
     if X.ndim == 2:
         X = X.unsqueeze(1)  # 1D signal: (N, L) → (N, 1, L)
     elif X.ndim == 3:
         X = X.unsqueeze(1)  # 2D image: (N, H, W) → (N, 1, H, W)
     elif X.ndim == 4:
         # Could be 3D volume (N, D, H, W) or 2D with channel (N, C, H, W)
-        # Heuristic: if dim 1 is small (<=16), assume it's already a channel dim
-        if X.shape[1] > 16:
+        if input_channels is not None:
+            # Explicit override: user specifies channel count
+            if input_channels == 1:
+                X = X.unsqueeze(1)  # Add channel: (N, D, H, W) → (N, 1, D, H, W)
+            # else: already has channels, leave as-is
+        elif X.shape[1] > 16:
+            # Heuristic fallback: large dim 1 suggests 3D volume needing channel
             X = X.unsqueeze(1)  # 3D volume: (N, D, H, W) → (N, 1, D, H, W)
     # X.ndim >= 5: assume channel dimension already exists
 
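A quick illustration of why 4-D arrays are ambiguous and what the `input_channels=1` override resolves; the shapes are made up:

```python
import torch

X = torch.randn(16, 8, 128, 128)  # 3D volumes with depth 8, no channel dim yet

# Heuristic path: X.shape[1] == 8 <= 16, so the data would be treated as 8-channel 2D images.
# Override path (--input_channels 1): add the channel dimension explicitly.
X_vol = X.unsqueeze(1)
print(X_vol.shape)  # torch.Size([16, 1, 8, 128, 128]), i.e. single-channel 3D volumes
```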
{wavedl-1.5.0 → wavedl-1.5.1/src/wavedl.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.0
+Version: 1.5.1
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -99,6 +99,7 @@ The framework handles the engineering challenges of large-scale deep learning
 
 ## ✨ Features
 
+<div align="center">
 <table width="100%">
 <tr>
 <td width="50%" valign="top">
@@ -189,6 +190,7 @@ Deploy models anywhere:
 </td>
 </tr>
 </table>
+</div>
 
 ---
 
@@ -277,6 +279,10 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
 # Export model to ONNX for deployment (LabVIEW, MATLAB, C++, etc.)
 python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
     --export onnx --export_path <output_file.onnx>
+
+# For 3D volumes with small depth (e.g., 8×128×128), override auto-detection
+python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
+    --input_channels 1
 ```
 
 **Output:**
@@ -372,6 +378,7 @@ WaveDL/
 │   └── utils/              # Utilities
 │       ├── data.py         # Memory-mapped data pipeline
 │       ├── metrics.py      # R², Pearson, visualization
+│       ├── constraints.py  # Physical constraints for training
 │       ├── distributed.py  # DDP synchronization
 │       ├── losses.py       # Loss function factory
 │       ├── optimizers.py   # Optimizer factory