wavedl 1.4.0__py3-none-any.whl → 1.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.4.0"
21
+ __version__ = "1.4.2"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
wavedl/hpc.py CHANGED
@@ -130,6 +130,18 @@ Environment Variables:
130
130
  default=0,
131
131
  help="Rank of this machine in multi-node setup (default: 0)",
132
132
  )
133
+ parser.add_argument(
134
+ "--main_process_ip",
135
+ type=str,
136
+ default=None,
137
+ help="IP address of the main process for multi-node training",
138
+ )
139
+ parser.add_argument(
140
+ "--main_process_port",
141
+ type=int,
142
+ default=None,
143
+ help="Port for multi-node communication (default: accelerate auto-selects)",
144
+ )
133
145
  parser.add_argument(
134
146
  "--mixed_precision",
135
147
  type=str,
@@ -207,12 +219,18 @@ def main() -> int:
207
219
  "launch",
208
220
  f"--num_processes={num_gpus}",
209
221
  f"--num_machines={args.num_machines}",
210
- "--machine_rank=0",
222
+ f"--machine_rank={args.machine_rank}",
211
223
  f"--mixed_precision={args.mixed_precision}",
212
224
  f"--dynamo_backend={args.dynamo_backend}",
213
- "-m",
214
- "wavedl.train",
215
- ] + train_args
225
+ ]
226
+
227
+ # Add multi-node networking args if specified (required for some clusters)
228
+ if args.main_process_ip:
229
+ cmd.append(f"--main_process_ip={args.main_process_ip}")
230
+ if args.main_process_port:
231
+ cmd.append(f"--main_process_port={args.main_process_port}")
232
+
233
+ cmd += ["-m", "wavedl.train"] + train_args
216
234
 
217
235
  # Create output directory if specified
218
236
  for i, arg in enumerate(train_args):
wavedl/hpo.py CHANGED
@@ -145,6 +145,7 @@ def create_objective(args):
145
145
  # Use temporary directory for trial output
146
146
  with tempfile.TemporaryDirectory() as tmpdir:
147
147
  cmd.extend(["--output_dir", tmpdir])
148
+ history_file = Path(tmpdir) / "training_history.csv"
148
149
 
149
150
  # Run training
150
151
  try:
@@ -156,29 +157,55 @@ def create_objective(args):
156
157
  cwd=Path(__file__).parent,
157
158
  )
158
159
 
159
- # Parse validation loss from output
160
- # Look for "Best val_loss: X.XXXX" in stdout
160
+ # Read best val_loss from training_history.csv (reliable machine-readable)
161
161
  val_loss = None
162
- for line in result.stdout.split("\n"):
163
- if "Best val_loss:" in line:
164
- try:
165
- val_loss = float(line.split(":")[-1].strip())
166
- except ValueError:
167
- pass
168
- # Also check for final validation loss
169
- if "val_loss=" in line.lower():
170
- try:
171
- # Extract number after val_loss=
172
- parts = line.lower().split("val_loss=")
173
- if len(parts) > 1:
174
- val_str = parts[1].split()[0].strip(",")
175
- val_loss = float(val_str)
176
- except (ValueError, IndexError):
177
- pass
162
+ if history_file.exists():
163
+ try:
164
+ import csv
165
+
166
+ with open(history_file) as f:
167
+ reader = csv.DictReader(f)
168
+ val_losses = []
169
+ for row in reader:
170
+ if "val_loss" in row:
171
+ try:
172
+ val_losses.append(float(row["val_loss"]))
173
+ except (ValueError, TypeError):
174
+ pass
175
+ if val_losses:
176
+ val_loss = min(val_losses) # Best (minimum) val_loss
177
+ except Exception as e:
178
+ print(f"Trial {trial.number}: Error reading history: {e}")
179
+
180
+ if val_loss is None:
181
+ # Fallback: parse stdout for training log format
182
+ # Pattern: "epoch | train_loss | val_loss | ..."
183
+ # Use regex to avoid false positives from unrelated lines
184
+ import re
185
+
186
+ # Match lines like: " 42 | 0.0123 | 0.0156 | ..."
187
+ log_pattern = re.compile(
188
+ r"^\s*\d+\s*\|\s*[\d.]+\s*\|\s*([\d.]+)\s*\|"
189
+ )
190
+ val_losses_stdout = []
191
+ for line in result.stdout.split("\n"):
192
+ match = log_pattern.match(line)
193
+ if match:
194
+ try:
195
+ val_losses_stdout.append(float(match.group(1)))
196
+ except ValueError:
197
+ continue
198
+ if val_losses_stdout:
199
+ val_loss = min(val_losses_stdout)
178
200
 
179
201
  if val_loss is None:
180
202
  # Training failed or no loss found
181
- print(f"Trial {trial.number}: Training failed")
203
+ print(f"Trial {trial.number}: Training failed (no val_loss found)")
204
+ if result.returncode != 0:
205
+ # Show last few lines of stderr for debugging
206
+ stderr_lines = result.stderr.strip().split("\n")[-3:]
207
+ for line in stderr_lines:
208
+ print(f" stderr: {line}")
182
209
  return float("inf")
183
210
 
