wavedl-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/train.py ADDED
@@ -0,0 +1,1079 @@
1
+ """
2
+ WaveDL - Deep Learning for Wave-based Inverse Problems
3
+ =======================================================
4
+ Target Environment: NVIDIA HPC GPUs (Multi-GPU DDP) | PyTorch 2.x | Python 3.11+
5
+
6
+ A modular training framework for wave-based inverse problems and regression:
7
+ 1. HPC-Grade DDP Training: BF16/FP16 mixed precision with torch.compile support
8
+ 2. Dynamic Model Selection: Use --model flag to select any registered architecture
9
+ 3. Zero-Copy Data Engine: Memmap-backed datasets for large-scale training
10
+ 4. Physics-Aware Metrics: Real-time physical MAE with proper unscaling
11
+ 5. Robust Checkpointing: Resume training, periodic saves, and training curves
12
+ 6. Deep Observability: WandB integration with scatter analysis
13
+
14
+ Usage:
15
+ # Recommended: Using the HPC helper script
16
+ ./run_training.sh --model cnn --batch_size 128 --wandb
17
+
18
+ # Or with direct accelerate launch
19
+ accelerate launch train.py --model cnn --batch_size 128 --wandb
20
+
21
+ # Multi-GPU with explicit config
22
+ accelerate launch --num_processes=4 --mixed_precision=bf16 \
23
+ train.py --model cnn --wandb --project_name "MyProject"
24
+
25
+ # Resume from checkpoint
26
+ accelerate launch train.py --model cnn --resume best_checkpoint --wandb
27
+
28
+ # List available models
29
+ python train.py --list_models
30
+
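+ # Train from a YAML config file (illustrative path; explicit CLI flags override config values)
+ accelerate launch train.py --config config.yaml --wandb
+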
31
+ Note:
32
+ For HPC clusters (Compute Canada, etc.), use run_training.sh, which handles
33
+ environment configuration automatically. Mixed precision is controlled via the
34
+ --precision flag (default: bf16).
35
+
36
+ Author: Ductho Le (ductho.le@outlook.com)
37
+ """
38
+
39
+ # ==============================================================================
40
+ # ENVIRONMENT CONFIGURATION FOR HPC SYSTEMS
41
+ # ==============================================================================
42
+ # IMPORTANT: These must be set BEFORE matplotlib is imported to be effective
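+ # (matplotlib writes its font cache under MPLCONFIGDIR; redirecting it to TMPDIR
+ # avoids permission/quota issues with read-only or slow $HOME on cluster nodes)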
43
+ import os
44
+
45
+
46
+ os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
47
+ os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
48
+
49
+ import argparse
50
+ import logging
51
+ import pickle
52
+ import shutil
53
+ import sys
54
+ import time
55
+ import warnings
56
+ from typing import Any
57
+
58
+ import matplotlib.pyplot as plt
59
+ import numpy as np
60
+ import pandas as pd
61
+ import torch
62
+ from accelerate import Accelerator
63
+ from accelerate.utils import set_seed
64
+ from sklearn.metrics import r2_score
65
+ from tqdm.auto import tqdm
66
+
67
+ # Local imports
68
+ from wavedl.models import build_model, get_model, list_models
69
+ from wavedl.utils import (
70
+ FIGURE_DPI,
71
+ MetricTracker,
72
+ broadcast_early_stop,
73
+ calc_pearson,
74
+ create_training_curves,
75
+ # New factory functions
76
+ get_loss,
77
+ get_lr,
78
+ get_optimizer,
79
+ get_scheduler,
80
+ is_epoch_based,
81
+ list_losses,
82
+ list_optimizers,
83
+ list_schedulers,
84
+ plot_scientific_scatter,
85
+ prepare_data,
86
+ )
87
+
88
+
89
+ # Optional WandB import
90
+ try:
91
+ import wandb
92
+
93
+ WANDB_AVAILABLE = True
94
+ except ImportError:
95
+ WANDB_AVAILABLE = False
96
+
97
+ # Filter non-critical warnings for cleaner training logs
98
+ warnings.filterwarnings("ignore", category=UserWarning)
99
+ warnings.filterwarnings("ignore", category=FutureWarning)
100
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
101
+ warnings.filterwarnings("ignore", module="pydantic")
102
+ warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
103
+
104
+
105
+ # ==============================================================================
106
+ # ARGUMENT PARSING
107
+ # ==============================================================================
108
+ def parse_args() -> tuple[argparse.Namespace, argparse.ArgumentParser]:
109
+ """Parse command-line arguments. Returns (args, parser) so config merging can detect explicit CLI overrides."""
110
+ parser = argparse.ArgumentParser(
111
+ description="Universal DDP Training Pipeline",
112
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
113
+ )
114
+
115
+ # Model Selection
116
+ parser.add_argument(
117
+ "--model",
118
+ type=str,
119
+ default="cnn",
120
+ help=f"Model architecture to train. Available: {list_models()}",
121
+ )
122
+ parser.add_argument(
123
+ "--list_models", action="store_true", help="List all available models and exit"
124
+ )
125
+
126
+ # Configuration File
127
+ parser.add_argument(
128
+ "--config",
129
+ type=str,
130
+ default=None,
131
+ help="Path to YAML config file. CLI args override config values.",
132
+ )
133
+
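+ # Illustrative config sketch (an assumption, not a spec): YAML keys are expected
+ # to mirror the flag names defined here, e.g.
+ #   model: cnn
+ #   batch_size: 256
+ #   lr: 5.0e-4
+ #   scheduler: cosine
+ # Flags passed explicitly on the command line still take precedence.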
134
+ # Hyperparameters
135
+ parser.add_argument(
136
+ "--batch_size", type=int, default=128, help="Batch size per GPU"
137
+ )
138
+ parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
139
+ parser.add_argument(
140
+ "--epochs", type=int, default=1000, help="Maximum training epochs"
141
+ )
142
+ parser.add_argument(
143
+ "--patience", type=int, default=20, help="Early stopping patience"
144
+ )
145
+ parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
146
+ parser.add_argument(
147
+ "--grad_clip", type=float, default=1.0, help="Gradient clipping norm"
148
+ )
149
+
150
+ # Loss Function
151
+ parser.add_argument(
152
+ "--loss",
153
+ type=str,
154
+ default="mse",
155
+ choices=["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"],
156
+ help=f"Loss function for training. Available: {list_losses()}",
157
+ )
158
+ parser.add_argument(
159
+ "--huber_delta", type=float, default=1.0, help="Delta for Huber loss"
160
+ )
161
+ parser.add_argument(
162
+ "--loss_weights",
163
+ type=str,
164
+ default=None,
165
+ help="Comma-separated weights for weighted_mse (e.g., '1.0,2.0,1.0')",
166
+ )
167
+
168
+ # Optimizer
169
+ parser.add_argument(
170
+ "--optimizer",
171
+ type=str,
172
+ default="adamw",
173
+ choices=["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"],
174
+ help=f"Optimizer for training. Available: {list_optimizers()}",
175
+ )
176
+ parser.add_argument(
177
+ "--momentum", type=float, default=0.9, help="Momentum for SGD/RMSprop"
178
+ )
179
+ parser.add_argument(
180
+ "--nesterov", action="store_true", help="Use Nesterov momentum (SGD)"
181
+ )
182
+ parser.add_argument(
183
+ "--betas",
184
+ type=str,
185
+ default="0.9,0.999",
186
+ help="Betas for Adam variants (comma-separated)",
187
+ )
188
+
189
+ # Learning Rate Scheduler
190
+ parser.add_argument(
191
+ "--scheduler",
192
+ type=str,
193
+ default="plateau",
194
+ choices=[
195
+ "plateau",
196
+ "cosine",
197
+ "cosine_restarts",
198
+ "onecycle",
199
+ "step",
200
+ "multistep",
201
+ "exponential",
202
+ "linear_warmup",
203
+ ],
204
+ help=f"LR scheduler. Available: {list_schedulers()}",
205
+ )
206
+ parser.add_argument(
207
+ "--scheduler_patience",
208
+ type=int,
209
+ default=10,
210
+ help="Patience for ReduceLROnPlateau",
211
+ )
212
+ parser.add_argument(
213
+ "--min_lr", type=float, default=1e-6, help="Minimum learning rate"
214
+ )
215
+ parser.add_argument(
216
+ "--scheduler_factor", type=float, default=0.5, help="LR reduction factor"
217
+ )
218
+ parser.add_argument(
219
+ "--warmup_epochs", type=int, default=5, help="Warmup epochs for linear_warmup"
220
+ )
221
+ parser.add_argument(
222
+ "--step_size", type=int, default=30, help="Step size for StepLR"
223
+ )
224
+ parser.add_argument(
225
+ "--milestones",
226
+ type=str,
227
+ default=None,
228
+ help="Comma-separated epochs for MultiStepLR (e.g., '30,60,90')",
229
+ )
230
+
231
+ # Data
232
+ parser.add_argument(
233
+ "--data_path", type=str, default="train_data.npz", help="Path to NPZ dataset"
234
+ )
235
+ parser.add_argument(
236
+ "--workers",
237
+ type=int,
238
+ default=-1,
239
+ help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
240
+ )
241
+ parser.add_argument("--seed", type=int, default=2025, help="Random seed")
242
+ parser.add_argument(
243
+ "--single_channel",
244
+ action="store_true",
245
+ help="Confirm data is single-channel (suppress ambiguous shape warnings for shallow 3D volumes)",
246
+ )
247
+
248
+ # Cross-Validation
249
+ parser.add_argument(
250
+ "--cv",
251
+ type=int,
252
+ default=0,
253
+ help="Enable K-fold cross-validation with K folds (0=disabled)",
254
+ )
255
+ parser.add_argument(
256
+ "--cv_stratify",
257
+ action="store_true",
258
+ help="Use stratified splitting for cross-validation",
259
+ )
260
+ parser.add_argument(
261
+ "--cv_bins",
262
+ type=int,
263
+ default=10,
264
+ help="Number of bins for stratified CV (only with --cv_stratify)",
265
+ )
266
+
267
+ # Checkpointing & Resume
268
+ parser.add_argument(
269
+ "--resume", type=str, default=None, help="Checkpoint directory to resume from"
270
+ )
271
+ parser.add_argument(
272
+ "--save_every",
273
+ type=int,
274
+ default=50,
275
+ help="Save checkpoint every N epochs (0=disable)",
276
+ )
277
+ parser.add_argument(
278
+ "--output_dir", type=str, default=".", help="Output directory for checkpoints"
279
+ )
280
+ parser.add_argument(
281
+ "--fresh",
282
+ action="store_true",
283
+ help="Force fresh training, ignore existing checkpoints",
284
+ )
285
+
286
+ # Performance
287
+ parser.add_argument(
288
+ "--compile", action="store_true", help="Enable torch.compile (PyTorch 2.x)"
289
+ )
290
+ parser.add_argument(
291
+ "--precision",
292
+ type=str,
293
+ default="bf16",
294
+ choices=["bf16", "fp16", "no"],
295
+ help="Mixed precision mode",
296
+ )
297
+
298
+ # Logging
299
+ parser.add_argument(
300
+ "--wandb", action="store_true", help="Enable Weights & Biases logging"
301
+ )
302
+ parser.add_argument(
303
+ "--project_name", type=str, default="DL-Training", help="WandB project name"
304
+ )
305
+ parser.add_argument("--run_name", type=str, default=None, help="WandB run name")
306
+
307
+ args = parser.parse_args()
308
+ return args, parser # Returns (Namespace, ArgumentParser)
309
+
310
+
311
+ # ==============================================================================
312
+ # MAIN TRAINING FUNCTION
313
+ # ==============================================================================
314
+ def main():
315
+ args, parser = parse_args()
316
+
317
+ # Handle --list_models flag
318
+ if args.list_models:
319
+ print("Available models:")
320
+ for name in list_models():
321
+ ModelClass = get_model(name)
322
+ # Get first non-empty docstring line
323
+ if ModelClass.__doc__:
324
+ lines = [
325
+ l.strip() for l in ModelClass.__doc__.splitlines() if l.strip()
326
+ ]
327
+ doc_first_line = lines[0] if lines else "No description"
328
+ else:
329
+ doc_first_line = "No description"
330
+ print(f" - {name}: {doc_first_line}")
331
+ sys.exit(0)
332
+
333
+ # Load and merge config file if provided
334
+ if args.config:
335
+ from wavedl.utils.config import (
336
+ load_config,
337
+ merge_config_with_args,
338
+ validate_config,
339
+ )
340
+
341
+ print(f"📄 Loading config from: {args.config}")
342
+ config = load_config(args.config)
343
+
344
+ # Validate config values
345
+ warnings_list = validate_config(config)
346
+ for w in warnings_list:
347
+ print(f" ⚠ {w}")
348
+
349
+ # Merge config with CLI args (CLI takes precedence via parser defaults detection)
350
+ args = merge_config_with_args(config, args, parser=parser)
351
+
352
+ # Handle --cv flag (cross-validation mode)
353
+ if args.cv > 0:
354
+ print(f"🔄 Cross-Validation Mode: {args.cv} folds")
355
+ from wavedl.utils.cross_validation import run_cross_validation
356
+
357
+ # Load data for CV using memory-efficient loader
358
+ from wavedl.utils.data import DataSource, get_data_source
359
+
360
+ data_format = DataSource.detect_format(args.data_path)
361
+ source = get_data_source(data_format)
362
+
363
+ # Use memory-mapped loading when available
364
+ _cv_handle = None
365
+ if hasattr(source, "load_mmap"):
366
+ result = source.load_mmap(args.data_path)
367
+ if hasattr(result, "inputs"):
368
+ _cv_handle = result
369
+ X, y = result.inputs, result.outputs
370
+ else:
371
+ X, y = result # NPZ returns tuple directly
372
+ else:
373
+ X, y = source.load(args.data_path)
374
+
375
+ # Handle sparse matrices (must materialize for CV shuffling)
376
+ if hasattr(X, "__getitem__") and len(X) > 0 and hasattr(X[0], "toarray"):
377
+ X = np.stack([x.toarray() for x in X])
378
+
379
+ # Normalize target shape: (N,) -> (N, 1) for consistency
380
+ if y.ndim == 1:
381
+ y = y.reshape(-1, 1)
382
+
383
+ # Validate and determine input shape (consistent with prepare_data)
384
+ # Check for ambiguous shapes that could be multi-channel or shallow 3D volume
385
+ sample_shape = X.shape[1:] # Per-sample shape
386
+
387
+ # Same heuristic as prepare_data: detect ambiguous 3D shapes
388
+ is_ambiguous_shape = (
389
+ len(sample_shape) == 3 # Exactly 3D: could be (C, H, W) or (D, H, W)
390
+ and sample_shape[0] <= 16 # First dim looks like channels
391
+ and sample_shape[1] > 16
392
+ and sample_shape[2] > 16 # Both spatial dims are large
393
+ )
394
+
395
+ if is_ambiguous_shape and not args.single_channel:
396
+ raise ValueError(
397
+ f"Ambiguous input shape detected: sample shape {sample_shape}. "
398
+ f"This could be either:\n"
399
+ f" - Multi-channel 2D data (C={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n"
400
+ f" - Single-channel 3D volume (D={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n\n"
401
+ f"If this is single-channel 3D/shallow volume data, use --single_channel flag.\n"
402
+ f"If this is multi-channel 2D data, reshape to (N*C, H, W) with adjusted targets."
403
+ )
404
+
405
+ # in_shape = spatial dimensions for model registry (channel added during training)
406
+ in_shape = sample_shape
407
+
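+ # Worked example of the heuristic above: a sample shape of (8, 128, 128) is
+ # flagged as ambiguous (8 <= 16 while both 128s exceed 16), since it could be
+ # 8-channel 2D data or a shallow single-channel volume; shapes like
+ # (32, 128, 128) or (8, 8, 128) are not flagged.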
408
+ # Run cross-validation
409
+ try:
410
+ run_cross_validation(
411
+ X=X,
412
+ y=y,
413
+ model_name=args.model,
414
+ in_shape=in_shape,
415
+ out_size=y.shape[1],
416
+ folds=args.cv,
417
+ stratify=args.cv_stratify,
418
+ stratify_bins=args.cv_bins,
419
+ batch_size=args.batch_size,
420
+ lr=args.lr,
421
+ epochs=args.epochs,
422
+ patience=args.patience,
423
+ weight_decay=args.weight_decay,
424
+ loss_name=args.loss,
425
+ optimizer_name=args.optimizer,
426
+ scheduler_name=args.scheduler,
427
+ output_dir=args.output_dir,
428
+ workers=args.workers,
429
+ seed=args.seed,
430
+ )
431
+ finally:
432
+ # Clean up file handle if HDF5/MAT
433
+ if _cv_handle is not None and hasattr(_cv_handle, "close"):
434
+ try:
435
+ _cv_handle.close()
436
+ except Exception:
437
+ pass
438
+ return
439
+
440
+ # ==========================================================================
441
+ # 1. SYSTEM INITIALIZATION
442
+ # ==========================================================================
443
+ # Initialize Accelerator for DDP and mixed precision
444
+ accelerator = Accelerator(
445
+ mixed_precision=args.precision,
446
+ log_with="wandb" if args.wandb and WANDB_AVAILABLE else None,
447
+ )
448
+ set_seed(args.seed)
449
+
450
+ # Configure logging (rank 0 only prints to console)
451
+ logging.basicConfig(
452
+ level=logging.INFO if accelerator.is_main_process else logging.ERROR,
453
+ format="%(asctime)s | %(levelname)s | %(message)s",
454
+ datefmt="%H:%M:%S",
455
+ )
456
+ logger = logging.getLogger("Trainer")
457
+
458
+ # Ensure output directory exists (critical for cache files, checkpoints, etc.)
459
+ os.makedirs(args.output_dir, exist_ok=True)
460
+
461
+ # Auto-detect optimal DataLoader workers if not specified
462
+ if args.workers < 0:
463
+ cpu_count = os.cpu_count() or 4
464
+ num_gpus = accelerator.num_processes
465
+ # Heuristic: 4-8 workers per GPU, bounded by available CPU cores
466
+ # Leave some cores for main process and system overhead
467
+ args.workers = min(8, max(2, (cpu_count - 2) // num_gpus))
468
+ if accelerator.is_main_process:
469
+ logger.info(
470
+ f"⚙️ Auto-detected workers: {args.workers} per GPU "
471
+ f"(CPUs: {cpu_count}, GPUs: {num_gpus})"
472
+ )
473
+
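+ # Worked example of the heuristic: a 64-core node with 4 GPUs gives
+ # min(8, max(2, (64 - 2) // 4)) = min(8, 15) = 8 workers per GPU;
+ # an 8-core single-GPU box gives min(8, max(2, 6)) = 6.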
474
+ if accelerator.is_main_process:
475
+ logger.info(f"🚀 Cluster Status: {accelerator.num_processes}x GPUs detected")
476
+ logger.info(
477
+ f" Model: {args.model} | Precision: {args.precision} | Compile: {args.compile}"
478
+ )
479
+ logger.info(
480
+ f" Loss: {args.loss} | Optimizer: {args.optimizer} | Scheduler: {args.scheduler}"
481
+ )
482
+ logger.info(f" Early Stopping Patience: {args.patience} epochs")
483
+ if args.save_every > 0:
484
+ logger.info(f" Periodic Checkpointing: Every {args.save_every} epochs")
485
+ if args.resume:
486
+ logger.info(f" 📂 Resuming from: {args.resume}")
487
+
488
+ # Initialize WandB
489
+ if args.wandb and WANDB_AVAILABLE:
490
+ accelerator.init_trackers(
491
+ project_name=args.project_name,
492
+ config=vars(args),
493
+ init_kwargs={"wandb": {"name": args.run_name or f"{args.model}_run"}},
494
+ )
495
+
496
+ # ==========================================================================
497
+ # 2. DATA & MODEL LOADING
498
+ # ==========================================================================
499
+ train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
500
+ args, logger, accelerator, cache_dir=args.output_dir
501
+ )
502
+
503
+ # Build model using registry
504
+ model = build_model(args.model, in_shape=in_shape, out_size=out_dim)
505
+
506
+ if accelerator.is_main_process:
507
+ param_info = model.parameter_summary()
508
+ logger.info(
509
+ f" Model Parameters: {param_info['trainable_parameters']:,} trainable"
510
+ )
511
+ logger.info(f" Model Size: {param_info['total_mb']:.2f} MB")
512
+
513
+ # Optional WandB model watching
514
+ if args.wandb and WANDB_AVAILABLE and accelerator.is_main_process:
515
+ wandb.watch(model, log="gradients", log_freq=100)
516
+
517
+ # Torch 2.0 compilation (requires compatible Triton on GPU)
518
+ if args.compile:
519
+ try:
520
+ # Test if Triton is available AND compatible with this PyTorch version
521
+ # PyTorch needs triton_key from triton.compiler.compiler
522
+ from triton.compiler.compiler import triton_key  # noqa: F401 (availability probe)
523
+
524
+ model = torch.compile(model)
525
+ if accelerator.is_main_process:
526
+ logger.info(" ✔ torch.compile enabled (Triton backend)")
527
+ except ImportError as e:
528
+ if accelerator.is_main_process:
529
+ if "triton" in str(e).lower():
530
+ logger.warning(
531
+ " ⚠ Triton not installed or incompatible version - torch.compile disabled. "
532
+ "Training will proceed without compilation."
533
+ )
534
+ else:
535
+ logger.warning(
536
+ f" ⚠ torch.compile setup failed: {e}. Continuing without compilation."
537
+ )
538
+ except Exception as e:
539
+ if accelerator.is_main_process:
540
+ logger.warning(
541
+ f" ⚠ torch.compile failed: {e}. Continuing without compilation."
542
+ )
543
+
544
+ # ==========================================================================
545
+ # 2.5. OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
546
+ # ==========================================================================
547
+ # Parse comma-separated arguments with validation
548
+ try:
549
+ betas_list = [float(x.strip()) for x in args.betas.split(",")]
550
+ if len(betas_list) != 2:
551
+ raise ValueError(
552
+ f"--betas must have exactly 2 values, got {len(betas_list)}"
553
+ )
554
+ if not all(0.0 <= b < 1.0 for b in betas_list):
555
+ raise ValueError(f"--betas values must be in [0, 1), got {betas_list}")
556
+ betas = tuple(betas_list)
557
+ except ValueError as e:
558
+ raise ValueError(
559
+ f"Invalid --betas format '{args.betas}': {e}. Expected format: '0.9,0.999'"
560
+ ) from e
561
+
562
+ loss_weights = None
563
+ if args.loss_weights:
564
+ loss_weights = [float(x.strip()) for x in args.loss_weights.split(",")]
565
+ milestones = None
566
+ if args.milestones:
567
+ milestones = [int(x.strip()) for x in args.milestones.split(",")]
568
+
569
+ # Create optimizer using factory
570
+ optimizer = get_optimizer(
571
+ name=args.optimizer,
572
+ params=model.get_optimizer_groups(args.lr, args.weight_decay),
573
+ lr=args.lr,
574
+ weight_decay=args.weight_decay,
575
+ momentum=args.momentum,
576
+ nesterov=args.nesterov,
577
+ betas=betas,
578
+ )
579
+
580
+ # Create loss function using factory
581
+ criterion = get_loss(
582
+ name=args.loss,
583
+ weights=loss_weights,
584
+ delta=args.huber_delta,
585
+ )
586
+ # Move criterion to device (important for WeightedMSELoss buffer)
587
+ criterion = criterion.to(accelerator.device)
588
+
589
+ # Track if scheduler should step per batch (OneCycleLR) or per epoch
590
+ scheduler_step_per_batch = not is_epoch_based(args.scheduler)
591
+
592
+ # ==========================================================================
593
+ # DDP Preparation Strategy:
594
+ # - For batch-based schedulers (OneCycleLR): prepare DataLoaders first to get
595
+ # the correct sharded batch count, then create scheduler
596
+ # - For epoch-based schedulers: create scheduler before prepare (no issue)
597
+ # ==========================================================================
598
+ if scheduler_step_per_batch:
599
+ # BATCH-BASED SCHEDULER (e.g., OneCycleLR)
600
+ # Prepare model, optimizer, dataloaders FIRST to get sharded loader length
601
+ model, optimizer, train_dl, val_dl = accelerator.prepare(
602
+ model, optimizer, train_dl, val_dl
603
+ )
604
+
605
+ # Now create scheduler with the CORRECT sharded steps_per_epoch
606
+ steps_per_epoch = len(train_dl) # Post-DDP sharded length
607
+ scheduler = get_scheduler(
608
+ name=args.scheduler,
609
+ optimizer=optimizer,
610
+ epochs=args.epochs,
611
+ steps_per_epoch=steps_per_epoch,
612
+ min_lr=args.min_lr,
613
+ patience=args.scheduler_patience,
614
+ factor=args.scheduler_factor,
615
+ gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
616
+ step_size=args.step_size,
617
+ milestones=milestones,
618
+ warmup_epochs=args.warmup_epochs,
619
+ )
620
+ # Prepare scheduler separately (Accelerator wraps it for state saving)
621
+ scheduler = accelerator.prepare(scheduler)
622
+ else:
623
+ # EPOCH-BASED SCHEDULER (plateau, cosine, step, etc.)
624
+ # No batch count dependency - create scheduler before prepare
625
+ scheduler = get_scheduler(
626
+ name=args.scheduler,
627
+ optimizer=optimizer,
628
+ epochs=args.epochs,
629
+ steps_per_epoch=None,
630
+ min_lr=args.min_lr,
631
+ patience=args.scheduler_patience,
632
+ factor=args.scheduler_factor,
633
+ gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
634
+ step_size=args.step_size,
635
+ milestones=milestones,
636
+ warmup_epochs=args.warmup_epochs,
637
+ )
638
+ # Prepare everything together
639
+ model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
640
+ model, optimizer, train_dl, val_dl, scheduler
641
+ )
642
+
643
+ # ==========================================================================
644
+ # 3. AUTO-RESUME / RESUME FROM CHECKPOINT
645
+ # ==========================================================================
646
+ start_epoch = 0
647
+ best_val_loss = float("inf")
648
+ patience_ctr = 0
649
+ history: list[dict[str, Any]] = []
650
+
651
+ # Define checkpoint paths
652
+ best_ckpt_path = os.path.join(args.output_dir, "best_checkpoint")
653
+ complete_flag_path = os.path.join(args.output_dir, "training_complete.flag")
654
+
655
+ # Auto-resume logic (if not --fresh and no explicit --resume)
656
+ if not args.fresh and args.resume is None:
657
+ if os.path.exists(complete_flag_path):
658
+ # Training already completed
659
+ if accelerator.is_main_process:
660
+ logger.info(
661
+ "✅ Training already completed (early stopping). Use --fresh to retrain."
662
+ )
663
+ return # Exit gracefully
664
+ elif os.path.exists(best_ckpt_path):
665
+ # Incomplete training found - auto-resume
666
+ args.resume = best_ckpt_path
667
+ if accelerator.is_main_process:
668
+ logger.info(f"🔄 Auto-resuming from: {best_ckpt_path}")
669
+
670
+ if args.resume:
671
+ if os.path.exists(args.resume):
672
+ logger.info(f"🔄 Loading checkpoint from: {args.resume}")
673
+ accelerator.load_state(args.resume)
674
+
675
+ # Restore training metadata
676
+ meta_path = os.path.join(args.resume, "training_meta.pkl")
677
+ if os.path.exists(meta_path):
678
+ with open(meta_path, "rb") as f:
679
+ meta = pickle.load(f)
680
+ start_epoch = meta.get("epoch", 0)
681
+ best_val_loss = meta.get("best_val_loss", float("inf"))
682
+ patience_ctr = meta.get("patience_ctr", 0)
683
+ logger.info(
684
+ f" ✅ Restored: Epoch {start_epoch}, Best Loss: {best_val_loss:.6f}"
685
+ )
686
+ else:
687
+ logger.warning(
688
+ " ⚠️ training_meta.pkl not found, starting from epoch 0"
689
+ )
690
+
691
+ # Restore history
692
+ history_path = os.path.join(args.output_dir, "training_history.csv")
693
+ if os.path.exists(history_path):
694
+ history = pd.read_csv(history_path).to_dict("records")
695
+ logger.info(f" ✅ Loaded {len(history)} epochs from history")
696
+ else:
697
+ raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
698
+
699
+ # ==========================================================================
700
+ # 4. PHYSICAL METRIC SETUP
701
+ # ==========================================================================
702
+ # Physical MAE = normalized MAE * scaler.scale_
703
+ phys_scale = torch.tensor(
704
+ scaler.scale_, device=accelerator.device, dtype=torch.float32
705
+ )
706
+
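+ # Worked example (assuming an sklearn-style scaler fit on the targets): with
+ # scaler.scale_ = [2.0, 0.5], a mean normalized error of [0.10, 0.20] reports a
+ # physical MAE of [0.20, 0.10] in the targets' original units.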
707
+ # ==========================================================================
708
+ # 5. TRAINING LOOP
709
+ # ==========================================================================
710
+ # Dynamic console header
711
+ if accelerator.is_main_process:
712
+ base_cols = ["Ep", "TrnLoss", "ValLoss", "R2", "PCC", "GradN", "LR", "MAE_Avg"]
713
+ param_cols = [f"MAE_P{i}" for i in range(out_dim)]
714
+ header = "{:<4} | {:<8} | {:<8} | {:<6} | {:<6} | {:<6} | {:<8} | {:<8}".format(
715
+ *base_cols
716
+ )
717
+ header += " | " + " | ".join([f"{c:<8}" for c in param_cols])
718
+ logger.info("=" * len(header))
719
+ logger.info(header)
720
+ logger.info("=" * len(header))
721
+
722
+ try:
724
+ total_training_time = 0.0
725
+
726
+ for epoch in range(start_epoch, args.epochs):
727
+ epoch_start_time = time.time()
728
+
729
+ # ==================== TRAINING PHASE ====================
730
+ model.train()
731
+ # Use GPU tensor for loss accumulation to avoid .item() sync per batch
732
+ train_loss_sum = torch.tensor(0.0, device=accelerator.device)
733
+ train_samples = 0
734
+ grad_norm_tracker = MetricTracker()
735
+
736
+ pbar = tqdm(
737
+ train_dl,
738
+ disable=not accelerator.is_main_process,
739
+ leave=False,
740
+ desc=f"Epoch {epoch + 1}",
741
+ )
742
+
743
+ for x, y in pbar:
744
+ with accelerator.accumulate(model):
745
+ pred = model(x)
746
+ loss = criterion(pred, y)
747
+
748
+ accelerator.backward(loss)
749
+
750
+ if accelerator.sync_gradients:
751
+ grad_norm = accelerator.clip_grad_norm_(
752
+ model.parameters(), args.grad_clip
753
+ )
754
+ if grad_norm is not None:
755
+ grad_norm_tracker.update(grad_norm.item())
756
+
757
+ optimizer.step()
758
+ optimizer.zero_grad(set_to_none=True) # Faster than zero_grad()
759
+
760
+ # Per-batch LR scheduling (e.g., OneCycleLR)
761
+ if scheduler_step_per_batch:
762
+ scheduler.step()
763
+
764
+ # Accumulate as tensors to avoid .item() sync per batch
765
+ train_loss_sum += loss.detach() * x.size(0)
766
+ train_samples += x.size(0)
767
+
768
+ # Single .item() call at end of epoch (reduces GPU sync overhead)
769
+ train_loss_scalar = train_loss_sum.item()
770
+
771
+ # Synchronize training metrics across GPUs
772
+ global_loss = accelerator.reduce(
773
+ torch.tensor([train_loss_scalar], device=accelerator.device),
774
+ reduction="sum",
775
+ ).item()
776
+ global_samples = accelerator.reduce(
777
+ torch.tensor([train_samples], device=accelerator.device),
778
+ reduction="sum",
779
+ ).item()
780
+ avg_train_loss = global_loss / global_samples
781
+
782
+ # ==================== VALIDATION PHASE ====================
783
+ model.eval()
784
+ # Use GPU tensor for loss accumulation (consistent with training phase)
785
+ val_loss_sum = torch.tensor(0.0, device=accelerator.device)
786
+ val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
787
+ val_samples = 0
788
+
789
+ # Accumulate predictions locally, gather ONCE at end (reduces sync overhead)
790
+ local_preds = []
791
+ local_targets = []
792
+
793
+ with torch.inference_mode():
794
+ for x, y in val_dl:
795
+ pred = model(x)
796
+ loss = criterion(pred, y)
797
+
798
+ val_loss_sum += loss.detach() * x.size(0)
799
+ val_samples += x.size(0)
800
+
801
+ # Physical MAE
802
+ mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
803
+ val_mae_sum += mae_batch
804
+
805
+ # Store locally (no GPU sync per batch)
806
+ local_preds.append(pred)
807
+ local_targets.append(y)
808
+
809
+ # Single gather at end of validation (2 syncs instead of 2×num_batches)
810
+ all_local_preds = torch.cat(local_preds)
811
+ all_local_targets = torch.cat(local_targets)
812
+ all_preds = accelerator.gather_for_metrics(all_local_preds)
813
+ all_targets = accelerator.gather_for_metrics(all_local_targets)
814
+
815
+ # Synchronize validation metrics
816
+ val_loss_scalar = val_loss_sum.item()
817
+ val_metrics = torch.cat(
818
+ [
819
+ torch.tensor([val_loss_scalar], device=accelerator.device),
820
+ val_mae_sum,
821
+ ]
822
+ )
823
+ val_metrics_sync = accelerator.reduce(val_metrics, reduction="sum")
824
+
825
+ total_val_samples = accelerator.reduce(
826
+ torch.tensor([val_samples], device=accelerator.device), reduction="sum"
827
+ ).item()
828
+
829
+ avg_val_loss = val_metrics_sync[0].item() / total_val_samples
830
+ # Cast to float32 before numpy (bf16 tensors can't convert directly)
831
+ avg_mae_per_param = (
832
+ (val_metrics_sync[1:] / total_val_samples).float().cpu().numpy()
833
+ )
834
+ avg_mae = avg_mae_per_param.mean()
835
+
836
+ # ==================== LOGGING & CHECKPOINTING ====================
837
+ if accelerator.is_main_process:
838
+ # Scientific metrics - cast to float32 before numpy (bf16 can't convert)
839
+ y_pred = all_preds.float().cpu().numpy()
840
+ y_true = all_targets.float().cpu().numpy()
841
+
842
+ # Trim DDP padding
843
+ real_len = len(val_dl.dataset)
844
+ if len(y_pred) > real_len:
845
+ y_pred = y_pred[:real_len]
846
+ y_true = y_true[:real_len]
847
+
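+ # Example: 1003 validation samples on 4 ranks get padded for even sharding, so
+ # the gathered arrays may exceed the dataset length (recent accelerate versions
+ # may already drop this padding in gather_for_metrics); the trim is defensive.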
848
+ # Guard against tiny validation sets (R² undefined for <2 samples)
849
+ if len(y_true) >= 2:
850
+ r2 = r2_score(y_true, y_pred)
851
+ else:
852
+ r2 = float("nan")
853
+ pcc = calc_pearson(y_true, y_pred)
854
+ current_lr = get_lr(optimizer)
855
+
856
+ # Update history
857
+ epoch_end_time = time.time()
858
+ epoch_time = epoch_end_time - epoch_start_time
859
+ total_training_time += epoch_time
860
+
861
+ epoch_stats = {
862
+ "epoch": epoch + 1,
863
+ "train_loss": avg_train_loss,
864
+ "val_loss": avg_val_loss,
865
+ "val_r2": r2,
866
+ "val_pearson": pcc,
867
+ "val_mae_avg": avg_mae,
868
+ "grad_norm": grad_norm_tracker.avg,
869
+ "lr": current_lr,
870
+ "epoch_time": round(epoch_time, 2),
871
+ "total_time": round(total_training_time, 2),
872
+ }
873
+ for i, mae in enumerate(avg_mae_per_param):
874
+ epoch_stats[f"MAE_Phys_P{i}"] = mae
875
+
876
+ history.append(epoch_stats)
877
+
878
+ # Console display
879
+ base_str = f"{epoch + 1:<4} | {avg_train_loss:<8.4f} | {avg_val_loss:<8.4f} | {r2:<6.4f} | {pcc:<6.4f} | {grad_norm_tracker.avg:<6.4f} | {current_lr:<8.2e} | {avg_mae:<8.4f}"
880
+ param_str = " | ".join([f"{m:<8.4f}" for m in avg_mae_per_param])
881
+ logger.info(f"{base_str} | {param_str}")
882
+
883
+ # WandB logging
884
+ if args.wandb and WANDB_AVAILABLE:
885
+ log_dict = {
886
+ "main/train_loss": avg_train_loss,
887
+ "main/val_loss": avg_val_loss,
888
+ "metrics/r2_score": r2,
889
+ "metrics/pearson_corr": pcc,
890
+ "metrics/mae_avg": avg_mae,
891
+ "system/grad_norm": grad_norm_tracker.avg,
892
+ "hyper/lr": current_lr,
893
+ }
894
+ for i, mae in enumerate(avg_mae_per_param):
895
+ log_dict[f"mae_detailed/P{i}"] = mae
896
+
897
+ # Periodic scatter plots
898
+ if (epoch % 5 == 0) or (avg_val_loss < best_val_loss):
899
+ real_true = scaler.inverse_transform(y_true)
900
+ real_pred = scaler.inverse_transform(y_pred)
901
+ fig = plot_scientific_scatter(real_true, real_pred)
902
+ log_dict["plots/scatter_analysis"] = wandb.Image(fig)
903
+ plt.close(fig)
904
+
905
+ accelerator.log(log_dict)
906
+
907
+ # ==========================================================================
908
+ # DDP-SAFE CHECKPOINT LOGIC
909
+ # ==========================================================================
910
+ # Step 1: Determine if this is the best epoch (BEFORE updating best_val_loss)
911
+ is_best_epoch = False
912
+ if accelerator.is_main_process:
913
+ if avg_val_loss < best_val_loss:
914
+ is_best_epoch = True
915
+
916
+ # Step 2: Broadcast decision to all ranks (required for save_state)
917
+ is_best_epoch = broadcast_early_stop(is_best_epoch, accelerator)
918
+
919
+ # Step 3: Save checkpoint with all ranks participating
920
+ if is_best_epoch:
921
+ ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
922
+ accelerator.save_state(ckpt_dir) # All ranks must call this
923
+
924
+ # Step 4: Rank 0 handles metadata and updates tracking variables
925
+ if accelerator.is_main_process:
926
+ best_val_loss = avg_val_loss # Update AFTER checkpoint saved
927
+ patience_ctr = 0
928
+
929
+ with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
930
+ pickle.dump(
931
+ {
932
+ "epoch": epoch + 1,
933
+ "best_val_loss": best_val_loss,
934
+ "patience_ctr": patience_ctr,
935
+ # Model info for auto-detection during inference
936
+ "model_name": args.model,
937
+ "in_shape": in_shape,
938
+ "out_dim": out_dim,
939
+ },
940
+ f,
941
+ )
942
+
943
+ unwrapped = accelerator.unwrap_model(model)
944
+ torch.save(
945
+ unwrapped.state_dict(),
946
+ os.path.join(args.output_dir, "best_model_weights.pth"),
947
+ )
948
+
949
+ # Copy scaler to checkpoint for portability
950
+ scaler_src = os.path.join(args.output_dir, "scaler.pkl")
951
+ scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
952
+ if os.path.exists(scaler_src) and not os.path.exists(scaler_dst):
953
+ shutil.copy2(scaler_src, scaler_dst)
954
+
955
+ logger.info(
956
+ f" 💾 Best model saved (val_loss: {best_val_loss:.6f})"
957
+ )
958
+
959
+ # Also save CSV on best model (ensures progress is saved)
960
+ pd.DataFrame(history).to_csv(
961
+ os.path.join(args.output_dir, "training_history.csv"),
962
+ index=False,
963
+ )
964
+ else:
965
+ if accelerator.is_main_process:
966
+ patience_ctr += 1
967
+
968
+ # Periodic checkpoint (all ranks participate in save_state)
969
+ periodic_checkpoint_needed = (
970
+ args.save_every > 0 and (epoch + 1) % args.save_every == 0
971
+ )
972
+ if periodic_checkpoint_needed:
973
+ ckpt_name = f"epoch_{epoch + 1}_checkpoint"
974
+ ckpt_dir = os.path.join(args.output_dir, ckpt_name)
975
+ accelerator.save_state(ckpt_dir) # All ranks participate
976
+
977
+ if accelerator.is_main_process:
978
+ with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
979
+ pickle.dump(
980
+ {
981
+ "epoch": epoch + 1,
982
+ "best_val_loss": best_val_loss,
983
+ "patience_ctr": patience_ctr,
984
+ # Model info for auto-detection during inference
985
+ "model_name": args.model,
986
+ "in_shape": in_shape,
987
+ "out_dim": out_dim,
988
+ },
989
+ f,
990
+ )
991
+ logger.info(f" 📁 Periodic checkpoint: {ckpt_name}")
992
+
993
+ # Save CSV with each checkpoint (keeps logs in sync with model state)
994
+ pd.DataFrame(history).to_csv(
995
+ os.path.join(args.output_dir, "training_history.csv"),
996
+ index=False,
997
+ )
998
+
999
+ # Learning rate scheduling (epoch-based schedulers only)
1000
+ if not scheduler_step_per_batch:
1001
+ if args.scheduler == "plateau":
1002
+ scheduler.step(avg_val_loss)
1003
+ else:
1004
+ scheduler.step()
1005
+
1006
+ # DDP-safe early stopping
1007
+ should_stop = (
1008
+ patience_ctr >= args.patience if accelerator.is_main_process else False
1009
+ )
1010
+ if broadcast_early_stop(should_stop, accelerator):
1011
+ if accelerator.is_main_process:
1012
+ logger.info(
1013
+ f"🛑 Early stopping at epoch {epoch + 1} (patience={args.patience})"
1014
+ )
1015
+ # Create completion flag to prevent auto-resume
1016
+ with open(
1017
+ os.path.join(args.output_dir, "training_complete.flag"), "w"
1018
+ ) as f:
1019
+ f.write(
1020
+ f"Training completed via early stopping at epoch {epoch + 1}\n"
1021
+ )
1022
+ break
1023
+
1024
+ except KeyboardInterrupt:
1025
+ logger.warning("Training interrupted. Saving emergency checkpoint...")
1026
+ accelerator.save_state(os.path.join(args.output_dir, "interrupted_checkpoint"))
1027
+
1028
+ except Exception as e:
1029
+ logger.error(f"Critical error: {e}", exc_info=True)
1030
+ raise
1031
+
1032
+ else:
1033
+ # Loop finished without raising. Early stopping writes the completion flag
1034
+ # inside the loop; otherwise mark normal completion here to prevent auto-resume on re-run
1035
+ if accelerator.is_main_process:
1036
+ if not os.path.exists(complete_flag_path):
1037
+ with open(complete_flag_path, "w") as f:
1038
+ f.write(f"Training completed normally after {args.epochs} epochs\n")
1039
+ logger.info(f"✅ Training completed after {args.epochs} epochs")
1040
+
1041
+ finally:
1042
+ # Final CSV write so the history includes any epochs logged since the last checkpoint save
1043
+ if accelerator.is_main_process and len(history) > 0:
1044
+ pd.DataFrame(history).to_csv(
1045
+ os.path.join(args.output_dir, "training_history.csv"),
1046
+ index=False,
1047
+ )
1048
+
1049
+ # Generate training curves plot (PNG + SVG)
1050
+ if accelerator.is_main_process and len(history) > 0:
1051
+ try:
1052
+ fig = create_training_curves(history, show_lr=True)
1053
+ for fmt in ["png", "svg"]:
1054
+ fig.savefig(
1055
+ os.path.join(args.output_dir, f"training_curves.{fmt}"),
1056
+ dpi=FIGURE_DPI,
1057
+ bbox_inches="tight",
1058
+ )
1059
+ plt.close(fig)
1060
+ logger.info("✔ Saved: training_curves.png, training_curves.svg")
1061
+ except Exception as e:
1062
+ logger.warning(f"Could not generate training curves: {e}")
1063
+
1064
+ if args.wandb and WANDB_AVAILABLE:
1065
+ accelerator.end_training()
1066
+
1067
+ # Clean up distributed process group to prevent resource leak warning
1068
+ if torch.distributed.is_initialized():
1069
+ torch.distributed.destroy_process_group()
1070
+
1071
+ logger.info("Training completed.")
1072
+
1073
+
1074
+ if __name__ == "__main__":
1075
+ try:
1076
+ torch.multiprocessing.set_start_method("spawn")
1077
+ except RuntimeError:
1078
+ pass
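+ # ("spawn" keeps CUDA safe in DataLoader worker processes, since forking a
+ #  CUDA-initialized process is unsupported; the RuntimeError simply means a
+ #  start method was already set.)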
1079
+ main()