wavedl 1.5.6__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/models/__init__.py +52 -4
- wavedl/models/_timm_utils.py +238 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/convnext.py +108 -33
- wavedl/models/convnext_v2.py +504 -0
- wavedl/models/densenet.py +5 -5
- wavedl/models/efficientnet.py +30 -13
- wavedl/models/efficientnetv2.py +32 -9
- wavedl/models/fastvit.py +285 -0
- wavedl/models/mamba.py +535 -0
- wavedl/models/maxvit.py +251 -0
- wavedl/models/mobilenetv3.py +35 -12
- wavedl/models/regnet.py +39 -16
- wavedl/models/resnet.py +5 -5
- wavedl/models/resnet3d.py +2 -2
- wavedl/models/swin.py +41 -9
- wavedl/models/tcn.py +25 -5
- wavedl/models/unet.py +1 -1
- wavedl/models/vit.py +6 -6
- wavedl/test.py +7 -3
- wavedl/train.py +57 -23
- wavedl/utils/constraints.py +11 -5
- wavedl/utils/data.py +120 -18
- wavedl/utils/metrics.py +287 -326
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/METADATA +104 -67
- wavedl-1.6.0.dist-info/RECORD +44 -0
- wavedl-1.5.6.dist-info/RECORD +0 -38
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/LICENSE +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/WHEEL +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/top_level.txt +0 -0
wavedl/models/vit.py
CHANGED
@@ -10,9 +10,9 @@ Supports both 1D (signals) and 2D (images) inputs via configurable patch embedding
     - 2D: Images/spectrograms → patches are grid squares

 **Variants**:
-    - vit_tiny: Smallest (~5.…
-    - vit_small: Light (~…
-    - vit_base: Standard (~…
+    - vit_tiny: Smallest (~5.4M backbone params, embed_dim=192, depth=12, heads=3)
+    - vit_small: Light (~21.4M backbone params, embed_dim=384, depth=12, heads=6)
+    - vit_base: Standard (~85.3M backbone params, embed_dim=768, depth=12, heads=12)

 References:
     Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words:

@@ -365,7 +365,7 @@ class ViTTiny(ViTBase):
     """
     ViT-Tiny: Smallest Vision Transformer variant.

-    ~5.…
+    ~5.4M backbone parameters. Good for: Quick experiments, smaller datasets.

     Args:
         in_shape: (L,) for 1D or (H, W) for 2D

@@ -398,7 +398,7 @@ class ViTSmall(ViTBase):
     """
     ViT-Small: Light Vision Transformer variant.

-    ~…
+    ~21.4M backbone parameters. Good for: Balanced performance.

     Args:
         in_shape: (L,) for 1D or (H, W) for 2D

@@ -429,7 +429,7 @@ class ViTBase_(ViTBase):
     """
     ViT-Base: Standard Vision Transformer variant.

-    ~…
+    ~85.3M backbone parameters. Good for: High accuracy, larger datasets.

     Args:
         in_shape: (L,) for 1D or (H, W) for 2D
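The updated counts are easy to sanity-check. In a standard ViT encoder block, attention contributes about 4·d² weights (QKV plus output projection) and the ratio-4 MLP about 8·d², so a depth-12 backbone has roughly 144·d² parameters. A quick sketch (an approximation that ignores biases, LayerNorms, and the patch embedding, so it lands slightly under the documented figures):

def vit_backbone_params(embed_dim: int, depth: int = 12) -> int:
    # Per block: 4*d^2 (attention QKV + output proj) + 8*d^2 (MLP, expansion 4)
    return depth * 12 * embed_dim**2

for name, d in [("vit_tiny", 192), ("vit_small", 384), ("vit_base", 768)]:
    print(f"{name}: ~{vit_backbone_params(d) / 1e6:.1f}M")
# vit_tiny: ~5.3M, vit_small: ~21.2M, vit_base: ~84.9M — consistent with the
# docstrings' 5.4M / 21.4M / 85.3M once biases and norms are added back.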
wavedl/test.py
CHANGED
@@ -311,7 +311,7 @@ def load_data_for_inference(
 # ==============================================================================
 def load_checkpoint(
     checkpoint_dir: str,
-    in_shape: tuple[int, …
+    in_shape: tuple[int, ...],
     out_size: int,
     model_name: str | None = None,
 ) -> tuple[nn.Module, any]:

@@ -320,7 +320,7 @@ def load_checkpoint(

     Args:
         checkpoint_dir: Path to checkpoint directory
-        in_shape: Input …
+        in_shape: Input spatial shape - (L,) for 1D, (H, W) for 2D, or (D, H, W) for 3D
         out_size: Number of output parameters
         model_name: Model architecture name (auto-detect if None)

@@ -376,7 +376,11 @@ def load_checkpoint(
     )

     logging.info(f"  Building model: {model_name}")
-    …
+    # Use pretrained=False: checkpoint weights will overwrite any pretrained weights,
+    # so downloading ImageNet weights is wasteful and breaks offline/HPC inference.
+    model = build_model(
+        model_name, in_shape=in_shape, out_size=out_size, pretrained=False
+    )

     # Load weights (check multiple formats in order of preference)
     weight_path = None
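The `pretrained=False` comment encodes a general PyTorch fact: `load_state_dict()` replaces every tensor the pretrained initialization produced, so downloading ImageNet weights before a full checkpoint restore is pure overhead. A minimal standalone sketch (toy module, not the WaveDL API):

import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for build_model(..., pretrained=False)
checkpoint = {k: torch.zeros_like(v) for k, v in model.state_dict().items()}
model.load_state_dict(checkpoint)  # the checkpoint overwrites ALL parameters
assert all((p == 0).all() for p in model.parameters())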
wavedl/train.py
CHANGED
@@ -238,18 +238,13 @@ def parse_args() -> argparse.Namespace:
         default=[],
         help="Python modules to import before training (for custom models)",
     )
-    parser.add_argument(
-        "--pretrained",
-        action="store_true",
-        default=True,
-        help="Use pretrained weights (default: True)",
-    )
     parser.add_argument(
         "--no_pretrained",
         dest="pretrained",
         action="store_false",
-        help="Train from scratch without pretrained weights",
+        help="Train from scratch without pretrained weights (default: use pretrained)",
     )
+    parser.set_defaults(pretrained=True)

     # Configuration File
     parser.add_argument(
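This removal fixes a real argparse wart: `--pretrained` with `action="store_true"` and `default=True` could never change anything (the flag sets the value it already has), so only `--no_pretrained` ever mattered. A single negative flag plus `set_defaults()` is the idiomatic pattern; a minimal sketch with a hypothetical `--no_cache` flag:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no_cache",            # hypothetical flag for illustration
    dest="cache",
    action="store_false",    # passing the flag writes False into args.cache
    help="Disable caching (default: enabled)",
)
parser.set_defaults(cache=True)  # the positive default lives in one place

assert parser.parse_args([]).cache is True
assert parser.parse_args(["--no_cache"]).cache is False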
@@ -367,6 +362,18 @@ def parse_args() -> argparse.Namespace:
         help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
     )
     parser.add_argument("--seed", type=int, default=2025, help="Random seed")
+    parser.add_argument(
+        "--deterministic",
+        action="store_true",
+        help="Enable deterministic mode for reproducibility (slower, disables TF32/cuDNN benchmark)",
+    )
+    parser.add_argument(
+        "--cache_validate",
+        type=str,
+        default="sha256",
+        choices=["sha256", "fast", "size"],
+        help="Cache validation mode: sha256 (full hash), fast (partial), size (quick)",
+    )
     parser.add_argument(
         "--single_channel",
         action="store_true",
@@ -512,11 +519,23 @@ def main():
             # Import as regular module
             importlib.import_module(module_name)
             print(f"✓ Imported module: {module_name}")
-        except ImportError as e:
+        except (ImportError, FileNotFoundError, SyntaxError, PermissionError) as e:
             print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
-            …
-            "…
-            )
+            if isinstance(e, FileNotFoundError):
+                print("  File does not exist. Check the path.", file=sys.stderr)
+            elif isinstance(e, SyntaxError):
+                print(
+                    f"  Syntax error at line {e.lineno}: {e.msg}", file=sys.stderr
+                )
+            elif isinstance(e, PermissionError):
+                print(
+                    "  Permission denied. Check file permissions.", file=sys.stderr
+                )
+            else:
+                print(
+                    "  Make sure the module is in your Python path or current directory.",
+                    file=sys.stderr,
+                )
             sys.exit(1)

     # Handle --list_models flag
@@ -648,6 +667,17 @@ def main():
     )
     set_seed(args.seed)

+    # Deterministic mode for scientific reproducibility
+    # Disables TF32 and cuDNN benchmark for exact reproducibility (slower)
+    if args.deterministic:
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cudnn.allow_tf32 = False
+        torch.use_deterministic_algorithms(True, warn_only=True)
+        if accelerator.is_main_process:
+            print("🔒 Deterministic mode enabled (slower but reproducible)")
+
     # Configure logging (rank 0 only prints to console)
     logging.basicConfig(
         level=logging.INFO if accelerator.is_main_process else logging.ERROR,
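One caveat that `warn_only=True` hints at: PyTorch only warns (rather than raising) when an op lacks a deterministic implementation, and some CUDA paths are only deterministic if `CUBLAS_WORKSPACE_CONFIG` is set before CUDA initializes. A sketch of the companion environment setup (standard PyTorch behavior, not something this diff adds):

import os

# Must be set before the first CUDA context is created; ":4096:8" trades a
# small workspace allocation for deterministic cuBLAS kernels.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import torch

torch.use_deterministic_algorithms(True, warn_only=True)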
@@ -999,12 +1029,14 @@ def main():

         for x, y in pbar:
             with accelerator.accumulate(model):
-                …
-                …
-                …
-                …
-                …
-                …
+                # Use mixed precision for forward pass (respects --precision flag)
+                with accelerator.autocast():
+                    pred = model(x)
+                    # Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
+                    if isinstance(criterion, PhysicsConstrainedLoss):
+                        loss = criterion(pred, y, x)
+                    else:
+                        loss = criterion(pred, y)

                 accelerator.backward(loss)

@@ -1053,12 +1085,14 @@ def main():

         with torch.inference_mode():
             for x, y in val_dl:
-                …
-                …
-                …
-                …
-                …
-                …
+                # Use mixed precision for validation (consistent with training)
+                with accelerator.autocast():
+                    pred = model(x)
+                    # Pass inputs for input-dependent constraints
+                    if isinstance(criterion, PhysicsConstrainedLoss):
+                        loss = criterion(pred, y, x)
+                    else:
+                        loss = criterion(pred, y)

                 val_loss_sum += loss.detach() * x.size(0)
                 val_samples += x.size(0)
wavedl/utils/constraints.py
CHANGED
@@ -207,22 +207,28 @@ class ExpressionConstraint(nn.Module):
             # Parse indices from the slice
             indices = self._parse_subscript_indices(node.slice)

+            # Auto-squeeze channel dimension for single-channel inputs
+            # This allows x[i,j] syntax for (B, 1, H, W) inputs instead of x[c,i,j]
+            inputs_for_indexing = inputs
+            if inputs.ndim >= 3 and inputs.shape[1] == 1:
+                inputs_for_indexing = inputs.squeeze(1)  # (B, 1, H, W) → (B, H, W)
+
             # Validate dimensions match
             # inputs shape: (batch, dim1) or (batch, dim1, dim2) or (batch, dim1, dim2, dim3)
-            input_ndim = …
+            input_ndim = inputs_for_indexing.ndim - 1  # Exclude batch dimension
             if len(indices) != input_ndim:
                 raise ValueError(
-                    f"Input has {input_ndim}D shape, but got {len(indices)} indices. "
+                    f"Input has {input_ndim}D shape (after channel squeeze), but got {len(indices)} indices. "
                     f"Use x[i] for 1D, x[i,j] for 2D, x[i,j,k] for 3D inputs."
                 )

             # Extract the value at the specified indices (for entire batch)
             if len(indices) == 1:
-                return …
+                return inputs_for_indexing[:, indices[0]]
             elif len(indices) == 2:
-                return …
+                return inputs_for_indexing[:, indices[0], indices[1]]
             elif len(indices) == 3:
-                return …
+                return inputs_for_indexing[:, indices[0], indices[1], indices[2]]
             else:
                 raise ValueError("Only 1D, 2D, or 3D input indexing supported.")
         elif isinstance(node, ast.Expression):
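A standalone illustration of the auto-squeeze rule (plain torch, not the WaveDL API): a single-channel batch (B, 1, H, W) is squeezed so that a constraint expression like x[3, 5] indexes (H, W) directly, yielding one scalar per sample.

import torch

inputs = torch.randn(8, 1, 32, 32)  # (B, 1, H, W)
x = inputs.squeeze(1) if inputs.ndim >= 3 and inputs.shape[1] == 1 else inputs
value = x[:, 3, 5]                  # one scalar per batch element
assert value.shape == (8,)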
wavedl/utils/data.py
CHANGED
@@ -50,9 +50,11 @@ INPUT_KEYS = ["input_train", "input_test", "X", "data", "inputs", "features", "x"]
 OUTPUT_KEYS = ["output_train", "output_test", "Y", "labels", "outputs", "targets", "y"]


-def _compute_file_hash(path: str, chunk_size: int = 8 * 1024 * 1024) -> str:
+def _compute_file_hash(
+    path: str, mode: str = "sha256", chunk_size: int = 8 * 1024 * 1024
+) -> str:
     """
-    Compute …
+    Compute hash of a file for cache validation.

     Uses chunked reading to handle large files efficiently without loading
     the entire file into memory. This is more reliable than mtime for detecting

@@ -61,16 +63,34 @@ def _compute_file_hash(path: str, chunk_size: int = 8 * 1024 * 1024) -> str:

     Args:
         path: Path to file to hash
+        mode: Validation mode:
+            - 'sha256': Full content hash (default, most reliable)
+            - 'fast': Partial hash (first+last 1MB + size, faster for large files)
+            - 'size': File size only (fastest, least reliable)
         chunk_size: Read buffer size (default 8MB for fast I/O)

     Returns:
-        …
+        Hash string for cache comparison
     """
-    …
-    …
-    …
-    …
-    …
+    if mode == "size":
+        return str(os.path.getsize(path))
+    elif mode == "fast":
+        # Hash first 1MB + last 1MB + file size for quick validation
+        file_size = os.path.getsize(path)
+        hasher = hashlib.sha256()
+        hasher.update(str(file_size).encode())
+        with open(path, "rb") as f:
+            hasher.update(f.read(1024 * 1024))  # First 1MB
+            if file_size > 2 * 1024 * 1024:
+                f.seek(-1024 * 1024, 2)
+                hasher.update(f.read())  # Last 1MB
+        return hasher.hexdigest()
+    else:  # sha256 (full)
+        hasher = hashlib.sha256()
+        with open(path, "rb") as f:
+            while chunk := f.read(chunk_size):
+                hasher.update(chunk)
+        return hasher.hexdigest()


 class LazyDataHandle:
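The trade-off between the modes is worth seeing concretely: "fast" hashes only the size and the first/last megabyte, so an edit in the middle of a file is invisible to it. A standalone sketch (throwaway file names are made up):

import hashlib
import os

def fast_tag(path: str) -> str:
    """Size + first/last 1 MB: the 'fast' mode's blind spot is the middle."""
    size = os.path.getsize(path)
    h = hashlib.sha256(str(size).encode())
    with open(path, "rb") as f:
        h.update(f.read(1024 * 1024))
        if size > 2 * 1024 * 1024:
            f.seek(-1024 * 1024, os.SEEK_END)
            h.update(f.read())
    return h.hexdigest()

# Two 4 MB files that differ only in the middle get the same 'fast' tag,
# which is exactly the case 'sha256' mode exists to catch.
a = bytearray(os.urandom(4 * 1024 * 1024))
b = bytearray(a)
b[2 * 1024 * 1024] ^= 0xFF  # change one byte in the middle
for name, buf in (("a.bin", a), ("b.bin", b)):
    with open(name, "wb") as f:
        f.write(buf)
assert fast_tag("a.bin") == fast_tag("b.bin")  # fast mode can't tell them apart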
@@ -454,9 +474,18 @@ class _TransposedH5Dataset:
         self.shape = tuple(reversed(h5_dataset.shape))
         self.dtype = h5_dataset.dtype

-    …
-    …
-    …
+    @property
+    def ndim(self) -> int:
+        """Number of dimensions (derived from shape for numpy compatibility)."""
+        return len(self.shape)
+
+    @property
+    def _transpose_axes(self) -> tuple[int, ...]:
+        """Transpose axis order for reversing dimensions.
+
+        For shape (A, B, C) -> reversed (C, B, A), transpose axes are (2, 1, 0).
+        """
+        return tuple(range(len(self._dataset.shape) - 1, -1, -1))

     def __len__(self) -> int:
         return self.shape[0]
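A standalone check of the `_transpose_axes` convention (plain NumPy, not the class itself): reversing all axes maps index (i, j, k) in the transposed view back to (k, j, i) in the stored array, which is why MATLAB-style column-major datasets read correctly through it.

import numpy as np

stored = np.arange(24).reshape(2, 3, 4)       # HDF5-style layout (A, B, C)
axes = tuple(range(stored.ndim - 1, -1, -1))  # (2, 1, 0), as in _transpose_axes
view = stored.transpose(axes)                 # logical layout (C, B, A)
assert view.shape == tuple(reversed(stored.shape))
assert view[3, 2, 1] == stored[1, 2, 3]       # indices reverse per element too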
@@ -840,10 +869,22 @@ def load_test_data(
             keys = list(probe.keys())
             inp_key = DataSource._find_key(keys, custom_input_keys)
             out_key = DataSource._find_key(keys, custom_output_keys)
+            # Strict validation: if user explicitly specified input_key, it must exist exactly
+            if input_key is not None and input_key not in keys:
+                raise KeyError(
+                    f"Explicit --input_key '{input_key}' not found. "
+                    f"Available keys: {keys}"
+                )
             if inp_key is None:
                 raise KeyError(
                     f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
                 )
+            # Strict validation: if user explicitly specified output_key, it must exist exactly
+            if output_key is not None and output_key not in keys:
+                raise KeyError(
+                    f"Explicit --output_key '{output_key}' not found. "
+                    f"Available keys: {keys}"
+                )
             data = NPZSource._load_and_copy(
                 path, [inp_key] + ([out_key] if out_key else [])
             )

@@ -858,10 +899,22 @@ def load_test_data(
             keys = list(f.keys())
             inp_key = DataSource._find_key(keys, custom_input_keys)
             out_key = DataSource._find_key(keys, custom_output_keys)
+            # Strict validation: if user explicitly specified input_key, it must exist exactly
+            if input_key is not None and input_key not in keys:
+                raise KeyError(
+                    f"Explicit --input_key '{input_key}' not found. "
+                    f"Available keys: {keys}"
+                )
             if inp_key is None:
                 raise KeyError(
                     f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
                 )
+            # Strict validation: if user explicitly specified output_key, it must exist exactly
+            if output_key is not None and output_key not in keys:
+                raise KeyError(
+                    f"Explicit --output_key '{output_key}' not found. "
+                    f"Available keys: {keys}"
+                )
             # OOM guard: warn if dataset is very large
             n_samples = f[inp_key].shape[0]
             if n_samples > 100000:

@@ -878,10 +931,22 @@ def load_test_data(
             keys = list(f.keys())
             inp_key = DataSource._find_key(keys, custom_input_keys)
             out_key = DataSource._find_key(keys, custom_output_keys)
+            # Strict validation: if user explicitly specified input_key, it must exist exactly
+            if input_key is not None and input_key not in keys:
+                raise KeyError(
+                    f"Explicit --input_key '{input_key}' not found. "
+                    f"Available keys: {keys}"
+                )
             if inp_key is None:
                 raise KeyError(
                     f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
                 )
+            # Strict validation: if user explicitly specified output_key, it must exist exactly
+            if output_key is not None and output_key not in keys:
+                raise KeyError(
+                    f"Explicit --output_key '{output_key}' not found. "
+                    f"Available keys: {keys}"
+                )
             # OOM guard: warn if dataset is very large (MAT is transposed)
             n_samples = f[inp_key].shape[-1]
             if n_samples > 100000:

@@ -909,8 +974,17 @@ def load_test_data(
         else:
             # Fallback to default source.load() for unknown formats
             inp, outp = source.load(path)
-    except KeyError:
-        # …
+    except KeyError as e:
+        # IMPORTANT: Only fall back to inference-only mode if outputs are
+        # genuinely missing (auto-detection failed). If user explicitly
+        # provided --output_key, they expect it to exist - don't silently drop.
+        if output_key is not None:
+            raise KeyError(
+                f"Explicit --output_key '{output_key}' not found in file. "
+                f"Available keys depend on file format. Original error: {e}"
+            ) from e
+
+        # Legitimate fallback: no explicit output_key, outputs just not present
         if format == "npz":
             # First pass to find keys
             with np.load(path, allow_pickle=False) as probe:
@@ -1020,9 +1094,33 @@ def load_test_data(
         if input_channels == 1:
             X = X.unsqueeze(1)  # Add channel: (N, D, H, W) → (N, 1, D, H, W)
             # else: already has channels, leave as-is
-        …
-        # …
-        …
+        else:
+            # Detect channels-last format: (N, H, W, C) where C is small (1-4)
+            # and spatial dims are large (>16). This catches common mistakes.
+            if X.shape[-1] <= 4 and X.shape[1] > 16 and X.shape[2] > 16:
+                raise ValueError(
+                    f"Input appears to be channels-last format: {tuple(X.shape)}. "
+                    "WaveDL expects channels-first (N, C, H, W). "
+                    "Convert your data using: X = X.permute(0, 3, 1, 2). "
+                    "If this is actually a 3D volume with small depth, "
+                    "use --input_channels 1 to add a channel dimension."
+                )
+            elif X.shape[1] > 16:
+                # Heuristic fallback: large dim 1 suggests 3D volume needing channel
+                X = X.unsqueeze(1)  # 3D volume: (N, D, H, W) → (N, 1, D, H, W)
+            else:
+                # Ambiguous case: shallow 3D volume (D <= 16) or multi-channel 2D
+                # Default to treating as multi-channel 2D (no modification needed)
+                # Log a warning so users know about the --input_channels option
+                import warnings
+
+                warnings.warn(
+                    f"Ambiguous 4D input shape: {tuple(X.shape)}. "
+                    f"Assuming {X.shape[1]} channels (multi-channel 2D). "
+                    f"For 3D volumes with depth={X.shape[1]}, use --input_channels 1.",
+                    UserWarning,
+                    stacklevel=2,
+                )
         # X.ndim >= 5: assume channel dimension already exists

     return X, y
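The three 4-D branches can be exercised standalone; a sketch mirroring the heuristic above, with made-up shapes:

def classify_4d(shape: tuple[int, ...]) -> str:
    """Mirror of the shape heuristic for a 4-D batch (N, a, b, c)."""
    n, a, b, c = shape
    if c <= 4 and a > 16 and b > 16:
        return "channels-last (N, H, W, C) - rejected with a fix-it hint"
    if a > 16:
        return "3D volume (N, D, H, W) - unsqueeze to (N, 1, D, H, W)"
    return "multi-channel 2D (N, C, H, W) - kept, with a warning"

print(classify_4d((8, 64, 64, 3)))   # channels-last ...
print(classify_4d((8, 32, 64, 64)))  # 3D volume ...
print(classify_4d((8, 3, 64, 64)))   # multi-channel 2D ...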
@@ -1207,7 +1305,9 @@ def prepare_data(
             cache_exists = False
         # Content hash check (robust against cloud sync mtime changes)
         elif cached_content_hash is not None:
-            current_hash = _compute_file_hash(…
+            current_hash = _compute_file_hash(
+                args.data_path, mode=getattr(args, "cache_validate", "sha256")
+            )
             if cached_content_hash != current_hash:
                 if accelerator.is_main_process:
                     logger.warning(

@@ -1330,7 +1430,9 @@ def prepare_data(

         # Save metadata (including data path, size, content hash for cache validation)
         file_stats = os.stat(args.data_path)
-        content_hash = _compute_file_hash(…
+        content_hash = _compute_file_hash(
+            args.data_path, mode=getattr(args, "cache_validate", "sha256")
+        )
         with open(META_FILE, "wb") as f:
             pickle.dump(
                 {