184
211
  print(f"Trial {trial.number}: val_loss={val_loss:.6f}")
@@ -1,23 +1,26 @@
1
1
  """
2
- Model Template for New Architectures
3
- =====================================
2
+ Model Template for Custom Architectures
3
+ ========================================
4
4
 
5
- Copy this file and modify to add new model architectures to the framework.
5
+ Copy this file and modify to add custom model architectures to WaveDL.
6
6
  The model will be automatically registered and available via --model flag.
7
7
 
8
- Steps to Add a New Model:
9
- 1. Copy this file to models/your_model.py
10
- 2. Rename the class and update @register_model("your_model")
11
- 3. Implement the __init__ and forward methods
12
- 4. Import your model in models/__init__.py:
13
- from wavedl.models.your_model import YourModel
14
- 5. Run: accelerate launch -m wavedl.train --model your_model --wandb
8
+ Quick Start:
9
+ 1. Copy this file to your project: cp _template.py my_model.py
10
+ 2. Rename the class and update @register_model("my_model")
11
+ 3. Implement your architecture in __init__ and forward
12
+ 4. Train: wavedl-train --import my_model --model my_model --data_path data.npz
13
+
14
+ Requirements (your model MUST):
15
+ 1. Inherit from BaseModel
16
+ 2. Accept (in_shape, out_size, **kwargs) in __init__
17
+ 3. Return tensor of shape (batch, out_size) from forward()
18
+
19
+ See README.md "Adding Custom Models" section for more details.
15
20
 
16
21
  Author: Ductho Le (ductho.le@outlook.com)
17
22
  """
18
23
 
19
- from typing import Any
20
-
21
24
  import torch
22
25
  import torch.nn as nn
23
26
 
@@ -25,7 +28,7 @@ from wavedl.models.base import BaseModel
25
28
 
26
29
 
27
30
  # Uncomment the decorator to register this model
28
- # @register_model("template")
31
+ # @register_model("my_model")
29
32
  class TemplateModel(BaseModel):
30
33
  """
31
34
  Template Model Architecture.
@@ -34,14 +37,16 @@ class TemplateModel(BaseModel):
34
37
  The first line will appear in --list_models output.
35
38
 
36
39
  Args:
37
- in_shape: Input spatial dimensions (H, W)
38
- out_size: Number of regression output targets
40
+ in_shape: Input spatial dimensions (auto-detected from data)
41
+ - 1D: (L,) for signals
42
+ - 2D: (H, W) for images
43
+ - 3D: (D, H, W) for volumes
44
+ out_size: Number of regression targets (auto-detected from data)
39
45
  hidden_dim: Size of hidden layers (default: 256)
40
- num_layers: Number of convolutional layers (default: 4)
41
46
  dropout: Dropout rate (default: 0.1)
42
47
 
43
48
  Input Shape:
44
- (B, 1, H, W) - Single-channel images
49
+ (B, 1, *in_shape) - e.g., (B, 1, 64, 64) for 2D
45
50
 
46
51
  Output Shape:
47
52
  (B, out_size) - Regression predictions
@@ -49,10 +54,9 @@ class TemplateModel(BaseModel):
49
54
 
50
55
  def __init__(
51
56
  self,
52
- in_shape: tuple[int, int],
57
+ in_shape: tuple,
53
58
  out_size: int,
54
59
  hidden_dim: int = 256,
55
- num_layers: int = 4,
56
60
  dropout: float = 0.1,
57
61
  **kwargs, # Accept extra kwargs for flexibility
58
62
  ):
@@ -61,14 +65,13 @@ class TemplateModel(BaseModel):
61
65
 
62
66
  # Store hyperparameters as attributes (optional but recommended)
63
67
  self.hidden_dim = hidden_dim
64
- self.num_layers = num_layers
65
68
  self.dropout_rate = dropout
66
69
 
67
70
  # =================================================================
68
71
  # BUILD YOUR ARCHITECTURE HERE
69
72
  # =================================================================
70
73
 
71
- # Example: Simple CNN encoder
74
+ # Example: Simple CNN encoder (assumes 2D input with 1 channel)
72
75
  self.encoder = nn.Sequential(
73
76
  # Layer 1
74
77
  nn.Conv2d(1, 32, kernel_size=3, padding=1),
@@ -106,10 +109,10 @@ class TemplateModel(BaseModel):
106
109
  """
107
110
  Forward pass of the model.
108
111
 
109
- REQUIRED: Must accept (B, C, H, W) and return (B, out_size)
112
+ REQUIRED: Must accept (B, C, *spatial) and return (B, out_size)
110
113
 
111
114
  Args:
112
- x: Input tensor of shape (B, 1, H, W)
115
+ x: Input tensor of shape (B, 1, *in_shape)
113
116
 
114
117
  Returns:
115
118
  Output tensor of shape (B, out_size)
@@ -122,35 +125,20 @@ class TemplateModel(BaseModel):
122
125
 
123
126
  return output
124
127
 
125
- @classmethod
126
- def get_default_config(cls) -> dict[str, Any]:
127
- """
128
- Return default hyperparameters for this model.
129
-
130
- OPTIONAL: Override to provide model-specific defaults.
131
- These can be used for documentation or config files.
132
- """
133
- return {
134
- "hidden_dim": 256,
135
- "num_layers": 4,
136
- "dropout": 0.1,
137
- }
138
-
139
128
 
