wavedl 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
wavedl/train.py CHANGED
@@ -1,1117 +1,1113 @@
1
- """
2
- WaveDL - Deep Learning for Wave-based Inverse Problems
3
- =======================================================
4
- Target Environment: NVIDIA HPC GPUs (Multi-GPU DDP) | PyTorch 2.x | Python 3.11+
5
-
6
- A modular training framework for wave-based inverse problems and regression:
7
- 1. HPC-Grade DDP Training: BF16/FP16 mixed precision with torch.compile support
8
- 2. Dynamic Model Selection: Use --model flag to select any registered architecture
9
- 3. Zero-Copy Data Engine: Memmap-backed datasets for large-scale training
10
- 4. Physics-Aware Metrics: Real-time physical MAE with proper unscaling
11
- 5. Robust Checkpointing: Resume training, periodic saves, and training curves
12
- 6. Deep Observability: WandB integration with scatter analysis
13
-
14
- Usage:
15
- # Recommended: Using the HPC launcher
16
- wavedl-hpc --model cnn --batch_size 128 --wandb
17
-
18
- # Or with direct accelerate launch
19
- accelerate launch -m wavedl.train --model cnn --batch_size 128 --wandb
20
-
21
- # Multi-GPU with explicit config
22
- wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb
23
-
24
- # Resume from checkpoint
25
- accelerate launch -m wavedl.train --model cnn --resume best_checkpoint --wandb
26
-
27
- # List available models
28
- wavedl-train --list_models
29
-
30
- Note:
31
- For HPC clusters (Compute Canada, etc.), use wavedl-hpc which handles
32
- environment configuration automatically. Mixed precision is controlled via
33
- --mixed_precision flag (default: bf16).
34
-
35
- Author: Ductho Le (ductho.le@outlook.com)
36
- """
37
-
38
- # ==============================================================================
39
- # ENVIRONMENT CONFIGURATION FOR HPC SYSTEMS
40
- # ==============================================================================
41
- # IMPORTANT: These must be set BEFORE matplotlib is imported to be effective
42
- import os
43
-
44
-
45
- os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
46
- os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
47
-
48
- import argparse
49
- import logging
50
- import pickle
51
- import shutil
52
- import sys
53
- import time
54
- import warnings
55
- from typing import Any
56
-
57
- import matplotlib.pyplot as plt
58
- import numpy as np
59
- import pandas as pd
60
- import torch
61
- from accelerate import Accelerator
62
- from accelerate.utils import set_seed
63
- from sklearn.metrics import r2_score
64
- from tqdm.auto import tqdm
65
-
66
- # Local imports
67
- from wavedl.models import build_model, get_model, list_models
68
- from wavedl.utils import (
69
- FIGURE_DPI,
70
- MetricTracker,
71
- broadcast_early_stop,
72
- calc_pearson,
73
- create_training_curves,
74
- # New factory functions
75
- get_loss,
76
- get_lr,
77
- get_optimizer,
78
- get_scheduler,
79
- is_epoch_based,
80
- list_losses,
81
- list_optimizers,
82
- list_schedulers,
83
- plot_scientific_scatter,
84
- prepare_data,
85
- )
86
-
87
-
88
- # Optional WandB import
89
- try:
90
- import wandb
91
-
92
- WANDB_AVAILABLE = True
93
- except ImportError:
94
- WANDB_AVAILABLE = False
95
-
96
- # Filter non-critical warnings for cleaner training logs
97
- warnings.filterwarnings("ignore", category=UserWarning)
98
- warnings.filterwarnings("ignore", category=FutureWarning)
99
- warnings.filterwarnings("ignore", category=DeprecationWarning)
100
- warnings.filterwarnings("ignore", module="pydantic")
101
- warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
102
-
103
-
104
- # ==============================================================================
105
- # ARGUMENT PARSING
106
- # ==============================================================================
107
- def parse_args() -> argparse.Namespace:
108
- """Parse command-line arguments with comprehensive options."""
109
- parser = argparse.ArgumentParser(
110
- description="Universal DDP Training Pipeline",
111
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
112
- )
113
-
114
- # Model Selection
115
- parser.add_argument(
116
- "--model",
117
- type=str,
118
- default="cnn",
119
- help=f"Model architecture to train. Available: {list_models()}",
120
- )
121
- parser.add_argument(
122
- "--list_models", action="store_true", help="List all available models and exit"
123
- )
124
- parser.add_argument(
125
- "--import",
126
- dest="import_modules",
127
- type=str,
128
- nargs="+",
129
- default=[],
130
- help="Python modules to import before training (for custom models)",
131
- )
132
-
133
- # Configuration File
134
- parser.add_argument(
135
- "--config",
136
- type=str,
137
- default=None,
138
- help="Path to YAML config file. CLI args override config values.",
139
- )
140
-
141
- # Hyperparameters
142
- parser.add_argument(
143
- "--batch_size", type=int, default=128, help="Batch size per GPU"
144
- )
145
- parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
146
- parser.add_argument(
147
- "--epochs", type=int, default=1000, help="Maximum training epochs"
148
- )
149
- parser.add_argument(
150
- "--patience", type=int, default=20, help="Early stopping patience"
151
- )
152
- parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
153
- parser.add_argument(
154
- "--grad_clip", type=float, default=1.0, help="Gradient clipping norm"
155
- )
156
-
157
- # Loss Function
158
- parser.add_argument(
159
- "--loss",
160
- type=str,
161
- default="mse",
162
- choices=["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"],
163
- help=f"Loss function for training. Available: {list_losses()}",
164
- )
165
- parser.add_argument(
166
- "--huber_delta", type=float, default=1.0, help="Delta for Huber loss"
167
- )
168
- parser.add_argument(
169
- "--loss_weights",
170
- type=str,
171
- default=None,
172
- help="Comma-separated weights for weighted_mse (e.g., '1.0,2.0,1.0')",
173
- )
174
-
175
- # Optimizer
176
- parser.add_argument(
177
- "--optimizer",
178
- type=str,
179
- default="adamw",
180
- choices=["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"],
181
- help=f"Optimizer for training. Available: {list_optimizers()}",
182
- )
183
- parser.add_argument(
184
- "--momentum", type=float, default=0.9, help="Momentum for SGD/RMSprop"
185
- )
186
- parser.add_argument(
187
- "--nesterov", action="store_true", help="Use Nesterov momentum (SGD)"
188
- )
189
- parser.add_argument(
190
- "--betas",
191
- type=str,
192
- default="0.9,0.999",
193
- help="Betas for Adam variants (comma-separated)",
194
- )
195
-
196
- # Learning Rate Scheduler
197
- parser.add_argument(
198
- "--scheduler",
199
- type=str,
200
- default="plateau",
201
- choices=[
202
- "plateau",
203
- "cosine",
204
- "cosine_restarts",
205
- "onecycle",
206
- "step",
207
- "multistep",
208
- "exponential",
209
- "linear_warmup",
210
- ],
211
- help=f"LR scheduler. Available: {list_schedulers()}",
212
- )
213
- parser.add_argument(
214
- "--scheduler_patience",
215
- type=int,
216
- default=10,
217
- help="Patience for ReduceLROnPlateau",
218
- )
219
- parser.add_argument(
220
- "--min_lr", type=float, default=1e-6, help="Minimum learning rate"
221
- )
222
- parser.add_argument(
223
- "--scheduler_factor", type=float, default=0.5, help="LR reduction factor"
224
- )
225
- parser.add_argument(
226
- "--warmup_epochs", type=int, default=5, help="Warmup epochs for linear_warmup"
227
- )
228
- parser.add_argument(
229
- "--step_size", type=int, default=30, help="Step size for StepLR"
230
- )
231
- parser.add_argument(
232
- "--milestones",
233
- type=str,
234
- default=None,
235
- help="Comma-separated epochs for MultiStepLR (e.g., '30,60,90')",
236
- )
237
-
238
- # Data
239
- parser.add_argument(
240
- "--data_path", type=str, default="train_data.npz", help="Path to NPZ dataset"
241
- )
242
- parser.add_argument(
243
- "--workers",
244
- type=int,
245
- default=-1,
246
- help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
247
- )
248
- parser.add_argument("--seed", type=int, default=2025, help="Random seed")
249
- parser.add_argument(
250
- "--single_channel",
251
- action="store_true",
252
- help="Confirm data is single-channel (suppress ambiguous shape warnings for shallow 3D volumes)",
253
- )
254
-
255
- # Cross-Validation
256
- parser.add_argument(
257
- "--cv",
258
- type=int,
259
- default=0,
260
- help="Enable K-fold cross-validation with K folds (0=disabled)",
261
- )
262
- parser.add_argument(
263
- "--cv_stratify",
264
- action="store_true",
265
- help="Use stratified splitting for cross-validation",
266
- )
267
- parser.add_argument(
268
- "--cv_bins",
269
- type=int,
270
- default=10,
271
- help="Number of bins for stratified CV (only with --cv_stratify)",
272
- )
273
-
274
- # Checkpointing & Resume
275
- parser.add_argument(
276
- "--resume", type=str, default=None, help="Checkpoint directory to resume from"
277
- )
278
- parser.add_argument(
279
- "--save_every",
280
- type=int,
281
- default=50,
282
- help="Save checkpoint every N epochs (0=disable)",
283
- )
284
- parser.add_argument(
285
- "--output_dir", type=str, default=".", help="Output directory for checkpoints"
286
- )
287
- parser.add_argument(
288
- "--fresh",
289
- action="store_true",
290
- help="Force fresh training, ignore existing checkpoints",
291
- )
292
-
293
- # Performance
294
- parser.add_argument(
295
- "--compile", action="store_true", help="Enable torch.compile (PyTorch 2.x)"
296
- )
297
- parser.add_argument(
298
- "--precision",
299
- type=str,
300
- default="bf16",
301
- choices=["bf16", "fp16", "no"],
302
- help="Mixed precision mode",
303
- )
304
-
305
- # Logging
306
- parser.add_argument(
307
- "--wandb", action="store_true", help="Enable Weights & Biases logging"
308
- )
309
- parser.add_argument(
310
- "--project_name", type=str, default="DL-Training", help="WandB project name"
311
- )
312
- parser.add_argument("--run_name", type=str, default=None, help="WandB run name")
313
-
314
- args = parser.parse_args()
315
- return args, parser # Returns (Namespace, ArgumentParser)
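Note that the annotation on parse_args declares argparse.Namespace while the function actually returns a (Namespace, ArgumentParser) pair, as the trailing comment says. A small sketch of a signature that matches the actual return value (no behavioral change intended):

```python
import argparse


def parse_args() -> tuple[argparse.Namespace, argparse.ArgumentParser]:
    """Return the parsed namespace together with the parser (needed later for config merging)."""
    parser = argparse.ArgumentParser(description="Universal DDP Training Pipeline")
    parser.add_argument("--model", type=str, default="cnn")
    args = parser.parse_args()
    return args, parser
```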
316
-
317
-
318
- # ==============================================================================
319
- # MAIN TRAINING FUNCTION
320
- # ==============================================================================
321
- def main():
322
- args, parser = parse_args()
323
-
324
- # Import custom model modules if specified
325
- if args.import_modules:
326
- import importlib
327
- import sys
328
-
329
- for module_name in args.import_modules:
330
- try:
331
- # Handle both module names (my_model) and file paths (./my_model.py)
332
- if module_name.endswith(".py"):
333
- # Import from file path
334
- import importlib.util
335
-
336
- spec = importlib.util.spec_from_file_location(
337
- "custom_module", module_name
338
- )
339
- if spec and spec.loader:
340
- module = importlib.util.module_from_spec(spec)
341
- sys.modules["custom_module"] = module
342
- spec.loader.exec_module(module)
343
- print(f"✓ Imported custom module from: {module_name}")
344
- else:
345
- # Import as regular module
346
- importlib.import_module(module_name)
347
- print(f" Imported module: {module_name}")
348
- except ImportError as e:
349
- print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
350
- print(
351
- " Make sure the module is in your Python path or current directory."
352
- )
353
- sys.exit(1)
354
-
355
- # Handle --list_models flag
356
- if args.list_models:
357
- print("Available models:")
358
- for name in list_models():
359
- ModelClass = get_model(name)
360
- # Get first non-empty docstring line
361
- if ModelClass.__doc__:
362
- lines = [
363
- l.strip() for l in ModelClass.__doc__.splitlines() if l.strip()
364
- ]
365
- doc_first_line = lines[0] if lines else "No description"
366
- else:
367
- doc_first_line = "No description"
368
- print(f" - {name}: {doc_first_line}")
369
- sys.exit(0)
370
-
371
- # Load and merge config file if provided
372
- if args.config:
373
- from wavedl.utils.config import (
374
- load_config,
375
- merge_config_with_args,
376
- validate_config,
377
- )
378
-
379
- print(f"📄 Loading config from: {args.config}")
380
- config = load_config(args.config)
381
-
382
- # Validate config values
383
- warnings_list = validate_config(config)
384
- for w in warnings_list:
385
- print(f" ⚠ {w}")
386
-
387
- # Merge config with CLI args (CLI takes precedence via parser defaults detection)
388
- args = merge_config_with_args(config, args, parser=parser)
389
-
390
- # Handle --cv flag (cross-validation mode)
391
- if args.cv > 0:
392
- print(f"🔄 Cross-Validation Mode: {args.cv} folds")
393
- from wavedl.utils.cross_validation import run_cross_validation
394
-
395
- # Load data for CV using memory-efficient loader
396
- from wavedl.utils.data import DataSource, get_data_source
397
-
398
- data_format = DataSource.detect_format(args.data_path)
399
- source = get_data_source(data_format)
400
-
401
- # Use memory-mapped loading when available
402
- _cv_handle = None
403
- if hasattr(source, "load_mmap"):
404
- result = source.load_mmap(args.data_path)
405
- if hasattr(result, "inputs"):
406
- _cv_handle = result
407
- X, y = result.inputs, result.outputs
408
- else:
409
- X, y = result # NPZ returns tuple directly
410
- else:
411
- X, y = source.load(args.data_path)
412
-
413
- # Handle sparse matrices (must materialize for CV shuffling)
414
- if hasattr(X, "__getitem__") and len(X) > 0 and hasattr(X[0], "toarray"):
415
- X = np.stack([x.toarray() for x in X])
416
-
417
- # Normalize target shape: (N,) -> (N, 1) for consistency
418
- if y.ndim == 1:
419
- y = y.reshape(-1, 1)
420
-
421
- # Validate and determine input shape (consistent with prepare_data)
422
- # Check for ambiguous shapes that could be multi-channel or shallow 3D volume
423
- sample_shape = X.shape[1:] # Per-sample shape
424
-
425
- # Same heuristic as prepare_data: detect ambiguous 3D shapes
426
- is_ambiguous_shape = (
427
- len(sample_shape) == 3 # Exactly 3D: could be (C, H, W) or (D, H, W)
428
- and sample_shape[0] <= 16 # First dim looks like channels
429
- and sample_shape[1] > 16
430
- and sample_shape[2] > 16 # Both spatial dims are large
431
- )
432
-
433
- if is_ambiguous_shape and not args.single_channel:
434
- raise ValueError(
435
- f"Ambiguous input shape detected: sample shape {sample_shape}. "
436
- f"This could be either:\n"
437
- f" - Multi-channel 2D data (C={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n"
438
- f" - Single-channel 3D volume (D={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n\n"
439
- f"If this is single-channel 3D/shallow volume data, use --single_channel flag.\n"
440
- f"If this is multi-channel 2D data, reshape to (N*C, H, W) with adjusted targets."
441
- )
442
-
443
- # in_shape = spatial dimensions for model registry (channel added during training)
444
- in_shape = sample_shape
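To make the heuristic above concrete, here is how it classifies a few hypothetical sample shapes (values chosen for illustration only):

```python
def looks_ambiguous(sample_shape: tuple[int, ...]) -> bool:
    """Same rule as above: 3D, small leading dim, both trailing dims clearly spatial."""
    return (
        len(sample_shape) == 3
        and sample_shape[0] <= 16
        and sample_shape[1] > 16
        and sample_shape[2] > 16
    )


print(looks_ambiguous((3, 128, 128)))  # True  -> could be (C, H, W) or a shallow (D, H, W) volume
print(looks_ambiguous((64, 64, 64)))   # False -> leading dim too large to be channels
print(looks_ambiguous((8, 12, 12)))    # False -> trailing dims too small to be spatial
```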
445
-
446
- # Run cross-validation
447
- try:
448
- run_cross_validation(
449
- X=X,
450
- y=y,
451
- model_name=args.model,
452
- in_shape=in_shape,
453
- out_size=y.shape[1],
454
- folds=args.cv,
455
- stratify=args.cv_stratify,
456
- stratify_bins=args.cv_bins,
457
- batch_size=args.batch_size,
458
- lr=args.lr,
459
- epochs=args.epochs,
460
- patience=args.patience,
461
- weight_decay=args.weight_decay,
462
- loss_name=args.loss,
463
- optimizer_name=args.optimizer,
464
- scheduler_name=args.scheduler,
465
- output_dir=args.output_dir,
466
- workers=args.workers,
467
- seed=args.seed,
468
- )
469
- finally:
470
- # Clean up file handle if HDF5/MAT
471
- if _cv_handle is not None and hasattr(_cv_handle, "close"):
472
- try:
473
- _cv_handle.close()
474
- except Exception:
475
- pass
476
- return
477
-
478
- # ==========================================================================
479
- # 1. SYSTEM INITIALIZATION
480
- # ==========================================================================
481
- # Initialize Accelerator for DDP and mixed precision
482
- accelerator = Accelerator(
483
- mixed_precision=args.precision,
484
- log_with="wandb" if args.wandb and WANDB_AVAILABLE else None,
485
- )
486
- set_seed(args.seed)
487
-
488
- # Configure logging (rank 0 only prints to console)
489
- logging.basicConfig(
490
- level=logging.INFO if accelerator.is_main_process else logging.ERROR,
491
- format="%(asctime)s | %(levelname)s | %(message)s",
492
- datefmt="%H:%M:%S",
493
- )
494
- logger = logging.getLogger("Trainer")
495
-
496
- # Ensure output directory exists (critical for cache files, checkpoints, etc.)
497
- os.makedirs(args.output_dir, exist_ok=True)
498
-
499
- # Auto-detect optimal DataLoader workers if not specified
500
- if args.workers < 0:
501
- cpu_count = os.cpu_count() or 4
502
- num_gpus = accelerator.num_processes
503
- # Heuristic: 4-8 workers per GPU, bounded by available CPU cores
504
- # Leave some cores for main process and system overhead
505
- args.workers = min(8, max(2, (cpu_count - 2) // num_gpus))
506
- if accelerator.is_main_process:
507
- logger.info(
508
- f"⚙️ Auto-detected workers: {args.workers} per GPU "
509
- f"(CPUs: {cpu_count}, GPUs: {num_gpus})"
510
- )
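A couple of worked values for the worker heuristic above (CPU and GPU counts are hypothetical):

```python
def auto_workers(cpu_count: int, num_gpus: int) -> int:
    # Same formula as above: 2-8 workers per GPU, leaving two cores for overhead
    return min(8, max(2, (cpu_count - 2) // num_gpus))


assert auto_workers(64, 4) == 8  # ample cores -> capped at 8 per GPU
assert auto_workers(8, 4) == 2   # scarce cores -> floor of 2 per GPU
assert auto_workers(16, 1) == 8  # single GPU -> (16 - 2) // 1 = 14, capped at 8
```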
511
-
512
- if accelerator.is_main_process:
513
- logger.info(f"🚀 Cluster Status: {accelerator.num_processes}x GPUs detected")
514
- logger.info(
515
- f" Model: {args.model} | Precision: {args.precision} | Compile: {args.compile}"
516
- )
517
- logger.info(
518
- f" Loss: {args.loss} | Optimizer: {args.optimizer} | Scheduler: {args.scheduler}"
519
- )
520
- logger.info(f" Early Stopping Patience: {args.patience} epochs")
521
- if args.save_every > 0:
522
- logger.info(f" Periodic Checkpointing: Every {args.save_every} epochs")
523
- if args.resume:
524
- logger.info(f" 📂 Resuming from: {args.resume}")
525
-
526
- # Initialize WandB
527
- if args.wandb and WANDB_AVAILABLE:
528
- accelerator.init_trackers(
529
- project_name=args.project_name,
530
- config=vars(args),
531
- init_kwargs={"wandb": {"name": args.run_name or f"{args.model}_run"}},
532
- )
533
-
534
- # ==========================================================================
535
- # 2. DATA & MODEL LOADING
536
- # ==========================================================================
537
- train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
538
- args, logger, accelerator, cache_dir=args.output_dir
539
- )
540
-
541
- # Build model using registry
542
- model = build_model(args.model, in_shape=in_shape, out_size=out_dim)
543
-
544
- if accelerator.is_main_process:
545
- param_info = model.parameter_summary()
546
- logger.info(
547
- f" Model Parameters: {param_info['trainable_parameters']:,} trainable"
548
- )
549
- logger.info(f" Model Size: {param_info['total_mb']:.2f} MB")
550
-
551
- # Optional WandB model watching
552
- if args.wandb and WANDB_AVAILABLE and accelerator.is_main_process:
553
- wandb.watch(model, log="gradients", log_freq=100)
554
-
555
- # Torch 2.0 compilation (requires compatible Triton on GPU)
556
- if args.compile:
557
- try:
558
- # Test if Triton is available AND compatible with this PyTorch version
559
- # PyTorch needs triton_key from triton.compiler.compiler
560
- from triton.compiler.compiler import triton_key
561
-
562
- model = torch.compile(model)
563
- if accelerator.is_main_process:
564
- logger.info(" ✔ torch.compile enabled (Triton backend)")
565
- except ImportError as e:
566
- if accelerator.is_main_process:
567
- if "triton" in str(e).lower():
568
- logger.warning(
569
- " ⚠ Triton not installed or incompatible version - torch.compile disabled. "
570
- "Training will proceed without compilation."
571
- )
572
- else:
573
- logger.warning(
574
- f" ⚠ torch.compile setup failed: {e}. Continuing without compilation."
575
- )
576
- except Exception as e:
577
- if accelerator.is_main_process:
578
- logger.warning(
579
- f" ⚠ torch.compile failed: {e}. Continuing without compilation."
580
- )
581
-
582
- # ==========================================================================
583
- # 2.5. OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
584
- # ==========================================================================
585
- # Parse comma-separated arguments with validation
586
- try:
587
- betas_list = [float(x.strip()) for x in args.betas.split(",")]
588
- if len(betas_list) != 2:
589
- raise ValueError(
590
- f"--betas must have exactly 2 values, got {len(betas_list)}"
591
- )
592
- if not all(0.0 <= b < 1.0 for b in betas_list):
593
- raise ValueError(f"--betas values must be in [0, 1), got {betas_list}")
594
- betas = tuple(betas_list)
595
- except ValueError as e:
596
- raise ValueError(
597
- f"Invalid --betas format '{args.betas}': {e}. Expected format: '0.9,0.999'"
598
- )
599
-
600
- loss_weights = None
601
- if args.loss_weights:
602
- loss_weights = [float(x.strip()) for x in args.loss_weights.split(",")]
603
- milestones = None
604
- if args.milestones:
605
- milestones = [int(x.strip()) for x in args.milestones.split(",")]
606
-
607
- # Create optimizer using factory
608
- optimizer = get_optimizer(
609
- name=args.optimizer,
610
- params=model.get_optimizer_groups(args.lr, args.weight_decay),
611
- lr=args.lr,
612
- weight_decay=args.weight_decay,
613
- momentum=args.momentum,
614
- nesterov=args.nesterov,
615
- betas=betas,
616
- )
617
-
618
- # Create loss function using factory
619
- criterion = get_loss(
620
- name=args.loss,
621
- weights=loss_weights,
622
- delta=args.huber_delta,
623
- )
624
- # Move criterion to device (important for WeightedMSELoss buffer)
625
- criterion = criterion.to(accelerator.device)
626
-
627
- # Track if scheduler should step per batch (OneCycleLR) or per epoch
628
- scheduler_step_per_batch = not is_epoch_based(args.scheduler)
629
-
630
- # ==========================================================================
631
- # DDP Preparation Strategy:
632
- # - For batch-based schedulers (OneCycleLR): prepare DataLoaders first to get
633
- # the correct sharded batch count, then create scheduler
634
- # - For epoch-based schedulers: create scheduler before prepare (no issue)
635
- # ==========================================================================
636
- if scheduler_step_per_batch:
637
- # BATCH-BASED SCHEDULER (e.g., OneCycleLR)
638
- # Prepare model, optimizer, dataloaders FIRST to get sharded loader length
639
- model, optimizer, train_dl, val_dl = accelerator.prepare(
640
- model, optimizer, train_dl, val_dl
641
- )
642
-
643
- # Now create scheduler with the CORRECT sharded steps_per_epoch
644
- steps_per_epoch = len(train_dl) # Post-DDP sharded length
645
- scheduler = get_scheduler(
646
- name=args.scheduler,
647
- optimizer=optimizer,
648
- epochs=args.epochs,
649
- steps_per_epoch=steps_per_epoch,
650
- min_lr=args.min_lr,
651
- patience=args.scheduler_patience,
652
- factor=args.scheduler_factor,
653
- gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
654
- step_size=args.step_size,
655
- milestones=milestones,
656
- warmup_epochs=args.warmup_epochs,
657
- )
658
- # Prepare scheduler separately (Accelerator wraps it for state saving)
659
- scheduler = accelerator.prepare(scheduler)
660
- else:
661
- # EPOCH-BASED SCHEDULER (plateau, cosine, step, etc.)
662
- # No batch count dependency - create scheduler before prepare
663
- scheduler = get_scheduler(
664
- name=args.scheduler,
665
- optimizer=optimizer,
666
- epochs=args.epochs,
667
- steps_per_epoch=None,
668
- min_lr=args.min_lr,
669
- patience=args.scheduler_patience,
670
- factor=args.scheduler_factor,
671
- gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
672
- step_size=args.step_size,
673
- milestones=milestones,
674
- warmup_epochs=args.warmup_epochs,
675
- )
676
- # Prepare everything together
677
- model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
678
- model, optimizer, train_dl, val_dl, scheduler
679
- )
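For context on why the preparation order matters for batch-based schedulers: under DDP each rank's DataLoader sees roughly 1/world_size of the batches, so a OneCycleLR built from the unsharded loader length would be scheduled for far too many steps. A back-of-the-envelope sketch with assumed sizes:

```python
import math

# Assumed sizes, for illustration only
n_samples, batch_size, world_size = 10_000, 128, 4

steps_unsharded = math.ceil(n_samples / batch_size)                # 79 optimizer steps per epoch
steps_per_rank = math.ceil(n_samples / (batch_size * world_size))  # ~20 steps per epoch per rank

print(steps_unsharded, steps_per_rank)
```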
680
-
681
- # ==========================================================================
682
- # 3. AUTO-RESUME / RESUME FROM CHECKPOINT
683
- # ==========================================================================
684
- start_epoch = 0
685
- best_val_loss = float("inf")
686
- patience_ctr = 0
687
- history: list[dict[str, Any]] = []
688
-
689
- # Define checkpoint paths
690
- best_ckpt_path = os.path.join(args.output_dir, "best_checkpoint")
691
- complete_flag_path = os.path.join(args.output_dir, "training_complete.flag")
692
-
693
- # Auto-resume logic (if not --fresh and no explicit --resume)
694
- if not args.fresh and args.resume is None:
695
- if os.path.exists(complete_flag_path):
696
- # Training already completed
697
- if accelerator.is_main_process:
698
- logger.info(
699
- "✅ Training already completed (early stopping). Use --fresh to retrain."
700
- )
701
- return # Exit gracefully
702
- elif os.path.exists(best_ckpt_path):
703
- # Incomplete training found - auto-resume
704
- args.resume = best_ckpt_path
705
- if accelerator.is_main_process:
706
- logger.info(f"🔄 Auto-resuming from: {best_ckpt_path}")
707
-
708
- if args.resume:
709
- if os.path.exists(args.resume):
710
- logger.info(f"🔄 Loading checkpoint from: {args.resume}")
711
- accelerator.load_state(args.resume)
712
-
713
- # Restore training metadata
714
- meta_path = os.path.join(args.resume, "training_meta.pkl")
715
- if os.path.exists(meta_path):
716
- with open(meta_path, "rb") as f:
717
- meta = pickle.load(f)
718
- start_epoch = meta.get("epoch", 0)
719
- best_val_loss = meta.get("best_val_loss", float("inf"))
720
- patience_ctr = meta.get("patience_ctr", 0)
721
- logger.info(
722
- f" Restored: Epoch {start_epoch}, Best Loss: {best_val_loss:.6f}"
723
- )
724
- else:
725
- logger.warning(
726
- " ⚠️ training_meta.pkl not found, starting from epoch 0"
727
- )
728
-
729
- # Restore history
730
- history_path = os.path.join(args.output_dir, "training_history.csv")
731
- if os.path.exists(history_path):
732
- history = pd.read_csv(history_path).to_dict("records")
733
- logger.info(f" ✅ Loaded {len(history)} epochs from history")
734
- else:
735
- raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
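The checkpoint metadata restored here is a plain pickle whose keys (epoch, best_val_loss, patience_ctr, model_name, in_shape, out_dim) are written by the checkpoint-save code later in this file. A small sketch for inspecting it offline (the path is assumed):

```python
import pickle

# Hypothetical checkpoint directory produced by a previous run
with open("best_checkpoint/training_meta.pkl", "rb") as f:
    meta = pickle.load(f)

print(meta.get("epoch"), meta.get("best_val_loss"), meta.get("model_name"))
print(meta.get("in_shape"), meta.get("out_dim"))
```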
736
-
737
- # ==========================================================================
738
- # 4. PHYSICAL METRIC SETUP
739
- # ==========================================================================
740
- # Physical MAE = normalized MAE * scaler.scale_
741
- phys_scale = torch.tensor(
742
- scaler.scale_, device=accelerator.device, dtype=torch.float32
743
- )
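The unscaling identity used here (normalized MAE times scaler.scale_ equals MAE in physical units) holds because inverse_transform is affine. A quick self-contained check, assuming the saved scaler is an sklearn StandardScaler (consistent with multiplying by scale_ here):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

y = np.array([[10.0], [20.0], [30.0]])  # hypothetical targets in physical units
scaler = StandardScaler().fit(y)

y_n = scaler.transform(y)
pred_n = y_n + 0.1                      # hypothetical normalized predictions

mae_physical = np.abs(pred_n - y_n).mean(axis=0) * scaler.scale_
mae_direct = np.abs(scaler.inverse_transform(pred_n) - y).mean(axis=0)

assert np.allclose(mae_physical, mae_direct)
```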
744
-
745
- # ==========================================================================
746
- # 5. TRAINING LOOP
747
- # ==========================================================================
748
- # Dynamic console header
749
- if accelerator.is_main_process:
750
- base_cols = ["Ep", "TrnLoss", "ValLoss", "R2", "PCC", "GradN", "LR", "MAE_Avg"]
751
- param_cols = [f"MAE_P{i}" for i in range(out_dim)]
752
- header = "{:<4} | {:<8} | {:<8} | {:<6} | {:<6} | {:<6} | {:<8} | {:<8}".format(
753
- *base_cols
754
- )
755
- header += " | " + " | ".join([f"{c:<8}" for c in param_cols])
756
- logger.info("=" * len(header))
757
- logger.info(header)
758
- logger.info("=" * len(header))
759
-
760
- try:
761
- time.time()
762
- total_training_time = 0.0
763
-
764
- for epoch in range(start_epoch, args.epochs):
765
- epoch_start_time = time.time()
766
-
767
- # ==================== TRAINING PHASE ====================
768
- model.train()
769
- # Use GPU tensor for loss accumulation to avoid .item() sync per batch
770
- train_loss_sum = torch.tensor(0.0, device=accelerator.device)
771
- train_samples = 0
772
- grad_norm_tracker = MetricTracker()
773
-
774
- pbar = tqdm(
775
- train_dl,
776
- disable=not accelerator.is_main_process,
777
- leave=False,
778
- desc=f"Epoch {epoch + 1}",
779
- )
780
-
781
- for x, y in pbar:
782
- with accelerator.accumulate(model):
783
- pred = model(x)
784
- loss = criterion(pred, y)
785
-
786
- accelerator.backward(loss)
787
-
788
- if accelerator.sync_gradients:
789
- grad_norm = accelerator.clip_grad_norm_(
790
- model.parameters(), args.grad_clip
791
- )
792
- if grad_norm is not None:
793
- grad_norm_tracker.update(grad_norm.item())
794
-
795
- optimizer.step()
796
- optimizer.zero_grad(set_to_none=True) # Faster than zero_grad()
797
-
798
- # Per-batch LR scheduling (e.g., OneCycleLR)
799
- if scheduler_step_per_batch:
800
- scheduler.step()
801
-
802
- # Accumulate as tensors to avoid .item() sync per batch
803
- train_loss_sum += loss.detach() * x.size(0)
804
- train_samples += x.size(0)
805
-
806
- # Single .item() call at end of epoch (reduces GPU sync overhead)
807
- train_loss_scalar = train_loss_sum.item()
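The accumulation pattern above keeps the running loss on the device and converts it to a Python float only once per epoch; calling .item() inside the loop would force a GPU-to-host synchronization on every batch. A minimal sketch of the same pattern (batch losses and sizes are made up):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

loss_sum = torch.zeros((), device=device)
samples = 0
for batch_loss, batch_size in [(0.52, 128), (0.47, 128), (0.45, 64)]:
    loss = torch.tensor(batch_loss, device=device)  # stand-in for criterion(pred, y)
    loss_sum += loss.detach() * batch_size          # stays on device: no per-batch host sync
    samples += batch_size

avg_loss = loss_sum.item() / samples                # single device-to-host sync at epoch end
print(avg_loss)
```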
808
-
809
- # Synchronize training metrics across GPUs
810
- global_loss = accelerator.reduce(
811
- torch.tensor([train_loss_scalar], device=accelerator.device),
812
- reduction="sum",
813
- ).item()
814
- global_samples = accelerator.reduce(
815
- torch.tensor([train_samples], device=accelerator.device),
816
- reduction="sum",
817
- ).item()
818
- avg_train_loss = global_loss / global_samples
819
-
820
- # ==================== VALIDATION PHASE ====================
821
- model.eval()
822
- # Use GPU tensor for loss accumulation (consistent with training phase)
823
- val_loss_sum = torch.tensor(0.0, device=accelerator.device)
824
- val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
825
- val_samples = 0
826
-
827
- # Accumulate predictions locally, gather ONCE at end (reduces sync overhead)
828
- local_preds = []
829
- local_targets = []
830
-
831
- with torch.inference_mode():
832
- for x, y in val_dl:
833
- pred = model(x)
834
- loss = criterion(pred, y)
835
-
836
- val_loss_sum += loss.detach() * x.size(0)
837
- val_samples += x.size(0)
838
-
839
- # Physical MAE
840
- mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
841
- val_mae_sum += mae_batch
842
-
843
- # Store locally (no GPU sync per batch)
844
- local_preds.append(pred)
845
- local_targets.append(y)
846
-
847
- # Single gather at end of validation (2 syncs instead of 2×num_batches)
848
- all_local_preds = torch.cat(local_preds)
849
- all_local_targets = torch.cat(local_targets)
850
- all_preds = accelerator.gather_for_metrics(all_local_preds)
851
- all_targets = accelerator.gather_for_metrics(all_local_targets)
852
-
853
- # Synchronize validation metrics
854
- val_loss_scalar = val_loss_sum.item()
855
- val_metrics = torch.cat(
856
- [
857
- torch.tensor([val_loss_scalar], device=accelerator.device),
858
- val_mae_sum,
859
- ]
860
- )
861
- val_metrics_sync = accelerator.reduce(val_metrics, reduction="sum")
862
-
863
- total_val_samples = accelerator.reduce(
864
- torch.tensor([val_samples], device=accelerator.device), reduction="sum"
865
- ).item()
866
-
867
- avg_val_loss = val_metrics_sync[0].item() / total_val_samples
868
- # Cast to float32 before numpy (bf16 tensors can't convert directly)
869
- avg_mae_per_param = (
870
- (val_metrics_sync[1:] / total_val_samples).float().cpu().numpy()
871
- )
872
- avg_mae = avg_mae_per_param.mean()
873
-
874
- # ==================== LOGGING & CHECKPOINTING ====================
875
- if accelerator.is_main_process:
876
- # Scientific metrics - cast to float32 before numpy (bf16 can't convert)
877
- y_pred = all_preds.float().cpu().numpy()
878
- y_true = all_targets.float().cpu().numpy()
879
-
880
- # Trim DDP padding
881
- real_len = len(val_dl.dataset)
882
- if len(y_pred) > real_len:
883
- y_pred = y_pred[:real_len]
884
- y_true = y_true[:real_len]
885
-
886
- # Guard against tiny validation sets (R² undefined for <2 samples)
887
- if len(y_true) >= 2:
888
- r2 = r2_score(y_true, y_pred)
889
- else:
890
- r2 = float("nan")
891
- pcc = calc_pearson(y_true, y_pred)
892
- current_lr = get_lr(optimizer)
893
-
894
- # Update history
895
- epoch_end_time = time.time()
896
- epoch_time = epoch_end_time - epoch_start_time
897
- total_training_time += epoch_time
898
-
899
- epoch_stats = {
900
- "epoch": epoch + 1,
901
- "train_loss": avg_train_loss,
902
- "val_loss": avg_val_loss,
903
- "val_r2": r2,
904
- "val_pearson": pcc,
905
- "val_mae_avg": avg_mae,
906
- "grad_norm": grad_norm_tracker.avg,
907
- "lr": current_lr,
908
- "epoch_time": round(epoch_time, 2),
909
- "total_time": round(total_training_time, 2),
910
- }
911
- for i, mae in enumerate(avg_mae_per_param):
912
- epoch_stats[f"MAE_Phys_P{i}"] = mae
913
-
914
- history.append(epoch_stats)
915
-
916
- # Console display
917
- base_str = f"{epoch + 1:<4} | {avg_train_loss:<8.4f} | {avg_val_loss:<8.4f} | {r2:<6.4f} | {pcc:<6.4f} | {grad_norm_tracker.avg:<6.4f} | {current_lr:<8.2e} | {avg_mae:<8.4f}"
918
- param_str = " | ".join([f"{m:<8.4f}" for m in avg_mae_per_param])
919
- logger.info(f"{base_str} | {param_str}")
920
-
921
- # WandB logging
922
- if args.wandb and WANDB_AVAILABLE:
923
- log_dict = {
924
- "main/train_loss": avg_train_loss,
925
- "main/val_loss": avg_val_loss,
926
- "metrics/r2_score": r2,
927
- "metrics/pearson_corr": pcc,
928
- "metrics/mae_avg": avg_mae,
929
- "system/grad_norm": grad_norm_tracker.avg,
930
- "hyper/lr": current_lr,
931
- }
932
- for i, mae in enumerate(avg_mae_per_param):
933
- log_dict[f"mae_detailed/P{i}"] = mae
934
-
935
- # Periodic scatter plots
936
- if (epoch % 5 == 0) or (avg_val_loss < best_val_loss):
937
- real_true = scaler.inverse_transform(y_true)
938
- real_pred = scaler.inverse_transform(y_pred)
939
- fig = plot_scientific_scatter(real_true, real_pred)
940
- log_dict["plots/scatter_analysis"] = wandb.Image(fig)
941
- plt.close(fig)
942
-
943
- accelerator.log(log_dict)
944
-
945
- # ==========================================================================
946
- # DDP-SAFE CHECKPOINT LOGIC
947
- # ==========================================================================
948
- # Step 1: Determine if this is the best epoch (BEFORE updating best_val_loss)
949
- is_best_epoch = False
950
- if accelerator.is_main_process:
951
- if avg_val_loss < best_val_loss:
952
- is_best_epoch = True
953
-
954
- # Step 2: Broadcast decision to all ranks (required for save_state)
955
- is_best_epoch = broadcast_early_stop(is_best_epoch, accelerator)
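broadcast_early_stop is the project's own helper; a generic sketch of the underlying idea (share a rank-0 decision so every rank takes the same branch and can join the collective save_state call), assuming an initialized torch.distributed process group:

```python
import torch.distributed as dist


def broadcast_flag_from_rank0(flag: bool) -> bool:
    """Every rank must call this; only rank 0's value is kept."""
    obj = [flag if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(obj, src=0)  # collective call: all ranks participate
    return bool(obj[0])
```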
956
-
957
- # Step 3: Save checkpoint with all ranks participating
958
- if is_best_epoch:
959
- ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
960
- accelerator.save_state(ckpt_dir) # All ranks must call this
961
-
962
- # Step 4: Rank 0 handles metadata and updates tracking variables
963
- if accelerator.is_main_process:
964
- best_val_loss = avg_val_loss # Update AFTER checkpoint saved
965
- patience_ctr = 0
966
-
967
- with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
968
- pickle.dump(
969
- {
970
- "epoch": epoch + 1,
971
- "best_val_loss": best_val_loss,
972
- "patience_ctr": patience_ctr,
973
- # Model info for auto-detection during inference
974
- "model_name": args.model,
975
- "in_shape": in_shape,
976
- "out_dim": out_dim,
977
- },
978
- f,
979
- )
980
-
981
- unwrapped = accelerator.unwrap_model(model)
982
- torch.save(
983
- unwrapped.state_dict(),
984
- os.path.join(args.output_dir, "best_model_weights.pth"),
985
- )
986
-
987
- # Copy scaler to checkpoint for portability
988
- scaler_src = os.path.join(args.output_dir, "scaler.pkl")
989
- scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
990
- if os.path.exists(scaler_src) and not os.path.exists(scaler_dst):
991
- shutil.copy2(scaler_src, scaler_dst)
992
-
993
- logger.info(
994
- f" 💾 Best model saved (val_loss: {best_val_loss:.6f})"
995
- )
996
-
997
- # Also save CSV on best model (ensures progress is saved)
998
- pd.DataFrame(history).to_csv(
999
- os.path.join(args.output_dir, "training_history.csv"),
1000
- index=False,
1001
- )
1002
- else:
1003
- if accelerator.is_main_process:
1004
- patience_ctr += 1
1005
-
1006
- # Periodic checkpoint (all ranks participate in save_state)
1007
- periodic_checkpoint_needed = (
1008
- args.save_every > 0 and (epoch + 1) % args.save_every == 0
1009
- )
1010
- if periodic_checkpoint_needed:
1011
- ckpt_name = f"epoch_{epoch + 1}_checkpoint"
1012
- ckpt_dir = os.path.join(args.output_dir, ckpt_name)
1013
- accelerator.save_state(ckpt_dir) # All ranks participate
1014
-
1015
- if accelerator.is_main_process:
1016
- with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
1017
- pickle.dump(
1018
- {
1019
- "epoch": epoch + 1,
1020
- "best_val_loss": best_val_loss,
1021
- "patience_ctr": patience_ctr,
1022
- # Model info for auto-detection during inference
1023
- "model_name": args.model,
1024
- "in_shape": in_shape,
1025
- "out_dim": out_dim,
1026
- },
1027
- f,
1028
- )
1029
- logger.info(f" 📁 Periodic checkpoint: {ckpt_name}")
1030
-
1031
- # Save CSV with each checkpoint (keeps logs in sync with model state)
1032
- pd.DataFrame(history).to_csv(
1033
- os.path.join(args.output_dir, "training_history.csv"),
1034
- index=False,
1035
- )
1036
-
1037
- # Learning rate scheduling (epoch-based schedulers only)
1038
- if not scheduler_step_per_batch:
1039
- if args.scheduler == "plateau":
1040
- scheduler.step(avg_val_loss)
1041
- else:
1042
- scheduler.step()
1043
-
1044
- # DDP-safe early stopping
1045
- should_stop = (
1046
- patience_ctr >= args.patience if accelerator.is_main_process else False
1047
- )
1048
- if broadcast_early_stop(should_stop, accelerator):
1049
- if accelerator.is_main_process:
1050
- logger.info(
1051
- f"🛑 Early stopping at epoch {epoch + 1} (patience={args.patience})"
1052
- )
1053
- # Create completion flag to prevent auto-resume
1054
- with open(
1055
- os.path.join(args.output_dir, "training_complete.flag"), "w"
1056
- ) as f:
1057
- f.write(
1058
- f"Training completed via early stopping at epoch {epoch + 1}\n"
1059
- )
1060
- break
1061
-
1062
- except KeyboardInterrupt:
1063
- logger.warning("Training interrupted. Saving emergency checkpoint...")
1064
- accelerator.save_state(os.path.join(args.output_dir, "interrupted_checkpoint"))
1065
-
1066
- except Exception as e:
1067
- logger.error(f"Critical error: {e}", exc_info=True)
1068
- raise
1069
-
1070
- else:
1071
- # Training completed normally (reached max epochs without early stopping)
1072
- # Create completion flag to prevent auto-resume on re-run
1073
- if accelerator.is_main_process:
1074
- if not os.path.exists(complete_flag_path):
1075
- with open(complete_flag_path, "w") as f:
1076
- f.write(f"Training completed normally after {args.epochs} epochs\n")
1077
- logger.info(f"✅ Training completed after {args.epochs} epochs")
1078
-
1079
- finally:
1080
- # Final CSV write to capture all epochs (handles non-multiple-of-10 endings)
1081
- if accelerator.is_main_process and len(history) > 0:
1082
- pd.DataFrame(history).to_csv(
1083
- os.path.join(args.output_dir, "training_history.csv"),
1084
- index=False,
1085
- )
1086
-
1087
- # Generate training curves plot (PNG + SVG)
1088
- if accelerator.is_main_process and len(history) > 0:
1089
- try:
1090
- fig = create_training_curves(history, show_lr=True)
1091
- for fmt in ["png", "svg"]:
1092
- fig.savefig(
1093
- os.path.join(args.output_dir, f"training_curves.{fmt}"),
1094
- dpi=FIGURE_DPI,
1095
- bbox_inches="tight",
1096
- )
1097
- plt.close(fig)
1098
- logger.info("✔ Saved: training_curves.png, training_curves.svg")
1099
- except Exception as e:
1100
- logger.warning(f"Could not generate training curves: {e}")
1101
-
1102
- if args.wandb and WANDB_AVAILABLE:
1103
- accelerator.end_training()
1104
-
1105
- # Clean up distributed process group to prevent resource leak warning
1106
- if torch.distributed.is_initialized():
1107
- torch.distributed.destroy_process_group()
1108
-
1109
- logger.info("Training completed.")
1110
-
1111
-
1112
- if __name__ == "__main__":
1113
- try:
1114
- torch.multiprocessing.set_start_method("spawn")
1115
- except RuntimeError:
1116
- pass
1117
- main()
1
+ """
2
+ WaveDL - Deep Learning for Wave-based Inverse Problems
3
+ =======================================================
4
+ Target Environment: NVIDIA HPC GPUs (Multi-GPU DDP) | PyTorch 2.x | Python 3.11+
5
+
6
+ A modular training framework for wave-based inverse problems and regression:
7
+ 1. HPC-Grade DDP Training: BF16/FP16 mixed precision with torch.compile support
8
+ 2. Dynamic Model Selection: Use --model flag to select any registered architecture
9
+ 3. Zero-Copy Data Engine: Memmap-backed datasets for large-scale training
10
+ 4. Physics-Aware Metrics: Real-time physical MAE with proper unscaling
11
+ 5. Robust Checkpointing: Resume training, periodic saves, and training curves
12
+ 6. Deep Observability: WandB integration with scatter analysis
13
+
14
+ Usage:
15
+ # Recommended: Using the HPC launcher
16
+ wavedl-hpc --model cnn --batch_size 128 --wandb
17
+
18
+ # Or with direct accelerate launch
19
+ accelerate launch -m wavedl.train --model cnn --batch_size 128 --wandb
20
+
21
+ # Multi-GPU with explicit config
22
+ wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb
23
+
24
+ # Resume from checkpoint
25
+ accelerate launch -m wavedl.train --model cnn --resume best_checkpoint --wandb
26
+
27
+ # List available models
28
+ wavedl-train --list_models
29
+
30
+ Note:
31
+ For HPC clusters (Compute Canada, etc.), use wavedl-hpc which handles
32
+ environment configuration automatically. Mixed precision is controlled via
33
+ --mixed_precision flag (default: bf16).
34
+
35
+ Author: Ductho Le (ductho.le@outlook.com)
36
+ """
37
+
38
+ from __future__ import annotations
39
+
40
+ import argparse
41
+ import logging
42
+ import os
43
+ import pickle
44
+ import shutil
45
+ import sys
46
+ import time
47
+ import warnings
48
+ from typing import Any
49
+
50
+ import matplotlib.pyplot as plt
51
+ import numpy as np
52
+ import pandas as pd
53
+ import torch
54
+ from accelerate import Accelerator
55
+ from accelerate.utils import set_seed
56
+ from sklearn.metrics import r2_score
57
+ from tqdm.auto import tqdm
58
+
59
+ from wavedl.models import build_model, get_model, list_models
60
+ from wavedl.utils import (
61
+ FIGURE_DPI,
62
+ MetricTracker,
63
+ broadcast_early_stop,
64
+ calc_pearson,
65
+ create_training_curves,
66
+ get_loss,
67
+ get_lr,
68
+ get_optimizer,
69
+ get_scheduler,
70
+ is_epoch_based,
71
+ list_losses,
72
+ list_optimizers,
73
+ list_schedulers,
74
+ plot_scientific_scatter,
75
+ prepare_data,
76
+ )
77
+
78
+
79
+ try:
80
+ import wandb
81
+
82
+ WANDB_AVAILABLE = True
83
+ except ImportError:
84
+ WANDB_AVAILABLE = False
85
+
86
+ # ==============================================================================
87
+ # RUNTIME CONFIGURATION (post-import)
88
+ # ==============================================================================
89
+ # Configure matplotlib paths for HPC systems without writable home directories
90
+ os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
91
+ os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
92
+
93
+ # Suppress non-critical warnings for cleaner training logs
94
+ warnings.filterwarnings("ignore", category=UserWarning)
95
+ warnings.filterwarnings("ignore", category=FutureWarning)
96
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
97
+ warnings.filterwarnings("ignore", module="pydantic")
98
+ warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
99
+
100
+
101
+ # ==============================================================================
102
+ # ARGUMENT PARSING
103
+ # ==============================================================================
104
+ def parse_args() -> argparse.Namespace:
105
+ """Parse command-line arguments with comprehensive options."""
106
+ parser = argparse.ArgumentParser(
107
+ description="Universal DDP Training Pipeline",
108
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
109
+ )
110
+
111
+ # Model Selection
112
+ parser.add_argument(
113
+ "--model",
114
+ type=str,
115
+ default="cnn",
116
+ help=f"Model architecture to train. Available: {list_models()}",
117
+ )
118
+ parser.add_argument(
119
+ "--list_models", action="store_true", help="List all available models and exit"
120
+ )
121
+ parser.add_argument(
122
+ "--import",
123
+ dest="import_modules",
124
+ type=str,
125
+ nargs="+",
126
+ default=[],
127
+ help="Python modules to import before training (for custom models)",
128
+ )
129
+
130
+ # Configuration File
131
+ parser.add_argument(
132
+ "--config",
133
+ type=str,
134
+ default=None,
135
+ help="Path to YAML config file. CLI args override config values.",
136
+ )
137
+
138
+ # Hyperparameters
139
+ parser.add_argument(
140
+ "--batch_size", type=int, default=128, help="Batch size per GPU"
141
+ )
142
+ parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
143
+ parser.add_argument(
144
+ "--epochs", type=int, default=1000, help="Maximum training epochs"
145
+ )
146
+ parser.add_argument(
147
+ "--patience", type=int, default=20, help="Early stopping patience"
148
+ )
149
+ parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
150
+ parser.add_argument(
151
+ "--grad_clip", type=float, default=1.0, help="Gradient clipping norm"
152
+ )
153
+
154
+ # Loss Function
155
+ parser.add_argument(
156
+ "--loss",
157
+ type=str,
158
+ default="mse",
159
+ choices=["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"],
160
+ help=f"Loss function for training. Available: {list_losses()}",
161
+ )
162
+ parser.add_argument(
163
+ "--huber_delta", type=float, default=1.0, help="Delta for Huber loss"
164
+ )
165
+ parser.add_argument(
166
+ "--loss_weights",
167
+ type=str,
168
+ default=None,
169
+ help="Comma-separated weights for weighted_mse (e.g., '1.0,2.0,1.0')",
170
+ )
171
+
172
+ # Optimizer
173
+ parser.add_argument(
174
+ "--optimizer",
175
+ type=str,
176
+ default="adamw",
177
+ choices=["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"],
178
+ help=f"Optimizer for training. Available: {list_optimizers()}",
179
+ )
180
+ parser.add_argument(
181
+ "--momentum", type=float, default=0.9, help="Momentum for SGD/RMSprop"
182
+ )
183
+ parser.add_argument(
184
+ "--nesterov", action="store_true", help="Use Nesterov momentum (SGD)"
185
+ )
186
+ parser.add_argument(
187
+ "--betas",
188
+ type=str,
189
+ default="0.9,0.999",
190
+ help="Betas for Adam variants (comma-separated)",
191
+ )
192
+
193
+ # Learning Rate Scheduler
194
+ parser.add_argument(
195
+ "--scheduler",
196
+ type=str,
197
+ default="plateau",
198
+ choices=[
199
+ "plateau",
200
+ "cosine",
201
+ "cosine_restarts",
202
+ "onecycle",
203
+ "step",
204
+ "multistep",
205
+ "exponential",
206
+ "linear_warmup",
207
+ ],
208
+ help=f"LR scheduler. Available: {list_schedulers()}",
209
+ )
210
+ parser.add_argument(
211
+ "--scheduler_patience",
212
+ type=int,
213
+ default=10,
214
+ help="Patience for ReduceLROnPlateau",
215
+ )
216
+ parser.add_argument(
217
+ "--min_lr", type=float, default=1e-6, help="Minimum learning rate"
218
+ )
219
+ parser.add_argument(
220
+ "--scheduler_factor", type=float, default=0.5, help="LR reduction factor"
221
+ )
222
+ parser.add_argument(
223
+ "--warmup_epochs", type=int, default=5, help="Warmup epochs for linear_warmup"
224
+ )
225
+ parser.add_argument(
226
+ "--step_size", type=int, default=30, help="Step size for StepLR"
227
+ )
228
+ parser.add_argument(
229
+ "--milestones",
230
+ type=str,
231
+ default=None,
232
+ help="Comma-separated epochs for MultiStepLR (e.g., '30,60,90')",
233
+ )
234
+
235
+ # Data
236
+ parser.add_argument(
237
+ "--data_path", type=str, default="train_data.npz", help="Path to NPZ dataset"
238
+ )
239
+ parser.add_argument(
240
+ "--workers",
241
+ type=int,
242
+ default=-1,
243
+ help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
244
+ )
245
+ parser.add_argument("--seed", type=int, default=2025, help="Random seed")
246
+ parser.add_argument(
247
+ "--single_channel",
248
+ action="store_true",
249
+ help="Confirm data is single-channel (suppress ambiguous shape warnings for shallow 3D volumes)",
250
+ )
251
+
252
+ # Cross-Validation
253
+ parser.add_argument(
254
+ "--cv",
255
+ type=int,
256
+ default=0,
257
+ help="Enable K-fold cross-validation with K folds (0=disabled)",
258
+ )
259
+ parser.add_argument(
260
+ "--cv_stratify",
261
+ action="store_true",
262
+ help="Use stratified splitting for cross-validation",
263
+ )
264
+ parser.add_argument(
265
+ "--cv_bins",
266
+ type=int,
267
+ default=10,
268
+ help="Number of bins for stratified CV (only with --cv_stratify)",
269
+ )
270
+
271
+ # Checkpointing & Resume
272
+ parser.add_argument(
273
+ "--resume", type=str, default=None, help="Checkpoint directory to resume from"
274
+ )
275
+ parser.add_argument(
276
+ "--save_every",
277
+ type=int,
278
+ default=50,
279
+ help="Save checkpoint every N epochs (0=disable)",
280
+ )
281
+ parser.add_argument(
282
+ "--output_dir", type=str, default=".", help="Output directory for checkpoints"
283
+ )
284
+ parser.add_argument(
285
+ "--fresh",
286
+ action="store_true",
287
+ help="Force fresh training, ignore existing checkpoints",
288
+ )
289
+
290
+ # Performance
291
+ parser.add_argument(
292
+ "--compile", action="store_true", help="Enable torch.compile (PyTorch 2.x)"
293
+ )
294
+ parser.add_argument(
295
+ "--precision",
296
+ type=str,
297
+ default="bf16",
298
+ choices=["bf16", "fp16", "no"],
299
+ help="Mixed precision mode",
300
+ )
301
+
302
+ # Logging
303
+ parser.add_argument(
304
+ "--wandb", action="store_true", help="Enable Weights & Biases logging"
305
+ )
306
+ parser.add_argument(
307
+ "--project_name", type=str, default="DL-Training", help="WandB project name"
308
+ )
309
+ parser.add_argument("--run_name", type=str, default=None, help="WandB run name")
310
+
311
+ args = parser.parse_args()
312
+ return args, parser # Returns (Namespace, ArgumentParser)
313
+
314
+
315
+ # ==============================================================================
316
+ # MAIN TRAINING FUNCTION
317
+ # ==============================================================================
318
+ def main():
319
+ args, parser = parse_args()
320
+
321
+ # Import custom model modules if specified
322
+ if args.import_modules:
323
+ import importlib
324
+
325
+ for module_name in args.import_modules:
326
+ try:
327
+ # Handle both module names (my_model) and file paths (./my_model.py)
328
+ if module_name.endswith(".py"):
329
+ # Import from file path
330
+ import importlib.util
331
+
332
+ spec = importlib.util.spec_from_file_location(
333
+ "custom_module", module_name
334
+ )
335
+ if spec and spec.loader:
336
+ module = importlib.util.module_from_spec(spec)
337
+ sys.modules["custom_module"] = module
338
+ spec.loader.exec_module(module)
339
+ print(f"✓ Imported custom module from: {module_name}")
340
+ else:
341
+ # Import as regular module
342
+ importlib.import_module(module_name)
343
+ print(f"✓ Imported module: {module_name}")
344
+ except ImportError as e:
345
+ print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
346
+ print(
347
+ " Make sure the module is in your Python path or current directory."
348
+ )
349
+ sys.exit(1)
350
+
351
+ # Handle --list_models flag
352
+ if args.list_models:
353
+ print("Available models:")
354
+ for name in list_models():
355
+ ModelClass = get_model(name)
356
+ # Get first non-empty docstring line
357
+ if ModelClass.__doc__:
358
+ lines = [
359
+ l.strip() for l in ModelClass.__doc__.splitlines() if l.strip()
360
+ ]
361
+ doc_first_line = lines[0] if lines else "No description"
362
+ else:
363
+ doc_first_line = "No description"
364
+ print(f" - {name}: {doc_first_line}")
365
+ sys.exit(0)
366
+
367
+ # Load and merge config file if provided
368
+ if args.config:
369
+ from wavedl.utils.config import (
370
+ load_config,
371
+ merge_config_with_args,
372
+ validate_config,
373
+ )
374
+
375
+ print(f"📄 Loading config from: {args.config}")
376
+ config = load_config(args.config)
377
+
378
+ # Validate config values
379
+ warnings_list = validate_config(config)
380
+ for w in warnings_list:
381
+ print(f" ⚠ {w}")
382
+
383
+ # Merge config with CLI args (CLI takes precedence via parser defaults detection)
384
+ args = merge_config_with_args(config, args, parser=parser)
385
+
386
+ # Handle --cv flag (cross-validation mode)
387
+ if args.cv > 0:
388
+ print(f"🔄 Cross-Validation Mode: {args.cv} folds")
389
+ from wavedl.utils.cross_validation import run_cross_validation
390
+
391
+ # Load data for CV using memory-efficient loader
392
+ from wavedl.utils.data import DataSource, get_data_source
393
+
394
+ data_format = DataSource.detect_format(args.data_path)
395
+ source = get_data_source(data_format)
396
+
397
+ # Use memory-mapped loading when available
398
+ _cv_handle = None
399
+ if hasattr(source, "load_mmap"):
400
+ result = source.load_mmap(args.data_path)
401
+ if hasattr(result, "inputs"):
402
+ _cv_handle = result
403
+ X, y = result.inputs, result.outputs
404
+ else:
405
+ X, y = result # NPZ returns tuple directly
406
+ else:
407
+ X, y = source.load(args.data_path)
408
+
409
+ # Handle sparse matrices (must materialize for CV shuffling)
410
+ if hasattr(X, "__getitem__") and len(X) > 0 and hasattr(X[0], "toarray"):
411
+ X = np.stack([x.toarray() for x in X])
412
+
413
+ # Normalize target shape: (N,) -> (N, 1) for consistency
414
+ if y.ndim == 1:
415
+ y = y.reshape(-1, 1)
416
+
417
+ # Validate and determine input shape (consistent with prepare_data)
418
+ # Check for ambiguous shapes that could be multi-channel or shallow 3D volume
419
+ sample_shape = X.shape[1:] # Per-sample shape
420
+
421
+ # Same heuristic as prepare_data: detect ambiguous 3D shapes
422
+ is_ambiguous_shape = (
423
+ len(sample_shape) == 3 # Exactly 3D: could be (C, H, W) or (D, H, W)
424
+ and sample_shape[0] <= 16 # First dim looks like channels
425
+ and sample_shape[1] > 16
426
+ and sample_shape[2] > 16 # Both spatial dims are large
427
+ )
428
+
429
+ if is_ambiguous_shape and not args.single_channel:
430
+ raise ValueError(
431
+ f"Ambiguous input shape detected: sample shape {sample_shape}. "
432
+ f"This could be either:\n"
433
+ f" - Multi-channel 2D data (C={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n"
434
+ f" - Single-channel 3D volume (D={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n\n"
435
+ f"If this is single-channel 3D/shallow volume data, use --single_channel flag.\n"
436
+ f"If this is multi-channel 2D data, reshape to (N*C, H, W) with adjusted targets."
437
+ )
438
+
439
+ # in_shape = spatial dimensions for model registry (channel added during training)
440
+ in_shape = sample_shape
441
+
442
+ # Run cross-validation
443
+ try:
444
+ run_cross_validation(
445
+ X=X,
446
+ y=y,
447
+ model_name=args.model,
448
+ in_shape=in_shape,
449
+ out_size=y.shape[1],
450
+ folds=args.cv,
451
+ stratify=args.cv_stratify,
452
+ stratify_bins=args.cv_bins,
453
+ batch_size=args.batch_size,
454
+ lr=args.lr,
455
+ epochs=args.epochs,
456
+ patience=args.patience,
457
+ weight_decay=args.weight_decay,
458
+ loss_name=args.loss,
459
+ optimizer_name=args.optimizer,
460
+ scheduler_name=args.scheduler,
461
+ output_dir=args.output_dir,
462
+ workers=args.workers,
463
+ seed=args.seed,
464
+ )
465
+ finally:
466
+ # Clean up file handle if HDF5/MAT
467
+ if _cv_handle is not None and hasattr(_cv_handle, "close"):
468
+ try:
469
+ _cv_handle.close()
470
+ except Exception:
471
+ pass
472
+ return
473
+
474
+ # ==========================================================================
475
+ # 1. SYSTEM INITIALIZATION
476
+ # ==========================================================================
477
+ # Initialize Accelerator for DDP and mixed precision
478
+ accelerator = Accelerator(
479
+ mixed_precision=args.precision,
480
+ log_with="wandb" if args.wandb and WANDB_AVAILABLE else None,
481
+ )
482
+ set_seed(args.seed)
483
+
484
+ # Configure logging (rank 0 only prints to console)
485
+ logging.basicConfig(
486
+ level=logging.INFO if accelerator.is_main_process else logging.ERROR,
487
+ format="%(asctime)s | %(levelname)s | %(message)s",
488
+ datefmt="%H:%M:%S",
489
+ )
490
+ logger = logging.getLogger("Trainer")
491
+
492
+ # Ensure output directory exists (critical for cache files, checkpoints, etc.)
493
+ os.makedirs(args.output_dir, exist_ok=True)
494
+
495
+ # Auto-detect optimal DataLoader workers if not specified
496
+ if args.workers < 0:
497
+ cpu_count = os.cpu_count() or 4
498
+ num_gpus = accelerator.num_processes
499
+ # Heuristic: 4-8 workers per GPU, bounded by available CPU cores
500
+ # Leave some cores for main process and system overhead
501
+ args.workers = min(8, max(2, (cpu_count - 2) // num_gpus))
502
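+ # Example: 64 CPUs / 4 GPUs -> min(8, max(2, 62 // 4)) = 8 workers per GPU;
+ #          8 CPUs / 4 GPUs  -> min(8, max(2, 6 // 4))  = 2 workers per GPU.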
+ if accelerator.is_main_process:
503
+ logger.info(
504
+ f"⚙️ Auto-detected workers: {args.workers} per GPU "
505
+ f"(CPUs: {cpu_count}, GPUs: {num_gpus})"
506
+ )
507
+
508
+ if accelerator.is_main_process:
509
+ logger.info(f"🚀 Cluster Status: {accelerator.num_processes}x GPUs detected")
510
+ logger.info(
511
+ f" Model: {args.model} | Precision: {args.precision} | Compile: {args.compile}"
512
+ )
513
+ logger.info(
514
+ f" Loss: {args.loss} | Optimizer: {args.optimizer} | Scheduler: {args.scheduler}"
515
+ )
516
+ logger.info(f" Early Stopping Patience: {args.patience} epochs")
517
+ if args.save_every > 0:
518
+ logger.info(f" Periodic Checkpointing: Every {args.save_every} epochs")
519
+ if args.resume:
520
+ logger.info(f" 📂 Resuming from: {args.resume}")
521
+
522
+ # Initialize WandB
523
+ if args.wandb and WANDB_AVAILABLE:
524
+ accelerator.init_trackers(
525
+ project_name=args.project_name,
526
+ config=vars(args),
527
+ init_kwargs={"wandb": {"name": args.run_name or f"{args.model}_run"}},
528
+ )
529
+
530
+ # ==========================================================================
531
+ # 2. DATA & MODEL LOADING
532
+ # ==========================================================================
533
+ train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
534
+ args, logger, accelerator, cache_dir=args.output_dir
535
+ )
536
+
537
+ # Build model using registry
538
+ model = build_model(args.model, in_shape=in_shape, out_size=out_dim)
539
+
540
+ if accelerator.is_main_process:
541
+ param_info = model.parameter_summary()
542
+ logger.info(
543
+ f" Model Parameters: {param_info['trainable_parameters']:,} trainable"
544
+ )
545
+ logger.info(f" Model Size: {param_info['total_mb']:.2f} MB")
546
+
547
+ # Optional WandB model watching
548
+ if args.wandb and WANDB_AVAILABLE and accelerator.is_main_process:
549
+ wandb.watch(model, log="gradients", log_freq=100)
550
+
551
+ # Torch 2.0 compilation (requires compatible Triton on GPU)
552
+ if args.compile:
553
+ try:
554
+ # Test if Triton is available AND compatible with this PyTorch version
555
+ # PyTorch needs triton_key from triton.compiler.compiler
556
+ from triton.compiler.compiler import triton_key  # noqa: F401 (availability probe only)
557
+
558
+ model = torch.compile(model)
559
+ if accelerator.is_main_process:
560
+ logger.info(" ✔ torch.compile enabled (Triton backend)")
561
+ except ImportError as e:
562
+ if accelerator.is_main_process:
563
+ if "triton" in str(e).lower():
564
+ logger.warning(
565
+ " ⚠ Triton not installed or incompatible version - torch.compile disabled. "
566
+ "Training will proceed without compilation."
567
+ )
568
+ else:
569
+ logger.warning(
570
+ f" torch.compile setup failed: {e}. Continuing without compilation."
571
+ )
572
+ except Exception as e:
573
+ if accelerator.is_main_process:
574
+ logger.warning(
575
+ f" ⚠ torch.compile failed: {e}. Continuing without compilation."
576
+ )
577
+
578
+ # ==========================================================================
579
+ # 2.5. OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
580
+ # ==========================================================================
581
+ # Parse comma-separated arguments with validation
582
+ try:
583
+ betas_list = [float(x.strip()) for x in args.betas.split(",")]
584
+ if len(betas_list) != 2:
585
+ raise ValueError(
586
+ f"--betas must have exactly 2 values, got {len(betas_list)}"
587
+ )
588
+ if not all(0.0 <= b < 1.0 for b in betas_list):
589
+ raise ValueError(f"--betas values must be in [0, 1), got {betas_list}")
590
+ betas = tuple(betas_list)
591
+ except ValueError as e:
592
+ raise ValueError(
593
+ f"Invalid --betas format '{args.betas}': {e}. Expected format: '0.9,0.999'"
594
+ )
595
+
596
+ loss_weights = None
597
+ if args.loss_weights:
598
+ loss_weights = [float(x.strip()) for x in args.loss_weights.split(",")]
599
+ milestones = None
600
+ if args.milestones:
601
+ milestones = [int(x.strip()) for x in args.milestones.split(",")]
602
+
603
+ # Create optimizer using factory
604
+ optimizer = get_optimizer(
605
+ name=args.optimizer,
606
+ params=model.get_optimizer_groups(args.lr, args.weight_decay),
607
+ lr=args.lr,
608
+ weight_decay=args.weight_decay,
609
+ momentum=args.momentum,
610
+ nesterov=args.nesterov,
611
+ betas=betas,
612
+ )
613
+
614
+ # Create loss function using factory
615
+ criterion = get_loss(
616
+ name=args.loss,
617
+ weights=loss_weights,
618
+ delta=args.huber_delta,
619
+ )
620
+ # Move criterion to device (important for WeightedMSELoss buffer)
621
+ criterion = criterion.to(accelerator.device)
622
+
623
+ # Track if scheduler should step per batch (OneCycleLR) or per epoch
624
+ scheduler_step_per_batch = not is_epoch_based(args.scheduler)
625
+
626
+ # ==========================================================================
627
+ # DDP Preparation Strategy:
628
+ # - For batch-based schedulers (OneCycleLR): prepare DataLoaders first to get
629
+ # the correct sharded batch count, then create scheduler
630
+ # - For epoch-based schedulers: create scheduler before prepare (no issue)
631
+ # ==========================================================================
632
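+ # Example of why the ordering matters: 10,000 samples, batch_size 128, 4 GPUs
+ # gives ~79 batches/epoch unsharded but only ~20 per rank after prepare();
+ # building OneCycleLR from the unsharded count would leave the LR cycle
+ # roughly three-quarters unfinished when training ends.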
+ if scheduler_step_per_batch:
633
+ # BATCH-BASED SCHEDULER (e.g., OneCycleLR)
634
+ # Prepare model, optimizer, dataloaders FIRST to get sharded loader length
635
+ model, optimizer, train_dl, val_dl = accelerator.prepare(
636
+ model, optimizer, train_dl, val_dl
637
+ )
638
+
639
+ # Now create scheduler with the CORRECT sharded steps_per_epoch
640
+ steps_per_epoch = len(train_dl) # Post-DDP sharded length
641
+ scheduler = get_scheduler(
642
+ name=args.scheduler,
643
+ optimizer=optimizer,
644
+ epochs=args.epochs,
645
+ steps_per_epoch=steps_per_epoch,
646
+ min_lr=args.min_lr,
647
+ patience=args.scheduler_patience,
648
+ factor=args.scheduler_factor,
649
+ gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
650
+ step_size=args.step_size,
651
+ milestones=milestones,
652
+ warmup_epochs=args.warmup_epochs,
653
+ )
654
+ # Prepare scheduler separately (Accelerator wraps it for state saving)
655
+ scheduler = accelerator.prepare(scheduler)
656
+ else:
657
+ # EPOCH-BASED SCHEDULER (plateau, cosine, step, etc.)
658
+ # No batch count dependency - create scheduler before prepare
659
+ scheduler = get_scheduler(
660
+ name=args.scheduler,
661
+ optimizer=optimizer,
662
+ epochs=args.epochs,
663
+ steps_per_epoch=None,
664
+ min_lr=args.min_lr,
665
+ patience=args.scheduler_patience,
666
+ factor=args.scheduler_factor,
667
+ gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
668
+ step_size=args.step_size,
669
+ milestones=milestones,
670
+ warmup_epochs=args.warmup_epochs,
671
+ )
672
+ # Prepare everything together
673
+ model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
674
+ model, optimizer, train_dl, val_dl, scheduler
675
+ )
676
+
677
+ # ==========================================================================
678
+ # 3. AUTO-RESUME / RESUME FROM CHECKPOINT
679
+ # ==========================================================================
680
+ start_epoch = 0
681
+ best_val_loss = float("inf")
682
+ patience_ctr = 0
683
+ history: list[dict[str, Any]] = []
684
+
685
+ # Define checkpoint paths
686
+ best_ckpt_path = os.path.join(args.output_dir, "best_checkpoint")
687
+ complete_flag_path = os.path.join(args.output_dir, "training_complete.flag")
688
+
689
+ # Auto-resume logic (if not --fresh and no explicit --resume)
690
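+ # Decision table: completion flag present  -> exit, nothing left to do;
+ #                  best_checkpoint present -> auto-resume from it;
+ #                  neither present         -> start a fresh run.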
+ if not args.fresh and args.resume is None:
691
+ if os.path.exists(complete_flag_path):
692
+ # Training already completed
693
+ if accelerator.is_main_process:
694
+ logger.info(
695
+ "✅ Training already completed (early stopping). Use --fresh to retrain."
696
+ )
697
+ return # Exit gracefully
698
+ elif os.path.exists(best_ckpt_path):
699
+ # Incomplete training found - auto-resume
700
+ args.resume = best_ckpt_path
701
+ if accelerator.is_main_process:
702
+ logger.info(f"🔄 Auto-resuming from: {best_ckpt_path}")
703
+
704
+ if args.resume:
705
+ if os.path.exists(args.resume):
706
+ logger.info(f"🔄 Loading checkpoint from: {args.resume}")
707
+ accelerator.load_state(args.resume)
708
+
709
+ # Restore training metadata
710
+ meta_path = os.path.join(args.resume, "training_meta.pkl")
711
+ if os.path.exists(meta_path):
712
+ with open(meta_path, "rb") as f:
713
+ meta = pickle.load(f)
714
+ start_epoch = meta.get("epoch", 0)
715
+ best_val_loss = meta.get("best_val_loss", float("inf"))
716
+ patience_ctr = meta.get("patience_ctr", 0)
717
+ logger.info(
718
+ f" ✅ Restored: Epoch {start_epoch}, Best Loss: {best_val_loss:.6f}"
719
+ )
720
+ else:
721
+ logger.warning(
722
+ " ⚠️ training_meta.pkl not found, starting from epoch 0"
723
+ )
724
+
725
+ # Restore history
726
+ history_path = os.path.join(args.output_dir, "training_history.csv")
727
+ if os.path.exists(history_path):
728
+ history = pd.read_csv(history_path).to_dict("records")
729
+ logger.info(f" ✅ Loaded {len(history)} epochs from history")
730
+ else:
731
+ raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
732
+
733
+ # ==========================================================================
734
+ # 4. PHYSICAL METRIC SETUP
735
+ # ==========================================================================
736
+ # Physical MAE = normalized MAE * scaler.scale_
737
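+ # Illustrative (hypothetical numbers): with scaler.scale_ = [2.5, 0.05] and a
+ # normalized MAE of [0.04, 0.20], the reported physical MAE is [0.10, 0.01]
+ # in the targets' original units.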
+ phys_scale = torch.tensor(
738
+ scaler.scale_, device=accelerator.device, dtype=torch.float32
739
+ )
740
+
741
+ # ==========================================================================
742
+ # 5. TRAINING LOOP
743
+ # ==========================================================================
744
+ # Dynamic console header
745
+ if accelerator.is_main_process:
746
+ base_cols = ["Ep", "TrnLoss", "ValLoss", "R2", "PCC", "GradN", "LR", "MAE_Avg"]
747
+ param_cols = [f"MAE_P{i}" for i in range(out_dim)]
748
+ header = "{:<4} | {:<8} | {:<8} | {:<6} | {:<6} | {:<6} | {:<8} | {:<8}".format(
749
+ *base_cols
750
+ )
751
+ header += " | " + " | ".join([f"{c:<8}" for c in param_cols])
752
+ logger.info("=" * len(header))
753
+ logger.info(header)
754
+ logger.info("=" * len(header))
755
+
756
+ try:
757
+ total_training_time = 0.0
759
+
760
+ for epoch in range(start_epoch, args.epochs):
761
+ epoch_start_time = time.time()
762
+
763
+ # ==================== TRAINING PHASE ====================
764
+ model.train()
765
+ # Use GPU tensor for loss accumulation to avoid .item() sync per batch
766
+ train_loss_sum = torch.tensor(0.0, device=accelerator.device)
767
+ train_samples = 0
768
+ grad_norm_tracker = MetricTracker()
769
+
770
+ pbar = tqdm(
771
+ train_dl,
772
+ disable=not accelerator.is_main_process,
773
+ leave=False,
774
+ desc=f"Epoch {epoch + 1}",
775
+ )
776
+
777
+ for x, y in pbar:
778
+ with accelerator.accumulate(model):
779
+ pred = model(x)
780
+ loss = criterion(pred, y)
781
+
782
+ accelerator.backward(loss)
783
+
784
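+ # With gradient accumulation, accelerator.sync_gradients is True only on the
+ # final micro-batch of each accumulation window, so gradient clipping is
+ # applied once per effective batch rather than once per micro-batch.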
+ if accelerator.sync_gradients:
785
+ grad_norm = accelerator.clip_grad_norm_(
786
+ model.parameters(), args.grad_clip
787
+ )
788
+ if grad_norm is not None:
789
+ grad_norm_tracker.update(grad_norm.item())
790
+
791
+ optimizer.step()
792
+ optimizer.zero_grad(set_to_none=True)  # set_to_none avoids a memset of the grad buffers
793
+
794
+ # Per-batch LR scheduling (e.g., OneCycleLR)
795
+ if scheduler_step_per_batch:
796
+ scheduler.step()
797
+
798
+ # Accumulate as tensors to avoid .item() sync per batch
799
+ train_loss_sum += loss.detach() * x.size(0)
800
+ train_samples += x.size(0)
801
+
802
+ # Single .item() call at end of epoch (reduces GPU sync overhead)
803
+ train_loss_scalar = train_loss_sum.item()
804
+
805
+ # Synchronize training metrics across GPUs
806
+ global_loss = accelerator.reduce(
807
+ torch.tensor([train_loss_scalar], device=accelerator.device),
808
+ reduction="sum",
809
+ ).item()
810
+ global_samples = accelerator.reduce(
811
+ torch.tensor([train_samples], device=accelerator.device),
812
+ reduction="sum",
813
+ ).item()
814
+ avg_train_loss = global_loss / global_samples
815
+
816
+ # ==================== VALIDATION PHASE ====================
817
+ model.eval()
818
+ # Use GPU tensor for loss accumulation (consistent with training phase)
819
+ val_loss_sum = torch.tensor(0.0, device=accelerator.device)
820
+ val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
821
+ val_samples = 0
822
+
823
+ # Accumulate predictions locally, gather ONCE at end (reduces sync overhead)
824
+ local_preds = []
825
+ local_targets = []
826
+
827
+ with torch.inference_mode():
828
+ for x, y in val_dl:
829
+ pred = model(x)
830
+ loss = criterion(pred, y)
831
+
832
+ val_loss_sum += loss.detach() * x.size(0)
833
+ val_samples += x.size(0)
834
+
835
+ # Physical MAE
836
+ mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
837
+ val_mae_sum += mae_batch
838
+
839
+ # Store locally (no GPU sync per batch)
840
+ local_preds.append(pred)
841
+ local_targets.append(y)
842
+
843
+ # Single gather at end of validation (2 syncs instead of 2×num_batches)
844
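+ # gather_for_metrics also drops the samples Accelerate duplicated to pad the
+ # last uneven batch (for loaders it prepared); the explicit length trim in
+ # the logging block below is kept as a defensive fallback.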
+ all_local_preds = torch.cat(local_preds)
845
+ all_local_targets = torch.cat(local_targets)
846
+ all_preds = accelerator.gather_for_metrics(all_local_preds)
847
+ all_targets = accelerator.gather_for_metrics(all_local_targets)
848
+
849
+ # Synchronize validation metrics
850
+ val_loss_scalar = val_loss_sum.item()
851
+ val_metrics = torch.cat(
852
+ [
853
+ torch.tensor([val_loss_scalar], device=accelerator.device),
854
+ val_mae_sum,
855
+ ]
856
+ )
857
+ val_metrics_sync = accelerator.reduce(val_metrics, reduction="sum")
858
+
859
+ total_val_samples = accelerator.reduce(
860
+ torch.tensor([val_samples], device=accelerator.device), reduction="sum"
861
+ ).item()
862
+
863
+ avg_val_loss = val_metrics_sync[0].item() / total_val_samples
864
+ # Cast to float32 before numpy (bf16 tensors can't convert directly)
865
+ avg_mae_per_param = (
866
+ (val_metrics_sync[1:] / total_val_samples).float().cpu().numpy()
867
+ )
868
+ avg_mae = avg_mae_per_param.mean()
869
+
870
+ # ==================== LOGGING & CHECKPOINTING ====================
871
+ if accelerator.is_main_process:
872
+ # Scientific metrics - cast to float32 before numpy (bf16 can't convert)
873
+ y_pred = all_preds.float().cpu().numpy()
874
+ y_true = all_targets.float().cpu().numpy()
875
+
876
+ # Trim DDP padding
877
+ real_len = len(val_dl.dataset)
878
+ if len(y_pred) > real_len:
879
+ y_pred = y_pred[:real_len]
880
+ y_true = y_true[:real_len]
881
+
882
+ # Guard against tiny validation sets (R² undefined for <2 samples)
883
+ if len(y_true) >= 2:
884
+ r2 = r2_score(y_true, y_pred)
885
+ else:
886
+ r2 = float("nan")
887
+ pcc = calc_pearson(y_true, y_pred)
888
+ current_lr = get_lr(optimizer)
889
+
890
+ # Update history
891
+ epoch_end_time = time.time()
892
+ epoch_time = epoch_end_time - epoch_start_time
893
+ total_training_time += epoch_time
894
+
895
+ epoch_stats = {
896
+ "epoch": epoch + 1,
897
+ "train_loss": avg_train_loss,
898
+ "val_loss": avg_val_loss,
899
+ "val_r2": r2,
900
+ "val_pearson": pcc,
901
+ "val_mae_avg": avg_mae,
902
+ "grad_norm": grad_norm_tracker.avg,
903
+ "lr": current_lr,
904
+ "epoch_time": round(epoch_time, 2),
905
+ "total_time": round(total_training_time, 2),
906
+ }
907
+ for i, mae in enumerate(avg_mae_per_param):
908
+ epoch_stats[f"MAE_Phys_P{i}"] = mae
909
+
910
+ history.append(epoch_stats)
911
+
912
+ # Console display
913
+ base_str = f"{epoch + 1:<4} | {avg_train_loss:<8.4f} | {avg_val_loss:<8.4f} | {r2:<6.4f} | {pcc:<6.4f} | {grad_norm_tracker.avg:<6.4f} | {current_lr:<8.2e} | {avg_mae:<8.4f}"
914
+ param_str = " | ".join([f"{m:<8.4f}" for m in avg_mae_per_param])
915
+ logger.info(f"{base_str} | {param_str}")
916
+
917
+ # WandB logging
918
+ if args.wandb and WANDB_AVAILABLE:
919
+ log_dict = {
920
+ "main/train_loss": avg_train_loss,
921
+ "main/val_loss": avg_val_loss,
922
+ "metrics/r2_score": r2,
923
+ "metrics/pearson_corr": pcc,
924
+ "metrics/mae_avg": avg_mae,
925
+ "system/grad_norm": grad_norm_tracker.avg,
926
+ "hyper/lr": current_lr,
927
+ }
928
+ for i, mae in enumerate(avg_mae_per_param):
929
+ log_dict[f"mae_detailed/P{i}"] = mae
930
+
931
+ # Periodic scatter plots
932
+ if (epoch % 5 == 0) or (avg_val_loss < best_val_loss):
933
+ real_true = scaler.inverse_transform(y_true)
934
+ real_pred = scaler.inverse_transform(y_pred)
935
+ fig = plot_scientific_scatter(real_true, real_pred)
936
+ log_dict["plots/scatter_analysis"] = wandb.Image(fig)
937
+ plt.close(fig)
938
+
939
+ accelerator.log(log_dict)
940
+
941
+ # ==========================================================================
942
+ # DDP-SAFE CHECKPOINT LOGIC
943
+ # ==========================================================================
944
+ # Step 1: Determine if this is the best epoch (BEFORE updating best_val_loss)
945
+ is_best_epoch = False
946
+ if accelerator.is_main_process:
947
+ if avg_val_loss < best_val_loss:
948
+ is_best_epoch = True
949
+
950
+ # Step 2: Broadcast decision to all ranks (required for save_state)
951
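+ # (save_state is a collective under sharded backends such as FSDP/DeepSpeed;
+ # if only rank 0 reached the save, the other ranks would hang.)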
+ is_best_epoch = broadcast_early_stop(is_best_epoch, accelerator)
952
+
953
+ # Step 3: Save checkpoint with all ranks participating
954
+ if is_best_epoch:
955
+ ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
956
+ accelerator.save_state(ckpt_dir) # All ranks must call this
957
+
958
+ # Step 4: Rank 0 handles metadata and updates tracking variables
959
+ if accelerator.is_main_process:
960
+ best_val_loss = avg_val_loss # Update AFTER checkpoint saved
961
+ patience_ctr = 0
962
+
963
+ with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
964
+ pickle.dump(
965
+ {
966
+ "epoch": epoch + 1,
967
+ "best_val_loss": best_val_loss,
968
+ "patience_ctr": patience_ctr,
969
+ # Model info for auto-detection during inference
970
+ "model_name": args.model,
971
+ "in_shape": in_shape,
972
+ "out_dim": out_dim,
973
+ },
974
+ f,
975
+ )
976
+
977
+ unwrapped = accelerator.unwrap_model(model)
978
+ torch.save(
979
+ unwrapped.state_dict(),
980
+ os.path.join(args.output_dir, "best_model_weights.pth"),
981
+ )
982
+
983
+ # Copy scaler to checkpoint for portability
984
+ scaler_src = os.path.join(args.output_dir, "scaler.pkl")
985
+ scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
986
+ if os.path.exists(scaler_src) and not os.path.exists(scaler_dst):
987
+ shutil.copy2(scaler_src, scaler_dst)
988
+
989
+ logger.info(
990
+ f" 💾 Best model saved (val_loss: {best_val_loss:.6f})"
991
+ )
992
+
993
+ # Also save CSV on best model (ensures progress is saved)
994
+ pd.DataFrame(history).to_csv(
995
+ os.path.join(args.output_dir, "training_history.csv"),
996
+ index=False,
997
+ )
998
+ else:
999
+ if accelerator.is_main_process:
1000
+ patience_ctr += 1
1001
+
1002
+ # Periodic checkpoint (all ranks participate in save_state)
1003
+ periodic_checkpoint_needed = (
1004
+ args.save_every > 0 and (epoch + 1) % args.save_every == 0
1005
+ )
1006
+ if periodic_checkpoint_needed:
1007
+ ckpt_name = f"epoch_{epoch + 1}_checkpoint"
1008
+ ckpt_dir = os.path.join(args.output_dir, ckpt_name)
1009
+ accelerator.save_state(ckpt_dir) # All ranks participate
1010
+
1011
+ if accelerator.is_main_process:
1012
+ with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
1013
+ pickle.dump(
1014
+ {
1015
+ "epoch": epoch + 1,
1016
+ "best_val_loss": best_val_loss,
1017
+ "patience_ctr": patience_ctr,
1018
+ # Model info for auto-detection during inference
1019
+ "model_name": args.model,
1020
+ "in_shape": in_shape,
1021
+ "out_dim": out_dim,
1022
+ },
1023
+ f,
1024
+ )
1025
+ logger.info(f" 📁 Periodic checkpoint: {ckpt_name}")
1026
+
1027
+ # Save CSV with each checkpoint (keeps logs in sync with model state)
1028
+ pd.DataFrame(history).to_csv(
1029
+ os.path.join(args.output_dir, "training_history.csv"),
1030
+ index=False,
1031
+ )
1032
+
1033
+ # Learning rate scheduling (epoch-based schedulers only)
1034
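+ # ReduceLROnPlateau needs the monitored metric to decide when to cut the LR;
+ # the other epoch-based schedulers advance on a bare step().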
+ if not scheduler_step_per_batch:
1035
+ if args.scheduler == "plateau":
1036
+ scheduler.step(avg_val_loss)
1037
+ else:
1038
+ scheduler.step()
1039
+
1040
+ # DDP-safe early stopping
1041
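+ # Only rank 0 tracks patience, so the decision is broadcast to every rank;
+ # otherwise rank 0 could break out of the loop while the other ranks block
+ # at their next collective and the job would hang.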
+ should_stop = (
1042
+ patience_ctr >= args.patience if accelerator.is_main_process else False
1043
+ )
1044
+ if broadcast_early_stop(should_stop, accelerator):
1045
+ if accelerator.is_main_process:
1046
+ logger.info(
1047
+ f"🛑 Early stopping at epoch {epoch + 1} (patience={args.patience})"
1048
+ )
1049
+ # Create completion flag to prevent auto-resume
1050
+ with open(
1051
+ os.path.join(args.output_dir, "training_complete.flag"), "w"
1052
+ ) as f:
1053
+ f.write(
1054
+ f"Training completed via early stopping at epoch {epoch + 1}\n"
1055
+ )
1056
+ break
1057
+
1058
+ except KeyboardInterrupt:
1059
+ logger.warning("Training interrupted. Saving emergency checkpoint...")
1060
+ accelerator.save_state(os.path.join(args.output_dir, "interrupted_checkpoint"))
1061
+
1062
+ except Exception as e:
1063
+ logger.error(f"Critical error: {e}", exc_info=True)
1064
+ raise
1065
+
1066
+ else:
1067
+ # Loop exited without an exception (max epochs reached, or early stopping already wrote the flag)
1068
+ # Create completion flag to prevent auto-resume on re-run
1069
+ if accelerator.is_main_process:
1070
+ if not os.path.exists(complete_flag_path):
1071
+ with open(complete_flag_path, "w") as f:
1072
+ f.write(f"Training completed normally after {args.epochs} epochs\n")
1073
+ logger.info(f"✅ Training completed after {args.epochs} epochs")
1074
+
1075
+ finally:
1076
+ # Final CSV write to capture all epochs (covers runs that end between checkpoint saves)
1077
+ if accelerator.is_main_process and len(history) > 0:
1078
+ pd.DataFrame(history).to_csv(
1079
+ os.path.join(args.output_dir, "training_history.csv"),
1080
+ index=False,
1081
+ )
1082
+
1083
+ # Generate training curves plot (PNG + SVG)
1084
+ if accelerator.is_main_process and len(history) > 0:
1085
+ try:
1086
+ fig = create_training_curves(history, show_lr=True)
1087
+ for fmt in ["png", "svg"]:
1088
+ fig.savefig(
1089
+ os.path.join(args.output_dir, f"training_curves.{fmt}"),
1090
+ dpi=FIGURE_DPI,
1091
+ bbox_inches="tight",
1092
+ )
1093
+ plt.close(fig)
1094
+ logger.info("✔ Saved: training_curves.png, training_curves.svg")
1095
+ except Exception as e:
1096
+ logger.warning(f"Could not generate training curves: {e}")
1097
+
1098
+ if args.wandb and WANDB_AVAILABLE:
1099
+ accelerator.end_training()
1100
+
1101
+ # Clean up distributed process group to prevent resource leak warning
1102
+ if torch.distributed.is_initialized():
1103
+ torch.distributed.destroy_process_group()
1104
+
1105
+ logger.info("Training completed.")
1106
+
1107
+
1108
+ if __name__ == "__main__":
1109
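+ # "spawn" avoids CUDA re-initialization issues in DataLoader worker
+ # processes; a RuntimeError simply means the start method was already set.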
+ try:
1110
+ torch.multiprocessing.set_start_method("spawn")
1111
+ except RuntimeError:
1112
+ pass
1113
+ main()