wavedl 1.5.6__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/models/vit.py CHANGED
@@ -10,9 +10,9 @@ Supports both 1D (signals) and 2D (images) inputs via configurable patch embeddi
  - 2D: Images/spectrograms → patches are grid squares

  **Variants**:
- - vit_tiny: Smallest (~5.7M params, embed_dim=192, depth=12, heads=3)
- - vit_small: Light (~22M params, embed_dim=384, depth=12, heads=6)
- - vit_base: Standard (~86M params, embed_dim=768, depth=12, heads=12)
+ - vit_tiny: Smallest (~5.4M backbone params, embed_dim=192, depth=12, heads=3)
+ - vit_small: Light (~21.4M backbone params, embed_dim=384, depth=12, heads=6)
+ - vit_base: Standard (~85.3M backbone params, embed_dim=768, depth=12, heads=12)

  References:
  Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words:
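The revised counts refer to the backbone only. A quick way to sanity-check them is to instantiate each variant and count parameters; the sketch below assumes the build_model factory seen elsewhere in this diff (the exact import path is an assumption), and the totals include the regression head, so they land slightly above the backbone-only figures.

    # Hedged sketch: count parameters for the listed ViT variants.
    from wavedl.models import build_model  # assumed import path

    for name in ("vit_tiny", "vit_small", "vit_base"):
        model = build_model(name, in_shape=(224, 224), out_size=3, pretrained=False)
        n_params = sum(p.numel() for p in model.parameters())
        print(f"{name}: {n_params / 1e6:.1f}M total parameters")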
@@ -365,7 +365,7 @@ class ViTTiny(ViTBase):
  """
  ViT-Tiny: Smallest Vision Transformer variant.

- ~5.7M parameters. Good for: Quick experiments, smaller datasets.
+ ~5.4M backbone parameters. Good for: Quick experiments, smaller datasets.

  Args:
  in_shape: (L,) for 1D or (H, W) for 2D
@@ -398,7 +398,7 @@ class ViTSmall(ViTBase):
  """
  ViT-Small: Light Vision Transformer variant.

- ~22M parameters. Good for: Balanced performance.
+ ~21.4M backbone parameters. Good for: Balanced performance.

  Args:
  in_shape: (L,) for 1D or (H, W) for 2D
@@ -429,7 +429,7 @@ class ViTBase_(ViTBase):
  """
  ViT-Base: Standard Vision Transformer variant.

- ~86M parameters. Good for: High accuracy, larger datasets.
+ ~85.3M backbone parameters. Good for: High accuracy, larger datasets.

  Args:
  in_shape: (L,) for 1D or (H, W) for 2D
wavedl/test.py CHANGED
@@ -311,7 +311,7 @@ def load_data_for_inference(
  # ==============================================================================
  def load_checkpoint(
  checkpoint_dir: str,
- in_shape: tuple[int, int],
+ in_shape: tuple[int, ...],
  out_size: int,
  model_name: str | None = None,
  ) -> tuple[nn.Module, any]:
@@ -320,7 +320,7 @@ def load_checkpoint(

  Args:
  checkpoint_dir: Path to checkpoint directory
- in_shape: Input image shape (H, W)
+ in_shape: Input spatial shape - (L,) for 1D, (H, W) for 2D, or (D, H, W) for 3D
  out_size: Number of output parameters
  model_name: Model architecture name (auto-detect if None)

@@ -376,7 +376,11 @@ def load_checkpoint(
  )

  logging.info(f" Building model: {model_name}")
- model = build_model(model_name, in_shape=in_shape, out_size=out_size)
+ # Use pretrained=False: checkpoint weights will overwrite any pretrained weights,
+ # so downloading ImageNet weights is wasteful and breaks offline/HPC inference.
+ model = build_model(
+ model_name, in_shape=in_shape, out_size=out_size, pretrained=False
+ )

  # Load weights (check multiple formats in order of preference)
  weight_path = None
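With in_shape widened to tuple[int, ...], the same loader serves 1D, 2D, and 3D checkpoints. A hedged usage sketch, with illustrative paths and shapes (the second return value is left opaque here, matching the loose annotation above):

    # Hedged usage sketch for the updated load_checkpoint signature.
    from wavedl.test import load_checkpoint  # assumed to be importable as a module

    model, extra = load_checkpoint(
        checkpoint_dir="runs/exp1",   # hypothetical checkpoint directory
        in_shape=(64, 64),            # (L,) for 1D, (H, W) for 2D, (D, H, W) for 3D
        out_size=3,
        model_name=None,              # auto-detect from checkpoint metadata
    )
    model.eval()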
wavedl/train.py CHANGED
@@ -238,18 +238,13 @@ def parse_args() -> argparse.Namespace:
  default=[],
  help="Python modules to import before training (for custom models)",
  )
- parser.add_argument(
- "--pretrained",
- action="store_true",
- default=True,
- help="Use pretrained weights (default: True)",
- )
  parser.add_argument(
  "--no_pretrained",
  dest="pretrained",
  action="store_false",
- help="Train from scratch without pretrained weights",
+ help="Train from scratch without pretrained weights (default: use pretrained)",
  )
+ parser.set_defaults(pretrained=True)

  # Configuration File
  parser.add_argument(
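The old --pretrained flag used store_true with default=True, which made the flag a no-op; the replacement keeps only the negative flag and sets the default via parser.set_defaults(). A minimal standalone sketch of the same idiom (not wavedl code):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--no_pretrained",
        dest="pretrained",
        action="store_false",
        help="Train from scratch without pretrained weights (default: use pretrained)",
    )
    parser.set_defaults(pretrained=True)

    print(parser.parse_args([]).pretrained)                   # True
    print(parser.parse_args(["--no_pretrained"]).pretrained)  # False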
@@ -367,6 +362,18 @@ def parse_args() -> argparse.Namespace:
  help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
  )
  parser.add_argument("--seed", type=int, default=2025, help="Random seed")
+ parser.add_argument(
+ "--deterministic",
+ action="store_true",
+ help="Enable deterministic mode for reproducibility (slower, disables TF32/cuDNN benchmark)",
+ )
+ parser.add_argument(
+ "--cache_validate",
+ type=str,
+ default="sha256",
+ choices=["sha256", "fast", "size"],
+ help="Cache validation mode: sha256 (full hash), fast (partial), size (quick)",
+ )
  parser.add_argument(
  "--single_channel",
  action="store_true",
@@ -512,11 +519,23 @@ def main():
  # Import as regular module
  importlib.import_module(module_name)
  print(f"✓ Imported module: {module_name}")
- except ImportError as e:
+ except (ImportError, FileNotFoundError, SyntaxError, PermissionError) as e:
  print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
- print(
- " Make sure the module is in your Python path or current directory."
- )
+ if isinstance(e, FileNotFoundError):
+ print(" File does not exist. Check the path.", file=sys.stderr)
+ elif isinstance(e, SyntaxError):
+ print(
+ f" Syntax error at line {e.lineno}: {e.msg}", file=sys.stderr
+ )
+ elif isinstance(e, PermissionError):
+ print(
+ " Permission denied. Check file permissions.", file=sys.stderr
+ )
+ else:
+ print(
+ " Make sure the module is in your Python path or current directory.",
+ file=sys.stderr,
+ )
  sys.exit(1)

  # Handle --list_models flag
@@ -648,6 +667,17 @@ def main():
  )
  set_seed(args.seed)

+ # Deterministic mode for scientific reproducibility
+ # Disables TF32 and cuDNN benchmark for exact reproducibility (slower)
+ if args.deterministic:
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = False
+ torch.use_deterministic_algorithms(True, warn_only=True)
+ if accelerator.is_main_process:
+ print("🔒 Deterministic mode enabled (slower but reproducible)")
+
  # Configure logging (rank 0 only prints to console)
  logging.basicConfig(
  level=logging.INFO if accelerator.is_main_process else logging.ERROR,
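These switches trade speed for bit-exact reruns. A standalone sketch of the same settings in plain PyTorch, with one caveat that is my addition rather than part of the diff: fully deterministic cuBLAS matmuls on CUDA also require CUBLAS_WORKSPACE_CONFIG to be set before the first CUDA call; with warn_only=True, ops lacking a deterministic implementation only warn instead of raising.

    import os
    import torch

    # Addition (not from the diff): required for deterministic cuBLAS kernels on CUDA.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    # warn_only=True: non-deterministic ops emit a warning instead of an error
    torch.use_deterministic_algorithms(True, warn_only=True)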
@@ -999,12 +1029,14 @@ def main():

  for x, y in pbar:
  with accelerator.accumulate(model):
- pred = model(x)
- # Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
- if isinstance(criterion, PhysicsConstrainedLoss):
- loss = criterion(pred, y, x)
- else:
- loss = criterion(pred, y)
+ # Use mixed precision for forward pass (respects --precision flag)
+ with accelerator.autocast():
+ pred = model(x)
+ # Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
+ if isinstance(criterion, PhysicsConstrainedLoss):
+ loss = criterion(pred, y, x)
+ else:
+ loss = criterion(pred, y)

  accelerator.backward(loss)

@@ -1053,12 +1085,14 @@ def main():

  with torch.inference_mode():
  for x, y in val_dl:
- pred = model(x)
- # Pass inputs for input-dependent constraints
- if isinstance(criterion, PhysicsConstrainedLoss):
- loss = criterion(pred, y, x)
- else:
- loss = criterion(pred, y)
+ # Use mixed precision for validation (consistent with training)
+ with accelerator.autocast():
+ pred = model(x)
+ # Pass inputs for input-dependent constraints
+ if isinstance(criterion, PhysicsConstrainedLoss):
+ loss = criterion(pred, y, x)
+ else:
+ loss = criterion(pred, y)

  val_loss_sum += loss.detach() * x.size(0)
  val_samples += x.size(0)
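In both loops only the forward pass and loss are wrapped in accelerator.autocast(); backward and the optimizer step stay under Accelerate's own mixed-precision handling. A self-contained sketch of the same loop shape with toy stand-ins (the real script builds these objects from CLI args):

    import torch
    import torch.nn as nn
    from accelerate import Accelerator

    model = nn.Linear(16, 3)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    train_dl = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.randn(64, 16), torch.randn(64, 3)),
        batch_size=8,
    )

    accelerator = Accelerator()  # mixed_precision comes from --precision in the real script
    model, optimizer, train_dl = accelerator.prepare(model, optimizer, train_dl)

    for x, y in train_dl:
        with accelerator.accumulate(model):
            with accelerator.autocast():  # forward + loss under the configured precision
                pred = model(x)
                loss = criterion(pred, y)
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()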
@@ -207,22 +207,28 @@ class ExpressionConstraint(nn.Module):
  # Parse indices from the slice
  indices = self._parse_subscript_indices(node.slice)

+ # Auto-squeeze channel dimension for single-channel inputs
+ # This allows x[i,j] syntax for (B, 1, H, W) inputs instead of x[c,i,j]
+ inputs_for_indexing = inputs
+ if inputs.ndim >= 3 and inputs.shape[1] == 1:
+ inputs_for_indexing = inputs.squeeze(1) # (B, 1, H, W) → (B, H, W)
+
  # Validate dimensions match
  # inputs shape: (batch, dim1) or (batch, dim1, dim2) or (batch, dim1, dim2, dim3)
- input_ndim = inputs.ndim - 1 # Exclude batch dimension
+ input_ndim = inputs_for_indexing.ndim - 1 # Exclude batch dimension
  if len(indices) != input_ndim:
  raise ValueError(
- f"Input has {input_ndim}D shape, but got {len(indices)} indices. "
+ f"Input has {input_ndim}D shape (after channel squeeze), but got {len(indices)} indices. "
  f"Use x[i] for 1D, x[i,j] for 2D, x[i,j,k] for 3D inputs."
  )

  # Extract the value at the specified indices (for entire batch)
  if len(indices) == 1:
- return inputs[:, indices[0]]
+ return inputs_for_indexing[:, indices[0]]
  elif len(indices) == 2:
- return inputs[:, indices[0], indices[1]]
+ return inputs_for_indexing[:, indices[0], indices[1]]
  elif len(indices) == 3:
- return inputs[:, indices[0], indices[1], indices[2]]
+ return inputs_for_indexing[:, indices[0], indices[1], indices[2]]
  else:
  raise ValueError("Only 1D, 2D, or 3D input indexing supported.")
  elif isinstance(node, ast.Expression):
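The auto-squeeze lets constraint expressions address spatial positions directly for single-channel batches. A shape-only sketch of that logic in isolation (standalone, not the wavedl class):

    import torch

    inputs = torch.randn(8, 1, 32, 32)            # (B, 1, H, W) single-channel batch
    inputs_for_indexing = inputs
    if inputs.ndim >= 3 and inputs.shape[1] == 1:
        inputs_for_indexing = inputs.squeeze(1)   # (B, 32, 32)

    # With the channel gone, a 2D subscript like x[i, j] yields one value per sample:
    i, j = 5, 7
    values = inputs_for_indexing[:, i, j]         # shape (B,)
    print(values.shape)                           # torch.Size([8])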
wavedl/utils/data.py CHANGED
@@ -50,9 +50,11 @@ INPUT_KEYS = ["input_train", "input_test", "X", "data", "inputs", "features", "x
  OUTPUT_KEYS = ["output_train", "output_test", "Y", "labels", "outputs", "targets", "y"]


- def _compute_file_hash(path: str, chunk_size: int = 8 * 1024 * 1024) -> str:
+ def _compute_file_hash(
+ path: str, mode: str = "sha256", chunk_size: int = 8 * 1024 * 1024
+ ) -> str:
  """
- Compute SHA256 hash of a file for cache validation.
+ Compute hash of a file for cache validation.

  Uses chunked reading to handle large files efficiently without loading
  the entire file into memory. This is more reliable than mtime for detecting
@@ -61,16 +63,34 @@ def _compute_file_hash(path: str, chunk_size: int = 8 * 1024 * 1024) -> str:

  Args:
  path: Path to file to hash
+ mode: Validation mode:
+ - 'sha256': Full content hash (default, most reliable)
+ - 'fast': Partial hash (first+last 1MB + size, faster for large files)
+ - 'size': File size only (fastest, least reliable)
  chunk_size: Read buffer size (default 8MB for fast I/O)

  Returns:
- Hex string of SHA256 hash
+ Hash string for cache comparison
  """
- hasher = hashlib.sha256()
- with open(path, "rb") as f:
- while chunk := f.read(chunk_size):
- hasher.update(chunk)
- return hasher.hexdigest()
+ if mode == "size":
+ return str(os.path.getsize(path))
+ elif mode == "fast":
+ # Hash first 1MB + last 1MB + file size for quick validation
+ file_size = os.path.getsize(path)
+ hasher = hashlib.sha256()
+ hasher.update(str(file_size).encode())
+ with open(path, "rb") as f:
+ hasher.update(f.read(1024 * 1024)) # First 1MB
+ if file_size > 2 * 1024 * 1024:
+ f.seek(-1024 * 1024, 2)
+ hasher.update(f.read()) # Last 1MB
+ return hasher.hexdigest()
+ else: # sha256 (full)
+ hasher = hashlib.sha256()
+ with open(path, "rb") as f:
+ while chunk := f.read(chunk_size):
+ hasher.update(chunk)
+ return hasher.hexdigest()


  class LazyDataHandle:
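The three modes trade reliability for speed; "size" only changes when the byte count changes, so an equal-length in-place edit would go unnoticed, which is the trade-off the --cache_validate choices expose. A hedged usage sketch on a throwaway file; _compute_file_hash is a private helper, so importing it directly is an assumption made for illustration only:

    import numpy as np
    from wavedl.utils.data import _compute_file_hash  # assumed import of a private helper

    np.savez("tiny.npz", X=np.random.rand(100, 32, 32), Y=np.random.rand(100, 3))

    full = _compute_file_hash("tiny.npz")                # default: full SHA-256 of the file
    fast = _compute_file_hash("tiny.npz", mode="fast")   # first (and last) 1 MB + size
    size = _compute_file_hash("tiny.npz", mode="size")   # byte count only
    print(full[:12], fast[:12], size)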
@@ -454,9 +474,18 @@ class _TransposedH5Dataset:
  self.shape = tuple(reversed(h5_dataset.shape))
  self.dtype = h5_dataset.dtype

- # Precompute transpose axis order for efficiency
- # For shape (A, B, C) -> reversed (C, B, A), transpose axes are (2, 1, 0)
- self._transpose_axes = tuple(range(len(h5_dataset.shape) - 1, -1, -1))
+ @property
+ def ndim(self) -> int:
+ """Number of dimensions (derived from shape for numpy compatibility)."""
+ return len(self.shape)
+
+ @property
+ def _transpose_axes(self) -> tuple[int, ...]:
+ """Transpose axis order for reversing dimensions.
+
+ For shape (A, B, C) -> reversed (C, B, A), transpose axes are (2, 1, 0).
+ """
+ return tuple(range(len(self._dataset.shape) - 1, -1, -1))

  def __len__(self) -> int:
  return self.shape[0]
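Turning ndim and _transpose_axes into properties derives them from the live dataset instead of caching them at construction. The axis-reversal rule itself is simple; a standalone numpy sketch:

    import numpy as np

    arr = np.arange(24).reshape(2, 3, 4)                  # MATLAB-style (A, B, C) layout
    transpose_axes = tuple(range(arr.ndim - 1, -1, -1))   # (2, 1, 0)
    reversed_view = arr.transpose(transpose_axes)         # shape (4, 3, 2)
    print(transpose_axes, reversed_view.shape)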
@@ -840,10 +869,22 @@ def load_test_data(
  keys = list(probe.keys())
  inp_key = DataSource._find_key(keys, custom_input_keys)
  out_key = DataSource._find_key(keys, custom_output_keys)
+ # Strict validation: if user explicitly specified input_key, it must exist exactly
+ if input_key is not None and input_key not in keys:
+ raise KeyError(
+ f"Explicit --input_key '{input_key}' not found. "
+ f"Available keys: {keys}"
+ )
  if inp_key is None:
  raise KeyError(
  f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
  )
+ # Strict validation: if user explicitly specified output_key, it must exist exactly
+ if output_key is not None and output_key not in keys:
+ raise KeyError(
+ f"Explicit --output_key '{output_key}' not found. "
+ f"Available keys: {keys}"
+ )
  data = NPZSource._load_and_copy(
  path, [inp_key] + ([out_key] if out_key else [])
  )
@@ -858,10 +899,22 @@ def load_test_data(
  keys = list(f.keys())
  inp_key = DataSource._find_key(keys, custom_input_keys)
  out_key = DataSource._find_key(keys, custom_output_keys)
+ # Strict validation: if user explicitly specified input_key, it must exist exactly
+ if input_key is not None and input_key not in keys:
+ raise KeyError(
+ f"Explicit --input_key '{input_key}' not found. "
+ f"Available keys: {keys}"
+ )
  if inp_key is None:
  raise KeyError(
  f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
  )
+ # Strict validation: if user explicitly specified output_key, it must exist exactly
+ if output_key is not None and output_key not in keys:
+ raise KeyError(
+ f"Explicit --output_key '{output_key}' not found. "
+ f"Available keys: {keys}"
+ )
  # OOM guard: warn if dataset is very large
  n_samples = f[inp_key].shape[0]
  if n_samples > 100000:
@@ -878,10 +931,22 @@ def load_test_data(
  keys = list(f.keys())
  inp_key = DataSource._find_key(keys, custom_input_keys)
  out_key = DataSource._find_key(keys, custom_output_keys)
+ # Strict validation: if user explicitly specified input_key, it must exist exactly
+ if input_key is not None and input_key not in keys:
+ raise KeyError(
+ f"Explicit --input_key '{input_key}' not found. "
+ f"Available keys: {keys}"
+ )
  if inp_key is None:
  raise KeyError(
  f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
  )
+ # Strict validation: if user explicitly specified output_key, it must exist exactly
+ if output_key is not None and output_key not in keys:
+ raise KeyError(
+ f"Explicit --output_key '{output_key}' not found. "
+ f"Available keys: {keys}"
+ )
  # OOM guard: warn if dataset is very large (MAT is transposed)
  n_samples = f[inp_key].shape[-1]
  if n_samples > 100000:
@@ -909,8 +974,17 @@ def load_test_data(
  else:
  # Fallback to default source.load() for unknown formats
  inp, outp = source.load(path)
- except KeyError:
- # Try with just inputs if outputs not found (inference-only mode)
+ except KeyError as e:
+ # IMPORTANT: Only fall back to inference-only mode if outputs are
+ # genuinely missing (auto-detection failed). If user explicitly
+ # provided --output_key, they expect it to exist - don't silently drop.
+ if output_key is not None:
+ raise KeyError(
+ f"Explicit --output_key '{output_key}' not found in file. "
+ f"Available keys depend on file format. Original error: {e}"
+ ) from e
+
+ # Legitimate fallback: no explicit output_key, outputs just not present
  if format == "npz":
  # First pass to find keys
  with np.load(path, allow_pickle=False) as probe:
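The guard separates two cases: outputs that are simply absent when auto-detection runs (legitimate inference-only data) versus an explicit --output_key that does not match, which now fails loudly instead of silently dropping targets. A condensed, standalone restatement of that rule (not wavedl code):

    def resolve_output_key(found_key, explicit_key, keys):
        if found_key is not None:
            return found_key              # outputs located (explicitly or by auto-detection)
        if explicit_key is not None:      # user named a key that is not in the file: fail loudly
            raise KeyError(f"Explicit --output_key '{explicit_key}' not found. Available keys: {keys}")
        return None                       # no outputs at all: legitimate inference-only mode

    print(resolve_output_key(None, None, ["X"]))  # None -> inference-only fallback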
@@ -1020,9 +1094,33 @@ def load_test_data(
  if input_channels == 1:
  X = X.unsqueeze(1) # Add channel: (N, D, H, W) → (N, 1, D, H, W)
  # else: already has channels, leave as-is
- elif X.shape[1] > 16:
- # Heuristic fallback: large dim 1 suggests 3D volume needing channel
- X = X.unsqueeze(1) # 3D volume: (N, D, H, W) (N, 1, D, H, W)
+ else:
+ # Detect channels-last format: (N, H, W, C) where C is small (1-4)
+ # and spatial dims are large (>16). This catches common mistakes.
+ if X.shape[-1] <= 4 and X.shape[1] > 16 and X.shape[2] > 16:
+ raise ValueError(
+ f"Input appears to be channels-last format: {tuple(X.shape)}. "
+ "WaveDL expects channels-first (N, C, H, W). "
+ "Convert your data using: X = X.permute(0, 3, 1, 2). "
+ "If this is actually a 3D volume with small depth, "
+ "use --input_channels 1 to add a channel dimension."
+ )
+ elif X.shape[1] > 16:
+ # Heuristic fallback: large dim 1 suggests 3D volume needing channel
+ X = X.unsqueeze(1) # 3D volume: (N, D, H, W) → (N, 1, D, H, W)
+ else:
+ # Ambiguous case: shallow 3D volume (D <= 16) or multi-channel 2D
+ # Default to treating as multi-channel 2D (no modification needed)
+ # Log a warning so users know about the --input_channels option
+ import warnings
+
+ warnings.warn(
+ f"Ambiguous 4D input shape: {tuple(X.shape)}. "
+ f"Assuming {X.shape[1]} channels (multi-channel 2D). "
+ f"For 3D volumes with depth={X.shape[1]}, use --input_channels 1.",
+ UserWarning,
+ stacklevel=2,
+ )
  # X.ndim >= 5: assume channel dimension already exists

  return X, y
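The new check refuses obvious channels-last batches instead of silently mis-treating them; the conversion the error message suggests is a single permute, sketched here on synthetic data:

    import torch

    X = torch.randn(16, 128, 128, 3)         # (N, H, W, C): channels-last, would be rejected
    X = X.permute(0, 3, 1, 2).contiguous()   # -> (N, C, H, W), the layout WaveDL expects
    print(X.shape)                           # torch.Size([16, 3, 128, 128])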
@@ -1207,7 +1305,9 @@ def prepare_data(
  cache_exists = False
  # Content hash check (robust against cloud sync mtime changes)
  elif cached_content_hash is not None:
- current_hash = _compute_file_hash(args.data_path)
+ current_hash = _compute_file_hash(
+ args.data_path, mode=getattr(args, "cache_validate", "sha256")
+ )
  if cached_content_hash != current_hash:
  if accelerator.is_main_process:
  logger.warning(
@@ -1330,7 +1430,9 @@ def prepare_data(

  # Save metadata (including data path, size, content hash for cache validation)
  file_stats = os.stat(args.data_path)
- content_hash = _compute_file_hash(args.data_path)
+ content_hash = _compute_file_hash(
+ args.data_path, mode=getattr(args, "cache_validate", "sha256")
+ )
  with open(META_FILE, "wb") as f:
  pickle.dump(
  {