140
129
  # =============================================================================
141
130
  # USAGE EXAMPLE
142
131
  # =============================================================================
143
132
  if __name__ == "__main__":
144
133
  # Quick test of the model
145
- model = TemplateModel(in_shape=(500, 500), out_size=5)
134
+ model = TemplateModel(in_shape=(64, 64), out_size=5)
146
135
 
147
136
  # Print model summary
148
137
  print(f"Model: {model.__class__.__name__}")
149
138
  print(f"Parameters: {model.count_parameters():,}")
150
- print(f"Default config: {model.get_default_config()}")
151
139
 
152
140
  # Test forward pass
153
- dummy_input = torch.randn(2, 1, 500, 500)
141
+ dummy_input = torch.randn(2, 1, 64, 64)
154
142
  output = model(dummy_input)
155
143
  print(f"Input shape: {dummy_input.shape}")
156
144
  print(f"Output shape: {output.shape}")
wavedl/models/base.py CHANGED
@@ -75,13 +75,61 @@ class BaseModel(nn.Module, ABC):
75
75
  Forward pass of the model.
76
76
 
77
77
  Args:
78
- x: Input tensor of shape (B, C, H, W)
78
+ x: Input tensor of shape (B, C, *spatial_dims)
79
+ - 1D: (B, C, L)
80
+ - 2D: (B, C, H, W)
81
+ - 3D: (B, C, D, H, W)
79
82
 
80
83
  Returns:
81
84
  Output tensor of shape (B, out_size)
82
85
  """
83
86
  pass
84
87
 
88
+ def validate_input_shape(self, x: torch.Tensor) -> None:
89
+ """
90
+ Validate input tensor shape against model's expected shape.
91
+
92
+ Call this at the start of forward() for explicit shape contract enforcement.
93
+ Provides clear, actionable error messages instead of cryptic Conv layer errors.
94
+
95
+ Args:
96
+ x: Input tensor to validate
97
+
98
+ Raises:
99
+ ValueError: If shape doesn't match expected dimensions
100
+
101
+ Example:
102
+ def forward(self, x):
103
+ self.validate_input_shape(x) # Optional but recommended
104
+ return self.model(x)
105
+ """
106
+ expected_ndim = len(self.in_shape) + 2 # +2 for (batch, channel)
107
+
108
+ if x.ndim != expected_ndim:
109
+ dim_names = {
110
+ 3: "1D (B, C, L)",
111
+ 4: "2D (B, C, H, W)",
112
+ 5: "3D (B, C, D, H, W)",
113
+ }
114
+ expected_name = dim_names.get(expected_ndim, f"{expected_ndim}D")
115
+ actual_name = dim_names.get(x.ndim, f"{x.ndim}D")
116
+ raise ValueError(
117
+ f"Input shape mismatch: model expects {expected_name} input, "
118
+ f"got {actual_name} with shape {tuple(x.shape)}.\n"
119
+ f"Expected in_shape: {self.in_shape} -> input should be (B, C, {', '.join(map(str, self.in_shape))})\n"
120
+ f"Hint: Check your data preprocessing - you may need to add/remove dimensions."
121
+ )
122
+
123
+ # Validate spatial dimensions match
124
+ spatial_dims = tuple(x.shape[2:]) # Skip batch and channel
125
+ if spatial_dims != tuple(self.in_shape):
126
+ raise ValueError(
127
+ f"Spatial dimension mismatch: model expects {self.in_shape}, "
128
+ f"got {spatial_dims}.\n"
129
+ f"Full input shape: {tuple(x.shape)} (B={x.shape[0]}, C={x.shape[1]})\n"
130
+ f"Hint: Ensure your data dimensions match the model's in_shape."
131
+ )
132
+
85
133
  def count_parameters(self, trainable_only: bool = True) -> int:
86
134
  """
87
135
  Count the number of parameters in the model.
wavedl/train.py CHANGED
@@ -12,11 +12,11 @@ A modular training framework for wave-based inverse problems and regression:
12
12
  6. Deep Observability: WandB integration with scatter analysis
13
13
 
14
14
  Usage:
15
- # Recommended: Using the HPC launcher
16
- wavedl-hpc --model cnn --batch_size 128 --wandb
15
+ # Recommended: Using the HPC launcher (handles accelerate configuration)
16
+ wavedl-hpc --model cnn --batch_size 128 --mixed_precision bf16 --wandb
17
17
 
18
- # Or with direct accelerate launch
19
- accelerate launch -m wavedl.train --model cnn --batch_size 128 --wandb
18
+ # Or direct training module (use --precision, not --mixed_precision)
19
+ accelerate launch -m wavedl.train --model cnn --batch_size 128 --precision bf16
20
20
 
21
21
  # Multi-GPU with explicit config
22
22
  wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb
@@ -28,9 +28,9 @@ Usage:
28
28
  wavedl-train --list_models
29
29
 
30
30
  Note:
