wavedl 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
wavedl/train.py CHANGED
@@ -1,1430 +1,1430 @@
1
- """
2
- WaveDL - Deep Learning for Wave-based Inverse Problems
3
- =======================================================
4
- Target Environment: NVIDIA HPC GPUs (Multi-GPU DDP) | PyTorch 2.x | Python 3.11+
5
-
6
- A modular training framework for wave-based inverse problems and regression:
7
- 1. HPC-Grade DDP Training: BF16/FP16 mixed precision with torch.compile support
8
- 2. Dynamic Model Selection: Use --model flag to select any registered architecture
9
- 3. Zero-Copy Data Engine: Memmap-backed datasets for large-scale training
10
- 4. Physics-Aware Metrics: Real-time physical MAE with proper unscaling
11
- 5. Robust Checkpointing: Resume training, periodic saves, and training curves
12
- 6. Deep Observability: WandB integration with scatter analysis
13
-
14
- Usage:
15
- # Recommended: Using the HPC launcher (handles accelerate configuration)
16
- wavedl-hpc --model cnn --batch_size 128 --mixed_precision bf16 --wandb
17
-
18
- # Or direct training module (use --precision, not --mixed_precision)
19
- accelerate launch -m wavedl.train --model cnn --batch_size 128 --precision bf16
20
-
21
- # Multi-GPU with explicit config
22
- wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb
23
-
24
- # Resume from checkpoint
25
- accelerate launch -m wavedl.train --model cnn --resume best_checkpoint --wandb
26
-
27
- # List available models
28
- wavedl-train --list_models
29
-
30
- Note:
31
- - wavedl-hpc: Uses --mixed_precision (passed to accelerate launch)
32
- - wavedl.train: Uses --precision (internal module flag)
33
- Both control the same behavior; use the appropriate flag for your entry point.
34
-
35
- Author: Ductho Le (ductho.le@outlook.com)
36
- """
37
-
38
- from __future__ import annotations
39
-
40
- # =============================================================================
41
- # HPC Environment Setup (MUST be before any library imports)
42
- # =============================================================================
43
- # Auto-configure writable cache directories when home is not writable.
44
- # Uses current working directory as fallback - works on HPC and local machines.
45
- import os
46
-
47
-
48
- def _setup_cache_dir(env_var: str, subdir: str) -> None:
49
- """Set cache directory to CWD if home is not writable."""
50
- if env_var in os.environ:
51
- return # User already set, respect their choice
52
-
53
- # Check if home is writable
54
- home = os.path.expanduser("~")
55
- if os.access(home, os.W_OK):
56
- return # Home is writable, let library use defaults
57
-
58
- # Home not writable - use current working directory
59
- cache_path = os.path.join(os.getcwd(), f".{subdir}")
60
- os.makedirs(cache_path, exist_ok=True)
61
- os.environ[env_var] = cache_path
62
-
63
-
64
- # Configure cache directories (before any library imports)
65
- _setup_cache_dir("TORCH_HOME", "torch_cache")
66
- _setup_cache_dir("MPLCONFIGDIR", "matplotlib")
67
- _setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
68
- _setup_cache_dir("XDG_DATA_HOME", "local/share")
69
- _setup_cache_dir("XDG_STATE_HOME", "local/state")
70
- _setup_cache_dir("XDG_CACHE_HOME", "cache")
71
-
72
-
73
- def _setup_per_rank_compile_cache() -> None:
74
- """Set per-GPU Triton/Inductor cache to prevent multi-process race warnings.
75
-
76
- When using torch.compile with multiple GPUs, all processes try to write to
77
- the same cache directory, causing 'Directory is not empty - skipping!' warnings.
78
- This gives each GPU rank its own isolated cache subdirectory.
79
- """
80
- # Get local rank from environment (set by accelerate/torchrun)
81
- local_rank = os.environ.get("LOCAL_RANK", "0")
82
-
83
- # Get cache base from environment or use CWD
84
- cache_base = os.environ.get(
85
- "TRITON_CACHE_DIR", os.path.join(os.getcwd(), ".triton_cache")
86
- )
87
-
88
- # Set per-rank cache directories
89
- os.environ["TRITON_CACHE_DIR"] = os.path.join(cache_base, f"rank_{local_rank}")
90
- os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(
91
- os.environ.get(
92
- "TORCHINDUCTOR_CACHE_DIR", os.path.join(os.getcwd(), ".inductor_cache")
93
- ),
94
- f"rank_{local_rank}",
95
- )
96
-
97
- # Create directories
98
- os.makedirs(os.environ["TRITON_CACHE_DIR"], exist_ok=True)
99
- os.makedirs(os.environ["TORCHINDUCTOR_CACHE_DIR"], exist_ok=True)
100
-
101
-
102
- # Setup per-rank compile caches (before torch imports)
103
- _setup_per_rank_compile_cache()
104
-
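For orientation, a minimal sketch of where the per-rank caches end up with the defaults above (hypothetical example, assuming nothing is preset and LOCAL_RANK comes from the launcher):

import os

# Illustrative only: reproduce the path arithmetic of _setup_per_rank_compile_cache
# for a hypothetical rank-1 process started from the current working directory.
local_rank = "1"  # example value of LOCAL_RANK
triton_dir = os.path.join(os.getcwd(), ".triton_cache", f"rank_{local_rank}")
inductor_dir = os.path.join(os.getcwd(), ".inductor_cache", f"rank_{local_rank}")
print(triton_dir)    # e.g. /path/to/run/.triton_cache/rank_1
print(inductor_dir)  # e.g. /path/to/run/.inductor_cache/rank_1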
105
- # =============================================================================
106
- # Standard imports (after environment setup)
107
- # =============================================================================
108
- import argparse
109
- import logging
110
- import pickle
111
- import shutil
112
- import sys
113
- import time
114
- import warnings
115
- from typing import Any
116
-
117
-
118
- # Suppress Pydantic warnings from accelerate's internal Field() usage
119
- warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
120
-
121
- import matplotlib.pyplot as plt
122
- import numpy as np
123
- import pandas as pd
124
- import torch
125
- import torch.distributed as dist
126
- from accelerate import Accelerator
127
- from accelerate.utils import set_seed
128
- from sklearn.metrics import r2_score
129
- from tqdm.auto import tqdm
130
-
131
- from wavedl.models import build_model, get_model, list_models
132
- from wavedl.utils import (
133
- FIGURE_DPI,
134
- MetricTracker,
135
- broadcast_early_stop,
136
- calc_pearson,
137
- create_training_curves,
138
- get_loss,
139
- get_lr,
140
- get_optimizer,
141
- get_scheduler,
142
- is_epoch_based,
143
- list_losses,
144
- list_optimizers,
145
- list_schedulers,
146
- plot_scientific_scatter,
147
- prepare_data,
148
- )
149
-
150
-
151
- try:
152
- import wandb
153
-
154
- WANDB_AVAILABLE = True
155
- except ImportError:
156
- WANDB_AVAILABLE = False
157
-
158
- # ==============================================================================
159
- # RUNTIME CONFIGURATION (post-import)
160
- # ==============================================================================
161
- # Configure matplotlib paths for HPC systems without writable home directories
162
- os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
163
- os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
164
-
165
- # Suppress warnings from known-noisy libraries, but preserve legitimate warnings
166
- # from torch/numpy about NaN, dtype, and numerical issues.
167
- warnings.filterwarnings("ignore", category=FutureWarning)
168
- warnings.filterwarnings("ignore", category=DeprecationWarning)
169
- # Pydantic v1/v2 compatibility warnings
170
- warnings.filterwarnings("ignore", module="pydantic")
171
- warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
172
- # Transformer library warnings (loading configs, etc.)
173
- warnings.filterwarnings("ignore", module="transformers")
174
- # Accelerate verbose messages
175
- warnings.filterwarnings("ignore", module="accelerate")
176
- # torch.compile backend selection warnings
177
- warnings.filterwarnings("ignore", message=".*TorchDynamo.*")
178
- warnings.filterwarnings("ignore", message=".*Dynamo is not supported.*")
179
- # Note: UserWarning from torch/numpy core is NOT suppressed to preserve
180
- # legitimate warnings about NaN values, dtype mismatches, etc.
181
-
182
- # ==============================================================================
183
- # GPU PERFORMANCE OPTIMIZATIONS (Ampere/Hopper: A100, H100)
184
- # ==============================================================================
185
- # Enable TF32 for faster matmul (safe precision for training, ~2x speedup)
186
- torch.backends.cuda.matmul.allow_tf32 = True
187
- torch.backends.cudnn.allow_tf32 = True
188
- torch.set_float32_matmul_precision("high") # Use TF32 for float32 ops
189
-
190
- # Enable cuDNN autotuning for fixed-size inputs (CNN-like models benefit most)
191
- # Note: First few batches may be slower due to benchmarking
192
- torch.backends.cudnn.benchmark = True
193
-
194
-
195
- # ==============================================================================
196
- # LOGGING UTILITIES
197
- # ==============================================================================
198
- from contextlib import contextmanager
199
-
200
-
201
- @contextmanager
202
- def suppress_accelerate_logging():
203
- """Temporarily suppress accelerate's verbose checkpoint save messages."""
204
- accelerate_logger = logging.getLogger("accelerate.checkpointing")
205
- original_level = accelerate_logger.level
206
- accelerate_logger.setLevel(logging.WARNING)
207
- try:
208
- yield
209
- finally:
210
- accelerate_logger.setLevel(original_level)
211
-
212
-
213
- # ==============================================================================
214
- # ARGUMENT PARSING
215
- # ==============================================================================
216
- def parse_args() -> tuple[argparse.Namespace, argparse.ArgumentParser]:
217
- """Parse command-line arguments with comprehensive options."""
218
- parser = argparse.ArgumentParser(
219
- description="Universal DDP Training Pipeline",
220
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
221
- )
222
-
223
- # Model Selection
224
- parser.add_argument(
225
- "--model",
226
- type=str,
227
- default="cnn",
228
- help=f"Model architecture to train. Available: {list_models()}",
229
- )
230
- parser.add_argument(
231
- "--list_models", action="store_true", help="List all available models and exit"
232
- )
233
- parser.add_argument(
234
- "--import",
235
- dest="import_modules",
236
- type=str,
237
- nargs="+",
238
- default=[],
239
- help="Python modules to import before training (for custom models)",
240
- )
241
- parser.add_argument(
242
- "--no_pretrained",
243
- dest="pretrained",
244
- action="store_false",
245
- help="Train from scratch without pretrained weights (default: use pretrained)",
246
- )
247
- parser.set_defaults(pretrained=True)
248
-
249
- # Configuration File
250
- parser.add_argument(
251
- "--config",
252
- type=str,
253
- default=None,
254
- help="Path to YAML config file. CLI args override config values.",
255
- )
256
-
257
- # Hyperparameters
258
- parser.add_argument(
259
- "--batch_size", type=int, default=128, help="Batch size per GPU"
260
- )
261
- parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
262
- parser.add_argument(
263
- "--epochs", type=int, default=1000, help="Maximum training epochs"
264
- )
265
- parser.add_argument(
266
- "--patience", type=int, default=20, help="Early stopping patience"
267
- )
268
- parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
269
- parser.add_argument(
270
- "--grad_clip", type=float, default=1.0, help="Gradient clipping norm"
271
- )
272
-
273
- # Loss Function
274
- parser.add_argument(
275
- "--loss",
276
- type=str,
277
- default="mse",
278
- choices=["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"],
279
- help=f"Loss function for training. Available: {list_losses()}",
280
- )
281
- parser.add_argument(
282
- "--huber_delta", type=float, default=1.0, help="Delta for Huber loss"
283
- )
284
- parser.add_argument(
285
- "--loss_weights",
286
- type=str,
287
- default=None,
288
- help="Comma-separated weights for weighted_mse (e.g., '1.0,2.0,1.0')",
289
- )
290
-
291
- # Optimizer
292
- parser.add_argument(
293
- "--optimizer",
294
- type=str,
295
- default="adamw",
296
- choices=["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"],
297
- help=f"Optimizer for training. Available: {list_optimizers()}",
298
- )
299
- parser.add_argument(
300
- "--momentum", type=float, default=0.9, help="Momentum for SGD/RMSprop"
301
- )
302
- parser.add_argument(
303
- "--nesterov", action="store_true", help="Use Nesterov momentum (SGD)"
304
- )
305
- parser.add_argument(
306
- "--betas",
307
- type=str,
308
- default="0.9,0.999",
309
- help="Betas for Adam variants (comma-separated)",
310
- )
311
-
312
- # Learning Rate Scheduler
313
- parser.add_argument(
314
- "--scheduler",
315
- type=str,
316
- default="plateau",
317
- choices=[
318
- "plateau",
319
- "cosine",
320
- "cosine_restarts",
321
- "onecycle",
322
- "step",
323
- "multistep",
324
- "exponential",
325
- "linear_warmup",
326
- ],
327
- help=f"LR scheduler. Available: {list_schedulers()}",
328
- )
329
- parser.add_argument(
330
- "--scheduler_patience",
331
- type=int,
332
- default=10,
333
- help="Patience for ReduceLROnPlateau",
334
- )
335
- parser.add_argument(
336
- "--min_lr", type=float, default=1e-6, help="Minimum learning rate"
337
- )
338
- parser.add_argument(
339
- "--scheduler_factor", type=float, default=0.5, help="LR reduction factor"
340
- )
341
- parser.add_argument(
342
- "--warmup_epochs", type=int, default=5, help="Warmup epochs for linear_warmup"
343
- )
344
- parser.add_argument(
345
- "--step_size", type=int, default=30, help="Step size for StepLR"
346
- )
347
- parser.add_argument(
348
- "--milestones",
349
- type=str,
350
- default=None,
351
- help="Comma-separated epochs for MultiStepLR (e.g., '30,60,90')",
352
- )
353
-
354
- # Data
355
- parser.add_argument(
356
- "--data_path", type=str, default="train_data.npz", help="Path to NPZ dataset"
357
- )
358
- parser.add_argument(
359
- "--workers",
360
- type=int,
361
- default=-1,
362
- help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
363
- )
364
- parser.add_argument("--seed", type=int, default=2025, help="Random seed")
365
- parser.add_argument(
366
- "--deterministic",
367
- action="store_true",
368
- help="Enable deterministic mode for reproducibility (slower, disables TF32/cuDNN benchmark)",
369
- )
370
- parser.add_argument(
371
- "--cache_validate",
372
- type=str,
373
- default="sha256",
374
- choices=["sha256", "fast", "size"],
375
- help="Cache validation mode: sha256 (full hash), fast (partial), size (quick)",
376
- )
377
- parser.add_argument(
378
- "--single_channel",
379
- action="store_true",
380
- help="Confirm data is single-channel (suppress ambiguous shape warnings for shallow 3D volumes)",
381
- )
382
-
383
- # Cross-Validation
384
- parser.add_argument(
385
- "--cv",
386
- type=int,
387
- default=0,
388
- help="Enable K-fold cross-validation with K folds (0=disabled)",
389
- )
390
- parser.add_argument(
391
- "--cv_stratify",
392
- action="store_true",
393
- help="Use stratified splitting for cross-validation",
394
- )
395
- parser.add_argument(
396
- "--cv_bins",
397
- type=int,
398
- default=10,
399
- help="Number of bins for stratified CV (only with --cv_stratify)",
400
- )
401
-
402
- # Checkpointing & Resume
403
- parser.add_argument(
404
- "--resume", type=str, default=None, help="Checkpoint directory to resume from"
405
- )
406
- parser.add_argument(
407
- "--save_every",
408
- type=int,
409
- default=50,
410
- help="Save checkpoint every N epochs (0=disable)",
411
- )
412
- parser.add_argument(
413
- "--output_dir", type=str, default=".", help="Output directory for checkpoints"
414
- )
415
- parser.add_argument(
416
- "--fresh",
417
- action="store_true",
418
- help="Force fresh training, ignore existing checkpoints",
419
- )
420
-
421
- # Performance
422
- parser.add_argument(
423
- "--compile", action="store_true", help="Enable torch.compile (PyTorch 2.x)"
424
- )
425
- parser.add_argument(
426
- "--precision",
427
- type=str,
428
- default="bf16",
429
- choices=["bf16", "fp16", "no"],
430
- help="Mixed precision mode",
431
- )
432
- # Alias for consistency with wavedl-hpc (--mixed_precision)
433
- parser.add_argument(
434
- "--mixed_precision",
435
- dest="precision",
436
- type=str,
437
- choices=["bf16", "fp16", "no"],
438
- help=argparse.SUPPRESS, # Hidden: use --precision instead
439
- )
440
-
441
- # Physical Constraints
442
- parser.add_argument(
443
- "--constraint",
444
- type=str,
445
- nargs="+",
446
- default=[],
447
- help="Soft constraint expressions: 'y0 - y1*y2' (penalize violations)",
448
- )
449
-
450
- parser.add_argument(
451
- "--constraint_file",
452
- type=str,
453
- default=None,
454
- help="Python file with constraint(pred, inputs) function",
455
- )
456
- parser.add_argument(
457
- "--constraint_weight",
458
- type=float,
459
- nargs="+",
460
- default=[0.1],
461
- help="Weight(s) for soft constraints (one per constraint, or single shared weight)",
462
- )
463
- parser.add_argument(
464
- "--constraint_reduction",
465
- type=str,
466
- default="mse",
467
- choices=["mse", "mae"],
468
- help="Reduction mode for constraint penalties",
469
- )
470
-
471
- # Logging
472
- parser.add_argument(
473
- "--wandb", action="store_true", help="Enable Weights & Biases logging"
474
- )
475
- parser.add_argument(
476
- "--wandb_watch",
477
- action="store_true",
478
- help="Enable WandB gradient watching (adds overhead, useful for debugging)",
479
- )
480
- parser.add_argument(
481
- "--project_name", type=str, default="DL-Training", help="WandB project name"
482
- )
483
- parser.add_argument("--run_name", type=str, default=None, help="WandB run name")
484
-
485
- args = parser.parse_args()
486
- return args, parser # Returns (Namespace, ArgumentParser)
487
-
488
-
489
- # ==============================================================================
490
- # MAIN TRAINING FUNCTION
491
- # ==============================================================================
492
- def main():
493
- args, parser = parse_args()
494
-
495
- # Import custom model modules if specified
496
- if args.import_modules:
497
- import importlib
498
-
499
- for module_name in args.import_modules:
500
- try:
501
- # Handle both module names (my_model) and file paths (./my_model.py)
502
- if module_name.endswith(".py"):
503
- # Import from file path with unique module name
504
- import importlib.util
505
-
506
- # Derive unique module name from filename to avoid collisions
507
- base_name = os.path.splitext(os.path.basename(module_name))[0]
508
- unique_name = f"wavedl_custom_{base_name}"
509
-
510
- spec = importlib.util.spec_from_file_location(
511
- unique_name, module_name
512
- )
513
- if spec and spec.loader:
514
- module = importlib.util.module_from_spec(spec)
515
- sys.modules[unique_name] = module
516
- spec.loader.exec_module(module)
517
- print(f"✓ Imported custom module from: {module_name}")
518
- else:
519
- # Import as regular module
520
- importlib.import_module(module_name)
521
- print(f"✓ Imported module: {module_name}")
522
- except (ImportError, FileNotFoundError, SyntaxError, PermissionError) as e:
523
- print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
524
- if isinstance(e, FileNotFoundError):
525
- print(" File does not exist. Check the path.", file=sys.stderr)
526
- elif isinstance(e, SyntaxError):
527
- print(
528
- f" Syntax error at line {e.lineno}: {e.msg}", file=sys.stderr
529
- )
530
- elif isinstance(e, PermissionError):
531
- print(
532
- " Permission denied. Check file permissions.", file=sys.stderr
533
- )
534
- else:
535
- print(
536
- " Make sure the module is in your Python path or current directory.",
537
- file=sys.stderr,
538
- )
539
- sys.exit(1)
540
-
541
- # Handle --list_models flag
542
- if args.list_models:
543
- print("Available models:")
544
- for name in list_models():
545
- ModelClass = get_model(name)
546
- # Get first non-empty docstring line
547
- if ModelClass.__doc__:
548
- lines = [
549
- l.strip() for l in ModelClass.__doc__.splitlines() if l.strip()
550
- ]
551
- doc_first_line = lines[0] if lines else "No description"
552
- else:
553
- doc_first_line = "No description"
554
- print(f" - {name}: {doc_first_line}")
555
- sys.exit(0)
556
-
557
- # Load and merge config file if provided
558
- if args.config:
559
- from wavedl.utils.config import (
560
- load_config,
561
- merge_config_with_args,
562
- validate_config,
563
- )
564
-
565
- print(f"📄 Loading config from: {args.config}")
566
- config = load_config(args.config)
567
-
568
- # Validate config values
569
- warnings_list = validate_config(config)
570
- for w in warnings_list:
571
- print(f" ⚠ {w}")
572
-
573
- # Merge config with CLI args (CLI takes precedence via parser defaults detection)
574
- args = merge_config_with_args(config, args, parser=parser)
575
-
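The "parser defaults detection" mentioned above can be pictured with the sketch below; it is only an illustration of the general idea, not the actual wavedl.utils.config.merge_config_with_args implementation:

import argparse

def merge_config_with_args_sketch(config: dict, args: argparse.Namespace,
                                  parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Config fills values still at their parser defaults; explicit CLI flags win."""
    merged = vars(args).copy()
    for key, value in config.items():
        # A value equal to the parser default means the user did not pass the flag
        # explicitly, so the config is allowed to override it.
        if key in merged and merged[key] == parser.get_default(key):
            merged[key] = value
    return argparse.Namespace(**merged)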
576
- # Handle --cv flag (cross-validation mode)
577
- if args.cv > 0:
578
- print(f"🔄 Cross-Validation Mode: {args.cv} folds")
579
- from wavedl.utils.cross_validation import run_cross_validation
580
-
581
- # Load data for CV using memory-efficient loader
582
- from wavedl.utils.data import DataSource, get_data_source
583
-
584
- data_format = DataSource.detect_format(args.data_path)
585
- source = get_data_source(data_format)
586
-
587
- # Use memory-mapped loading when available (now returns LazyDataHandle for all formats)
588
- _cv_handle = None
589
- if hasattr(source, "load_mmap"):
590
- _cv_handle = source.load_mmap(args.data_path)
591
- X, y = _cv_handle.inputs, _cv_handle.outputs
592
- else:
593
- X, y = source.load(args.data_path)
594
-
595
- # Handle sparse matrices (must materialize for CV shuffling)
596
- if hasattr(X, "__getitem__") and len(X) > 0 and hasattr(X[0], "toarray"):
597
- X = np.stack([x.toarray() for x in X])
598
-
599
- # Normalize target shape: (N,) -> (N, 1) for consistency
600
- if y.ndim == 1:
601
- y = y.reshape(-1, 1)
602
-
603
- # Validate and determine input shape (consistent with prepare_data)
604
- # Check for ambiguous shapes that could be multi-channel or shallow 3D volume
605
- sample_shape = X.shape[1:] # Per-sample shape
606
-
607
- # Same heuristic as prepare_data: detect ambiguous 3D shapes
608
- is_ambiguous_shape = (
609
- len(sample_shape) == 3 # Exactly 3D: could be (C, H, W) or (D, H, W)
610
- and sample_shape[0] <= 16 # First dim looks like channels
611
- and sample_shape[1] > 16
612
- and sample_shape[2] > 16 # Both spatial dims are large
613
- )
614
-
615
- if is_ambiguous_shape and not args.single_channel:
616
- raise ValueError(
617
- f"Ambiguous input shape detected: sample shape {sample_shape}. "
618
- f"This could be either:\n"
619
- f" - Multi-channel 2D data (C={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n"
620
- f" - Single-channel 3D volume (D={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n\n"
621
- f"If this is single-channel 3D/shallow volume data, use --single_channel flag.\n"
622
- f"If this is multi-channel 2D data, reshape to (N*C, H, W) with adjusted targets."
623
- )
624
-
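A few concrete shapes run through the heuristic above (illustrative examples only):

def _looks_ambiguous(sample_shape: tuple) -> bool:
    # Same rule as above: exactly 3 dims, a small leading dim, two clearly spatial dims.
    return (
        len(sample_shape) == 3
        and sample_shape[0] <= 16
        and sample_shape[1] > 16
        and sample_shape[2] > 16
    )

print(_looks_ambiguous((3, 128, 128)))   # True  -> 3 channels or a 3-slice volume? needs --single_channel
print(_looks_ambiguous((64, 128, 128)))  # False -> leading dim too large to be channels
print(_looks_ambiguous((3, 8, 8)))       # False -> trailing dims too small to be spatial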
625
- # in_shape = spatial dimensions for model registry (channel added during training)
626
- in_shape = sample_shape
627
-
628
- # Run cross-validation
629
- try:
630
- run_cross_validation(
631
- X=X,
632
- y=y,
633
- model_name=args.model,
634
- in_shape=in_shape,
635
- out_size=y.shape[1],
636
- folds=args.cv,
637
- stratify=args.cv_stratify,
638
- stratify_bins=args.cv_bins,
639
- batch_size=args.batch_size,
640
- lr=args.lr,
641
- epochs=args.epochs,
642
- patience=args.patience,
643
- weight_decay=args.weight_decay,
644
- loss_name=args.loss,
645
- optimizer_name=args.optimizer,
646
- scheduler_name=args.scheduler,
647
- output_dir=args.output_dir,
648
- workers=args.workers,
649
- seed=args.seed,
650
- )
651
- finally:
652
- # Clean up file handle if HDF5/MAT
653
- if _cv_handle is not None and hasattr(_cv_handle, "close"):
654
- try:
655
- _cv_handle.close()
656
- except Exception as e:
657
- logging.debug(f"Failed to close CV data handle: {e}")
658
- return
659
-
660
- # ==========================================================================
661
- # SYSTEM INITIALIZATION
662
- # ==========================================================================
663
- # Initialize Accelerator for DDP and mixed precision
664
- accelerator = Accelerator(
665
- mixed_precision=args.precision,
666
- log_with="wandb" if args.wandb and WANDB_AVAILABLE else None,
667
- )
668
- set_seed(args.seed)
669
-
670
- # Deterministic mode for scientific reproducibility
671
- # Disables TF32 and cuDNN benchmark for exact reproducibility (slower)
672
- if args.deterministic:
673
- torch.backends.cudnn.benchmark = False
674
- torch.backends.cudnn.deterministic = True
675
- torch.backends.cuda.matmul.allow_tf32 = False
676
- torch.backends.cudnn.allow_tf32 = False
677
- torch.use_deterministic_algorithms(True, warn_only=True)
678
- if accelerator.is_main_process:
679
- print("🔒 Deterministic mode enabled (slower but reproducible)")
680
-
681
- # Configure logging (rank 0 only prints to console)
682
- logging.basicConfig(
683
- level=logging.INFO if accelerator.is_main_process else logging.ERROR,
684
- format="%(asctime)s | %(levelname)s | %(message)s",
685
- datefmt="%H:%M:%S",
686
- )
687
- logger = logging.getLogger("Trainer")
688
-
689
- # Ensure output directory exists (critical for cache files, checkpoints, etc.)
690
- os.makedirs(args.output_dir, exist_ok=True)
691
-
692
- # Auto-detect optimal DataLoader workers if not specified
693
- if args.workers < 0:
694
- cpu_count = os.cpu_count() or 4
695
- num_gpus = accelerator.num_processes
696
- # Heuristic: 2-16 workers per GPU, bounded by available CPU cores
697
- # Increased cap from 8 to 16 for high-throughput GPUs (H100, A100)
698
- args.workers = min(16, max(2, (cpu_count - 2) // num_gpus))
699
- if accelerator.is_main_process:
700
- logger.info(
701
- f"⚙️ Auto-detected workers: {args.workers} per GPU "
702
- f"(CPUs: {cpu_count}, GPUs: {num_gpus})"
703
- )
704
-
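Worked examples of the worker heuristic above (illustrative CPU/GPU counts):

def auto_workers(cpu_count: int, num_gpus: int) -> int:
    # Mirrors the formula above: clamp to the 2..16 range, sharing spare cores across GPUs.
    return min(16, max(2, (cpu_count - 2) // num_gpus))

print(auto_workers(64, 4))   # 15 -> (64 - 2) // 4 = 15, below the cap of 16
print(auto_workers(256, 8))  # 16 -> 31 exceeds the cap, clamped to 16
print(auto_workers(8, 4))    # 2  -> (8 - 2) // 4 = 1, raised to the floor of 2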
705
- if accelerator.is_main_process:
706
- logger.info(f"🚀 Cluster Status: {accelerator.num_processes}x GPUs detected")
707
- logger.info(
708
- f" Model: {args.model} | Precision: {args.precision} | Compile: {args.compile}"
709
- )
710
- logger.info(
711
- f" Loss: {args.loss} | Optimizer: {args.optimizer} | Scheduler: {args.scheduler}"
712
- )
713
- logger.info(f" Early Stopping Patience: {args.patience} epochs")
714
- if args.save_every > 0:
715
- logger.info(f" Periodic Checkpointing: Every {args.save_every} epochs")
716
- if args.resume:
717
- logger.info(f" 📂 Resuming from: {args.resume}")
718
-
719
- # Initialize WandB
720
- if args.wandb and WANDB_AVAILABLE:
721
- accelerator.init_trackers(
722
- project_name=args.project_name,
723
- config=vars(args),
724
- init_kwargs={"wandb": {"name": args.run_name or f"{args.model}_run"}},
725
- )
726
-
727
- # ==========================================================================
728
- # DATA & MODEL LOADING
729
- # ==========================================================================
730
- train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
731
- args, logger, accelerator, cache_dir=args.output_dir
732
- )
733
-
734
- # Build model using registry
735
- model = build_model(
736
- args.model, in_shape=in_shape, out_size=out_dim, pretrained=args.pretrained
737
- )
738
-
739
- if accelerator.is_main_process:
740
- param_info = model.parameter_summary()
741
- logger.info(
742
- f" Model Parameters: {param_info['trainable_parameters']:,} trainable"
743
- )
744
- logger.info(f" Model Size: {param_info['total_mb']:.2f} MB")
745
-
746
- # Optional WandB model watching (opt-in due to overhead on large models)
747
- if (
748
- args.wandb
749
- and args.wandb_watch
750
- and WANDB_AVAILABLE
751
- and accelerator.is_main_process
752
- ):
753
- wandb.watch(model, log="gradients", log_freq=100)
754
- logger.info(" 📊 WandB gradient watching enabled")
755
-
756
- # Torch 2.0 compilation (requires compatible Triton on GPU)
757
- if args.compile:
758
- try:
759
- # Test if Triton is available - just import the package
760
- # Different Triton versions have different internal APIs, so just check base import
761
- import triton
762
-
763
- model = torch.compile(model)
764
- if accelerator.is_main_process:
765
- logger.info(" ✔ torch.compile enabled (Triton backend)")
766
- except ImportError as e:
767
- if accelerator.is_main_process:
768
- if "triton" in str(e).lower():
769
- logger.warning(
770
- " ⚠ Triton not installed or incompatible version - torch.compile disabled. "
771
- "Training will proceed without compilation."
772
- )
773
- else:
774
- logger.warning(
775
- f" ⚠ torch.compile setup failed: {e}. Continuing without compilation."
776
- )
777
- except Exception as e:
778
- if accelerator.is_main_process:
779
- logger.warning(
780
- f" ⚠ torch.compile failed: {e}. Continuing without compilation."
781
- )
782
-
783
- # ==========================================================================
784
- # OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
785
- # ==========================================================================
786
- # Parse comma-separated arguments with validation
787
- try:
788
- betas_list = [float(x.strip()) for x in args.betas.split(",")]
789
- if len(betas_list) != 2:
790
- raise ValueError(
791
- f"--betas must have exactly 2 values, got {len(betas_list)}"
792
- )
793
- if not all(0.0 <= b < 1.0 for b in betas_list):
794
- raise ValueError(f"--betas values must be in [0, 1), got {betas_list}")
795
- betas = tuple(betas_list)
796
- except ValueError as e:
797
- raise ValueError(
798
- f"Invalid --betas format '{args.betas}': {e}. Expected format: '0.9,0.999'"
799
- )
800
-
801
- loss_weights = None
802
- if args.loss_weights:
803
- loss_weights = [float(x.strip()) for x in args.loss_weights.split(",")]
804
- milestones = None
805
- if args.milestones:
806
- milestones = [int(x.strip()) for x in args.milestones.split(",")]
807
-
808
- # Create optimizer using factory
809
- optimizer = get_optimizer(
810
- name=args.optimizer,
811
- params=model.get_optimizer_groups(args.lr, args.weight_decay),
812
- lr=args.lr,
813
- weight_decay=args.weight_decay,
814
- momentum=args.momentum,
815
- nesterov=args.nesterov,
816
- betas=betas,
817
- )
818
-
819
- # Create loss function using factory
820
- criterion = get_loss(
821
- name=args.loss,
822
- weights=loss_weights,
823
- delta=args.huber_delta,
824
- )
825
- # Move criterion to device (important for WeightedMSELoss buffer)
826
- criterion = criterion.to(accelerator.device)
827
-
828
- # ==========================================================================
829
- # PHYSICAL CONSTRAINTS INTEGRATION
830
- # ==========================================================================
831
- from wavedl.utils.constraints import (
832
- PhysicsConstrainedLoss,
833
- build_constraints,
834
- )
835
-
836
- # Build soft constraints
837
- soft_constraints = build_constraints(
838
- expressions=args.constraint,
839
- file_path=args.constraint_file,
840
- reduction=args.constraint_reduction,
841
- )
842
-
843
- # Wrap criterion with PhysicsConstrainedLoss if we have soft constraints
844
- if soft_constraints:
845
- # Pass output scaler so constraints can be evaluated in physical space
846
- output_mean = scaler.mean_ if hasattr(scaler, "mean_") else None
847
- output_std = scaler.scale_ if hasattr(scaler, "scale_") else None
848
- criterion = PhysicsConstrainedLoss(
849
- criterion,
850
- soft_constraints,
851
- weights=args.constraint_weight,
852
- output_mean=output_mean,
853
- output_std=output_std,
854
- )
855
- if accelerator.is_main_process:
856
- logger.info(
857
- f" 🔬 Physical constraints: {len(soft_constraints)} constraint(s) "
858
- f"with weight(s) {args.constraint_weight}"
859
- )
860
- if output_mean is not None:
861
- logger.info(
862
- " 📐 Constraints evaluated in physical space (denormalized)"
863
- )
864
-
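For intuition, a simplified sketch of how a soft constraint such as 'y0 - y1*y2' becomes a penalty in physical units; this is a stand-in for illustration, not the actual wavedl.utils.constraints.PhysicsConstrainedLoss:

import torch

def soft_constraint_penalty(pred: torch.Tensor, mean: torch.Tensor, std: torch.Tensor,
                            weight: float = 0.1) -> torch.Tensor:
    """Penalize violations of y0 == y1 * y2 on denormalized predictions ('mse' reduction)."""
    phys = pred * std + mean                      # undo standard scaling -> physical space
    violation = phys[:, 0] - phys[:, 1] * phys[:, 2]
    return weight * torch.mean(violation ** 2)    # 'mae' reduction would use violation.abs()

# Illustrative use: total = base_criterion(pred, target) + soft_constraint_penalty(pred, mean, std)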
865
- # Track if scheduler should step per batch (OneCycleLR) or per epoch
866
- scheduler_step_per_batch = not is_epoch_based(args.scheduler)
867
-
868
- # ==========================================================================
869
- # DDP Preparation Strategy:
870
- # - For batch-based schedulers (OneCycleLR): prepare DataLoaders first to get
871
- # the correct sharded batch count, then create scheduler
872
- # - For epoch-based schedulers: create scheduler before prepare (no issue)
873
- # ==========================================================================
874
- if scheduler_step_per_batch:
875
- # BATCH-BASED SCHEDULER (e.g., OneCycleLR)
876
- # Prepare model, optimizer, dataloaders FIRST to get sharded loader length
877
- model, optimizer, train_dl, val_dl = accelerator.prepare(
878
- model, optimizer, train_dl, val_dl
879
- )
880
-
881
- # Now create scheduler with the CORRECT sharded steps_per_epoch
882
- steps_per_epoch = len(train_dl) # Post-DDP sharded length
883
- scheduler = get_scheduler(
884
- name=args.scheduler,
885
- optimizer=optimizer,
886
- epochs=args.epochs,
887
- steps_per_epoch=steps_per_epoch,
888
- min_lr=args.min_lr,
889
- patience=args.scheduler_patience,
890
- factor=args.scheduler_factor,
891
- gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
892
- step_size=args.step_size,
893
- milestones=milestones,
894
- warmup_epochs=args.warmup_epochs,
895
- )
896
- # Prepare scheduler separately (Accelerator wraps it for state saving)
897
- scheduler = accelerator.prepare(scheduler)
898
- else:
899
- # EPOCH-BASED SCHEDULER (plateau, cosine, step, etc.)
900
- # No batch count dependency - create scheduler before prepare
901
- scheduler = get_scheduler(
902
- name=args.scheduler,
903
- optimizer=optimizer,
904
- epochs=args.epochs,
905
- steps_per_epoch=None,
906
- min_lr=args.min_lr,
907
- patience=args.scheduler_patience,
908
- factor=args.scheduler_factor,
909
- gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
910
- step_size=args.step_size,
911
- milestones=milestones,
912
- warmup_epochs=args.warmup_epochs,
913
- )
914
-
915
- # For ReduceLROnPlateau: DON'T include scheduler in accelerator.prepare()
916
- # because accelerator wraps scheduler.step() to sync across processes,
917
- # which defeats our rank-0-only stepping for correct patience counting.
918
- # Other schedulers are safe to prepare (no internal state affected by multi-call).
919
- if args.scheduler == "plateau":
920
- model, optimizer, train_dl, val_dl = accelerator.prepare(
921
- model, optimizer, train_dl, val_dl
922
- )
923
- # Scheduler stays unwrapped - we handle sync manually in training loop
924
- # But register it for checkpointing so state is saved/loaded on resume
925
- accelerator.register_for_checkpointing(scheduler)
926
- else:
927
- model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
928
- model, optimizer, train_dl, val_dl, scheduler
929
- )
930
-
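A rough worked example of why the preparation order matters for batch-stepped schedulers (illustrative numbers; exact loader lengths depend on the sampler and drop_last):

import math

n_train, batch_per_gpu, n_gpus = 100_000, 128, 4

steps_unsharded = math.ceil(n_train / batch_per_gpu)           # 782 optimizer steps per epoch
steps_sharded = math.ceil(n_train / (batch_per_gpu * n_gpus))  # 196 steps per epoch per rank

# A OneCycleLR built from the unsharded count would plan ~4x more steps than each
# rank actually takes, so its schedule would never finish; building it after
# accelerator.prepare() uses the correct per-rank length.
print(steps_unsharded, steps_sharded)  # 782 196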
931
- # ==========================================================================
932
- # AUTO-RESUME / RESUME FROM CHECKPOINT
933
- # ==========================================================================
934
- start_epoch = 0
935
- best_val_loss = float("inf")
936
- patience_ctr = 0
937
- history: list[dict[str, Any]] = []
938
-
939
- # Define checkpoint paths
940
- best_ckpt_path = os.path.join(args.output_dir, "best_checkpoint")
941
- complete_flag_path = os.path.join(args.output_dir, "training_complete.flag")
942
-
943
- # Auto-resume logic (if not --fresh and no explicit --resume)
944
- if not args.fresh and args.resume is None:
945
- if os.path.exists(complete_flag_path):
946
- # Training already completed
947
- if accelerator.is_main_process:
948
- logger.info(
949
- "✅ Training already completed (early stopping). Use --fresh to retrain."
950
- )
951
- return # Exit gracefully
952
- elif os.path.exists(best_ckpt_path):
953
- # Incomplete training found - auto-resume
954
- args.resume = best_ckpt_path
955
- if accelerator.is_main_process:
956
- logger.info(f"🔄 Auto-resuming from: {best_ckpt_path}")
957
-
958
- if args.resume:
959
- if os.path.exists(args.resume):
960
- logger.info(f"🔄 Loading checkpoint from: {args.resume}")
961
- accelerator.load_state(args.resume)
962
-
963
- # Restore training metadata
964
- meta_path = os.path.join(args.resume, "training_meta.pkl")
965
- if os.path.exists(meta_path):
966
- with open(meta_path, "rb") as f:
967
- meta = pickle.load(f)
968
- start_epoch = meta.get("epoch", 0)
969
- best_val_loss = meta.get("best_val_loss", float("inf"))
970
- patience_ctr = meta.get("patience_ctr", 0)
971
- logger.info(
972
- f" ✅ Restored: Epoch {start_epoch}, Best Loss: {best_val_loss:.6f}"
973
- )
974
- else:
975
- logger.warning(
976
- " ⚠️ training_meta.pkl not found, starting from epoch 0"
977
- )
978
-
979
- # Restore history
980
- history_path = os.path.join(args.output_dir, "training_history.csv")
981
- if os.path.exists(history_path):
982
- history = pd.read_csv(history_path).to_dict("records")
983
- logger.info(f" ✅ Loaded {len(history)} epochs from history")
984
- else:
985
- raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
986
-
987
- # ==========================================================================
988
- # PHYSICAL METRIC SETUP
989
- # ==========================================================================
990
- # Physical MAE = normalized MAE * scaler.scale_
991
- phys_scale = torch.tensor(
992
- scaler.scale_, device=accelerator.device, dtype=torch.float32
993
- )
994
-
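A quick worked example of this unscaling (illustrative values; scaler.scale_ holds the per-target spread of a fitted scaler such as sklearn's StandardScaler):

import numpy as np

scale_ = np.array([200.0, 0.5])        # physical spread per target, e.g. m/s and mm
norm_abs_err = np.array([0.05, 0.02])  # mean |pred - target| in normalized units

phys_mae = norm_abs_err * scale_
print(phys_mae)                        # [10.    0.01] -> MAE expressed in physical units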
995
- # ==========================================================================
996
- # TRAINING LOOP
997
- # ==========================================================================
998
- # Dynamic console header
999
- if accelerator.is_main_process:
1000
- base_cols = ["Ep", "TrnLoss", "ValLoss", "R2", "PCC", "GradN", "LR", "MAE_Avg"]
1001
- param_cols = [f"MAE_P{i}" for i in range(out_dim)]
1002
- header = "{:<4} | {:<8} | {:<8} | {:<6} | {:<6} | {:<6} | {:<8} | {:<8}".format(
1003
- *base_cols
1004
- )
1005
- header += " | " + " | ".join([f"{c:<8}" for c in param_cols])
1006
- logger.info("=" * len(header))
1007
- logger.info(header)
1008
- logger.info("=" * len(header))
1009
-
1010
- try:
1011
- total_training_time = 0.0
1012
-
1013
- for epoch in range(start_epoch, args.epochs):
1014
- epoch_start_time = time.time()
1015
-
1016
- # ==================== TRAINING PHASE ====================
1017
- model.train()
1018
- # Use GPU tensor for loss accumulation to avoid .item() sync per batch
1019
- train_loss_sum = torch.tensor(0.0, device=accelerator.device)
1020
- train_samples = 0
1021
- grad_norm_tracker = MetricTracker()
1022
-
1023
- pbar = tqdm(
1024
- train_dl,
1025
- disable=not accelerator.is_main_process,
1026
- leave=False,
1027
- desc=f"Epoch {epoch + 1}",
1028
- )
1029
-
1030
- for x, y in pbar:
1031
- with accelerator.accumulate(model):
1032
- # Use mixed precision for forward pass (respects --precision flag)
1033
- with accelerator.autocast():
1034
- pred = model(x)
1035
- # Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
1036
- if isinstance(criterion, PhysicsConstrainedLoss):
1037
- loss = criterion(pred, y, x)
1038
- else:
1039
- loss = criterion(pred, y)
1040
-
1041
- accelerator.backward(loss)
1042
-
1043
- if accelerator.sync_gradients:
1044
- grad_norm = accelerator.clip_grad_norm_(
1045
- model.parameters(), args.grad_clip
1046
- )
1047
- if grad_norm is not None:
1048
- grad_norm_tracker.update(grad_norm.item())
1049
-
1050
- optimizer.step()
1051
- optimizer.zero_grad(set_to_none=True) # Faster than zero_grad()
1052
-
1053
- # Per-batch LR scheduling (e.g., OneCycleLR)
1054
- if scheduler_step_per_batch:
1055
- scheduler.step()
1056
-
1057
- # Accumulate as tensors to avoid .item() sync per batch
1058
- train_loss_sum += loss.detach() * x.size(0)
1059
- train_samples += x.size(0)
1060
-
1061
- # Single .item() call at end of epoch (reduces GPU sync overhead)
1062
- train_loss_scalar = train_loss_sum.item()
1063
-
1064
- # Synchronize training metrics across GPUs
1065
- global_loss = accelerator.reduce(
1066
- torch.tensor([train_loss_scalar], device=accelerator.device),
1067
- reduction="sum",
1068
- ).item()
1069
- global_samples = accelerator.reduce(
1070
- torch.tensor([train_samples], device=accelerator.device),
1071
- reduction="sum",
1072
- ).item()
1073
- avg_train_loss = global_loss / global_samples
1074
-
1075
- # ==================== VALIDATION PHASE ====================
1076
- model.eval()
1077
- # Use GPU tensor for loss accumulation (consistent with training phase)
1078
- val_loss_sum = torch.tensor(0.0, device=accelerator.device)
1079
- val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
1080
- val_samples = 0
1081
-
1082
- # Accumulate predictions locally ON CPU to prevent GPU OOM
1083
- local_preds = []
1084
- local_targets = []
1085
-
1086
- with torch.inference_mode():
1087
- for x, y in val_dl:
1088
- # Use mixed precision for validation (consistent with training)
1089
- with accelerator.autocast():
1090
- pred = model(x)
1091
- # Pass inputs for input-dependent constraints
1092
- if isinstance(criterion, PhysicsConstrainedLoss):
1093
- loss = criterion(pred, y, x)
1094
- else:
1095
- loss = criterion(pred, y)
1096
-
1097
- val_loss_sum += loss.detach() * x.size(0)
1098
- val_samples += x.size(0)
1099
-
1100
- # Physical MAE
1101
- mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
1102
- val_mae_sum += mae_batch
1103
-
1104
- # Store on CPU (critical for large val sets)
1105
- local_preds.append(pred.detach().cpu())
1106
- local_targets.append(y.detach().cpu())
1107
-
1108
- # Concatenate locally on CPU (moved back to GPU only for the cross-rank gather below)
1109
- local_preds_cat = torch.cat(local_preds)
1110
- local_targets_cat = torch.cat(local_targets)
1111
-
1112
- # Gather predictions and targets using Accelerate's CPU-efficient utility
1113
- # gather_for_metrics handles:
1114
- # - DDP padding removal (no need to trim manually)
1115
- # - Efficient cross-rank gathering without GPU memory spike
1116
- # - Results are moved back to CPU below for metric computation
1117
- if accelerator.num_processes > 1:
1118
- # Move to GPU for gather (required by NCCL), then back to CPU
1119
- # gather_for_metrics is more memory-efficient than manual gather
1120
- # as it processes in chunks internally
1121
- gathered_preds = accelerator.gather_for_metrics(
1122
- local_preds_cat.to(accelerator.device)
1123
- ).cpu()
1124
- gathered_targets = accelerator.gather_for_metrics(
1125
- local_targets_cat.to(accelerator.device)
1126
- ).cpu()
1127
- else:
1128
- # Single-GPU mode: no gathering needed
1129
- gathered_preds = local_preds_cat
1130
- gathered_targets = local_targets_cat
1131
-
1132
- # Synchronize validation metrics (scalars only - efficient)
1133
- val_loss_scalar = val_loss_sum.item()
1134
- val_metrics = torch.cat(
1135
- [
1136
- torch.tensor([val_loss_scalar], device=accelerator.device),
1137
- val_mae_sum,
1138
- ]
1139
- )
1140
- val_metrics_sync = accelerator.reduce(val_metrics, reduction="sum")
1141
-
1142
- total_val_samples = accelerator.reduce(
1143
- torch.tensor([val_samples], device=accelerator.device), reduction="sum"
1144
- ).item()
1145
-
1146
- avg_val_loss = val_metrics_sync[0].item() / total_val_samples
1147
- # Cast to float32 before numpy (bf16 tensors can't convert directly)
1148
- avg_mae_per_param = (
1149
- (val_metrics_sync[1:] / total_val_samples).float().cpu().numpy()
1150
- )
1151
- avg_mae = avg_mae_per_param.mean()
1152
-
1153
- # ==================== LOGGING & CHECKPOINTING ====================
1154
- if accelerator.is_main_process:
1155
- # Scientific metrics - cast to float32 before numpy
1156
- # gather_for_metrics already handles DDP padding removal
1157
- y_pred = gathered_preds.float().numpy()
1158
- y_true = gathered_targets.float().numpy()
1159
-
1160
- # Guard against tiny validation sets (R² undefined for <2 samples)
1161
- if len(y_true) >= 2:
1162
- r2 = r2_score(y_true, y_pred)
1163
- else:
1164
- r2 = float("nan")
1165
- pcc = calc_pearson(y_true, y_pred)
1166
- current_lr = get_lr(optimizer)
1167
-
1168
- # Update history
1169
- epoch_end_time = time.time()
1170
- epoch_time = epoch_end_time - epoch_start_time
1171
- total_training_time += epoch_time
1172
-
1173
- epoch_stats = {
1174
- "epoch": epoch + 1,
1175
- "train_loss": avg_train_loss,
1176
- "val_loss": avg_val_loss,
1177
- "val_r2": r2,
1178
- "val_pearson": pcc,
1179
- "val_mae_avg": avg_mae,
1180
- "grad_norm": grad_norm_tracker.avg,
1181
- "lr": current_lr,
1182
- "epoch_time": round(epoch_time, 2),
1183
- "total_time": round(total_training_time, 2),
1184
- }
1185
- for i, mae in enumerate(avg_mae_per_param):
1186
- epoch_stats[f"MAE_Phys_P{i}"] = mae
1187
-
1188
- history.append(epoch_stats)
1189
-
1190
- # Console display
1191
- base_str = f"{epoch + 1:<4} | {avg_train_loss:<8.4f} | {avg_val_loss:<8.4f} | {r2:<6.4f} | {pcc:<6.4f} | {grad_norm_tracker.avg:<6.4f} | {current_lr:<8.2e} | {avg_mae:<8.4f}"
1192
- param_str = " | ".join([f"{m:<8.4f}" for m in avg_mae_per_param])
1193
- logger.info(f"{base_str} | {param_str}")
1194
-
1195
- # WandB logging
1196
- if args.wandb and WANDB_AVAILABLE:
1197
- log_dict = {
1198
- "main/train_loss": avg_train_loss,
1199
- "main/val_loss": avg_val_loss,
1200
- "metrics/r2_score": r2,
1201
- "metrics/pearson_corr": pcc,
1202
- "metrics/mae_avg": avg_mae,
1203
- "system/grad_norm": grad_norm_tracker.avg,
1204
- "hyper/lr": current_lr,
1205
- }
1206
- for i, mae in enumerate(avg_mae_per_param):
1207
- log_dict[f"mae_detailed/P{i}"] = mae
1208
-
1209
- # Periodic scatter plots
1210
- if (epoch % 5 == 0) or (avg_val_loss < best_val_loss):
1211
- real_true = scaler.inverse_transform(y_true)
1212
- real_pred = scaler.inverse_transform(y_pred)
1213
- fig = plot_scientific_scatter(real_true, real_pred)
1214
- log_dict["plots/scatter_analysis"] = wandb.Image(fig)
1215
- plt.close(fig)
1216
-
1217
- accelerator.log(log_dict)
1218
-
1219
- # ==========================================================================
1220
- # DDP-SAFE CHECKPOINT LOGIC
1221
- # ==========================================================================
1222
- # Step 1: Determine if this is the best epoch (BEFORE updating best_val_loss)
1223
- is_best_epoch = False
1224
- if accelerator.is_main_process:
1225
- if avg_val_loss < best_val_loss:
1226
- is_best_epoch = True
1227
-
1228
- # Step 2: Broadcast decision to all ranks (required for save_state)
1229
- is_best_epoch = broadcast_early_stop(is_best_epoch, accelerator)
1230
-
1231
- # Step 3: Save checkpoint with all ranks participating
1232
- if is_best_epoch:
1233
- ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
1234
- with suppress_accelerate_logging():
1235
- accelerator.save_state(ckpt_dir, safe_serialization=False)
1236
-
1237
- # Step 4: Rank 0 handles metadata and updates tracking variables
1238
- if accelerator.is_main_process:
1239
- best_val_loss = avg_val_loss # Update AFTER checkpoint saved
1240
- patience_ctr = 0
1241
-
1242
- with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
1243
- pickle.dump(
1244
- {
1245
- "epoch": epoch + 1,
1246
- "best_val_loss": best_val_loss,
1247
- "patience_ctr": patience_ctr,
1248
- # Model info for auto-detection during inference
1249
- "model_name": args.model,
1250
- "in_shape": in_shape,
1251
- "out_dim": out_dim,
1252
- },
1253
- f,
1254
- )
1255
-
1256
- # Unwrap model for saving (handle torch.compile compatibility)
1257
- try:
1258
- unwrapped = accelerator.unwrap_model(model)
1259
- except KeyError:
1260
- # torch.compile model may not have _orig_mod in expected location
1261
- # Fall back to getting the module directly
1262
- unwrapped = model.module if hasattr(model, "module") else model
1263
- # If still compiled, try to get the underlying model
1264
- if hasattr(unwrapped, "_orig_mod"):
1265
- unwrapped = unwrapped._orig_mod
1266
-
1267
- torch.save(
1268
- unwrapped.state_dict(),
1269
- os.path.join(args.output_dir, "best_model_weights.pth"),
1270
- )
1271
-
1272
- # Copy scaler to checkpoint for portability
1273
- scaler_src = os.path.join(args.output_dir, "scaler.pkl")
1274
- scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
1275
- if os.path.exists(scaler_src) and not os.path.exists(scaler_dst):
1276
- shutil.copy2(scaler_src, scaler_dst)
1277
-
1278
- logger.info(
1279
- f" 💾 Best model saved (val_loss: {best_val_loss:.6f})"
1280
- )
1281
-
1282
- # Also save CSV on best model (ensures progress is saved)
1283
- pd.DataFrame(history).to_csv(
1284
- os.path.join(args.output_dir, "training_history.csv"),
1285
- index=False,
1286
- )
1287
- else:
1288
- if accelerator.is_main_process:
1289
- patience_ctr += 1
1290
-
1291
- # Periodic checkpoint (all ranks participate in save_state)
1292
- periodic_checkpoint_needed = (
1293
- args.save_every > 0 and (epoch + 1) % args.save_every == 0
1294
- )
1295
- if periodic_checkpoint_needed:
1296
- ckpt_name = f"epoch_{epoch + 1}_checkpoint"
1297
- ckpt_dir = os.path.join(args.output_dir, ckpt_name)
1298
- with suppress_accelerate_logging():
1299
- accelerator.save_state(ckpt_dir, safe_serialization=False)
1300
-
1301
- if accelerator.is_main_process:
1302
- with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
1303
- pickle.dump(
1304
- {
1305
- "epoch": epoch + 1,
1306
- "best_val_loss": best_val_loss,
1307
- "patience_ctr": patience_ctr,
1308
- # Model info for auto-detection during inference
1309
- "model_name": args.model,
1310
- "in_shape": in_shape,
1311
- "out_dim": out_dim,
1312
- },
1313
- f,
1314
- )
1315
- logger.info(f" 📁 Periodic checkpoint: {ckpt_name}")
1316
-
1317
- # Save CSV with each checkpoint (keeps logs in sync with model state)
1318
- pd.DataFrame(history).to_csv(
1319
- os.path.join(args.output_dir, "training_history.csv"),
1320
- index=False,
1321
- )
1322
-
1323
- # Learning rate scheduling (epoch-based schedulers only)
1324
- # NOTE: For ReduceLROnPlateau with DDP, we must step only on main process
1325
- # to avoid patience counter being incremented by all GPU processes.
1326
- # Then we sync the new LR to all processes to keep them consistent.
1327
- if not scheduler_step_per_batch:
1328
- if args.scheduler == "plateau":
1329
- # Step only on main process to avoid multi-GPU patience bug
1330
- if accelerator.is_main_process:
1331
- scheduler.step(avg_val_loss)
1332
-
1333
- # Sync LR across all processes after main process updates it
1334
- accelerator.wait_for_everyone()
1335
-
1336
- # Broadcast new LR from rank 0 to all processes
1337
- if dist.is_initialized():
1338
- if accelerator.is_main_process:
1339
- new_lr = optimizer.param_groups[0]["lr"]
1340
- else:
1341
- new_lr = 0.0
1342
- new_lr_tensor = torch.tensor(
1343
- new_lr, device=accelerator.device, dtype=torch.float32
1344
- )
1345
- dist.broadcast(new_lr_tensor, src=0)
1346
- # Update LR on non-main processes
1347
- if not accelerator.is_main_process:
1348
- for param_group in optimizer.param_groups:
1349
- param_group["lr"] = new_lr_tensor.item()
1350
- else:
1351
- scheduler.step()
1352
-
1353
- # DDP-safe early stopping
1354
- should_stop = (
1355
- patience_ctr >= args.patience if accelerator.is_main_process else False
1356
- )
1357
- if broadcast_early_stop(should_stop, accelerator):
1358
- if accelerator.is_main_process:
1359
- logger.info(
1360
- f"🛑 Early stopping at epoch {epoch + 1} (patience={args.patience})"
1361
- )
1362
- # Create completion flag to prevent auto-resume
1363
- with open(
1364
- os.path.join(args.output_dir, "training_complete.flag"), "w"
1365
- ) as f:
1366
- f.write(
1367
- f"Training completed via early stopping at epoch {epoch + 1}\n"
1368
- )
1369
- break
1370
-
1371
- except KeyboardInterrupt:
1372
- logger.warning("Training interrupted. Saving emergency checkpoint...")
1373
- with suppress_accelerate_logging():
1374
- accelerator.save_state(
1375
- os.path.join(args.output_dir, "interrupted_checkpoint"),
1376
- safe_serialization=False,
1377
- )
1378
-
1379
- except Exception as e:
1380
- logger.error(f"Critical error: {e}", exc_info=True)
1381
- raise
1382
-
1383
- else:
1384
- # Training completed normally (reached max epochs without early stopping)
1385
- # Create completion flag to prevent auto-resume on re-run
1386
- if accelerator.is_main_process:
1387
- if not os.path.exists(complete_flag_path):
1388
- with open(complete_flag_path, "w") as f:
1389
- f.write(f"Training completed normally after {args.epochs} epochs\n")
1390
- logger.info(f"✅ Training completed after {args.epochs} epochs")
1391
-
1392
- finally:
1393
- # Final CSV write to capture all epochs (covers epochs after the last checkpoint save)
1394
- if accelerator.is_main_process and len(history) > 0:
1395
- pd.DataFrame(history).to_csv(
1396
- os.path.join(args.output_dir, "training_history.csv"),
1397
- index=False,
1398
- )
1399
-
1400
- # Generate training curves plot (PNG + SVG)
1401
- if accelerator.is_main_process and len(history) > 0:
1402
- try:
1403
- fig = create_training_curves(history, show_lr=True)
1404
- for fmt in ["png", "svg"]:
1405
- fig.savefig(
1406
- os.path.join(args.output_dir, f"training_curves.{fmt}"),
1407
- dpi=FIGURE_DPI,
1408
- bbox_inches="tight",
1409
- )
1410
- plt.close(fig)
1411
- logger.info("✔ Saved: training_curves.png, training_curves.svg")
1412
- except Exception as e:
1413
- logger.warning(f"Could not generate training curves: {e}")
1414
-
1415
- if args.wandb and WANDB_AVAILABLE:
1416
- accelerator.end_training()
1417
-
1418
- # Clean up distributed process group to prevent resource leak warning
1419
- if torch.distributed.is_initialized():
1420
- torch.distributed.destroy_process_group()
1421
-
1422
- logger.info("Training completed.")
1423
-
1424
-
1425
- if __name__ == "__main__":
1426
- try:
1427
- torch.multiprocessing.set_start_method("spawn")
1428
- except RuntimeError:
1429
- pass
1430
- main()
1
+ """
2
+ WaveDL - Deep Learning for Wave-based Inverse Problems
3
+ =======================================================
4
+ Target Environment: NVIDIA HPC GPUs (Multi-GPU DDP) | PyTorch 2.x | Python 3.11+
5
+
6
+ A modular training framework for wave-based inverse problems and regression:
7
+ 1. HPC-Grade DDP Training: BF16/FP16 mixed precision with torch.compile support
8
+ 2. Dynamic Model Selection: Use --model flag to select any registered architecture
9
+ 3. Zero-Copy Data Engine: Memmap-backed datasets for large-scale training
10
+ 4. Physics-Aware Metrics: Real-time physical MAE with proper unscaling
11
+ 5. Robust Checkpointing: Resume training, periodic saves, and training curves
12
+ 6. Deep Observability: WandB integration with scatter analysis
13
+
14
+ Usage:
15
+ # Recommended: Using the HPC launcher (handles accelerate configuration)
16
+ wavedl-hpc --model cnn --batch_size 128 --mixed_precision bf16 --wandb
17
+
18
+ # Or direct training module (use --precision, not --mixed_precision)
19
+ accelerate launch -m wavedl.train --model cnn --batch_size 128 --precision bf16
20
+
21
+ # Multi-GPU with explicit config
22
+ wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb
23
+
24
+ # Resume from checkpoint
25
+ accelerate launch -m wavedl.train --model cnn --resume best_checkpoint --wandb
26
+
27
+ # List available models
28
+ wavedl-train --list_models
29
+
30
+ Note:
31
+ - wavedl-hpc: Uses --mixed_precision (passed to accelerate launch)
32
+ - wavedl.train: Uses --precision (internal module flag)
33
+ Both control the same behavior; use the appropriate flag for your entry point.
34
+
35
+ Author: Ductho Le (ductho.le@outlook.com)
36
+ """
37
+
38
+ from __future__ import annotations
39
+
40
+ # =============================================================================
41
+ # HPC Environment Setup (MUST be before any library imports)
42
+ # =============================================================================
43
+ # Auto-configure writable cache directories when home is not writable.
44
+ # Uses current working directory as fallback - works on HPC and local machines.
45
+ import os
46
+
47
+
48
+ def _setup_cache_dir(env_var: str, subdir: str) -> None:
49
+ """Set cache directory to CWD if home is not writable."""
50
+ if env_var in os.environ:
51
+ return # User already set, respect their choice
52
+
53
+ # Check if home is writable
54
+ home = os.path.expanduser("~")
55
+ if os.access(home, os.W_OK):
56
+ return # Home is writable, let library use defaults
57
+
58
+ # Home not writable - use current working directory
59
+ cache_path = os.path.join(os.getcwd(), f".{subdir}")
60
+ os.makedirs(cache_path, exist_ok=True)
61
+ os.environ[env_var] = cache_path
62
+
63
+
64
+ # Configure cache directories (before any library imports)
65
+ _setup_cache_dir("TORCH_HOME", "torch_cache")
66
+ _setup_cache_dir("MPLCONFIGDIR", "matplotlib")
67
+ _setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
68
+ _setup_cache_dir("XDG_DATA_HOME", "local/share")
69
+ _setup_cache_dir("XDG_STATE_HOME", "local/state")
70
+ _setup_cache_dir("XDG_CACHE_HOME", "cache")
71
+
72
+
73
+ def _setup_per_rank_compile_cache() -> None:
74
+ """Set per-GPU Triton/Inductor cache to prevent multi-process race warnings.
75
+
76
+ When using torch.compile with multiple GPUs, all processes try to write to
77
+ the same cache directory, causing 'Directory is not empty - skipping!' warnings.
78
+ This gives each GPU rank its own isolated cache subdirectory.
79
+ """
80
+ # Get local rank from environment (set by accelerate/torchrun)
81
+ local_rank = os.environ.get("LOCAL_RANK", "0")
82
+
83
+ # Get cache base from environment or use CWD
84
+ cache_base = os.environ.get(
85
+ "TRITON_CACHE_DIR", os.path.join(os.getcwd(), ".triton_cache")
86
+ )
87
+
88
+ # Set per-rank cache directories
89
+ os.environ["TRITON_CACHE_DIR"] = os.path.join(cache_base, f"rank_{local_rank}")
90
+ os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(
91
+ os.environ.get(
92
+ "TORCHINDUCTOR_CACHE_DIR", os.path.join(os.getcwd(), ".inductor_cache")
93
+ ),
94
+ f"rank_{local_rank}",
95
+ )
96
+
97
+ # Create directories
98
+ os.makedirs(os.environ["TRITON_CACHE_DIR"], exist_ok=True)
99
+ os.makedirs(os.environ["TORCHINDUCTOR_CACHE_DIR"], exist_ok=True)
100
+
101
+
102
+ # Setup per-rank compile caches (before torch imports)
103
+ _setup_per_rank_compile_cache()
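+ # e.g., a process launched with LOCAL_RANK=2 and no cache env vars set ends up with
+ # ./.triton_cache/rank_2 and ./.inductor_cache/rank_2 (illustrative CWD-relative paths)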
104
+
105
+ # =============================================================================
106
+ # Standard imports (after environment setup)
107
+ # =============================================================================
108
+ import argparse
109
+ import logging
110
+ import pickle
111
+ import shutil
112
+ import sys
113
+ import time
114
+ import warnings
115
+ from typing import Any
116
+
117
+
118
+ # Suppress Pydantic warnings from accelerate's internal Field() usage
119
+ warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
120
+
121
+ import matplotlib.pyplot as plt
122
+ import numpy as np
123
+ import pandas as pd
124
+ import torch
125
+ import torch.distributed as dist
126
+ from accelerate import Accelerator
127
+ from accelerate.utils import set_seed
128
+ from sklearn.metrics import r2_score
129
+ from tqdm.auto import tqdm
130
+
131
+ from wavedl.models import build_model, get_model, list_models
132
+ from wavedl.utils import (
133
+ FIGURE_DPI,
134
+ MetricTracker,
135
+ broadcast_early_stop,
136
+ calc_pearson,
137
+ create_training_curves,
138
+ get_loss,
139
+ get_lr,
140
+ get_optimizer,
141
+ get_scheduler,
142
+ is_epoch_based,
143
+ list_losses,
144
+ list_optimizers,
145
+ list_schedulers,
146
+ plot_scientific_scatter,
147
+ prepare_data,
148
+ )
149
+
150
+
151
+ try:
152
+ import wandb
153
+
154
+ WANDB_AVAILABLE = True
155
+ except ImportError:
156
+ WANDB_AVAILABLE = False
157
+
158
+ # ==============================================================================
159
+ # RUNTIME CONFIGURATION (post-import)
160
+ # ==============================================================================
161
+ # Configure matplotlib paths for HPC systems without writable home directories
162
+ os.environ.setdefault("MPLCONFIGDIR", os.path.join(os.getenv("TMPDIR", "/tmp"), "matplotlib"))
163
+ os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
164
+
165
+ # Suppress warnings from known-noisy libraries, but preserve legitimate warnings
166
+ # from torch/numpy about NaN, dtype, and numerical issues.
167
+ warnings.filterwarnings("ignore", category=FutureWarning)
168
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
169
+ # Pydantic v1/v2 compatibility warnings
170
+ warnings.filterwarnings("ignore", module="pydantic")
171
+ warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
172
+ # Transformer library warnings (loading configs, etc.)
173
+ warnings.filterwarnings("ignore", module="transformers")
174
+ # Accelerate verbose messages
175
+ warnings.filterwarnings("ignore", module="accelerate")
176
+ # torch.compile backend selection warnings
177
+ warnings.filterwarnings("ignore", message=".*TorchDynamo.*")
178
+ warnings.filterwarnings("ignore", message=".*Dynamo is not supported.*")
179
+ # Note: UserWarning from torch/numpy core is NOT suppressed to preserve
180
+ # legitimate warnings about NaN values, dtype mismatches, etc.
181
+
182
+ # ==============================================================================
183
+ # GPU PERFORMANCE OPTIMIZATIONS (Ampere/Hopper: A100, H100)
184
+ # ==============================================================================
185
+ # Enable TF32 for faster matmul (safe precision for training, ~2x speedup)
186
+ torch.backends.cuda.matmul.allow_tf32 = True
187
+ torch.backends.cudnn.allow_tf32 = True
188
+ torch.set_float32_matmul_precision("high") # Use TF32 for float32 ops
189
+
190
+ # Enable cuDNN autotuning for fixed-size inputs (CNN-like models benefit most)
191
+ # Note: First few batches may be slower due to benchmarking
192
+ torch.backends.cudnn.benchmark = True
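+ # Note: passing --deterministic (handled in main) re-disables TF32 and cuDNN benchmarking
+ # for exact reproducibility.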
193
+
194
+
195
+ # ==============================================================================
196
+ # LOGGING UTILITIES
197
+ # ==============================================================================
198
+ from contextlib import contextmanager
199
+
200
+
201
+ @contextmanager
202
+ def suppress_accelerate_logging():
203
+ """Temporarily suppress accelerate's verbose checkpoint save messages."""
204
+ accelerate_logger = logging.getLogger("accelerate.checkpointing")
205
+ original_level = accelerate_logger.level
206
+ accelerate_logger.setLevel(logging.WARNING)
207
+ try:
208
+ yield
209
+ finally:
210
+ accelerate_logger.setLevel(original_level)
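+ # Used around the checkpoint saves below, e.g.:
+ #     with suppress_accelerate_logging():
+ #         accelerator.save_state(ckpt_dir, safe_serialization=False)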
211
+
212
+
213
+ # ==============================================================================
214
+ # ARGUMENT PARSING
215
+ # ==============================================================================
216
+ def parse_args() -> tuple[argparse.Namespace, argparse.ArgumentParser]:
217
+ """Parse command-line arguments with comprehensive options."""
218
+ parser = argparse.ArgumentParser(
219
+ description="Universal DDP Training Pipeline",
220
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
221
+ )
222
+
223
+ # Model Selection
224
+ parser.add_argument(
225
+ "--model",
226
+ type=str,
227
+ default="cnn",
228
+ help=f"Model architecture to train. Available: {list_models()}",
229
+ )
230
+ parser.add_argument(
231
+ "--list_models", action="store_true", help="List all available models and exit"
232
+ )
233
+ parser.add_argument(
234
+ "--import",
235
+ dest="import_modules",
236
+ type=str,
237
+ nargs="+",
238
+ default=[],
239
+ help="Python modules to import before training (for custom models)",
240
+ )
241
+ parser.add_argument(
242
+ "--no_pretrained",
243
+ dest="pretrained",
244
+ action="store_false",
245
+ help="Train from scratch without pretrained weights (default: use pretrained)",
246
+ )
247
+ parser.set_defaults(pretrained=True)
248
+
249
+ # Configuration File
250
+ parser.add_argument(
251
+ "--config",
252
+ type=str,
253
+ default=None,
254
+ help="Path to YAML config file. CLI args override config values.",
255
+ )
256
+
257
+ # Hyperparameters
258
+ parser.add_argument(
259
+ "--batch_size", type=int, default=128, help="Batch size per GPU"
260
+ )
261
+ parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
262
+ parser.add_argument(
263
+ "--epochs", type=int, default=1000, help="Maximum training epochs"
264
+ )
265
+ parser.add_argument(
266
+ "--patience", type=int, default=20, help="Early stopping patience"
267
+ )
268
+ parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
269
+ parser.add_argument(
270
+ "--grad_clip", type=float, default=1.0, help="Gradient clipping norm"
271
+ )
272
+
273
+ # Loss Function
274
+ parser.add_argument(
275
+ "--loss",
276
+ type=str,
277
+ default="mse",
278
+ choices=["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"],
279
+ help=f"Loss function for training. Available: {list_losses()}",
280
+ )
281
+ parser.add_argument(
282
+ "--huber_delta", type=float, default=1.0, help="Delta for Huber loss"
283
+ )
284
+ parser.add_argument(
285
+ "--loss_weights",
286
+ type=str,
287
+ default=None,
288
+ help="Comma-separated weights for weighted_mse (e.g., '1.0,2.0,1.0')",
289
+ )
290
+
291
+ # Optimizer
292
+ parser.add_argument(
293
+ "--optimizer",
294
+ type=str,
295
+ default="adamw",
296
+ choices=["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"],
297
+ help=f"Optimizer for training. Available: {list_optimizers()}",
298
+ )
299
+ parser.add_argument(
300
+ "--momentum", type=float, default=0.9, help="Momentum for SGD/RMSprop"
301
+ )
302
+ parser.add_argument(
303
+ "--nesterov", action="store_true", help="Use Nesterov momentum (SGD)"
304
+ )
305
+ parser.add_argument(
306
+ "--betas",
307
+ type=str,
308
+ default="0.9,0.999",
309
+ help="Betas for Adam variants (comma-separated)",
310
+ )
311
+
312
+ # Learning Rate Scheduler
313
+ parser.add_argument(
314
+ "--scheduler",
315
+ type=str,
316
+ default="plateau",
317
+ choices=[
318
+ "plateau",
319
+ "cosine",
320
+ "cosine_restarts",
321
+ "onecycle",
322
+ "step",
323
+ "multistep",
324
+ "exponential",
325
+ "linear_warmup",
326
+ ],
327
+ help=f"LR scheduler. Available: {list_schedulers()}",
328
+ )
329
+ parser.add_argument(
330
+ "--scheduler_patience",
331
+ type=int,
332
+ default=10,
333
+ help="Patience for ReduceLROnPlateau",
334
+ )
335
+ parser.add_argument(
336
+ "--min_lr", type=float, default=1e-6, help="Minimum learning rate"
337
+ )
338
+ parser.add_argument(
339
+ "--scheduler_factor", type=float, default=0.5, help="LR reduction factor"
340
+ )
341
+ parser.add_argument(
342
+ "--warmup_epochs", type=int, default=5, help="Warmup epochs for linear_warmup"
343
+ )
344
+ parser.add_argument(
345
+ "--step_size", type=int, default=30, help="Step size for StepLR"
346
+ )
347
+ parser.add_argument(
348
+ "--milestones",
349
+ type=str,
350
+ default=None,
351
+ help="Comma-separated epochs for MultiStepLR (e.g., '30,60,90')",
352
+ )
353
+
354
+ # Data
355
+ parser.add_argument(
356
+ "--data_path", type=str, default="train_data.npz", help="Path to NPZ dataset"
357
+ )
358
+ parser.add_argument(
359
+ "--workers",
360
+ type=int,
361
+ default=-1,
362
+ help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
363
+ )
364
+ parser.add_argument("--seed", type=int, default=2025, help="Random seed")
365
+ parser.add_argument(
366
+ "--deterministic",
367
+ action="store_true",
368
+ help="Enable deterministic mode for reproducibility (slower, disables TF32/cuDNN benchmark)",
369
+ )
370
+ parser.add_argument(
371
+ "--cache_validate",
372
+ type=str,
373
+ default="sha256",
374
+ choices=["sha256", "fast", "size"],
375
+ help="Cache validation mode: sha256 (full hash), fast (partial), size (quick)",
376
+ )
377
+ parser.add_argument(
378
+ "--single_channel",
379
+ action="store_true",
380
+ help="Confirm data is single-channel (suppress ambiguous shape warnings for shallow 3D volumes)",
381
+ )
382
+
383
+ # Cross-Validation
384
+ parser.add_argument(
385
+ "--cv",
386
+ type=int,
387
+ default=0,
388
+ help="Enable K-fold cross-validation with K folds (0=disabled)",
389
+ )
390
+ parser.add_argument(
391
+ "--cv_stratify",
392
+ action="store_true",
393
+ help="Use stratified splitting for cross-validation",
394
+ )
395
+ parser.add_argument(
396
+ "--cv_bins",
397
+ type=int,
398
+ default=10,
399
+ help="Number of bins for stratified CV (only with --cv_stratify)",
400
+ )
401
+
402
+ # Checkpointing & Resume
403
+ parser.add_argument(
404
+ "--resume", type=str, default=None, help="Checkpoint directory to resume from"
405
+ )
406
+ parser.add_argument(
407
+ "--save_every",
408
+ type=int,
409
+ default=50,
410
+ help="Save checkpoint every N epochs (0=disable)",
411
+ )
412
+ parser.add_argument(
413
+ "--output_dir", type=str, default=".", help="Output directory for checkpoints"
414
+ )
415
+ parser.add_argument(
416
+ "--fresh",
417
+ action="store_true",
418
+ help="Force fresh training, ignore existing checkpoints",
419
+ )
420
+
421
+ # Performance
422
+ parser.add_argument(
423
+ "--compile", action="store_true", help="Enable torch.compile (PyTorch 2.x)"
424
+ )
425
+ parser.add_argument(
426
+ "--precision",
427
+ type=str,
428
+ default="bf16",
429
+ choices=["bf16", "fp16", "no"],
430
+ help="Mixed precision mode",
431
+ )
432
+ # Alias for consistency with wavedl-hpc (--mixed_precision)
433
+ parser.add_argument(
434
+ "--mixed_precision",
435
+ dest="precision",
436
+ type=str,
437
+ choices=["bf16", "fp16", "no"],
438
+ help=argparse.SUPPRESS, # Hidden: use --precision instead
439
+ )
440
+
441
+ # Physical Constraints
442
+ parser.add_argument(
443
+ "--constraint",
444
+ type=str,
445
+ nargs="+",
446
+ default=[],
447
+ help="Soft constraint expressions: 'y0 - y1*y2' (penalize violations)",
448
+ )
449
+
450
+ parser.add_argument(
451
+ "--constraint_file",
452
+ type=str,
453
+ default=None,
454
+ help="Python file with constraint(pred, inputs) function",
455
+ )
456
+ parser.add_argument(
457
+ "--constraint_weight",
458
+ type=float,
459
+ nargs="+",
460
+ default=[0.1],
461
+ help="Weight(s) for soft constraints (one per constraint, or single shared weight)",
462
+ )
463
+ parser.add_argument(
464
+ "--constraint_reduction",
465
+ type=str,
466
+ default="mse",
467
+ choices=["mse", "mae"],
468
+ help="Reduction mode for constraint penalties",
469
+ )
470
+
471
+ # Logging
472
+ parser.add_argument(
473
+ "--wandb", action="store_true", help="Enable Weights & Biases logging"
474
+ )
475
+ parser.add_argument(
476
+ "--wandb_watch",
477
+ action="store_true",
478
+ help="Enable WandB gradient watching (adds overhead, useful for debugging)",
479
+ )
480
+ parser.add_argument(
481
+ "--project_name", type=str, default="DL-Training", help="WandB project name"
482
+ )
483
+ parser.add_argument("--run_name", type=str, default=None, help="WandB run name")
484
+
485
+ args = parser.parse_args()
486
+ return args, parser # Returns (Namespace, ArgumentParser)
487
+
488
+
489
+ # ==============================================================================
490
+ # MAIN TRAINING FUNCTION
491
+ # ==============================================================================
492
+ def main():
493
+ args, parser = parse_args()
494
+
495
+ # Import custom model modules if specified
496
+ if args.import_modules:
497
+ import importlib
498
+
499
+ for module_name in args.import_modules:
500
+ try:
501
+ # Handle both module names (my_model) and file paths (./my_model.py)
502
+ if module_name.endswith(".py"):
503
+ # Import from file path with unique module name
504
+ import importlib.util
505
+
506
+ # Derive unique module name from filename to avoid collisions
507
+ base_name = os.path.splitext(os.path.basename(module_name))[0]
508
+ unique_name = f"wavedl_custom_{base_name}"
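+ # e.g., --import ./models/my_net.py is registered under the module name "wavedl_custom_my_net"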
509
+
510
+ spec = importlib.util.spec_from_file_location(
511
+ unique_name, module_name
512
+ )
513
+ if spec and spec.loader:
514
+ module = importlib.util.module_from_spec(spec)
515
+ sys.modules[unique_name] = module
516
+ spec.loader.exec_module(module)
517
+ print(f"✓ Imported custom module from: {module_name}")
518
+ else:
519
+ # Import as regular module
520
+ importlib.import_module(module_name)
521
+ print(f"✓ Imported module: {module_name}")
522
+ except (ImportError, FileNotFoundError, SyntaxError, PermissionError) as e:
523
+ print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
524
+ if isinstance(e, FileNotFoundError):
525
+ print(" File does not exist. Check the path.", file=sys.stderr)
526
+ elif isinstance(e, SyntaxError):
527
+ print(
528
+ f" Syntax error at line {e.lineno}: {e.msg}", file=sys.stderr
529
+ )
530
+ elif isinstance(e, PermissionError):
531
+ print(
532
+ " Permission denied. Check file permissions.", file=sys.stderr
533
+ )
534
+ else:
535
+ print(
536
+ " Make sure the module is in your Python path or current directory.",
537
+ file=sys.stderr,
538
+ )
539
+ sys.exit(1)
540
+
541
+ # Handle --list_models flag
542
+ if args.list_models:
543
+ print("Available models:")
544
+ for name in list_models():
545
+ ModelClass = get_model(name)
546
+ # Get first non-empty docstring line
547
+ if ModelClass.__doc__:
548
+ lines = [
549
+ l.strip() for l in ModelClass.__doc__.splitlines() if l.strip()
550
+ ]
551
+ doc_first_line = lines[0] if lines else "No description"
552
+ else:
553
+ doc_first_line = "No description"
554
+ print(f" - {name}: {doc_first_line}")
555
+ sys.exit(0)
556
+
557
+ # Load and merge config file if provided
558
+ if args.config:
559
+ from wavedl.utils.config import (
560
+ load_config,
561
+ merge_config_with_args,
562
+ validate_config,
563
+ )
564
+
565
+ print(f"📄 Loading config from: {args.config}")
566
+ config = load_config(args.config)
567
+
568
+ # Validate config values
569
+ warnings_list = validate_config(config)
570
+ for w in warnings_list:
571
+ print(f" ⚠ {w}")
572
+
573
+ # Merge config with CLI args (CLI takes precedence via parser defaults detection)
574
+ args = merge_config_with_args(config, args, parser=parser)
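+ # e.g., an "lr: 5e-4" entry in the YAML takes effect only if --lr was left at its parser default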
575
+
576
+ # Handle --cv flag (cross-validation mode)
577
+ if args.cv > 0:
578
+ print(f"🔄 Cross-Validation Mode: {args.cv} folds")
579
+ from wavedl.utils.cross_validation import run_cross_validation
580
+
581
+ # Load data for CV using memory-efficient loader
582
+ from wavedl.utils.data import DataSource, get_data_source
583
+
584
+ data_format = DataSource.detect_format(args.data_path)
585
+ source = get_data_source(data_format)
586
+
587
+ # Use memory-mapped loading when available (now returns LazyDataHandle for all formats)
588
+ _cv_handle = None
589
+ if hasattr(source, "load_mmap"):
590
+ _cv_handle = source.load_mmap(args.data_path)
591
+ X, y = _cv_handle.inputs, _cv_handle.outputs
592
+ else:
593
+ X, y = source.load(args.data_path)
594
+
595
+ # Handle sparse matrices (must materialize for CV shuffling)
596
+ if hasattr(X, "__getitem__") and len(X) > 0 and hasattr(X[0], "toarray"):
597
+ X = np.stack([x.toarray() for x in X])
598
+
599
+ # Normalize target shape: (N,) -> (N, 1) for consistency
600
+ if y.ndim == 1:
601
+ y = y.reshape(-1, 1)
602
+
603
+ # Validate and determine input shape (consistent with prepare_data)
604
+ # Check for ambiguous shapes that could be multi-channel or shallow 3D volume
605
+ sample_shape = X.shape[1:] # Per-sample shape
606
+
607
+ # Same heuristic as prepare_data: detect ambiguous 3D shapes
608
+ is_ambiguous_shape = (
609
+ len(sample_shape) == 3 # Exactly 3D: could be (C, H, W) or (D, H, W)
610
+ and sample_shape[0] <= 16 # First dim looks like channels
611
+ and sample_shape[1] > 16
612
+ and sample_shape[2] > 16 # Both spatial dims are large
613
+ )
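+ # e.g., a sample shape of (4, 64, 64) is flagged as ambiguous, while (64, 64, 64) is not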
614
+
615
+ if is_ambiguous_shape and not args.single_channel:
616
+ raise ValueError(
617
+ f"Ambiguous input shape detected: sample shape {sample_shape}. "
618
+ f"This could be either:\n"
619
+ f" - Multi-channel 2D data (C={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n"
620
+ f" - Single-channel 3D volume (D={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n\n"
621
+ f"If this is single-channel 3D/shallow volume data, use --single_channel flag.\n"
622
+ f"If this is multi-channel 2D data, reshape to (N*C, H, W) with adjusted targets."
623
+ )
624
+
625
+ # in_shape = spatial dimensions for model registry (channel added during training)
626
+ in_shape = sample_shape
627
+
628
+ # Run cross-validation
629
+ try:
630
+ run_cross_validation(
631
+ X=X,
632
+ y=y,
633
+ model_name=args.model,
634
+ in_shape=in_shape,
635
+ out_size=y.shape[1],
636
+ folds=args.cv,
637
+ stratify=args.cv_stratify,
638
+ stratify_bins=args.cv_bins,
639
+ batch_size=args.batch_size,
640
+ lr=args.lr,
641
+ epochs=args.epochs,
642
+ patience=args.patience,
643
+ weight_decay=args.weight_decay,
644
+ loss_name=args.loss,
645
+ optimizer_name=args.optimizer,
646
+ scheduler_name=args.scheduler,
647
+ output_dir=args.output_dir,
648
+ workers=args.workers,
649
+ seed=args.seed,
650
+ )
651
+ finally:
652
+ # Clean up file handle if HDF5/MAT
653
+ if _cv_handle is not None and hasattr(_cv_handle, "close"):
654
+ try:
655
+ _cv_handle.close()
656
+ except Exception as e:
657
+ logging.debug(f"Failed to close CV data handle: {e}")
658
+ return
659
+
660
+ # ==========================================================================
661
+ # SYSTEM INITIALIZATION
662
+ # ==========================================================================
663
+ # Initialize Accelerator for DDP and mixed precision
664
+ accelerator = Accelerator(
665
+ mixed_precision=args.precision,
666
+ log_with="wandb" if args.wandb and WANDB_AVAILABLE else None,
667
+ )
668
+ set_seed(args.seed)
669
+
670
+ # Deterministic mode for scientific reproducibility
671
+ # Disables TF32 and cuDNN benchmark for exact reproducibility (slower)
672
+ if args.deterministic:
673
+ torch.backends.cudnn.benchmark = False
674
+ torch.backends.cudnn.deterministic = True
675
+ torch.backends.cuda.matmul.allow_tf32 = False
676
+ torch.backends.cudnn.allow_tf32 = False
677
+ torch.use_deterministic_algorithms(True, warn_only=True)
678
+ if accelerator.is_main_process:
679
+ print("🔒 Deterministic mode enabled (slower but reproducible)")
680
+
681
+ # Configure logging (rank 0 only prints to console)
682
+ logging.basicConfig(
683
+ level=logging.INFO if accelerator.is_main_process else logging.ERROR,
684
+ format="%(asctime)s | %(levelname)s | %(message)s",
685
+ datefmt="%H:%M:%S",
686
+ )
687
+ logger = logging.getLogger("Trainer")
688
+
689
+ # Ensure output directory exists (critical for cache files, checkpoints, etc.)
690
+ os.makedirs(args.output_dir, exist_ok=True)
691
+
692
+ # Auto-detect optimal DataLoader workers if not specified
693
+ if args.workers < 0:
694
+ cpu_count = os.cpu_count() or 4
695
+ num_gpus = accelerator.num_processes
696
+ # Heuristic: 4-16 workers per GPU, bounded by available CPU cores
697
+ # Increased cap from 8 to 16 for high-throughput GPUs (H100, A100)
698
+ args.workers = min(16, max(2, (cpu_count - 2) // num_gpus))
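+ # e.g., 64 CPU cores across 4 GPUs gives min(16, max(2, 62 // 4)) = 15 workers per GPU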
699
+ if accelerator.is_main_process:
700
+ logger.info(
701
+ f"⚙️ Auto-detected workers: {args.workers} per GPU "
702
+ f"(CPUs: {cpu_count}, GPUs: {num_gpus})"
703
+ )
704
+
705
+ if accelerator.is_main_process:
706
+ logger.info(f"🚀 Cluster Status: {accelerator.num_processes}x GPUs detected")
707
+ logger.info(
708
+ f" Model: {args.model} | Precision: {args.precision} | Compile: {args.compile}"
709
+ )
710
+ logger.info(
711
+ f" Loss: {args.loss} | Optimizer: {args.optimizer} | Scheduler: {args.scheduler}"
712
+ )
713
+ logger.info(f" Early Stopping Patience: {args.patience} epochs")
714
+ if args.save_every > 0:
715
+ logger.info(f" Periodic Checkpointing: Every {args.save_every} epochs")
716
+ if args.resume:
717
+ logger.info(f" 📂 Resuming from: {args.resume}")
718
+
719
+ # Initialize WandB
720
+ if args.wandb and WANDB_AVAILABLE:
721
+ accelerator.init_trackers(
722
+ project_name=args.project_name,
723
+ config=vars(args),
724
+ init_kwargs={"wandb": {"name": args.run_name or f"{args.model}_run"}},
725
+ )
726
+
727
+ # ==========================================================================
728
+ # DATA & MODEL LOADING
729
+ # ==========================================================================
730
+ train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
731
+ args, logger, accelerator, cache_dir=args.output_dir
732
+ )
733
+
734
+ # Build model using registry
735
+ model = build_model(
736
+ args.model, in_shape=in_shape, out_size=out_dim, pretrained=args.pretrained
737
+ )
738
+
739
+ if accelerator.is_main_process:
740
+ param_info = model.parameter_summary()
741
+ logger.info(
742
+ f" Model Parameters: {param_info['trainable_parameters']:,} trainable"
743
+ )
744
+ logger.info(f" Model Size: {param_info['total_mb']:.2f} MB")
745
+
746
+ # Optional WandB model watching (opt-in due to overhead on large models)
747
+ if (
748
+ args.wandb
749
+ and args.wandb_watch
750
+ and WANDB_AVAILABLE
751
+ and accelerator.is_main_process
752
+ ):
753
+ wandb.watch(model, log="gradients", log_freq=100)
754
+ logger.info(" 📊 WandB gradient watching enabled")
755
+
756
+ # Torch 2.0 compilation (requires compatible Triton on GPU)
757
+ if args.compile:
758
+ try:
759
+ # Test if Triton is available - just import the package
760
+ # Different Triton versions have different internal APIs, so just check base import
761
+ import triton
762
+
763
+ model = torch.compile(model)
764
+ if accelerator.is_main_process:
765
+ logger.info(" ✔ torch.compile enabled (Triton backend)")
766
+ except ImportError as e:
767
+ if accelerator.is_main_process:
768
+ if "triton" in str(e).lower():
769
+ logger.warning(
770
+ " ⚠ Triton not installed or incompatible version - torch.compile disabled. "
771
+ "Training will proceed without compilation."
772
+ )
773
+ else:
774
+ logger.warning(
775
+ f" ⚠ torch.compile setup failed: {e}. Continuing without compilation."
776
+ )
777
+ except Exception as e:
778
+ if accelerator.is_main_process:
779
+ logger.warning(
780
+ f" ⚠ torch.compile failed: {e}. Continuing without compilation."
781
+ )
782
+
783
+ # ==========================================================================
784
+ # OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
785
+ # ==========================================================================
786
+ # Parse comma-separated arguments with validation
787
+ try:
788
+ betas_list = [float(x.strip()) for x in args.betas.split(",")]
789
+ if len(betas_list) != 2:
790
+ raise ValueError(
791
+ f"--betas must have exactly 2 values, got {len(betas_list)}"
792
+ )
793
+ if not all(0.0 <= b < 1.0 for b in betas_list):
794
+ raise ValueError(f"--betas values must be in [0, 1), got {betas_list}")
795
+ betas = tuple(betas_list)
796
+ except ValueError as e:
797
+ raise ValueError(
798
+ f"Invalid --betas format '{args.betas}': {e}. Expected format: '0.9,0.999'"
799
+ ) from e
800
+
801
+ loss_weights = None
802
+ if args.loss_weights:
803
+ loss_weights = [float(x.strip()) for x in args.loss_weights.split(",")]
804
+ milestones = None
805
+ if args.milestones:
806
+ milestones = [int(x.strip()) for x in args.milestones.split(",")]
807
+
808
+ # Create optimizer using factory
809
+ optimizer = get_optimizer(
810
+ name=args.optimizer,
811
+ params=model.get_optimizer_groups(args.lr, args.weight_decay),
812
+ lr=args.lr,
813
+ weight_decay=args.weight_decay,
814
+ momentum=args.momentum,
815
+ nesterov=args.nesterov,
816
+ betas=betas,
817
+ )
818
+
819
+ # Create loss function using factory
820
+ criterion = get_loss(
821
+ name=args.loss,
822
+ weights=loss_weights,
823
+ delta=args.huber_delta,
824
+ )
825
+ # Move criterion to device (important for WeightedMSELoss buffer)
826
+ criterion = criterion.to(accelerator.device)
827
+
828
+ # ==========================================================================
829
+ # PHYSICAL CONSTRAINTS INTEGRATION
830
+ # ==========================================================================
831
+ from wavedl.utils.constraints import (
832
+ PhysicsConstrainedLoss,
833
+ build_constraints,
834
+ )
835
+
836
+ # Build soft constraints
837
+ soft_constraints = build_constraints(
838
+ expressions=args.constraint,
839
+ file_path=args.constraint_file,
840
+ reduction=args.constraint_reduction,
841
+ )
842
+
843
+ # Wrap criterion with PhysicsConstrainedLoss if we have soft constraints
844
+ if soft_constraints:
845
+ # Pass output scaler so constraints can be evaluated in physical space
846
+ output_mean = scaler.mean_ if hasattr(scaler, "mean_") else None
847
+ output_std = scaler.scale_ if hasattr(scaler, "scale_") else None
848
+ criterion = PhysicsConstrainedLoss(
849
+ criterion,
850
+ soft_constraints,
851
+ weights=args.constraint_weight,
852
+ output_mean=output_mean,
853
+ output_std=output_std,
854
+ )
855
+ if accelerator.is_main_process:
856
+ logger.info(
857
+ f" 🔬 Physical constraints: {len(soft_constraints)} constraint(s) "
858
+ f"with weight(s) {args.constraint_weight}"
859
+ )
860
+ if output_mean is not None:
861
+ logger.info(
862
+ " 📐 Constraints evaluated in physical space (denormalized)"
863
+ )
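+ # Example: --constraint "y0 - y1*y2" penalizes deviations from y0 = y1*y2 in the predictions,
+ # reduced with MSE or MAE per --constraint_reduction and scaled by --constraint_weight.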
864
+
865
+ # Track if scheduler should step per batch (OneCycleLR) or per epoch
866
+ scheduler_step_per_batch = not is_epoch_based(args.scheduler)
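+ # e.g., onecycle steps every batch, while plateau/cosine/step advance once per epoch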
867
+
868
+ # ==========================================================================
869
+ # DDP Preparation Strategy:
870
+ # - For batch-based schedulers (OneCycleLR): prepare DataLoaders first to get
871
+ # the correct sharded batch count, then create scheduler
872
+ # - For epoch-based schedulers: create scheduler before prepare (no issue)
873
+ # ==========================================================================
874
+ if scheduler_step_per_batch:
875
+ # BATCH-BASED SCHEDULER (e.g., OneCycleLR)
876
+ # Prepare model, optimizer, dataloaders FIRST to get sharded loader length
877
+ model, optimizer, train_dl, val_dl = accelerator.prepare(
878
+ model, optimizer, train_dl, val_dl
879
+ )
880
+
881
+ # Now create scheduler with the CORRECT sharded steps_per_epoch
882
+ steps_per_epoch = len(train_dl) # Post-DDP sharded length
883
+ scheduler = get_scheduler(
884
+ name=args.scheduler,
885
+ optimizer=optimizer,
886
+ epochs=args.epochs,
887
+ steps_per_epoch=steps_per_epoch,
888
+ min_lr=args.min_lr,
889
+ patience=args.scheduler_patience,
890
+ factor=args.scheduler_factor,
891
+ gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
892
+ step_size=args.step_size,
893
+ milestones=milestones,
894
+ warmup_epochs=args.warmup_epochs,
895
+ )
896
+ # Prepare scheduler separately (Accelerator wraps it for state saving)
897
+ scheduler = accelerator.prepare(scheduler)
898
+ else:
899
+ # EPOCH-BASED SCHEDULER (plateau, cosine, step, etc.)
900
+ # No batch count dependency - create scheduler before prepare
901
+ scheduler = get_scheduler(
902
+ name=args.scheduler,
903
+ optimizer=optimizer,
904
+ epochs=args.epochs,
905
+ steps_per_epoch=None,
906
+ min_lr=args.min_lr,
907
+ patience=args.scheduler_patience,
908
+ factor=args.scheduler_factor,
909
+ gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
910
+ step_size=args.step_size,
911
+ milestones=milestones,
912
+ warmup_epochs=args.warmup_epochs,
913
+ )
914
+
915
+ # For ReduceLROnPlateau: DON'T include scheduler in accelerator.prepare()
916
+ # because accelerator wraps scheduler.step() to sync across processes,
917
+ # which defeats our rank-0-only stepping for correct patience counting.
918
+ # Other schedulers are safe to prepare (no internal state affected by multi-call).
919
+ if args.scheduler == "plateau":
920
+ model, optimizer, train_dl, val_dl = accelerator.prepare(
921
+ model, optimizer, train_dl, val_dl
922
+ )
923
+ # Scheduler stays unwrapped - we handle sync manually in training loop
924
+ # But register it for checkpointing so state is saved/loaded on resume
925
+ accelerator.register_for_checkpointing(scheduler)
926
+ else:
927
+ model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
928
+ model, optimizer, train_dl, val_dl, scheduler
929
+ )
930
+
931
+ # ==========================================================================
932
+ # AUTO-RESUME / RESUME FROM CHECKPOINT
933
+ # ==========================================================================
934
+ start_epoch = 0
935
+ best_val_loss = float("inf")
936
+ patience_ctr = 0
937
+ history: list[dict[str, Any]] = []
938
+
939
+ # Define checkpoint paths
940
+ best_ckpt_path = os.path.join(args.output_dir, "best_checkpoint")
941
+ complete_flag_path = os.path.join(args.output_dir, "training_complete.flag")
942
+
943
+ # Auto-resume logic (if not --fresh and no explicit --resume)
944
+ if not args.fresh and args.resume is None:
945
+ if os.path.exists(complete_flag_path):
946
+ # Training already completed
947
+ if accelerator.is_main_process:
948
+ logger.info(
949
+ "✅ Training already completed (early stopping). Use --fresh to retrain."
950
+ )
951
+ return # Exit gracefully
952
+ elif os.path.exists(best_ckpt_path):
953
+ # Incomplete training found - auto-resume
954
+ args.resume = best_ckpt_path
955
+ if accelerator.is_main_process:
956
+ logger.info(f"🔄 Auto-resuming from: {best_ckpt_path}")
957
+
958
+ if args.resume:
959
+ if os.path.exists(args.resume):
960
+ logger.info(f"🔄 Loading checkpoint from: {args.resume}")
961
+ accelerator.load_state(args.resume)
962
+
963
+ # Restore training metadata
964
+ meta_path = os.path.join(args.resume, "training_meta.pkl")
965
+ if os.path.exists(meta_path):
966
+ with open(meta_path, "rb") as f:
967
+ meta = pickle.load(f)
968
+ start_epoch = meta.get("epoch", 0)
969
+ best_val_loss = meta.get("best_val_loss", float("inf"))
970
+ patience_ctr = meta.get("patience_ctr", 0)
971
+ logger.info(
972
+ f" ✅ Restored: Epoch {start_epoch}, Best Loss: {best_val_loss:.6f}"
973
+ )
974
+ else:
975
+ logger.warning(
976
+ " ⚠️ training_meta.pkl not found, starting from epoch 0"
977
+ )
978
+
979
+ # Restore history
980
+ history_path = os.path.join(args.output_dir, "training_history.csv")
981
+ if os.path.exists(history_path):
982
+ history = pd.read_csv(history_path).to_dict("records")
983
+ logger.info(f" ✅ Loaded {len(history)} epochs from history")
984
+ else:
985
+ raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
986
+
987
+ # ==========================================================================
988
+ # PHYSICAL METRIC SETUP
989
+ # ==========================================================================
990
+ # Physical MAE = normalized MAE * scaler.scale_
991
+ phys_scale = torch.tensor(
992
+ scaler.scale_, device=accelerator.device, dtype=torch.float32
993
+ )
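+ # Illustrative example: if scaler.scale_[0] is 1e3, a normalized MAE of 0.01 on target 0
+ # is reported as a physical MAE of 10 (in that target's original units).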
994
+
995
+ # ==========================================================================
996
+ # TRAINING LOOP
997
+ # ==========================================================================
998
+ # Dynamic console header
999
+ if accelerator.is_main_process:
1000
+ base_cols = ["Ep", "TrnLoss", "ValLoss", "R2", "PCC", "GradN", "LR", "MAE_Avg"]
1001
+ param_cols = [f"MAE_P{i}" for i in range(out_dim)]
1002
+ header = "{:<4} | {:<8} | {:<8} | {:<6} | {:<6} | {:<6} | {:<8} | {:<8}".format(
1003
+ *base_cols
1004
+ )
1005
+ header += " | " + " | ".join([f"{c:<8}" for c in param_cols])
1006
+ logger.info("=" * len(header))
1007
+ logger.info(header)
1008
+ logger.info("=" * len(header))
1009
+
1010
+ try:
1011
+ total_training_time = 0.0
1012
+
1013
+ for epoch in range(start_epoch, args.epochs):
1014
+ epoch_start_time = time.time()
1015
+
1016
+ # ==================== TRAINING PHASE ====================
1017
+ model.train()
1018
+ # Use GPU tensor for loss accumulation to avoid .item() sync per batch
1019
+ train_loss_sum = torch.tensor(0.0, device=accelerator.device)
1020
+ train_samples = 0
1021
+ grad_norm_tracker = MetricTracker()
1022
+
1023
+ pbar = tqdm(
1024
+ train_dl,
1025
+ disable=not accelerator.is_main_process,
1026
+ leave=False,
1027
+ desc=f"Epoch {epoch + 1}",
1028
+ )
1029
+
1030
+ for x, y in pbar:
1031
+ with accelerator.accumulate(model):
1032
+ # Use mixed precision for forward pass (respects --precision flag)
1033
+ with accelerator.autocast():
1034
+ pred = model(x)
1035
+ # Pass inputs for input-dependent constraints (x_mean, x[...], etc.)
1036
+ if isinstance(criterion, PhysicsConstrainedLoss):
1037
+ loss = criterion(pred, y, x)
1038
+ else:
1039
+ loss = criterion(pred, y)
1040
+
1041
+ accelerator.backward(loss)
1042
+
1043
+ if accelerator.sync_gradients:
1044
+ grad_norm = accelerator.clip_grad_norm_(
1045
+ model.parameters(), args.grad_clip
1046
+ )
1047
+ if grad_norm is not None:
1048
+ grad_norm_tracker.update(grad_norm.item())
1049
+
1050
+ optimizer.step()
1051
+ optimizer.zero_grad(set_to_none=True) # Faster than zero_grad()
1052
+
1053
+ # Per-batch LR scheduling (e.g., OneCycleLR)
1054
+ if scheduler_step_per_batch:
1055
+ scheduler.step()
1056
+
1057
+ # Accumulate as tensors to avoid .item() sync per batch
1058
+ train_loss_sum += loss.detach() * x.size(0)
1059
+ train_samples += x.size(0)
1060
+
1061
+ # Single .item() call at end of epoch (reduces GPU sync overhead)
1062
+ train_loss_scalar = train_loss_sum.item()
1063
+
1064
+ # Synchronize training metrics across GPUs
1065
+ global_loss = accelerator.reduce(
1066
+ torch.tensor([train_loss_scalar], device=accelerator.device),
1067
+ reduction="sum",
1068
+ ).item()
1069
+ global_samples = accelerator.reduce(
1070
+ torch.tensor([train_samples], device=accelerator.device),
1071
+ reduction="sum",
1072
+ ).item()
1073
+ avg_train_loss = global_loss / global_samples
1074
+
1075
+ # ==================== VALIDATION PHASE ====================
1076
+ model.eval()
1077
+ # Use GPU tensor for loss accumulation (consistent with training phase)
1078
+ val_loss_sum = torch.tensor(0.0, device=accelerator.device)
1079
+ val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
1080
+ val_samples = 0
1081
+
1082
+ # Accumulate predictions locally ON CPU to prevent GPU OOM
1083
+ local_preds = []
1084
+ local_targets = []
1085
+
1086
+ with torch.inference_mode():
1087
+ for x, y in val_dl:
1088
+ # Use mixed precision for validation (consistent with training)
1089
+ with accelerator.autocast():
1090
+ pred = model(x)
1091
+ # Pass inputs for input-dependent constraints
1092
+ if isinstance(criterion, PhysicsConstrainedLoss):
1093
+ loss = criterion(pred, y, x)
1094
+ else:
1095
+ loss = criterion(pred, y)
1096
+
1097
+ val_loss_sum += loss.detach() * x.size(0)
1098
+ val_samples += x.size(0)
1099
+
1100
+ # Physical MAE
1101
+ mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
1102
+ val_mae_sum += mae_batch
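+ # Running per-target sum of absolute physical errors; after the cross-GPU reduction below
+ # it is divided by the global sample count to yield the per-parameter MAE columns.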
1103
+
1104
+ # Store on CPU (critical for large val sets)
1105
+ local_preds.append(pred.detach().cpu())
1106
+ local_targets.append(y.detach().cpu())
1107
+
1108
+ # Concatenate locally on CPU; tensors are moved to the GPU only for the cross-rank gather below
1109
+ local_preds_cat = torch.cat(local_preds)
1110
+ local_targets_cat = torch.cat(local_targets)
1111
+
1112
+ # Gather predictions and targets with Accelerate's gather_for_metrics utility
1113
+ # gather_for_metrics handles:
1114
+ # - DDP padding removal (no need to trim manually)
1115
+ # - Efficient cross-rank gathering without GPU memory spike
1116
+ # - Gathered tensors are moved back to CPU here for metric computation
1117
+ if accelerator.num_processes > 1:
1118
+ # Move to GPU for gather (required by NCCL), then back to CPU
1119
+ # gather_for_metrics is more memory-efficient than manual gather
1120
+ # as it processes in chunks internally
1121
+ gathered_preds = accelerator.gather_for_metrics(
1122
+ local_preds_cat.to(accelerator.device)
1123
+ ).cpu()
1124
+ gathered_targets = accelerator.gather_for_metrics(
1125
+ local_targets_cat.to(accelerator.device)
1126
+ ).cpu()
1127
+ else:
1128
+ # Single-GPU mode: no gathering needed
1129
+ gathered_preds = local_preds_cat
1130
+ gathered_targets = local_targets_cat
1131
+
1132
+ # Synchronize validation metrics (scalars only - efficient)
1133
+ val_loss_scalar = val_loss_sum.item()
1134
+ val_metrics = torch.cat(
1135
+ [
1136
+ torch.tensor([val_loss_scalar], device=accelerator.device),
1137
+ val_mae_sum,
1138
+ ]
1139
+ )
1140
+ val_metrics_sync = accelerator.reduce(val_metrics, reduction="sum")
1141
+
1142
+ total_val_samples = accelerator.reduce(
1143
+ torch.tensor([val_samples], device=accelerator.device), reduction="sum"
1144
+ ).item()
1145
+
1146
+ avg_val_loss = val_metrics_sync[0].item() / total_val_samples
1147
+ # Cast to float32 before numpy (bf16 tensors can't convert directly)
1148
+ avg_mae_per_param = (
1149
+ (val_metrics_sync[1:] / total_val_samples).float().cpu().numpy()
1150
+ )
1151
+ avg_mae = avg_mae_per_param.mean()
1152
+
1153
+ # ==================== LOGGING & CHECKPOINTING ====================
1154
+ if accelerator.is_main_process:
1155
+ # Scientific metrics - cast to float32 before numpy
1156
+ # gather_for_metrics already handles DDP padding removal
1157
+ y_pred = gathered_preds.float().numpy()
1158
+ y_true = gathered_targets.float().numpy()
1159
+
1160
+ # Guard against tiny validation sets (R² undefined for <2 samples)
1161
+ if len(y_true) >= 2:
1162
+ r2 = r2_score(y_true, y_pred)
1163
+ else:
1164
+ r2 = float("nan")
1165
+ pcc = calc_pearson(y_true, y_pred)
1166
+ current_lr = get_lr(optimizer)
1167
+
1168
+ # Update history
1169
+ epoch_end_time = time.time()
1170
+ epoch_time = epoch_end_time - epoch_start_time
1171
+ total_training_time += epoch_time
1172
+
1173
+ epoch_stats = {
1174
+ "epoch": epoch + 1,
1175
+ "train_loss": avg_train_loss,
1176
+ "val_loss": avg_val_loss,
1177
+ "val_r2": r2,
1178
+ "val_pearson": pcc,
1179
+ "val_mae_avg": avg_mae,
1180
+ "grad_norm": grad_norm_tracker.avg,
1181
+ "lr": current_lr,
1182
+ "epoch_time": round(epoch_time, 2),
1183
+ "total_time": round(total_training_time, 2),
1184
+ }
1185
+ for i, mae in enumerate(avg_mae_per_param):
1186
+ epoch_stats[f"MAE_Phys_P{i}"] = mae
1187
+
1188
+ history.append(epoch_stats)
1189
+
1190
+ # Console display
1191
+ base_str = f"{epoch + 1:<4} | {avg_train_loss:<8.4f} | {avg_val_loss:<8.4f} | {r2:<6.4f} | {pcc:<6.4f} | {grad_norm_tracker.avg:<6.4f} | {current_lr:<8.2e} | {avg_mae:<8.4f}"
1192
+ param_str = " | ".join([f"{m:<8.4f}" for m in avg_mae_per_param])
1193
+ logger.info(f"{base_str} | {param_str}")
1194
+
1195
+ # WandB logging
1196
+ if args.wandb and WANDB_AVAILABLE:
1197
+ log_dict = {
1198
+ "main/train_loss": avg_train_loss,
1199
+ "main/val_loss": avg_val_loss,
1200
+ "metrics/r2_score": r2,
1201
+ "metrics/pearson_corr": pcc,
1202
+ "metrics/mae_avg": avg_mae,
1203
+ "system/grad_norm": grad_norm_tracker.avg,
1204
+ "hyper/lr": current_lr,
1205
+ }
1206
+ for i, mae in enumerate(avg_mae_per_param):
1207
+ log_dict[f"mae_detailed/P{i}"] = mae
1208
+
1209
+ # Periodic scatter plots
1210
+ if (epoch % 5 == 0) or (avg_val_loss < best_val_loss):
1211
+ real_true = scaler.inverse_transform(y_true)
1212
+ real_pred = scaler.inverse_transform(y_pred)
1213
+ fig = plot_scientific_scatter(real_true, real_pred)
1214
+ log_dict["plots/scatter_analysis"] = wandb.Image(fig)
1215
+ plt.close(fig)
1216
+
1217
+ accelerator.log(log_dict)
1218
+
1219
+ # ==========================================================================
1220
+ # DDP-SAFE CHECKPOINT LOGIC
1221
+ # ==========================================================================
1222
+ # Step 1: Determine if this is the best epoch (BEFORE updating best_val_loss)
1223
+ is_best_epoch = False
1224
+ if accelerator.is_main_process:
1225
+ if avg_val_loss < best_val_loss:
1226
+ is_best_epoch = True
1227
+
1228
+ # Step 2: Broadcast decision to all ranks (required for save_state)
1229
+ is_best_epoch = broadcast_early_stop(is_best_epoch, accelerator)
1230
+
1231
+ # Step 3: Save checkpoint with all ranks participating
1232
+ if is_best_epoch:
1233
+ ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
1234
+ with suppress_accelerate_logging():
1235
+ accelerator.save_state(ckpt_dir, safe_serialization=False)
1236
+
1237
+ # Step 4: Rank 0 handles metadata and updates tracking variables
1238
+ if accelerator.is_main_process:
1239
+ best_val_loss = avg_val_loss # Update AFTER checkpoint saved
1240
+ patience_ctr = 0
1241
+
1242
+ with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
1243
+ pickle.dump(
1244
+ {
1245
+ "epoch": epoch + 1,
1246
+ "best_val_loss": best_val_loss,
1247
+ "patience_ctr": patience_ctr,
1248
+ # Model info for auto-detection during inference
1249
+ "model_name": args.model,
1250
+ "in_shape": in_shape,
1251
+ "out_dim": out_dim,
1252
+ },
1253
+ f,
1254
+ )
1255
+
1256
+ # Unwrap model for saving (handle torch.compile compatibility)
1257
+ try:
1258
+ unwrapped = accelerator.unwrap_model(model)
1259
+ except KeyError:
1260
+ # torch.compile model may not have _orig_mod in expected location
1261
+ # Fall back to getting the module directly
1262
+ unwrapped = model.module if hasattr(model, "module") else model
1263
+ # If still compiled, try to get the underlying model
1264
+ if hasattr(unwrapped, "_orig_mod"):
1265
+ unwrapped = unwrapped._orig_mod
1266
+
1267
+ torch.save(
1268
+ unwrapped.state_dict(),
1269
+ os.path.join(args.output_dir, "best_model_weights.pth"),
1270
+ )
1271
+
1272
+ # Copy scaler to checkpoint for portability
1273
+ scaler_src = os.path.join(args.output_dir, "scaler.pkl")
1274
+ scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
1275
+ if os.path.exists(scaler_src) and not os.path.exists(scaler_dst):
1276
+ shutil.copy2(scaler_src, scaler_dst)
1277
+
1278
+ logger.info(
1279
+ f" 💾 Best model saved (val_loss: {best_val_loss:.6f})"
1280
+ )
1281
+
1282
+ # Also save CSV on best model (ensures progress is saved)
1283
+ pd.DataFrame(history).to_csv(
1284
+ os.path.join(args.output_dir, "training_history.csv"),
1285
+ index=False,
1286
+ )
1287
+ else:
1288
+ if accelerator.is_main_process:
1289
+ patience_ctr += 1
1290
+
1291
+ # Periodic checkpoint (all ranks participate in save_state)
1292
+ periodic_checkpoint_needed = (
1293
+ args.save_every > 0 and (epoch + 1) % args.save_every == 0
1294
+ )
1295
+ if periodic_checkpoint_needed:
1296
+ ckpt_name = f"epoch_{epoch + 1}_checkpoint"
1297
+ ckpt_dir = os.path.join(args.output_dir, ckpt_name)
1298
+ with suppress_accelerate_logging():
1299
+ accelerator.save_state(ckpt_dir, safe_serialization=False)
1300
+
1301
+ if accelerator.is_main_process:
1302
+ with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
1303
+ pickle.dump(
1304
+ {
1305
+ "epoch": epoch + 1,
1306
+ "best_val_loss": best_val_loss,
1307
+ "patience_ctr": patience_ctr,
1308
+ # Model info for auto-detection during inference
1309
+ "model_name": args.model,
1310
+ "in_shape": in_shape,
1311
+ "out_dim": out_dim,
1312
+ },
1313
+ f,
1314
+ )
1315
+ logger.info(f" 📁 Periodic checkpoint: {ckpt_name}")
1316
+
1317
+ # Save CSV with each checkpoint (keeps logs in sync with model state)
1318
+ pd.DataFrame(history).to_csv(
1319
+ os.path.join(args.output_dir, "training_history.csv"),
1320
+ index=False,
1321
+ )
1322
+
1323
+ # Learning rate scheduling (epoch-based schedulers only)
1324
+ # NOTE: For ReduceLROnPlateau with DDP, we must step only on main process
1325
+ # to avoid patience counter being incremented by all GPU processes.
1326
+ # Then we sync the new LR to all processes to keep them consistent.
1327
+ if not scheduler_step_per_batch:
1328
+ if args.scheduler == "plateau":
1329
+ # Step only on main process to avoid multi-GPU patience bug
1330
+ if accelerator.is_main_process:
1331
+ scheduler.step(avg_val_loss)
1332
+
1333
+ # Sync LR across all processes after main process updates it
1334
+ accelerator.wait_for_everyone()
1335
+
1336
+ # Broadcast new LR from rank 0 to all processes
1337
+ if dist.is_initialized():
1338
+ if accelerator.is_main_process:
1339
+ new_lr = optimizer.param_groups[0]["lr"]
1340
+ else:
1341
+ new_lr = 0.0
1342
+ new_lr_tensor = torch.tensor(
1343
+ new_lr, device=accelerator.device, dtype=torch.float32
1344
+ )
1345
+ dist.broadcast(new_lr_tensor, src=0)
1346
+ # Update LR on non-main processes
1347
+ if not accelerator.is_main_process:
1348
+ for param_group in optimizer.param_groups:
1349
+ param_group["lr"] = new_lr_tensor.item()
1350
+ else:
1351
+ scheduler.step()
1352
+
1353
+ # DDP-safe early stopping
1354
+ should_stop = (
1355
+ patience_ctr >= args.patience if accelerator.is_main_process else False
1356
+ )
1357
+ if broadcast_early_stop(should_stop, accelerator):
1358
+ if accelerator.is_main_process:
1359
+ logger.info(
1360
+ f"🛑 Early stopping at epoch {epoch + 1} (patience={args.patience})"
1361
+ )
1362
+ # Create completion flag to prevent auto-resume
1363
+ with open(
1364
+ os.path.join(args.output_dir, "training_complete.flag"), "w"
1365
+ ) as f:
1366
+ f.write(
1367
+ f"Training completed via early stopping at epoch {epoch + 1}\n"
1368
+ )
1369
+ break
1370
+
1371
+ except KeyboardInterrupt:
1372
+ logger.warning("Training interrupted. Saving emergency checkpoint...")
1373
+ with suppress_accelerate_logging():
1374
+ accelerator.save_state(
1375
+ os.path.join(args.output_dir, "interrupted_checkpoint"),
1376
+ safe_serialization=False,
1377
+ )
1378
+
1379
+ except Exception as e:
1380
+ logger.error(f"Critical error: {e}", exc_info=True)
1381
+ raise
1382
+
1383
+ else:
1384
+ # Loop exited without an exception (max epochs reached, or early stopping already wrote the flag)
1385
+ # Create completion flag to prevent auto-resume on re-run
1386
+ if accelerator.is_main_process:
1387
+ if not os.path.exists(complete_flag_path):
1388
+ with open(complete_flag_path, "w") as f:
1389
+ f.write(f"Training completed normally after {args.epochs} epochs\n")
1390
+ logger.info(f"✅ Training completed after {args.epochs} epochs")
1391
+
1392
+ finally:
1393
+ # Final CSV write to capture all epochs (covers runs that end between checkpoint saves)
1394
+ if accelerator.is_main_process and len(history) > 0:
1395
+ pd.DataFrame(history).to_csv(
1396
+ os.path.join(args.output_dir, "training_history.csv"),
1397
+ index=False,
1398
+ )
1399
+
1400
+ # Generate training curves plot (PNG + SVG)
1401
+ if accelerator.is_main_process and len(history) > 0:
1402
+ try:
1403
+ fig = create_training_curves(history, show_lr=True)
1404
+ for fmt in ["png", "svg"]:
1405
+ fig.savefig(
1406
+ os.path.join(args.output_dir, f"training_curves.{fmt}"),
1407
+ dpi=FIGURE_DPI,
1408
+ bbox_inches="tight",
1409
+ )
1410
+ plt.close(fig)
1411
+ logger.info("✔ Saved: training_curves.png, training_curves.svg")
1412
+ except Exception as e:
1413
+ logger.warning(f"Could not generate training curves: {e}")
1414
+
1415
+ if args.wandb and WANDB_AVAILABLE:
1416
+ accelerator.end_training()
1417
+
1418
+ # Clean up distributed process group to prevent resource leak warning
1419
+ if torch.distributed.is_initialized():
1420
+ torch.distributed.destroy_process_group()
1421
+
1422
+ logger.info("Training completed.")
1423
+
1424
+
1425
+ if __name__ == "__main__":
1426
+ try:
1427
+ torch.multiprocessing.set_start_method("spawn")
1428
+ except RuntimeError:
1429
+ pass
1430
+ main()