31
- For HPC clusters (Compute Canada, etc.), use wavedl-hpc which handles
32
- environment configuration automatically. Mixed precision is controlled via
33
- --mixed_precision flag (default: bf16).
31
+ - wavedl-hpc: Uses --mixed_precision (passed to accelerate launch)
32
+ - wavedl.train: Uses --precision (internal module flag)
33
+ Both control the same behavior; use the appropriate flag for your entry point.
34
34
 
35
35
  Author: Ductho Le (ductho.le@outlook.com)
36
36
  """
@@ -97,6 +97,18 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
97
97
  warnings.filterwarnings("ignore", module="pydantic")
98
98
  warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
99
99
 
100
+ # ==============================================================================
101
+ # GPU PERFORMANCE OPTIMIZATIONS (Ampere/Hopper: A100, H100)
102
+ # ==============================================================================
103
+ # Enable TF32 for faster matmul (safe precision for training, ~2x speedup)
104
+ torch.backends.cuda.matmul.allow_tf32 = True
105
+ torch.backends.cudnn.allow_tf32 = True
106
+ torch.set_float32_matmul_precision("high") # Use TF32 for float32 ops
107
+
108
+ # Enable cuDNN autotuning for fixed-size inputs (CNN-like models benefit most)
109
+ # Note: First few batches may be slower due to benchmarking
110
+ torch.backends.cudnn.benchmark = True
111
+
100
112
 
101
113
  # ==============================================================================
102
114
  # ARGUMENT PARSING
@@ -298,11 +310,24 @@ def parse_args() -> argparse.Namespace:
298
310
  choices=["bf16", "fp16", "no"],
299
311
  help="Mixed precision mode",
300
312
  )
313
+ # Alias for consistency with wavedl-hpc (--mixed_precision)
314
+ parser.add_argument(
315
+ "--mixed_precision",
316
+ dest="precision",
317
+ type=str,
318
+ choices=["bf16", "fp16", "no"],
319
+ help=argparse.SUPPRESS, # Hidden: use --precision instead
320
+ )
301
321
 
302
322
  # Logging
303
323
  parser.add_argument(
304
324
  "--wandb", action="store_true", help="Enable Weights & Biases logging"
305
325
  )
326
+ parser.add_argument(
327
+ "--wandb_watch",
328
+ action="store_true",
329
+ help="Enable WandB gradient watching (adds overhead, useful for debugging)",
330
+ )
306
331
  parser.add_argument(
307
332
  "--project_name", type=str, default="DL-Training", help="WandB project name"
308
333
  )
@@ -467,8 +492,8 @@ def main():
467
492
  if _cv_handle is not None and hasattr(_cv_handle, "close"):
468
493
  try:
469
494
  _cv_handle.close()
470
- except Exception:
471
- pass
495
+ except Exception as e:
496
+ logging.debug(f"Failed to close CV data handle: {e}")
472
497
  return
473
498
 
474
499
  # ==========================================================================
@@ -496,9 +521,9 @@ def main():
496
521
  if args.workers < 0:
497
522
  cpu_count = os.cpu_count() or 4
498
523
  num_gpus = accelerator.num_processes
499
- # Heuristic: 4-8 workers per GPU, bounded by available CPU cores
500
- # Leave some cores for main process and system overhead
501
- args.workers = min(8, max(2, (cpu_count - 2) // num_gpus))
524
+ # Heuristic: 4-16 workers per GPU, bounded by available CPU cores
525
+ # Increased cap from 8 to 16 for high-throughput GPUs (H100, A100)
526
+ args.workers = min(16, max(2, (cpu_count - 2) // num_gpus))
502
527
  if accelerator.is_main_process:
503
528
  logger.info(
504
529
  f"⚙️ Auto-detected workers: {args.workers} per GPU "
@@ -544,9 +569,15 @@ def main():
544
569
  )
545
570
  logger.info(f" Model Size: {param_info['total_mb']:.2f} MB")
546
571
 
547
- # Optional WandB model watching
548
- if args.wandb and WANDB_AVAILABLE and accelerator.is_main_process:
572
+ # Optional WandB model watching (opt-in due to overhead on large models)
573
+ if (
574
+ args.wandb
575
+ and args.wandb_watch
576
+ and WANDB_AVAILABLE
577
+ and accelerator.is_main_process
578
+ ):
549
579
  wandb.watch(model, log="gradients", log_freq=100)
580
+ logger.info(" 📊 WandB gradient watching enabled")
550
581
 
551
582
  # Torch 2.0 compilation (requires compatible Triton on GPU)
552
583
  if args.compile:
@@ -820,7 +851,7 @@ def main():
820
851
  val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
821
852
  val_samples = 0
822
853
 
823
- # Accumulate predictions locally, gather ONCE at end (reduces sync overhead)
854
+ # Accumulate predictions locally ON CPU to prevent GPU OOM
824
855
  local_preds = []
825
856
  local_targets = []
826
857
 
@@ -836,17 +867,19 @@ def main():
836
867
  mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
837
868
  val_mae_sum += mae_batch
838
869
 
839
- # Store locally (no GPU sync per batch)
840
- local_preds.append(pred)
841
- local_targets.append(y)
870
+ # Store on CPU (critical for large val sets)
871
+ local_preds.append(pred.detach().cpu())
872
+ local_targets.append(y.detach().cpu())
873
+
874
+ # Concatenate locally on CPU (no GPU memory spike)
875
+ cpu_preds = torch.cat(local_preds)
876
+ cpu_targets = torch.cat(local_targets)
842
877
 
843
- # Single gather at end of validation (2 syncs instead of 2×num_batches)
844
- all_local_preds = torch.cat(local_preds)
845
- all_local_targets = torch.cat(local_targets)
846
- all_preds = accelerator.gather_for_metrics(all_local_preds)
847
- all_targets = accelerator.gather_for_metrics(all_local_targets)
878
+ # Gather to rank 0 only via gather_object (avoids all-gather to every rank)
879
+ # gather_object returns list of objects from each rank: [(preds0, targs0), (preds1, targs1), ...]
880
+ gathered = accelerator.gather_object((cpu_preds, cpu_targets))
848
881
 
849
- # Synchronize validation metrics
882
+ # Synchronize validation metrics (scalars only - efficient)
850
883
  val_loss_scalar = val_loss_sum.item()
851
884
  val_metrics = torch.cat(
852
885
  [
@@ -869,9 +902,14 @@ def main():
869
902
 
870
903
  # ==================== LOGGING & CHECKPOINTING ====================
871
904
  if accelerator.is_main_process:
872
- # Scientific metrics - cast to float32 before numpy (bf16 can't convert)
873
- y_pred = all_preds.float().cpu().numpy()
874
- y_true = all_targets.float().cpu().numpy()
905
+ # Concatenate gathered tensors from all ranks (only on rank 0)
906
+ # gathered is list of tuples: [(preds_rank0, targs_rank0), (preds_rank1, targs_rank1), ...]
907
+ all_preds = torch.cat([item[0] for item in gathered])
908
+ all_targets = torch.cat([item[1] for item in gathered])
909
+
910
+ # Scientific metrics - cast to float32 before numpy
911
+ y_pred = all_preds.float().numpy()
912
+ y_true = all_targets.float().numpy()
875
913
 
876
914
  # Trim DDP padding
877
915
  real_len = len(val_dl.dataset)
wavedl/utils/config.py CHANGED
@@ -116,7 +116,13 @@ def merge_config_with_args(
116
116
  """
117
117
  # Get parser defaults to detect which args were explicitly set by user
118
118
  if parser is not None:
119
- defaults = vars(parser.parse_args([]))
119
+ # Safe extraction: iterate actions instead of parse_args([])
120
+ # This avoids failures if required arguments are added later
121
+ defaults = {
122
+ action.dest: action.default
123
+ for action in parser._actions
124
+ if action.dest != "help"
125
+ }
120
126
  else:
121
127
  # Fallback: reconstruct defaults from known patterns
122
128
  # This works because argparse stores actual values, and we compare
@@ -141,6 +147,9 @@ def merge_config_with_args(
141
147
  setattr(args, key, value)
142
148
  elif not ignore_unknown:
143
149
  logging.warning(f"Unknown config key: {key}")
150
+ else:
151
+ # Even in ignore_unknown mode, log for discoverability
152
+ logging.debug(f"Config key '{key}' ignored: not a valid argument")
144
153
 
145
154
  return args
146
155
 
@@ -188,12 +197,15 @@ def save_config(
188
197
  return str(output_path)
189
198
 
190
199
 
191
- def validate_config(config: dict[str, Any]) -> list[str]:
200
+ def validate_config(
201
+ config: dict[str, Any], known_keys: list[str] | None = None
202
+ ) -> list[str]:
192
203
  """
193
204
  Validate configuration values against known options.
194
205
 
195
206
  Args:
196
207
  config: Configuration dictionary
208
+ known_keys: Optional list of valid keys (if None, uses defaults from parser args)
197
209
 
198
210
  Returns:
199
211
  List of warning messages (empty if valid)
@@ -229,9 +241,83 @@ def validate_config(config: dict[str, Any]) -> list[str]:
229
241
  for key, (min_val, max_val, msg) in numeric_checks.items():
230
242
  if key in config:
231
243
  val = config[key]
244
+ # Type check: ensure value is numeric before comparison
245
+ if not isinstance(val, (int, float)):
246
+ warnings.append(
247
+ f"Invalid type for '{key}': expected number, got {type(val).__name__} ({val!r})"
248
+ )
249
+ continue
232
250
  if not (min_val <= val <= max_val):
233
251
  warnings.append(f"{msg}: got {val}")
234
252
 
253
+ # Check for unknown/unrecognized keys (helps catch typos)
254
+ # Default known keys based on common training arguments
255
+ default_known_keys = {
256
+ # Model
257
+ "model",
258
+ "import_modules",
259
+ # Hyperparameters
260
+ "batch_size",
261
+ "lr",
262
+ "epochs",
263
+ "patience",
264
+ "weight_decay",
265
+ "grad_clip",
266
+ # Loss
267
+ "loss",
268
+ "huber_delta",
269
+ "loss_weights",
270
+ # Optimizer
271
+ "optimizer",
272
+ "momentum",
273
+ "nesterov",
274
+ "betas",
275
+ # Scheduler
276
+ "scheduler",
277
+ "scheduler_patience",
278
+ "min_lr",
279
+ "scheduler_factor",
280
+ "warmup_epochs",
281
+ "step_size",
282
+ "milestones",
283
+ # Data
284
+ "data_path",
285
+ "workers",
286
+ "seed",
287
+ "single_channel",
288
+ # Cross-validation
289
+ "cv",
290
+ "cv_stratify",
291
+ "cv_bins",
292
+ # Checkpointing
293
+ "resume",
294
+ "save_every",
295
+ "output_dir",
296
+ "fresh",
297
+ # Performance
298
+ "compile",
299
+ "precision",
300
+ "mixed_precision",
301
+ # Logging
302
+ "wandb",
303
+ "wandb_watch",
304
+ "project_name",
305
+ "run_name",
306
+ # Config
307
+ "config",
308
+ "list_models",
309
+ # Metadata (internal)
310
+ "_metadata",
311
+ }
312
+
313
+ check_keys = set(known_keys) if known_keys else default_known_keys
314
+
315
+ for key in config:
316
+ if key not in check_keys:
317
+ warnings.append(
318
+ f"Unknown config key: '{key}' - check for typos or see wavedl-train --help"
319
+ )
320
+
235
321
  return warnings
236
322
 
237
323
 
wavedl/utils/data.py CHANGED
@@ -735,7 +735,7 @@ def load_test_data(
735
735
  try:
736
736
  inp, outp = source.load(path)
737
737
  except KeyError:
738
- # Try with just inputs if outputs not found
738
+ # Try with just inputs if outputs not found (inference-only mode)
739
739
  if format == "npz":
740
740
  data = np.load(path, allow_pickle=True)
741
741
  keys = list(data.keys())
@@ -751,6 +751,54 @@ def load_test_data(
751
751
  )
752
752
  out_key = DataSource._find_key(keys, custom_output_keys)
753
753
  outp = data[out_key] if out_key else None
754
+ elif format == "hdf5":
755
+ # HDF5: input-only loading for inference
756
+ with h5py.File(path, "r") as f:
757
+ keys = list(f.keys())
758
+ inp_key = DataSource._find_key(keys, custom_input_keys)
759
+ if inp_key is None:
760
+ raise KeyError(
761
+ f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
762
+ )
763
+ # Check size - load_test_data is eager, large files should use DataLoader
764
+ n_samples = f[inp_key].shape[0]
765
+ if n_samples > 100000:
766
+ raise ValueError(
767
+ f"Dataset has {n_samples:,} samples. load_test_data() loads "
768
+ f"everything into RAM which may cause OOM. For large inference "
769
+ f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
770
+ )
771
+ inp = f[inp_key][:]
772
+ out_key = DataSource._find_key(keys, custom_output_keys)
773
+ outp = f[out_key][:] if out_key else None
774
+ elif format == "mat":
775
+ # MAT v7.3: input-only loading with proper sparse handling
776
+ mat_source = MATSource()
777
+ with h5py.File(path, "r") as f:
778
+ keys = list(f.keys())
779
+ inp_key = DataSource._find_key(keys, custom_input_keys)
780
+ if inp_key is None:
781
+ raise KeyError(
782
+ f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
783
+ )
784
+ # Check size - load_test_data is eager, large files should use DataLoader
785
+ n_samples = f[inp_key].shape[-1] # MAT is transposed
786
+ if n_samples > 100000:
787
+ raise ValueError(
788
+ f"Dataset has {n_samples:,} samples. load_test_data() loads "
789
+ f"everything into RAM which may cause OOM. For large inference "
790
+ f"sets, use a DataLoader with MATSource.load_mmap() instead."
791
+ )
792
+ # Use _load_dataset for sparse support and proper transpose
793
+ inp = mat_source._load_dataset(f, inp_key)
794
+ out_key = DataSource._find_key(keys, custom_output_keys)
795
+ if out_key:
796
+ outp = mat_source._load_dataset(f, out_key)
797
+ # Handle 1D outputs that become (1, N) after transpose
798
+ if outp.ndim == 2 and outp.shape[0] == 1:
799
+ outp = outp.T
800
+ else:
801
+ outp = None
754
802
  else:
755
803
  raise
756
804
 
@@ -949,6 +997,15 @@ def prepare_data(
949
997
  with open(META_FILE, "rb") as f:
950
998
  meta = pickle.load(f)
951
999
  cached_data_path = meta.get("data_path", None)
1000
+ cached_file_size = meta.get("file_size", None)
1001
+ cached_file_mtime = meta.get("file_mtime", None)
1002
+
1003
+ # Get current file stats
1004
+ current_stats = os.stat(args.data_path)
1005
+ current_size = current_stats.st_size
1006
+ current_mtime = current_stats.st_mtime
1007
+
1008
+ # Check if data path changed
952
1009
  if cached_data_path != os.path.abspath(args.data_path):
953
1010
  if accelerator.is_main_process:
954
1011
  logger.warning(
@@ -958,6 +1015,23 @@ def prepare_data(
958
1015
  f" Invalidating cache and regenerating..."
959
1016
  )
960
1017
  cache_exists = False
1018
+ # Check if file was modified (size or mtime changed)
1019
+ elif cached_file_size is not None and cached_file_size != current_size:
1020
+ if accelerator.is_main_process:
1021
+ logger.warning(
1022
+ f"⚠️ Data file size changed!\n"
1023
+ f" Cached size: {cached_file_size:,} bytes\n"
1024
+ f" Current size: {current_size:,} bytes\n"
1025
+ f" Invalidating cache and regenerating..."
1026
+ )
1027
+ cache_exists = False
1028
+ elif cached_file_mtime is not None and cached_file_mtime != current_mtime:
1029
+ if accelerator.is_main_process:
1030
+ logger.warning(
1031
+ "⚠️ Data file was modified!\n"
1032
+ " Cache may be stale, regenerating..."
1033
+ )
1034
+ cache_exists = False
961
1035
  except Exception:
962
1036
  cache_exists = False
963
1037
 
@@ -1053,13 +1127,16 @@ def prepare_data(
1053
1127
  f" Shape Detected: {full_shape} [{dim_type}] | Output Dim: {out_dim}"
1054
1128
  )
1055
1129
 
1056
- # Save metadata (including data path for cache validation)
1130
+ # Save metadata (including data path, size, mtime for cache validation)
1131
+ file_stats = os.stat(args.data_path)
1057
1132
  with open(META_FILE, "wb") as f:
1058
1133
  pickle.dump(
1059
1134
  {
1060
1135
  "shape": full_shape,
1061
1136
  "out_dim": out_dim,
1062
1137
  "data_path": os.path.abspath(args.data_path),
1138
+ "file_size": file_stats.st_size,
1139
+ "file_mtime": file_stats.st_mtime,
1063
1140
  },
1064
1141
  f,
1065
1142
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.4.0
3
+ Version: 1.4.2
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -502,12 +502,32 @@ WaveDL/
502
502
 
503
503
  | Argument | Default | Description |
504
504
  |----------|---------|-------------|
505
- | `--compile` | `False` | Enable `torch.compile` |
505
+ | `--compile` | `False` | Enable `torch.compile` (recommended for long runs) |
506
506
  | `--precision` | `bf16` | Mixed precision mode (`bf16`, `fp16`, `no`) |
507
+ | `--workers` | `-1` | DataLoader workers per GPU (-1=auto, up to 16) |
507
508
  | `--wandb` | `False` | Enable W&B logging |
509
+ | `--wandb_watch` | `False` | Enable W&B gradient watching (adds overhead) |
508
510
  | `--project_name` | `DL-Training` | W&B project name |
509
511
  | `--run_name` | `None` | W&B run name (auto-generated if not set) |
510
512
 
513
+ **Automatic GPU Optimizations:**
514
+
515
+ WaveDL automatically enables performance optimizations for modern GPUs:
516
+
517
+ | Optimization | Effect | GPU Support |
518
+ |--------------|--------|-------------|
519
+ | **TF32 precision** | ~2x speedup for float32 matmul | A100, H100 (Ampere+) |
520
+ | **cuDNN benchmark** | Auto-tuned convolutions | All NVIDIA GPUs |
521
+ | **Worker scaling** | Up to 16 workers per GPU | All systems |
522
+
523
+ > [!NOTE]
524
+ > These optimizations are **backward compatible** — they have no effect on older GPUs (V100, T4, GTX) or CPU-only systems. No configuration needed.
525
+
526
+ **HPC Best Practices:**
527
+ - Stage data to `$SLURM_TMPDIR` (local NVMe) for maximum I/O throughput
528
+ - Use `--compile` for training runs > 50 epochs
529
+ - Increase `--workers` manually if auto-detection is suboptimal
530
+
511
531
  </details>
512
532
 
513
533
  <details>
@@ -690,7 +710,7 @@ python -m wavedl.hpo --data_path train.npz --models cnn --n_trials 50
690
710
  python -m wavedl.hpo --data_path train.npz --models cnn resnet18 resnet50 vit_small densenet121 --n_trials 200
691
711
  ```
692
712
 
693
- **Step 3: Train with best parameters**
713
+ **Train with best parameters**
694
714
 
695
715
  After HPO completes, it prints the optimal command:
696
716
  ```bash
@@ -837,7 +857,7 @@ import numpy as np
837
857
  X = np.random.randn(1000, 256, 256).astype(np.float32)
838
858
  y = np.random.randn(1000, 5).astype(np.float32)
839
859
 
840
- np.savez('test_data.npz', input_train=X, output_train=y)
860
+ np.savez('test_data.npz', input_test=X, output_test=y)
841
861
  ```
842
862
 
843
863
  </details>
@@ -849,7 +869,7 @@ np.savez('test_data.npz', input_train=X, output_train=y)
849
869
  import numpy as np
850
870
 
851
871
  data = np.load('train_data.npz')
852
- assert data['input_train'].ndim == 3, "Input must be 3D: (N, H, W)"
872
+ assert data['input_train'].ndim >= 2, "Input must be at least 2D: (N, ...) "
853
873
  assert data['output_train'].ndim == 2, "Output must be 2D: (N, T)"
854
874
  assert len(data['input_train']) == len(data['output_train']), "Sample mismatch"
855
875
 
@@ -1004,7 +1024,7 @@ Beyond the material characterization example above, the WaveDL pipeline can be a
1004
1024
  | Resource | Description |
1005
1025
  |----------|-------------|
1006
1026
  | Technical Paper | In-depth framework description *(coming soon)* |
1007
- | [`_template.py`](models/_template.py) | Template for new architectures |
1027
+ | [`_template.py`](src/wavedl/models/_template.py) | Template for custom architectures |
1008
1028
 
1009
1029
  ---
1010
1030
 
@@ -1,11 +1,11 @@
1
- wavedl/__init__.py,sha256=rGJqm3RDNQt8I9vv57urpUSDeLrmKFFYFFbQj7JC85k,1177
2
- wavedl/hpc.py,sha256=87oGGDj1lRJaM01eInHPJX3L-G5aexC3Y06376BdVhg,7321
3
- wavedl/hpo.py,sha256=aZoa_Oto_anZpIhz-YM6kN8KxQXTolUvDEyg3NXwBrY,11542
1
+ wavedl/__init__.py,sha256=K52yq0nkj2B3W0ZpR3tb7RUcHMIiANAE2d1WTuGVZLI,1177
2
+ wavedl/hpc.py,sha256=de_GKERX8GS10sXRX9yXiGzMnk1jjq8JPzRw7QDs6d4,7967
3
+ wavedl/hpo.py,sha256=YJXsnSGEBSVUqp_2ah7zu3_VClAUqZrdkuzDaSqQUjU,12952
4
4
  wavedl/test.py,sha256=jZmRJaivYYTMMTaccCi0yQjHOfp0a9YWR1wAPeKFH-k,36246
5
- wavedl/train.py,sha256=LtPJqULHZU3tPbJNlSwK4GlLxQB1Df2yQfWiEHMn54o,44945
5
+ wavedl/train.py,sha256=Gh02hlfjcote6w1sgUndJzKU_FhkJckTAhlq1aL8FY8,46842
6
6
  wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
7
- wavedl/models/_template.py,sha256=TYQSeH6NIHf9iqJ4n6PK0s7xsjitxFsu-n1EdWL24AA,4822
8
- wavedl/models/base.py,sha256=atM7NJQhtgpA-NauLBVUkXTd8XccjGPBqg3qxX05SRo,5289
7
+ wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
8
+ wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
9
9
  wavedl/models/cnn.py,sha256=rn2Xmup0w_ll6wuAnYclSeIVazoSUrUGPY-9XnhA1gE,8341
10
10
  wavedl/models/convnext.py,sha256=5zELY0ztMB6FxJB9uBurloT7JBdxLXezmrNRzLQjrI0,12846
11
11
  wavedl/models/densenet.py,sha256=LzNbQOvtcJJ4SVf-XvIlXGNUgVS2SXl-MMPbr8lcYrA,12995
@@ -21,17 +21,17 @@ wavedl/models/tcn.py,sha256=RtY13QpFHqz72b4ultv2lStCIDxfvjySVe5JaTx_GaM,12601
21
21
  wavedl/models/unet.py,sha256=LqIXhasdBygwP7SZNNmiW1bHMPaJTVBpaeHtPgEHkdU,7790
22
22
  wavedl/models/vit.py,sha256=0C3GZk11VsYFTl14d86Wtl1Zk1T5rYJjvkaEfEN4N3k,11100
23
23
  wavedl/utils/__init__.py,sha256=YMgzuwndjr64kt9k0_6_9PMJYTVdiaH5veSMff_ZycA,3051
24
- wavedl/utils/config.py,sha256=E0_m5aQ1OdwEwzZysSwc5v905P4g3SDprObFAeVIj9g,8107
24
+ wavedl/utils/config.py,sha256=fMoucikIQHn85mVhGMa7TnXTuFDcEEPjfXk2EjbkJR0,10591
25
25
  wavedl/utils/cross_validation.py,sha256=117ac9KDzaIaqhtP8ZRs15Xpqmq5fLpX2-vqkNvtMaU,17487
26
- wavedl/utils/data.py,sha256=9LrB9MC6jRZzbRSc9xiGzJWoh8FahwP_68REqBAT3Os,44131
26
+ wavedl/utils/data.py,sha256=_OaWvU5oFVJW0NwM5WyDD0Kb1hy5MgvJIFpzvJGux9w,48214
27
27
  wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
28
28
  wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
29
29
  wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
30
30
  wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
31
31
  wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
32
- wavedl-1.4.0.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
33
- wavedl-1.4.0.dist-info/METADATA,sha256=NDjDovaoOPWEno7rQVIId5_kOSyYrTA0nQgkyOfbdXI,39286
34
- wavedl-1.4.0.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
35
- wavedl-1.4.0.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
36
- wavedl-1.4.0.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
37
- wavedl-1.4.0.dist-info/RECORD,,
32
+ wavedl-1.4.2.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
33
+ wavedl-1.4.2.dist-info/METADATA,sha256=xHMuRcGdF8Tdju6aBK601mqLYx0uSIdmtjoQrHpPYmI,40245
34
+ wavedl-1.4.2.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
35
+ wavedl-1.4.2.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
36
+ wavedl-1.4.2.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
37
+ wavedl-1.4.2.dist-info/RECORD,,
File without changes