wavedl 1.3.1__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
wavedl/train.py CHANGED
@@ -1,1116 +1,1144 @@
1
- """
2
- WaveDL - Deep Learning for Wave-based Inverse Problems
3
- =======================================================
4
- Target Environment: NVIDIA HPC GPUs (Multi-GPU DDP) | PyTorch 2.x | Python 3.11+
5
-
6
- A modular training framework for wave-based inverse problems and regression:
7
- 1. HPC-Grade DDP Training: BF16/FP16 mixed precision with torch.compile support
8
- 2. Dynamic Model Selection: Use --model flag to select any registered architecture
9
- 3. Zero-Copy Data Engine: Memmap-backed datasets for large-scale training
10
- 4. Physics-Aware Metrics: Real-time physical MAE with proper unscaling
11
- 5. Robust Checkpointing: Resume training, periodic saves, and training curves
12
- 6. Deep Observability: WandB integration with scatter analysis
13
-
14
- Usage:
15
- # Recommended: Using the HPC launcher
16
- wavedl-hpc --model cnn --batch_size 128 --wandb
17
-
18
- # Or with direct accelerate launch
19
- accelerate launch -m wavedl.train --model cnn --batch_size 128 --wandb
20
-
21
- # Multi-GPU with explicit config
22
- wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb
23
-
24
- # Resume from checkpoint
25
- accelerate launch -m wavedl.train --model cnn --resume best_checkpoint --wandb
26
-
27
- # List available models
28
- wavedl-train --list_models
29
-
30
- Note:
31
- For HPC clusters (Compute Canada, etc.), use wavedl-hpc which handles
32
- environment configuration automatically. Mixed precision is controlled via
33
- --mixed_precision flag (default: bf16).
34
-
35
- Author: Ductho Le (ductho.le@outlook.com)
36
- """
37
-
38
- # ==============================================================================
39
- # ENVIRONMENT CONFIGURATION FOR HPC SYSTEMS
40
- # ==============================================================================
41
- # IMPORTANT: These must be set BEFORE matplotlib is imported to be effective
42
- import os
43
-
44
-
45
- os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
46
- os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
47
-
48
- import argparse
49
- import logging
50
- import pickle
51
- import shutil
52
- import sys
53
- import time
54
- import warnings
55
- from typing import Any
56
-
57
- import matplotlib.pyplot as plt
58
- import numpy as np
59
- import pandas as pd
60
- import torch
61
- from accelerate import Accelerator
62
- from accelerate.utils import set_seed
63
- from sklearn.metrics import r2_score
64
- from tqdm.auto import tqdm
65
-
66
- # Local imports
67
- from wavedl.models import build_model, get_model, list_models
68
- from wavedl.utils import (
69
- FIGURE_DPI,
70
- MetricTracker,
71
- broadcast_early_stop,
72
- calc_pearson,
73
- create_training_curves,
74
- # New factory functions
75
- get_loss,
76
- get_lr,
77
- get_optimizer,
78
- get_scheduler,
79
- is_epoch_based,
80
- list_losses,
81
- list_optimizers,
82
- list_schedulers,
83
- plot_scientific_scatter,
84
- prepare_data,
85
- )
86
-
87
-
88
- # Optional WandB import
89
- try:
90
- import wandb
91
-
92
- WANDB_AVAILABLE = True
93
- except ImportError:
94
- WANDB_AVAILABLE = False
95
-
96
- # Filter non-critical warnings for cleaner training logs
97
- warnings.filterwarnings("ignore", category=UserWarning)
98
- warnings.filterwarnings("ignore", category=FutureWarning)
99
- warnings.filterwarnings("ignore", category=DeprecationWarning)
100
- warnings.filterwarnings("ignore", module="pydantic")
101
- warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
102
-
103
-
104
- # ==============================================================================
105
- # ARGUMENT PARSING
106
- # ==============================================================================
107
- def parse_args() -> argparse.Namespace:
108
- """Parse command-line arguments with comprehensive options."""
109
- parser = argparse.ArgumentParser(
110
- description="Universal DDP Training Pipeline",
111
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
112
- )
113
-
114
- # Model Selection
115
- parser.add_argument(
116
- "--model",
117
- type=str,
118
- default="cnn",
119
- help=f"Model architecture to train. Available: {list_models()}",
120
- )
121
- parser.add_argument(
122
- "--list_models", action="store_true", help="List all available models and exit"
123
- )
124
- parser.add_argument(
125
- "--import",
126
- dest="import_modules",
127
- type=str,
128
- nargs="+",
129
- default=[],
130
- help="Python modules to import before training (for custom models)",
131
- )
132
-
133
- # Configuration File
134
- parser.add_argument(
135
- "--config",
136
- type=str,
137
- default=None,
138
- help="Path to YAML config file. CLI args override config values.",
139
- )
140
-
141
- # Hyperparameters
142
- parser.add_argument(
143
- "--batch_size", type=int, default=128, help="Batch size per GPU"
144
- )
145
- parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
146
- parser.add_argument(
147
- "--epochs", type=int, default=1000, help="Maximum training epochs"
148
- )
149
- parser.add_argument(
150
- "--patience", type=int, default=20, help="Early stopping patience"
151
- )
152
- parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
153
- parser.add_argument(
154
- "--grad_clip", type=float, default=1.0, help="Gradient clipping norm"
155
- )
156
-
157
- # Loss Function
158
- parser.add_argument(
159
- "--loss",
160
- type=str,
161
- default="mse",
162
- choices=["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"],
163
- help=f"Loss function for training. Available: {list_losses()}",
164
- )
165
- parser.add_argument(
166
- "--huber_delta", type=float, default=1.0, help="Delta for Huber loss"
167
- )
168
- parser.add_argument(
169
- "--loss_weights",
170
- type=str,
171
- default=None,
172
- help="Comma-separated weights for weighted_mse (e.g., '1.0,2.0,1.0')",
173
- )
174
-
175
- # Optimizer
176
- parser.add_argument(
177
- "--optimizer",
178
- type=str,
179
- default="adamw",
180
- choices=["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"],
181
- help=f"Optimizer for training. Available: {list_optimizers()}",
182
- )
183
- parser.add_argument(
184
- "--momentum", type=float, default=0.9, help="Momentum for SGD/RMSprop"
185
- )
186
- parser.add_argument(
187
- "--nesterov", action="store_true", help="Use Nesterov momentum (SGD)"
188
- )
189
- parser.add_argument(
190
- "--betas",
191
- type=str,
192
- default="0.9,0.999",
193
- help="Betas for Adam variants (comma-separated)",
194
- )
195
-
196
- # Learning Rate Scheduler
197
- parser.add_argument(
198
- "--scheduler",
199
- type=str,
200
- default="plateau",
201
- choices=[
202
- "plateau",
203
- "cosine",
204
- "cosine_restarts",
205
- "onecycle",
206
- "step",
207
- "multistep",
208
- "exponential",
209
- "linear_warmup",
210
- ],
211
- help=f"LR scheduler. Available: {list_schedulers()}",
212
- )
213
- parser.add_argument(
214
- "--scheduler_patience",
215
- type=int,
216
- default=10,
217
- help="Patience for ReduceLROnPlateau",
218
- )
219
- parser.add_argument(
220
- "--min_lr", type=float, default=1e-6, help="Minimum learning rate"
221
- )
222
- parser.add_argument(
223
- "--scheduler_factor", type=float, default=0.5, help="LR reduction factor"
224
- )
225
- parser.add_argument(
226
- "--warmup_epochs", type=int, default=5, help="Warmup epochs for linear_warmup"
227
- )
228
- parser.add_argument(
229
- "--step_size", type=int, default=30, help="Step size for StepLR"
230
- )
231
- parser.add_argument(
232
- "--milestones",
233
- type=str,
234
- default=None,
235
- help="Comma-separated epochs for MultiStepLR (e.g., '30,60,90')",
236
- )
237
-
238
- # Data
239
- parser.add_argument(
240
- "--data_path", type=str, default="train_data.npz", help="Path to NPZ dataset"
241
- )
242
- parser.add_argument(
243
- "--workers",
244
- type=int,
245
- default=-1,
246
- help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
247
- )
248
- parser.add_argument("--seed", type=int, default=2025, help="Random seed")
249
- parser.add_argument(
250
- "--single_channel",
251
- action="store_true",
252
- help="Confirm data is single-channel (suppress ambiguous shape warnings for shallow 3D volumes)",
253
- )
254
-
255
- # Cross-Validation
256
- parser.add_argument(
257
- "--cv",
258
- type=int,
259
- default=0,
260
- help="Enable K-fold cross-validation with K folds (0=disabled)",
261
- )
262
- parser.add_argument(
263
- "--cv_stratify",
264
- action="store_true",
265
- help="Use stratified splitting for cross-validation",
266
- )
267
- parser.add_argument(
268
- "--cv_bins",
269
- type=int,
270
- default=10,
271
- help="Number of bins for stratified CV (only with --cv_stratify)",
272
- )
273
-
274
- # Checkpointing & Resume
275
- parser.add_argument(
276
- "--resume", type=str, default=None, help="Checkpoint directory to resume from"
277
- )
278
- parser.add_argument(
279
- "--save_every",
280
- type=int,
281
- default=50,
282
- help="Save checkpoint every N epochs (0=disable)",
283
- )
284
- parser.add_argument(
285
- "--output_dir", type=str, default=".", help="Output directory for checkpoints"
286
- )
287
- parser.add_argument(
288
- "--fresh",
289
- action="store_true",
290
- help="Force fresh training, ignore existing checkpoints",
291
- )
292
-
293
- # Performance
294
- parser.add_argument(
295
- "--compile", action="store_true", help="Enable torch.compile (PyTorch 2.x)"
296
- )
297
- parser.add_argument(
298
- "--precision",
299
- type=str,
300
- default="bf16",
301
- choices=["bf16", "fp16", "no"],
302
- help="Mixed precision mode",
303
- )
304
-
305
- # Logging
306
- parser.add_argument(
307
- "--wandb", action="store_true", help="Enable Weights & Biases logging"
308
- )
309
- parser.add_argument(
310
- "--project_name", type=str, default="DL-Training", help="WandB project name"
311
- )
312
- parser.add_argument("--run_name", type=str, default=None, help="WandB run name")
313
-
314
- args = parser.parse_args()
315
- return args, parser # Returns (Namespace, ArgumentParser)
316
-
317
-
318
- # ==============================================================================
319
- # MAIN TRAINING FUNCTION
320
- # ==============================================================================
321
- def main():
322
- args, parser = parse_args()
323
-
324
- # Import custom model modules if specified
325
- if args.import_modules:
326
- import importlib
327
-
328
- for module_name in args.import_modules:
329
- try:
330
- # Handle both module names (my_model) and file paths (./my_model.py)
331
- if module_name.endswith(".py"):
332
- # Import from file path
333
- import importlib.util
334
-
335
- spec = importlib.util.spec_from_file_location(
336
- "custom_module", module_name
337
- )
338
- if spec and spec.loader:
339
- module = importlib.util.module_from_spec(spec)
340
- sys.modules["custom_module"] = module
341
- spec.loader.exec_module(module)
342
- print(f"✓ Imported custom module from: {module_name}")
343
- else:
344
- # Import as regular module
345
- importlib.import_module(module_name)
346
- print(f"✓ Imported module: {module_name}")
347
- except ImportError as e:
348
- print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
349
- print(
350
- " Make sure the module is in your Python path or current directory."
351
- )
352
- sys.exit(1)
353
-
354
- # Handle --list_models flag
355
- if args.list_models:
356
- print("Available models:")
357
- for name in list_models():
358
- ModelClass = get_model(name)
359
- # Get first non-empty docstring line
360
- if ModelClass.__doc__:
361
- lines = [
362
- l.strip() for l in ModelClass.__doc__.splitlines() if l.strip()
363
- ]
364
- doc_first_line = lines[0] if lines else "No description"
365
- else:
366
- doc_first_line = "No description"
367
- print(f" - {name}: {doc_first_line}")
368
- sys.exit(0)
369
-
370
- # Load and merge config file if provided
371
- if args.config:
372
- from wavedl.utils.config import (
373
- load_config,
374
- merge_config_with_args,
375
- validate_config,
376
- )
377
-
378
- print(f"📄 Loading config from: {args.config}")
379
- config = load_config(args.config)
380
-
381
- # Validate config values
382
- warnings_list = validate_config(config)
383
- for w in warnings_list:
384
- print(f" ⚠ {w}")
385
-
386
- # Merge config with CLI args (CLI takes precedence via parser defaults detection)
387
- args = merge_config_with_args(config, args, parser=parser)
388
-
389
- # Handle --cv flag (cross-validation mode)
390
- if args.cv > 0:
391
- print(f"🔄 Cross-Validation Mode: {args.cv} folds")
392
- from wavedl.utils.cross_validation import run_cross_validation
393
-
394
- # Load data for CV using memory-efficient loader
395
- from wavedl.utils.data import DataSource, get_data_source
396
-
397
- data_format = DataSource.detect_format(args.data_path)
398
- source = get_data_source(data_format)
399
-
400
- # Use memory-mapped loading when available
401
- _cv_handle = None
402
- if hasattr(source, "load_mmap"):
403
- result = source.load_mmap(args.data_path)
404
- if hasattr(result, "inputs"):
405
- _cv_handle = result
406
- X, y = result.inputs, result.outputs
407
- else:
408
- X, y = result # NPZ returns tuple directly
409
- else:
410
- X, y = source.load(args.data_path)
411
-
412
- # Handle sparse matrices (must materialize for CV shuffling)
413
- if hasattr(X, "__getitem__") and len(X) > 0 and hasattr(X[0], "toarray"):
414
- X = np.stack([x.toarray() for x in X])
415
-
416
- # Normalize target shape: (N,) -> (N, 1) for consistency
417
- if y.ndim == 1:
418
- y = y.reshape(-1, 1)
419
-
420
- # Validate and determine input shape (consistent with prepare_data)
421
- # Check for ambiguous shapes that could be multi-channel or shallow 3D volume
422
- sample_shape = X.shape[1:] # Per-sample shape
423
-
424
- # Same heuristic as prepare_data: detect ambiguous 3D shapes
425
- is_ambiguous_shape = (
426
- len(sample_shape) == 3 # Exactly 3D: could be (C, H, W) or (D, H, W)
427
- and sample_shape[0] <= 16 # First dim looks like channels
428
- and sample_shape[1] > 16
429
- and sample_shape[2] > 16 # Both spatial dims are large
430
- )
431
-
432
- if is_ambiguous_shape and not args.single_channel:
433
- raise ValueError(
434
- f"Ambiguous input shape detected: sample shape {sample_shape}. "
435
- f"This could be either:\n"
436
- f" - Multi-channel 2D data (C={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n"
437
- f" - Single-channel 3D volume (D={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n\n"
438
- f"If this is single-channel 3D/shallow volume data, use --single_channel flag.\n"
439
- f"If this is multi-channel 2D data, reshape to (N*C, H, W) with adjusted targets."
440
- )
441
-
442
- # in_shape = spatial dimensions for model registry (channel added during training)
443
- in_shape = sample_shape
444
-
445
- # Run cross-validation
446
- try:
447
- run_cross_validation(
448
- X=X,
449
- y=y,
450
- model_name=args.model,
451
- in_shape=in_shape,
452
- out_size=y.shape[1],
453
- folds=args.cv,
454
- stratify=args.cv_stratify,
455
- stratify_bins=args.cv_bins,
456
- batch_size=args.batch_size,
457
- lr=args.lr,
458
- epochs=args.epochs,
459
- patience=args.patience,
460
- weight_decay=args.weight_decay,
461
- loss_name=args.loss,
462
- optimizer_name=args.optimizer,
463
- scheduler_name=args.scheduler,
464
- output_dir=args.output_dir,
465
- workers=args.workers,
466
- seed=args.seed,
467
- )
468
- finally:
469
- # Clean up file handle if HDF5/MAT
470
- if _cv_handle is not None and hasattr(_cv_handle, "close"):
471
- try:
472
- _cv_handle.close()
473
- except Exception:
474
- pass
475
- return
476
-
477
- # ==========================================================================
478
- # 1. SYSTEM INITIALIZATION
479
- # ==========================================================================
480
- # Initialize Accelerator for DDP and mixed precision
481
- accelerator = Accelerator(
482
- mixed_precision=args.precision,
483
- log_with="wandb" if args.wandb and WANDB_AVAILABLE else None,
484
- )
485
- set_seed(args.seed)
486
-
487
- # Configure logging (rank 0 only prints to console)
488
- logging.basicConfig(
489
- level=logging.INFO if accelerator.is_main_process else logging.ERROR,
490
- format="%(asctime)s | %(levelname)s | %(message)s",
491
- datefmt="%H:%M:%S",
492
- )
493
- logger = logging.getLogger("Trainer")
494
-
495
- # Ensure output directory exists (critical for cache files, checkpoints, etc.)
496
- os.makedirs(args.output_dir, exist_ok=True)
497
-
498
- # Auto-detect optimal DataLoader workers if not specified
499
- if args.workers < 0:
500
- cpu_count = os.cpu_count() or 4
501
- num_gpus = accelerator.num_processes
502
- # Heuristic: 4-8 workers per GPU, bounded by available CPU cores
503
- # Leave some cores for main process and system overhead
504
- args.workers = min(8, max(2, (cpu_count - 2) // num_gpus))
505
- if accelerator.is_main_process:
506
- logger.info(
507
- f"⚙️ Auto-detected workers: {args.workers} per GPU "
508
- f"(CPUs: {cpu_count}, GPUs: {num_gpus})"
509
- )
510
-
511
- if accelerator.is_main_process:
512
- logger.info(f"🚀 Cluster Status: {accelerator.num_processes}x GPUs detected")
513
- logger.info(
514
- f" Model: {args.model} | Precision: {args.precision} | Compile: {args.compile}"
515
- )
516
- logger.info(
517
- f" Loss: {args.loss} | Optimizer: {args.optimizer} | Scheduler: {args.scheduler}"
518
- )
519
- logger.info(f" Early Stopping Patience: {args.patience} epochs")
520
- if args.save_every > 0:
521
- logger.info(f" Periodic Checkpointing: Every {args.save_every} epochs")
522
- if args.resume:
523
- logger.info(f" 📂 Resuming from: {args.resume}")
524
-
525
- # Initialize WandB
526
- if args.wandb and WANDB_AVAILABLE:
527
- accelerator.init_trackers(
528
- project_name=args.project_name,
529
- config=vars(args),
530
- init_kwargs={"wandb": {"name": args.run_name or f"{args.model}_run"}},
531
- )
532
-
533
- # ==========================================================================
534
- # 2. DATA & MODEL LOADING
535
- # ==========================================================================
536
- train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
537
- args, logger, accelerator, cache_dir=args.output_dir
538
- )
539
-
540
- # Build model using registry
541
- model = build_model(args.model, in_shape=in_shape, out_size=out_dim)
542
-
543
- if accelerator.is_main_process:
544
- param_info = model.parameter_summary()
545
- logger.info(
546
- f" Model Parameters: {param_info['trainable_parameters']:,} trainable"
547
- )
548
- logger.info(f" Model Size: {param_info['total_mb']:.2f} MB")
549
-
550
- # Optional WandB model watching
551
- if args.wandb and WANDB_AVAILABLE and accelerator.is_main_process:
552
- wandb.watch(model, log="gradients", log_freq=100)
553
-
554
- # Torch 2.0 compilation (requires compatible Triton on GPU)
555
- if args.compile:
556
- try:
557
- # Test if Triton is available AND compatible with this PyTorch version
558
- # PyTorch needs triton_key from triton.compiler.compiler
559
- from triton.compiler.compiler import triton_key
560
-
561
- model = torch.compile(model)
562
- if accelerator.is_main_process:
563
- logger.info(" ✔ torch.compile enabled (Triton backend)")
564
- except ImportError as e:
565
- if accelerator.is_main_process:
566
- if "triton" in str(e).lower():
567
- logger.warning(
568
- " Triton not installed or incompatible version - torch.compile disabled. "
569
- "Training will proceed without compilation."
570
- )
571
- else:
572
- logger.warning(
573
- f" ⚠ torch.compile setup failed: {e}. Continuing without compilation."
574
- )
575
- except Exception as e:
576
- if accelerator.is_main_process:
577
- logger.warning(
578
- f" ⚠ torch.compile failed: {e}. Continuing without compilation."
579
- )
580
-
581
- # ==========================================================================
582
- # 2.5. OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
583
- # ==========================================================================
584
- # Parse comma-separated arguments with validation
585
- try:
586
- betas_list = [float(x.strip()) for x in args.betas.split(",")]
587
- if len(betas_list) != 2:
588
- raise ValueError(
589
- f"--betas must have exactly 2 values, got {len(betas_list)}"
590
- )
591
- if not all(0.0 <= b < 1.0 for b in betas_list):
592
- raise ValueError(f"--betas values must be in [0, 1), got {betas_list}")
593
- betas = tuple(betas_list)
594
- except ValueError as e:
595
- raise ValueError(
596
- f"Invalid --betas format '{args.betas}': {e}. Expected format: '0.9,0.999'"
597
- )
598
-
599
- loss_weights = None
600
- if args.loss_weights:
601
- loss_weights = [float(x.strip()) for x in args.loss_weights.split(",")]
602
- milestones = None
603
- if args.milestones:
604
- milestones = [int(x.strip()) for x in args.milestones.split(",")]
605
-
606
- # Create optimizer using factory
607
- optimizer = get_optimizer(
608
- name=args.optimizer,
609
- params=model.get_optimizer_groups(args.lr, args.weight_decay),
610
- lr=args.lr,
611
- weight_decay=args.weight_decay,
612
- momentum=args.momentum,
613
- nesterov=args.nesterov,
614
- betas=betas,
615
- )
616
-
617
- # Create loss function using factory
618
- criterion = get_loss(
619
- name=args.loss,
620
- weights=loss_weights,
621
- delta=args.huber_delta,
622
- )
623
- # Move criterion to device (important for WeightedMSELoss buffer)
624
- criterion = criterion.to(accelerator.device)
625
-
626
- # Track if scheduler should step per batch (OneCycleLR) or per epoch
627
- scheduler_step_per_batch = not is_epoch_based(args.scheduler)
628
-
629
- # ==========================================================================
630
- # DDP Preparation Strategy:
631
- # - For batch-based schedulers (OneCycleLR): prepare DataLoaders first to get
632
- # the correct sharded batch count, then create scheduler
633
- # - For epoch-based schedulers: create scheduler before prepare (no issue)
634
- # ==========================================================================
635
- if scheduler_step_per_batch:
636
- # BATCH-BASED SCHEDULER (e.g., OneCycleLR)
637
- # Prepare model, optimizer, dataloaders FIRST to get sharded loader length
638
- model, optimizer, train_dl, val_dl = accelerator.prepare(
639
- model, optimizer, train_dl, val_dl
640
- )
641
-
642
- # Now create scheduler with the CORRECT sharded steps_per_epoch
643
- steps_per_epoch = len(train_dl) # Post-DDP sharded length
644
- scheduler = get_scheduler(
645
- name=args.scheduler,
646
- optimizer=optimizer,
647
- epochs=args.epochs,
648
- steps_per_epoch=steps_per_epoch,
649
- min_lr=args.min_lr,
650
- patience=args.scheduler_patience,
651
- factor=args.scheduler_factor,
652
- gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
653
- step_size=args.step_size,
654
- milestones=milestones,
655
- warmup_epochs=args.warmup_epochs,
656
- )
657
- # Prepare scheduler separately (Accelerator wraps it for state saving)
658
- scheduler = accelerator.prepare(scheduler)
659
- else:
660
- # EPOCH-BASED SCHEDULER (plateau, cosine, step, etc.)
661
- # No batch count dependency - create scheduler before prepare
662
- scheduler = get_scheduler(
663
- name=args.scheduler,
664
- optimizer=optimizer,
665
- epochs=args.epochs,
666
- steps_per_epoch=None,
667
- min_lr=args.min_lr,
668
- patience=args.scheduler_patience,
669
- factor=args.scheduler_factor,
670
- gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
671
- step_size=args.step_size,
672
- milestones=milestones,
673
- warmup_epochs=args.warmup_epochs,
674
- )
675
- # Prepare everything together
676
- model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
677
- model, optimizer, train_dl, val_dl, scheduler
678
- )
679
-
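The preparation-order comment above hinges on one fact: OneCycleLR sizes its schedule as epochs × steps_per_epoch, and under DDP the per-rank steps_per_epoch is only known after `accelerator.prepare()` shards the DataLoader. A self-contained toy sketch (single process, made-up numbers, not part of the package) showing the schedule length and the once-per-batch stepping:

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(4, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# e.g. 40 global batches split across 4 ranks -> 10 batches per rank per epoch;
# this per-rank count is what len(train_dl) returns after accelerator.prepare()
epochs, steps_per_epoch = 3, 10
sched = OneCycleLR(opt, max_lr=1e-3, epochs=epochs, steps_per_epoch=steps_per_epoch)

lrs = []
for _ in range(epochs * steps_per_epoch):   # scheduler.step() once per batch
    opt.step()
    sched.step()
    lrs.append(opt.param_groups[0]["lr"])

print(f"total steps: {len(lrs)}, peak lr: {max(lrs):.2e}, final lr: {lrs[-1]:.2e}")
```

Building the scheduler from the pre-shard loader length would inflate the step count, so the learning rate would still be mid-schedule when training actually ends.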
680
- # ==========================================================================
681
- # 3. AUTO-RESUME / RESUME FROM CHECKPOINT
682
- # ==========================================================================
683
- start_epoch = 0
684
- best_val_loss = float("inf")
685
- patience_ctr = 0
686
- history: list[dict[str, Any]] = []
687
-
688
- # Define checkpoint paths
689
- best_ckpt_path = os.path.join(args.output_dir, "best_checkpoint")
690
- complete_flag_path = os.path.join(args.output_dir, "training_complete.flag")
691
-
692
- # Auto-resume logic (if not --fresh and no explicit --resume)
693
- if not args.fresh and args.resume is None:
694
- if os.path.exists(complete_flag_path):
695
- # Training already completed
696
- if accelerator.is_main_process:
697
- logger.info(
698
- "✅ Training already completed (early stopping). Use --fresh to retrain."
699
- )
700
- return # Exit gracefully
701
- elif os.path.exists(best_ckpt_path):
702
- # Incomplete training found - auto-resume
703
- args.resume = best_ckpt_path
704
- if accelerator.is_main_process:
705
- logger.info(f"🔄 Auto-resuming from: {best_ckpt_path}")
706
-
707
- if args.resume:
708
- if os.path.exists(args.resume):
709
- logger.info(f"🔄 Loading checkpoint from: {args.resume}")
710
- accelerator.load_state(args.resume)
711
-
712
- # Restore training metadata
713
- meta_path = os.path.join(args.resume, "training_meta.pkl")
714
- if os.path.exists(meta_path):
715
- with open(meta_path, "rb") as f:
716
- meta = pickle.load(f)
717
- start_epoch = meta.get("epoch", 0)
718
- best_val_loss = meta.get("best_val_loss", float("inf"))
719
- patience_ctr = meta.get("patience_ctr", 0)
720
- logger.info(
721
- f" ✅ Restored: Epoch {start_epoch}, Best Loss: {best_val_loss:.6f}"
722
- )
723
- else:
724
- logger.warning(
725
- " ⚠️ training_meta.pkl not found, starting from epoch 0"
726
- )
727
-
728
- # Restore history
729
- history_path = os.path.join(args.output_dir, "training_history.csv")
730
- if os.path.exists(history_path):
731
- history = pd.read_csv(history_path).to_dict("records")
732
- logger.info(f" ✅ Loaded {len(history)} epochs from history")
733
- else:
734
- raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
735
-
736
- # ==========================================================================
737
- # 4. PHYSICAL METRIC SETUP
738
- # ==========================================================================
739
- # Physical MAE = normalized MAE * scaler.scale_
740
- phys_scale = torch.tensor(
741
- scaler.scale_, device=accelerator.device, dtype=torch.float32
742
- )
743
-
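Because StandardScaler is affine (z = (x - mean) / scale), an absolute error in normalized space maps back to physical units by a plain multiplication with `scaler.scale_`, which is what the validation loop relies on. A small numeric check with toy target values (not from the package):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Two hypothetical physical targets with very different magnitudes
y_phys = rng.normal(loc=[3.0e3, 0.5], scale=[2.0e2, 0.1], size=(256, 2))

scaler = StandardScaler().fit(y_phys)
y_norm = scaler.transform(y_phys)

# Pretend the model predicts in normalized space with some error
pred_norm = y_norm + rng.normal(scale=0.05, size=y_norm.shape)

mae_norm = np.mean(np.abs(pred_norm - y_norm), axis=0)
mae_phys_via_scale = mae_norm * scaler.scale_            # what the training loop logs
mae_phys_direct = np.mean(
    np.abs(scaler.inverse_transform(pred_norm) - y_phys), axis=0
)

print(np.allclose(mae_phys_via_scale, mae_phys_direct))  # True: |Δz|·scale == |Δx|
```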
744
- # ==========================================================================
745
- # 5. TRAINING LOOP
746
- # ==========================================================================
747
- # Dynamic console header
748
- if accelerator.is_main_process:
749
- base_cols = ["Ep", "TrnLoss", "ValLoss", "R2", "PCC", "GradN", "LR", "MAE_Avg"]
750
- param_cols = [f"MAE_P{i}" for i in range(out_dim)]
751
- header = "{:<4} | {:<8} | {:<8} | {:<6} | {:<6} | {:<6} | {:<8} | {:<8}".format(
752
- *base_cols
753
- )
754
- header += " | " + " | ".join([f"{c:<8}" for c in param_cols])
755
- logger.info("=" * len(header))
756
- logger.info(header)
757
- logger.info("=" * len(header))
758
-
759
- try:
760
- time.time()
761
- total_training_time = 0.0
762
-
763
- for epoch in range(start_epoch, args.epochs):
764
- epoch_start_time = time.time()
765
-
766
- # ==================== TRAINING PHASE ====================
767
- model.train()
768
- # Use GPU tensor for loss accumulation to avoid .item() sync per batch
769
- train_loss_sum = torch.tensor(0.0, device=accelerator.device)
770
- train_samples = 0
771
- grad_norm_tracker = MetricTracker()
772
-
773
- pbar = tqdm(
774
- train_dl,
775
- disable=not accelerator.is_main_process,
776
- leave=False,
777
- desc=f"Epoch {epoch + 1}",
778
- )
779
-
780
- for x, y in pbar:
781
- with accelerator.accumulate(model):
782
- pred = model(x)
783
- loss = criterion(pred, y)
784
-
785
- accelerator.backward(loss)
786
-
787
- if accelerator.sync_gradients:
788
- grad_norm = accelerator.clip_grad_norm_(
789
- model.parameters(), args.grad_clip
790
- )
791
- if grad_norm is not None:
792
- grad_norm_tracker.update(grad_norm.item())
793
-
794
- optimizer.step()
795
- optimizer.zero_grad(set_to_none=True) # Faster than zero_grad()
796
-
797
- # Per-batch LR scheduling (e.g., OneCycleLR)
798
- if scheduler_step_per_batch:
799
- scheduler.step()
800
-
801
- # Accumulate as tensors to avoid .item() sync per batch
802
- train_loss_sum += loss.detach() * x.size(0)
803
- train_samples += x.size(0)
804
-
805
- # Single .item() call at end of epoch (reduces GPU sync overhead)
806
- train_loss_scalar = train_loss_sum.item()
807
-
808
- # Synchronize training metrics across GPUs
809
- global_loss = accelerator.reduce(
810
- torch.tensor([train_loss_scalar], device=accelerator.device),
811
- reduction="sum",
812
- ).item()
813
- global_samples = accelerator.reduce(
814
- torch.tensor([train_samples], device=accelerator.device),
815
- reduction="sum",
816
- ).item()
817
- avg_train_loss = global_loss / global_samples
818
-
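The accumulation comments above describe a standard trick: keep the running loss as a device tensor and call `.item()` once per epoch, since every `.item()` forces a host-device synchronization. A toy single-process sketch of the pattern (not from the package):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 1).to(device)
loss_fn = torch.nn.MSELoss()

# Accumulate on-device; calling .item() here would sync every batch
loss_sum = torch.tensor(0.0, device=device)
n_samples = 0
for _ in range(5):                                   # stand-in for the train loader
    x = torch.randn(32, 8, device=device)
    y = torch.randn(32, 1, device=device)
    loss = loss_fn(model(x), y)
    loss_sum += loss.detach() * x.size(0)            # stays on the device
    n_samples += x.size(0)

avg_loss = (loss_sum / n_samples).item()             # single sync per epoch
print(f"avg loss over {n_samples} samples: {avg_loss:.4f}")
```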
819
- # ==================== VALIDATION PHASE ====================
820
- model.eval()
821
- # Use GPU tensor for loss accumulation (consistent with training phase)
822
- val_loss_sum = torch.tensor(0.0, device=accelerator.device)
823
- val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
824
- val_samples = 0
825
-
826
- # Accumulate predictions locally, gather ONCE at end (reduces sync overhead)
827
- local_preds = []
828
- local_targets = []
829
-
830
- with torch.inference_mode():
831
- for x, y in val_dl:
832
- pred = model(x)
833
- loss = criterion(pred, y)
834
-
835
- val_loss_sum += loss.detach() * x.size(0)
836
- val_samples += x.size(0)
837
-
838
- # Physical MAE
839
- mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
840
- val_mae_sum += mae_batch
841
-
842
- # Store locally (no GPU sync per batch)
843
- local_preds.append(pred)
844
- local_targets.append(y)
845
-
846
- # Single gather at end of validation (2 syncs instead of 2×num_batches)
847
- all_local_preds = torch.cat(local_preds)
848
- all_local_targets = torch.cat(local_targets)
849
- all_preds = accelerator.gather_for_metrics(all_local_preds)
850
- all_targets = accelerator.gather_for_metrics(all_local_targets)
851
-
852
- # Synchronize validation metrics
853
- val_loss_scalar = val_loss_sum.item()
854
- val_metrics = torch.cat(
855
- [
856
- torch.tensor([val_loss_scalar], device=accelerator.device),
857
- val_mae_sum,
858
- ]
859
- )
860
- val_metrics_sync = accelerator.reduce(val_metrics, reduction="sum")
861
-
862
- total_val_samples = accelerator.reduce(
863
- torch.tensor([val_samples], device=accelerator.device), reduction="sum"
864
- ).item()
865
-
866
- avg_val_loss = val_metrics_sync[0].item() / total_val_samples
867
- # Cast to float32 before numpy (bf16 tensors can't convert directly)
868
- avg_mae_per_param = (
869
- (val_metrics_sync[1:] / total_val_samples).float().cpu().numpy()
870
- )
871
- avg_mae = avg_mae_per_param.mean()
872
-
873
- # ==================== LOGGING & CHECKPOINTING ====================
874
- if accelerator.is_main_process:
875
- # Scientific metrics - cast to float32 before numpy (bf16 can't convert)
876
- y_pred = all_preds.float().cpu().numpy()
877
- y_true = all_targets.float().cpu().numpy()
878
-
879
- # Trim DDP padding
880
- real_len = len(val_dl.dataset)
881
- if len(y_pred) > real_len:
882
- y_pred = y_pred[:real_len]
883
- y_true = y_true[:real_len]
884
-
885
- # Guard against tiny validation sets (R² undefined for <2 samples)
886
- if len(y_true) >= 2:
887
- r2 = r2_score(y_true, y_pred)
888
- else:
889
- r2 = float("nan")
890
- pcc = calc_pearson(y_true, y_pred)
891
- current_lr = get_lr(optimizer)
892
-
893
- # Update history
894
- epoch_end_time = time.time()
895
- epoch_time = epoch_end_time - epoch_start_time
896
- total_training_time += epoch_time
897
-
898
- epoch_stats = {
899
- "epoch": epoch + 1,
900
- "train_loss": avg_train_loss,
901
- "val_loss": avg_val_loss,
902
- "val_r2": r2,
903
- "val_pearson": pcc,
904
- "val_mae_avg": avg_mae,
905
- "grad_norm": grad_norm_tracker.avg,
906
- "lr": current_lr,
907
- "epoch_time": round(epoch_time, 2),
908
- "total_time": round(total_training_time, 2),
909
- }
910
- for i, mae in enumerate(avg_mae_per_param):
911
- epoch_stats[f"MAE_Phys_P{i}"] = mae
912
-
913
- history.append(epoch_stats)
914
-
915
- # Console display
916
- base_str = f"{epoch + 1:<4} | {avg_train_loss:<8.4f} | {avg_val_loss:<8.4f} | {r2:<6.4f} | {pcc:<6.4f} | {grad_norm_tracker.avg:<6.4f} | {current_lr:<8.2e} | {avg_mae:<8.4f}"
917
- param_str = " | ".join([f"{m:<8.4f}" for m in avg_mae_per_param])
918
- logger.info(f"{base_str} | {param_str}")
919
-
920
- # WandB logging
921
- if args.wandb and WANDB_AVAILABLE:
922
- log_dict = {
923
- "main/train_loss": avg_train_loss,
924
- "main/val_loss": avg_val_loss,
925
- "metrics/r2_score": r2,
926
- "metrics/pearson_corr": pcc,
927
- "metrics/mae_avg": avg_mae,
928
- "system/grad_norm": grad_norm_tracker.avg,
929
- "hyper/lr": current_lr,
930
- }
931
- for i, mae in enumerate(avg_mae_per_param):
932
- log_dict[f"mae_detailed/P{i}"] = mae
933
-
934
- # Periodic scatter plots
935
- if (epoch % 5 == 0) or (avg_val_loss < best_val_loss):
936
- real_true = scaler.inverse_transform(y_true)
937
- real_pred = scaler.inverse_transform(y_pred)
938
- fig = plot_scientific_scatter(real_true, real_pred)
939
- log_dict["plots/scatter_analysis"] = wandb.Image(fig)
940
- plt.close(fig)
941
-
942
- accelerator.log(log_dict)
943
-
944
- # ==========================================================================
945
- # DDP-SAFE CHECKPOINT LOGIC
946
- # ==========================================================================
947
- # Step 1: Determine if this is the best epoch (BEFORE updating best_val_loss)
948
- is_best_epoch = False
949
- if accelerator.is_main_process:
950
- if avg_val_loss < best_val_loss:
951
- is_best_epoch = True
952
-
953
- # Step 2: Broadcast decision to all ranks (required for save_state)
954
- is_best_epoch = broadcast_early_stop(is_best_epoch, accelerator)
955
-
956
- # Step 3: Save checkpoint with all ranks participating
957
- if is_best_epoch:
958
- ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
959
- accelerator.save_state(ckpt_dir) # All ranks must call this
960
-
961
- # Step 4: Rank 0 handles metadata and updates tracking variables
962
- if accelerator.is_main_process:
963
- best_val_loss = avg_val_loss # Update AFTER checkpoint saved
964
- patience_ctr = 0
965
-
966
- with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
967
- pickle.dump(
968
- {
969
- "epoch": epoch + 1,
970
- "best_val_loss": best_val_loss,
971
- "patience_ctr": patience_ctr,
972
- # Model info for auto-detection during inference
973
- "model_name": args.model,
974
- "in_shape": in_shape,
975
- "out_dim": out_dim,
976
- },
977
- f,
978
- )
979
-
980
- unwrapped = accelerator.unwrap_model(model)
981
- torch.save(
982
- unwrapped.state_dict(),
983
- os.path.join(args.output_dir, "best_model_weights.pth"),
984
- )
985
-
986
- # Copy scaler to checkpoint for portability
987
- scaler_src = os.path.join(args.output_dir, "scaler.pkl")
988
- scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
989
- if os.path.exists(scaler_src) and not os.path.exists(scaler_dst):
990
- shutil.copy2(scaler_src, scaler_dst)
991
-
992
- logger.info(
993
- f" 💾 Best model saved (val_loss: {best_val_loss:.6f})"
994
- )
995
-
996
- # Also save CSV on best model (ensures progress is saved)
997
- pd.DataFrame(history).to_csv(
998
- os.path.join(args.output_dir, "training_history.csv"),
999
- index=False,
1000
- )
1001
- else:
1002
- if accelerator.is_main_process:
1003
- patience_ctr += 1
1004
-
1005
- # Periodic checkpoint (all ranks participate in save_state)
1006
- periodic_checkpoint_needed = (
1007
- args.save_every > 0 and (epoch + 1) % args.save_every == 0
1008
- )
1009
- if periodic_checkpoint_needed:
1010
- ckpt_name = f"epoch_{epoch + 1}_checkpoint"
1011
- ckpt_dir = os.path.join(args.output_dir, ckpt_name)
1012
- accelerator.save_state(ckpt_dir) # All ranks participate
1013
-
1014
- if accelerator.is_main_process:
1015
- with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
1016
- pickle.dump(
1017
- {
1018
- "epoch": epoch + 1,
1019
- "best_val_loss": best_val_loss,
1020
- "patience_ctr": patience_ctr,
1021
- # Model info for auto-detection during inference
1022
- "model_name": args.model,
1023
- "in_shape": in_shape,
1024
- "out_dim": out_dim,
1025
- },
1026
- f,
1027
- )
1028
- logger.info(f" 📁 Periodic checkpoint: {ckpt_name}")
1029
-
1030
- # Save CSV with each checkpoint (keeps logs in sync with model state)
1031
- pd.DataFrame(history).to_csv(
1032
- os.path.join(args.output_dir, "training_history.csv"),
1033
- index=False,
1034
- )
1035
-
1036
- # Learning rate scheduling (epoch-based schedulers only)
1037
- if not scheduler_step_per_batch:
1038
- if args.scheduler == "plateau":
1039
- scheduler.step(avg_val_loss)
1040
- else:
1041
- scheduler.step()
1042
-
1043
- # DDP-safe early stopping
1044
- should_stop = (
1045
- patience_ctr >= args.patience if accelerator.is_main_process else False
1046
- )
1047
- if broadcast_early_stop(should_stop, accelerator):
1048
- if accelerator.is_main_process:
1049
- logger.info(
1050
- f"🛑 Early stopping at epoch {epoch + 1} (patience={args.patience})"
1051
- )
1052
- # Create completion flag to prevent auto-resume
1053
- with open(
1054
- os.path.join(args.output_dir, "training_complete.flag"), "w"
1055
- ) as f:
1056
- f.write(
1057
- f"Training completed via early stopping at epoch {epoch + 1}\n"
1058
- )
1059
- break
1060
-
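Both the best-checkpoint decision and the early-stop decision above are made on rank 0 and then shared with every rank through `broadcast_early_stop`, so all ranks enter `save_state` or `break` together. The helper itself is not shown in this diff; below is a minimal hypothetical sketch of that rank-0-decides pattern using `torch.distributed` directly (it falls back to the local flag in single-process runs; the actual `wavedl.utils` implementation may differ):

```python
import torch.distributed as dist


def broadcast_flag(flag: bool, src: int = 0) -> bool:
    """Share a rank-`src` boolean decision with every rank.

    Hypothetical stand-in for wavedl.utils.broadcast_early_stop; in a
    single-process run it simply returns the local flag unchanged.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return flag
    payload = [flag]                          # broadcast_object_list fills in place
    dist.broadcast_object_list(payload, src=src)
    return bool(payload[0])
```

Every rank must reach the call; a collective that only some ranks execute blocks forever.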
1061
- except KeyboardInterrupt:
1062
- logger.warning("Training interrupted. Saving emergency checkpoint...")
1063
- accelerator.save_state(os.path.join(args.output_dir, "interrupted_checkpoint"))
1064
-
1065
- except Exception as e:
1066
- logger.error(f"Critical error: {e}", exc_info=True)
1067
- raise
1068
-
1069
- else:
1070
- # Training completed normally (reached max epochs without early stopping)
1071
- # Create completion flag to prevent auto-resume on re-run
1072
- if accelerator.is_main_process:
1073
- if not os.path.exists(complete_flag_path):
1074
- with open(complete_flag_path, "w") as f:
1075
- f.write(f"Training completed normally after {args.epochs} epochs\n")
1076
- logger.info(f"✅ Training completed after {args.epochs} epochs")
1077
-
1078
- finally:
1079
- # Final CSV write to capture all epochs (handles non-multiple-of-10 endings)
1080
- if accelerator.is_main_process and len(history) > 0:
1081
- pd.DataFrame(history).to_csv(
1082
- os.path.join(args.output_dir, "training_history.csv"),
1083
- index=False,
1084
- )
1085
-
1086
- # Generate training curves plot (PNG + SVG)
1087
- if accelerator.is_main_process and len(history) > 0:
1088
- try:
1089
- fig = create_training_curves(history, show_lr=True)
1090
- for fmt in ["png", "svg"]:
1091
- fig.savefig(
1092
- os.path.join(args.output_dir, f"training_curves.{fmt}"),
1093
- dpi=FIGURE_DPI,
1094
- bbox_inches="tight",
1095
- )
1096
- plt.close(fig)
1097
- logger.info("✔ Saved: training_curves.png, training_curves.svg")
1098
- except Exception as e:
1099
- logger.warning(f"Could not generate training curves: {e}")
1100
-
1101
- if args.wandb and WANDB_AVAILABLE:
1102
- accelerator.end_training()
1103
-
1104
- # Clean up distributed process group to prevent resource leak warning
1105
- if torch.distributed.is_initialized():
1106
- torch.distributed.destroy_process_group()
1107
-
1108
- logger.info("Training completed.")
1109
-
1110
-
1111
- if __name__ == "__main__":
1112
- try:
1113
- torch.multiprocessing.set_start_method("spawn")
1114
- except RuntimeError:
1115
- pass
1116
- main()
1
+ """
2
+ WaveDL - Deep Learning for Wave-based Inverse Problems
3
+ =======================================================
4
+ Target Environment: NVIDIA HPC GPUs (Multi-GPU DDP) | PyTorch 2.x | Python 3.11+
5
+
6
+ A modular training framework for wave-based inverse problems and regression:
7
+ 1. HPC-Grade DDP Training: BF16/FP16 mixed precision with torch.compile support
8
+ 2. Dynamic Model Selection: Use --model flag to select any registered architecture
9
+ 3. Zero-Copy Data Engine: Memmap-backed datasets for large-scale training
10
+ 4. Physics-Aware Metrics: Real-time physical MAE with proper unscaling
11
+ 5. Robust Checkpointing: Resume training, periodic saves, and training curves
12
+ 6. Deep Observability: WandB integration with scatter analysis
13
+
14
+ Usage:
15
+ # Recommended: Using the HPC launcher (handles accelerate configuration)
16
+ wavedl-hpc --model cnn --batch_size 128 --mixed_precision bf16 --wandb
17
+
18
+ # Or direct training module (use --precision, not --mixed_precision)
19
+ accelerate launch -m wavedl.train --model cnn --batch_size 128 --precision bf16
20
+
21
+ # Multi-GPU with explicit config
22
+ wavedl-hpc --num_gpus 4 --mixed_precision bf16 --model cnn --wandb
23
+
24
+ # Resume from checkpoint
25
+ accelerate launch -m wavedl.train --model cnn --resume best_checkpoint --wandb
26
+
27
+ # List available models
28
+ wavedl-train --list_models
29
+
30
+ Note:
31
+ - wavedl-hpc: Uses --mixed_precision (passed to accelerate launch)
32
+ - wavedl.train: Uses --precision (internal module flag)
33
+ Both control the same behavior; use the appropriate flag for your entry point.
34
+
35
+ Author: Ductho Le (ductho.le@outlook.com)
36
+ """
37
+
38
+ from __future__ import annotations
39
+
40
+ import argparse
41
+ import logging
42
+ import os
43
+ import pickle
44
+ import shutil
45
+ import sys
46
+ import time
47
+ import warnings
48
+ from typing import Any
49
+
50
+ import matplotlib.pyplot as plt
51
+ import numpy as np
52
+ import pandas as pd
53
+ import torch
54
+ from accelerate import Accelerator
55
+ from accelerate.utils import set_seed
56
+ from sklearn.metrics import r2_score
57
+ from tqdm.auto import tqdm
58
+
59
+ from wavedl.models import build_model, get_model, list_models
60
+ from wavedl.utils import (
61
+ FIGURE_DPI,
62
+ MetricTracker,
63
+ broadcast_early_stop,
64
+ calc_pearson,
65
+ create_training_curves,
66
+ get_loss,
67
+ get_lr,
68
+ get_optimizer,
69
+ get_scheduler,
70
+ is_epoch_based,
71
+ list_losses,
72
+ list_optimizers,
73
+ list_schedulers,
74
+ plot_scientific_scatter,
75
+ prepare_data,
76
+ )
77
+
78
+
79
+ try:
80
+ import wandb
81
+
82
+ WANDB_AVAILABLE = True
83
+ except ImportError:
84
+ WANDB_AVAILABLE = False
85
+
86
+ # ==============================================================================
87
+ # RUNTIME CONFIGURATION (post-import)
88
+ # ==============================================================================
89
+ # Configure matplotlib paths for HPC systems without writable home directories
90
+ os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
91
+ os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
92
+
93
+ # Suppress non-critical warnings for cleaner training logs
94
+ warnings.filterwarnings("ignore", category=UserWarning)
95
+ warnings.filterwarnings("ignore", category=FutureWarning)
96
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
97
+ warnings.filterwarnings("ignore", module="pydantic")
98
+ warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
99
+
100
+ # ==============================================================================
101
+ # GPU PERFORMANCE OPTIMIZATIONS (Ampere/Hopper: A100, H100)
102
+ # ==============================================================================
103
+ # Enable TF32 for faster matmul (safe precision for training, ~2x speedup)
104
+ torch.backends.cuda.matmul.allow_tf32 = True
105
+ torch.backends.cudnn.allow_tf32 = True
106
+ torch.set_float32_matmul_precision("high") # Use TF32 for float32 ops
107
+
108
+ # Enable cuDNN autotuning for fixed-size inputs (CNN-like models benefit most)
109
+ # Note: First few batches may be slower due to benchmarking
110
+ torch.backends.cudnn.benchmark = True
111
+
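The flags above are process-global, set at import time, and are effectively no-ops without a CUDA device: TF32 tensor cores only exist on compute capability 8.0+ parts (A100/H100 class). An optional sanity-check sketch (not part of the package) that reports what actually applies on the current machine:

```python
import torch


def report_gpu_math_settings() -> None:
    # Echo the global switches set at import time
    print("matmul TF32:", torch.backends.cuda.matmul.allow_tf32)
    print("cuDNN TF32: ", torch.backends.cudnn.allow_tf32)
    print("cuDNN benchmark:", torch.backends.cudnn.benchmark)
    print("float32 matmul precision:", torch.get_float32_matmul_precision())

    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(0)
        name = torch.cuda.get_device_name(0)
        # TF32 tensor cores require compute capability 8.0+ (Ampere or newer)
        print(f"{name}: compute capability {major}.{minor} -> "
              f"TF32 {'active' if major >= 8 else 'unavailable (flag ignored)'}")
    else:
        print("No CUDA device: TF32/cuDNN flags have no effect on CPU runs.")


report_gpu_math_settings()
```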
112
+
113
+ # ==============================================================================
114
+ # ARGUMENT PARSING
115
+ # ==============================================================================
116
+ def parse_args() -> argparse.Namespace:
117
+ """Parse command-line arguments with comprehensive options."""
118
+ parser = argparse.ArgumentParser(
119
+ description="Universal DDP Training Pipeline",
120
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
121
+ )
122
+
123
+ # Model Selection
124
+ parser.add_argument(
125
+ "--model",
126
+ type=str,
127
+ default="cnn",
128
+ help=f"Model architecture to train. Available: {list_models()}",
129
+ )
130
+ parser.add_argument(
131
+ "--list_models", action="store_true", help="List all available models and exit"
132
+ )
133
+ parser.add_argument(
134
+ "--import",
135
+ dest="import_modules",
136
+ type=str,
137
+ nargs="+",
138
+ default=[],
139
+ help="Python modules to import before training (for custom models)",
140
+ )
141
+
142
+ # Configuration File
143
+ parser.add_argument(
144
+ "--config",
145
+ type=str,
146
+ default=None,
147
+ help="Path to YAML config file. CLI args override config values.",
148
+ )
149
+
150
+ # Hyperparameters
151
+ parser.add_argument(
152
+ "--batch_size", type=int, default=128, help="Batch size per GPU"
153
+ )
154
+ parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
155
+ parser.add_argument(
156
+ "--epochs", type=int, default=1000, help="Maximum training epochs"
157
+ )
158
+ parser.add_argument(
159
+ "--patience", type=int, default=20, help="Early stopping patience"
160
+ )
161
+ parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
162
+ parser.add_argument(
163
+ "--grad_clip", type=float, default=1.0, help="Gradient clipping norm"
164
+ )
165
+
166
+ # Loss Function
167
+ parser.add_argument(
168
+ "--loss",
169
+ type=str,
170
+ default="mse",
171
+ choices=["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"],
172
+ help=f"Loss function for training. Available: {list_losses()}",
173
+ )
174
+ parser.add_argument(
175
+ "--huber_delta", type=float, default=1.0, help="Delta for Huber loss"
176
+ )
177
+ parser.add_argument(
178
+ "--loss_weights",
179
+ type=str,
180
+ default=None,
181
+ help="Comma-separated weights for weighted_mse (e.g., '1.0,2.0,1.0')",
182
+ )
183
+
184
+ # Optimizer
185
+ parser.add_argument(
186
+ "--optimizer",
187
+ type=str,
188
+ default="adamw",
189
+ choices=["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"],
190
+ help=f"Optimizer for training. Available: {list_optimizers()}",
191
+ )
192
+ parser.add_argument(
193
+ "--momentum", type=float, default=0.9, help="Momentum for SGD/RMSprop"
194
+ )
195
+ parser.add_argument(
196
+ "--nesterov", action="store_true", help="Use Nesterov momentum (SGD)"
197
+ )
198
+ parser.add_argument(
199
+ "--betas",
200
+ type=str,
201
+ default="0.9,0.999",
202
+ help="Betas for Adam variants (comma-separated)",
203
+ )
204
+
205
+ # Learning Rate Scheduler
206
+ parser.add_argument(
207
+ "--scheduler",
208
+ type=str,
209
+ default="plateau",
210
+ choices=[
211
+ "plateau",
212
+ "cosine",
213
+ "cosine_restarts",
214
+ "onecycle",
215
+ "step",
216
+ "multistep",
217
+ "exponential",
218
+ "linear_warmup",
219
+ ],
220
+ help=f"LR scheduler. Available: {list_schedulers()}",
221
+ )
222
+ parser.add_argument(
223
+ "--scheduler_patience",
224
+ type=int,
225
+ default=10,
226
+ help="Patience for ReduceLROnPlateau",
227
+ )
228
+ parser.add_argument(
229
+ "--min_lr", type=float, default=1e-6, help="Minimum learning rate"
230
+ )
231
+ parser.add_argument(
232
+ "--scheduler_factor", type=float, default=0.5, help="LR reduction factor"
233
+ )
234
+ parser.add_argument(
235
+ "--warmup_epochs", type=int, default=5, help="Warmup epochs for linear_warmup"
236
+ )
237
+ parser.add_argument(
238
+ "--step_size", type=int, default=30, help="Step size for StepLR"
239
+ )
240
+ parser.add_argument(
241
+ "--milestones",
242
+ type=str,
243
+ default=None,
244
+ help="Comma-separated epochs for MultiStepLR (e.g., '30,60,90')",
245
+ )
246
+
247
+ # Data
248
+ parser.add_argument(
249
+ "--data_path", type=str, default="train_data.npz", help="Path to NPZ dataset"
250
+ )
251
+ parser.add_argument(
252
+ "--workers",
253
+ type=int,
254
+ default=-1,
255
+ help="DataLoader workers per GPU (-1=auto-detect based on CPU cores)",
256
+ )
257
+ parser.add_argument("--seed", type=int, default=2025, help="Random seed")
258
+ parser.add_argument(
259
+ "--single_channel",
260
+ action="store_true",
261
+ help="Confirm data is single-channel (suppress ambiguous shape warnings for shallow 3D volumes)",
262
+ )
263
+
264
+ # Cross-Validation
265
+ parser.add_argument(
266
+ "--cv",
267
+ type=int,
268
+ default=0,
269
+ help="Enable K-fold cross-validation with K folds (0=disabled)",
270
+ )
271
+ parser.add_argument(
272
+ "--cv_stratify",
273
+ action="store_true",
274
+ help="Use stratified splitting for cross-validation",
275
+ )
276
+ parser.add_argument(
277
+ "--cv_bins",
278
+ type=int,
279
+ default=10,
280
+ help="Number of bins for stratified CV (only with --cv_stratify)",
281
+ )
282
+
283
+ # Checkpointing & Resume
284
+ parser.add_argument(
285
+ "--resume", type=str, default=None, help="Checkpoint directory to resume from"
286
+ )
287
+ parser.add_argument(
288
+ "--save_every",
289
+ type=int,
290
+ default=50,
291
+ help="Save checkpoint every N epochs (0=disable)",
292
+ )
293
+ parser.add_argument(
294
+ "--output_dir", type=str, default=".", help="Output directory for checkpoints"
295
+ )
296
+ parser.add_argument(
297
+ "--fresh",
298
+ action="store_true",
299
+ help="Force fresh training, ignore existing checkpoints",
300
+ )
301
+
302
+ # Performance
303
+ parser.add_argument(
304
+ "--compile", action="store_true", help="Enable torch.compile (PyTorch 2.x)"
305
+ )
306
+ parser.add_argument(
307
+ "--precision",
308
+ type=str,
309
+ default="bf16",
310
+ choices=["bf16", "fp16", "no"],
311
+ help="Mixed precision mode",
312
+ )
313
+ # Alias for consistency with wavedl-hpc (--mixed_precision)
314
+ parser.add_argument(
315
+ "--mixed_precision",
316
+ dest="precision",
317
+ type=str,
318
+ choices=["bf16", "fp16", "no"],
319
+ help=argparse.SUPPRESS, # Hidden: use --precision instead
320
+ )
321
+
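The hidden flag above works purely through argparse: both options write to the same `dest`, and `argparse.SUPPRESS` keeps the alias out of `--help`. A standalone sketch of the same pattern:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp16", "no"])
parser.add_argument(
    "--mixed_precision",            # hidden alias: writes to the same attribute
    dest="precision",
    type=str,
    choices=["bf16", "fp16", "no"],
    help=argparse.SUPPRESS,         # keeps the alias out of --help output
)

print(parser.parse_args(["--precision", "fp16"]).precision)      # fp16
print(parser.parse_args(["--mixed_precision", "no"]).precision)  # no
print(parser.parse_args([]).precision)                           # bf16 (default)
```

Because `--precision` is registered first, its default survives when neither flag is given; if both are passed, whichever appears later on the command line wins.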
322
+ # Logging
323
+ parser.add_argument(
324
+ "--wandb", action="store_true", help="Enable Weights & Biases logging"
325
+ )
326
+ parser.add_argument(
327
+ "--wandb_watch",
328
+ action="store_true",
329
+ help="Enable WandB gradient watching (adds overhead, useful for debugging)",
330
+ )
331
+ parser.add_argument(
332
+ "--project_name", type=str, default="DL-Training", help="WandB project name"
333
+ )
334
+ parser.add_argument("--run_name", type=str, default=None, help="WandB run name")
335
+
336
+ args = parser.parse_args()
337
+ return args, parser # Returns (Namespace, ArgumentParser)
338
+
339
+
340
+ # ==============================================================================
341
+ # MAIN TRAINING FUNCTION
342
+ # ==============================================================================
343
+ def main():
344
+ args, parser = parse_args()
345
+
346
+ # Import custom model modules if specified
347
+ if args.import_modules:
348
+ import importlib
349
+
350
+ for module_name in args.import_modules:
351
+ try:
352
+ # Handle both module names (my_model) and file paths (./my_model.py)
353
+ if module_name.endswith(".py"):
354
+ # Import from file path
355
+ import importlib.util
356
+
357
+ spec = importlib.util.spec_from_file_location(
358
+ "custom_module", module_name
359
+ )
360
+ if spec and spec.loader:
361
+ module = importlib.util.module_from_spec(spec)
362
+ sys.modules["custom_module"] = module
363
+ spec.loader.exec_module(module)
364
+ print(f"✓ Imported custom module from: {module_name}")
365
+ else:
366
+ # Import as regular module
367
+ importlib.import_module(module_name)
368
+ print(f"✓ Imported module: {module_name}")
369
+ except ImportError as e:
370
+ print(f"✗ Failed to import '{module_name}': {e}", file=sys.stderr)
371
+ print(
372
+ " Make sure the module is in your Python path or current directory."
373
+ )
374
+ sys.exit(1)
375
+
376
+ # Handle --list_models flag
377
+ if args.list_models:
378
+ print("Available models:")
379
+ for name in list_models():
380
+ ModelClass = get_model(name)
381
+ # Get first non-empty docstring line
382
+ if ModelClass.__doc__:
383
+ lines = [
384
+ l.strip() for l in ModelClass.__doc__.splitlines() if l.strip()
385
+ ]
386
+ doc_first_line = lines[0] if lines else "No description"
387
+ else:
388
+ doc_first_line = "No description"
389
+ print(f" - {name}: {doc_first_line}")
390
+ sys.exit(0)
391
+
392
+ # Load and merge config file if provided
393
+ if args.config:
394
+ from wavedl.utils.config import (
395
+ load_config,
396
+ merge_config_with_args,
397
+ validate_config,
398
+ )
399
+
400
+ print(f"📄 Loading config from: {args.config}")
401
+ config = load_config(args.config)
402
+
403
+ # Validate config values
404
+ warnings_list = validate_config(config)
405
+ for w in warnings_list:
406
+ print(f" ⚠ {w}")
407
+
408
+ # Merge config with CLI args (CLI takes precedence via parser defaults detection)
409
+ args = merge_config_with_args(config, args, parser=parser)
410
+
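The merge comment above describes precedence by "parser defaults detection": a config value is applied only where the parsed value still equals the parser's default, i.e. the user did not set that flag on the command line. `merge_config_with_args` lives in `wavedl.utils.config` and is not shown in this diff; the sketch below is a hypothetical illustration of that idea, not the package's implementation:

```python
import argparse


def merge_config_with_args_sketch(config: dict, args: argparse.Namespace,
                                  parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Apply a config value only when the CLI value still equals the parser default."""
    for key, value in config.items():
        if hasattr(args, key) and getattr(args, key) == parser.get_default(key):
            setattr(args, key, value)
    return args


parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--batch_size", type=int, default=128)

args = parser.parse_args(["--lr", "5e-4"])                  # user overrides lr only
merged = merge_config_with_args_sketch({"lr": 1e-2, "batch_size": 64}, args, parser)
print(merged.lr, merged.batch_size)                         # 0.0005 64 -> CLI wins for lr
```

The known blind spot of this approach is that explicitly passing a flag's default value on the CLI is indistinguishable from not passing it at all.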
411
+ # Handle --cv flag (cross-validation mode)
412
+ if args.cv > 0:
413
+ print(f"🔄 Cross-Validation Mode: {args.cv} folds")
414
+ from wavedl.utils.cross_validation import run_cross_validation
415
+
416
+ # Load data for CV using memory-efficient loader
417
+ from wavedl.utils.data import DataSource, get_data_source
418
+
419
+ data_format = DataSource.detect_format(args.data_path)
420
+ source = get_data_source(data_format)
421
+
422
+ # Use memory-mapped loading when available
423
+ _cv_handle = None
424
+ if hasattr(source, "load_mmap"):
425
+ result = source.load_mmap(args.data_path)
426
+ if hasattr(result, "inputs"):
427
+ _cv_handle = result
428
+ X, y = result.inputs, result.outputs
429
+ else:
430
+ X, y = result # NPZ returns tuple directly
431
+ else:
432
+ X, y = source.load(args.data_path)
433
+
434
+ # Handle sparse matrices (must materialize for CV shuffling)
435
+ if hasattr(X, "__getitem__") and len(X) > 0 and hasattr(X[0], "toarray"):
436
+ X = np.stack([x.toarray() for x in X])
437
+
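The duck-typed check above (`hasattr(X[0], "toarray")`) catches containers of per-sample SciPy sparse matrices and densifies them, since K-fold shuffling needs one contiguous array. A toy illustration (SciPy comes in as a scikit-learn dependency, which this module already imports):

```python
import numpy as np
from scipy.sparse import csr_matrix

# Toy stand-in for a collection of per-sample sparse 2D matrices
dense = np.zeros((3, 4, 5), dtype=np.float32)
dense[0, 1, 2] = 1.0
dense[2, 3, 4] = 7.5
X_sparse = [csr_matrix(dense[i]) for i in range(dense.shape[0])]

# Same duck-typed check and densification as in the CV path above
if hasattr(X_sparse, "__getitem__") and len(X_sparse) > 0 and hasattr(X_sparse[0], "toarray"):
    X = np.stack([x.toarray() for x in X_sparse])

print(X.shape, np.array_equal(X, dense))   # (3, 4, 5) True
```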
438
+ # Normalize target shape: (N,) -> (N, 1) for consistency
439
+ if y.ndim == 1:
440
+ y = y.reshape(-1, 1)
441
+
442
+ # Validate and determine input shape (consistent with prepare_data)
443
+ # Check for ambiguous shapes that could be multi-channel or shallow 3D volume
444
+ sample_shape = X.shape[1:] # Per-sample shape
445
+
446
+ # Same heuristic as prepare_data: detect ambiguous 3D shapes
447
+ is_ambiguous_shape = (
448
+ len(sample_shape) == 3 # Exactly 3D: could be (C, H, W) or (D, H, W)
449
+ and sample_shape[0] <= 16 # First dim looks like channels
450
+ and sample_shape[1] > 16
451
+ and sample_shape[2] > 16 # Both spatial dims are large
452
+ )
453
+
454
+ if is_ambiguous_shape and not args.single_channel:
455
+ raise ValueError(
456
+ f"Ambiguous input shape detected: sample shape {sample_shape}. "
457
+ f"This could be either:\n"
458
+ f" - Multi-channel 2D data (C={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n"
459
+ f" - Single-channel 3D volume (D={sample_shape[0]}, H={sample_shape[1]}, W={sample_shape[2]})\n\n"
460
+ f"If this is single-channel 3D/shallow volume data, use --single_channel flag.\n"
461
+ f"If this is multi-channel 2D data, reshape to (N*C, H, W) with adjusted targets."
462
+ )
463
+
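The heuristic above flags only the genuinely ambiguous case: exactly three per-sample dimensions with a channel-sized leading axis and two large spatial axes. A few concrete shapes run through the same test (illustrative only):

```python
def is_ambiguous(sample_shape: tuple[int, ...]) -> bool:
    # Same test as above: exactly 3 dims, small leading dim, large spatial dims
    return (
        len(sample_shape) == 3
        and sample_shape[0] <= 16
        and sample_shape[1] > 16
        and sample_shape[2] > 16
    )


print(is_ambiguous((3, 64, 64)))   # True  -> could be (C,H,W) or a shallow (D,H,W) volume
print(is_ambiguous((32, 64, 64)))  # False -> leading dim too large to look like channels
print(is_ambiguous((64, 64)))      # False -> plain 2D sample, nothing to disambiguate
print(is_ambiguous((8, 8, 8)))     # False -> spatial dims too small to trigger the check
```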
464
+ # in_shape = spatial dimensions for model registry (channel added during training)
465
+ in_shape = sample_shape
466
+
467
+ # Run cross-validation
468
+ try:
469
+ run_cross_validation(
470
+ X=X,
471
+ y=y,
472
+ model_name=args.model,
473
+ in_shape=in_shape,
474
+ out_size=y.shape[1],
475
+ folds=args.cv,
476
+ stratify=args.cv_stratify,
477
+ stratify_bins=args.cv_bins,
478
+ batch_size=args.batch_size,
479
+ lr=args.lr,
480
+ epochs=args.epochs,
481
+ patience=args.patience,
482
+ weight_decay=args.weight_decay,
483
+ loss_name=args.loss,
484
+ optimizer_name=args.optimizer,
485
+ scheduler_name=args.scheduler,
486
+ output_dir=args.output_dir,
487
+ workers=args.workers,
488
+ seed=args.seed,
489
+ )
490
+ finally:
491
+ # Clean up file handle if HDF5/MAT
492
+ if _cv_handle is not None and hasattr(_cv_handle, "close"):
493
+ try:
494
+ _cv_handle.close()
495
+ except Exception as e:
496
+ logging.debug(f"Failed to close CV data handle: {e}")
497
+ return
498
+
499
+ # ==========================================================================
500
+ # 1. SYSTEM INITIALIZATION
501
+ # ==========================================================================
502
+ # Initialize Accelerator for DDP and mixed precision
503
+ accelerator = Accelerator(
504
+ mixed_precision=args.precision,
505
+ log_with="wandb" if args.wandb and WANDB_AVAILABLE else None,
506
+ )
507
+ set_seed(args.seed)
508
+
509
+ # Configure logging (rank 0 only prints to console)
510
+ logging.basicConfig(
511
+ level=logging.INFO if accelerator.is_main_process else logging.ERROR,
512
+ format="%(asctime)s | %(levelname)s | %(message)s",
513
+ datefmt="%H:%M:%S",
514
+ )
515
+ logger = logging.getLogger("Trainer")
516
+
517
+ # Ensure output directory exists (critical for cache files, checkpoints, etc.)
518
+ os.makedirs(args.output_dir, exist_ok=True)
519
+
520
+ # Auto-detect optimal DataLoader workers if not specified
521
+ if args.workers < 0:
522
+ cpu_count = os.cpu_count() or 4
523
+ num_gpus = accelerator.num_processes
524
+ # Heuristic: 2-16 workers per GPU, bounded by available CPU cores
525
+ # Increased cap from 8 to 16 for high-throughput GPUs (H100, A100)
526
+ args.workers = min(16, max(2, (cpu_count - 2) // num_gpus))
527
+ if accelerator.is_main_process:
528
+ logger.info(
529
+ f"⚙️ Auto-detected workers: {args.workers} per GPU "
530
+ f"(CPUs: {cpu_count}, GPUs: {num_gpus})"
531
+ )
532
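+ # Worked example of the heuristic (hypothetical node sizes):
+ #   64 CPUs, 4 GPUs -> min(16, max(2, (64 - 2) // 4)) = min(16, 15) = 15 workers/GPU
+ #    8 CPUs, 4 GPUs -> min(16, max(2, (8 - 2) // 4))  = min(16, 2)  = 2 workers/GPU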
+
533
+ if accelerator.is_main_process:
534
+ logger.info(f"🚀 Cluster Status: {accelerator.num_processes}x GPUs detected")
535
+ logger.info(
536
+ f" Model: {args.model} | Precision: {args.precision} | Compile: {args.compile}"
537
+ )
538
+ logger.info(
539
+ f" Loss: {args.loss} | Optimizer: {args.optimizer} | Scheduler: {args.scheduler}"
540
+ )
541
+ logger.info(f" Early Stopping Patience: {args.patience} epochs")
542
+ if args.save_every > 0:
543
+ logger.info(f" Periodic Checkpointing: Every {args.save_every} epochs")
544
+ if args.resume:
545
+ logger.info(f" 📂 Resuming from: {args.resume}")
546
+
547
+ # Initialize WandB
548
+ if args.wandb and WANDB_AVAILABLE:
549
+ accelerator.init_trackers(
550
+ project_name=args.project_name,
551
+ config=vars(args),
552
+ init_kwargs={"wandb": {"name": args.run_name or f"{args.model}_run"}},
553
+ )
554
+
555
+ # ==========================================================================
556
+ # 2. DATA & MODEL LOADING
557
+ # ==========================================================================
558
+ train_dl, val_dl, scaler, in_shape, out_dim = prepare_data(
559
+ args, logger, accelerator, cache_dir=args.output_dir
560
+ )
561
+
562
+ # Build model using registry
563
+ model = build_model(args.model, in_shape=in_shape, out_size=out_dim)
564
+
565
+ if accelerator.is_main_process:
566
+ param_info = model.parameter_summary()
567
+ logger.info(
568
+ f" Model Parameters: {param_info['trainable_parameters']:,} trainable"
569
+ )
570
+ logger.info(f" Model Size: {param_info['total_mb']:.2f} MB")
571
+
572
+ # Optional WandB model watching (opt-in due to overhead on large models)
573
+ if (
574
+ args.wandb
575
+ and args.wandb_watch
576
+ and WANDB_AVAILABLE
577
+ and accelerator.is_main_process
578
+ ):
579
+ wandb.watch(model, log="gradients", log_freq=100)
580
+ logger.info(" 📊 WandB gradient watching enabled")
581
+
582
+ # Torch 2.0 compilation (requires compatible Triton on GPU)
583
+ if args.compile:
584
+ try:
585
+ # Test if Triton is available AND compatible with this PyTorch version
586
+ # PyTorch needs triton_key from triton.compiler.compiler
587
+ from triton.compiler.compiler import triton_key
588
+
589
+ model = torch.compile(model)
590
+ if accelerator.is_main_process:
591
+ logger.info(" ✔ torch.compile enabled (Triton backend)")
592
+ except ImportError as e:
593
+ if accelerator.is_main_process:
594
+ if "triton" in str(e).lower():
595
+ logger.warning(
596
+ " Triton not installed or incompatible version - torch.compile disabled. "
597
+ "Training will proceed without compilation."
598
+ )
599
+ else:
600
+ logger.warning(
601
+ f" ⚠ torch.compile setup failed: {e}. Continuing without compilation."
602
+ )
603
+ except Exception as e:
604
+ if accelerator.is_main_process:
605
+ logger.warning(
606
+ f" ⚠ torch.compile failed: {e}. Continuing without compilation."
607
+ )
608
+
609
+ # ==========================================================================
610
+ # 2.5. OPTIMIZER, SCHEDULER & LOSS CONFIGURATION
611
+ # ==========================================================================
612
+ # Parse comma-separated arguments with validation
613
+ try:
614
+ betas_list = [float(x.strip()) for x in args.betas.split(",")]
615
+ if len(betas_list) != 2:
616
+ raise ValueError(
617
+ f"--betas must have exactly 2 values, got {len(betas_list)}"
618
+ )
619
+ if not all(0.0 <= b < 1.0 for b in betas_list):
620
+ raise ValueError(f"--betas values must be in [0, 1), got {betas_list}")
621
+ betas = tuple(betas_list)
622
+ except ValueError as e:
623
+ raise ValueError(
624
+ f"Invalid --betas format '{args.betas}': {e}. Expected format: '0.9,0.999'"
625
+ ) from e
626
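+ # Illustrative --betas inputs and how the validation above treats them:
+ #   "0.9,0.999"  -> betas = (0.9, 0.999)   (accepted)
+ #   "0.9"        -> ValueError: needs exactly 2 values
+ #   "0.9,1.0"    -> ValueError: values must lie in [0, 1)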
+
627
+ loss_weights = None
628
+ if args.loss_weights:
629
+ loss_weights = [float(x.strip()) for x in args.loss_weights.split(",")]
630
+ milestones = None
631
+ if args.milestones:
632
+ milestones = [int(x.strip()) for x in args.milestones.split(",")]
633
+
634
+ # Create optimizer using factory
635
+ optimizer = get_optimizer(
636
+ name=args.optimizer,
637
+ params=model.get_optimizer_groups(args.lr, args.weight_decay),
638
+ lr=args.lr,
639
+ weight_decay=args.weight_decay,
640
+ momentum=args.momentum,
641
+ nesterov=args.nesterov,
642
+ betas=betas,
643
+ )
644
+
645
+ # Create loss function using factory
646
+ criterion = get_loss(
647
+ name=args.loss,
648
+ weights=loss_weights,
649
+ delta=args.huber_delta,
650
+ )
651
+ # Move criterion to device (important for WeightedMSELoss buffer)
652
+ criterion = criterion.to(accelerator.device)
653
+
654
+ # Track if scheduler should step per batch (OneCycleLR) or per epoch
655
+ scheduler_step_per_batch = not is_epoch_based(args.scheduler)
656
+
657
+ # ==========================================================================
658
+ # DDP Preparation Strategy:
659
+ # - For batch-based schedulers (OneCycleLR): prepare DataLoaders first to get
660
+ # the correct sharded batch count, then create scheduler
661
+ # - For epoch-based schedulers: create scheduler before prepare (no issue)
662
+ # ==========================================================================
663
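+ # Example with hypothetical numbers: 10,000 training samples, batch_size=128, 4 GPUs.
+ # Before prepare(): len(train_dl) ~= ceil(10000 / 128) = 79 steps/epoch (unsharded).
+ # After  prepare(): len(train_dl) ~= ceil(2500 / 128)  = 20 steps/epoch per rank.
+ # A OneCycleLR built from the unsharded 79 would step ~4x too slowly and never
+ # complete its schedule, hence the scheduler is created only after sharding below.
+ # (Exact counts depend on drop_last/padding.)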
+ if scheduler_step_per_batch:
664
+ # BATCH-BASED SCHEDULER (e.g., OneCycleLR)
665
+ # Prepare model, optimizer, dataloaders FIRST to get sharded loader length
666
+ model, optimizer, train_dl, val_dl = accelerator.prepare(
667
+ model, optimizer, train_dl, val_dl
668
+ )
669
+
670
+ # Now create scheduler with the CORRECT sharded steps_per_epoch
671
+ steps_per_epoch = len(train_dl) # Post-DDP sharded length
672
+ scheduler = get_scheduler(
673
+ name=args.scheduler,
674
+ optimizer=optimizer,
675
+ epochs=args.epochs,
676
+ steps_per_epoch=steps_per_epoch,
677
+ min_lr=args.min_lr,
678
+ patience=args.scheduler_patience,
679
+ factor=args.scheduler_factor,
680
+ gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
681
+ step_size=args.step_size,
682
+ milestones=milestones,
683
+ warmup_epochs=args.warmup_epochs,
684
+ )
685
+ # Prepare scheduler separately (Accelerator wraps it for state saving)
686
+ scheduler = accelerator.prepare(scheduler)
687
+ else:
688
+ # EPOCH-BASED SCHEDULER (plateau, cosine, step, etc.)
689
+ # No batch count dependency - create scheduler before prepare
690
+ scheduler = get_scheduler(
691
+ name=args.scheduler,
692
+ optimizer=optimizer,
693
+ epochs=args.epochs,
694
+ steps_per_epoch=None,
695
+ min_lr=args.min_lr,
696
+ patience=args.scheduler_patience,
697
+ factor=args.scheduler_factor,
698
+ gamma=args.scheduler_factor, # For Step/MultiStep/Exponential schedulers
699
+ step_size=args.step_size,
700
+ milestones=milestones,
701
+ warmup_epochs=args.warmup_epochs,
702
+ )
703
+ # Prepare everything together
704
+ model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
705
+ model, optimizer, train_dl, val_dl, scheduler
706
+ )
707
+
708
+ # ==========================================================================
709
+ # 3. AUTO-RESUME / RESUME FROM CHECKPOINT
710
+ # ==========================================================================
711
+ start_epoch = 0
712
+ best_val_loss = float("inf")
713
+ patience_ctr = 0
714
+ history: list[dict[str, Any]] = []
715
+
716
+ # Define checkpoint paths
717
+ best_ckpt_path = os.path.join(args.output_dir, "best_checkpoint")
718
+ complete_flag_path = os.path.join(args.output_dir, "training_complete.flag")
719
+
720
+ # Auto-resume logic (if not --fresh and no explicit --resume)
721
+ if not args.fresh and args.resume is None:
722
+ if os.path.exists(complete_flag_path):
723
+ # Training already completed
724
+ if accelerator.is_main_process:
725
+ logger.info(
726
+ "✅ Training already completed (early stopping). Use --fresh to retrain."
727
+ )
728
+ return # Exit gracefully
729
+ elif os.path.exists(best_ckpt_path):
730
+ # Incomplete training found - auto-resume
731
+ args.resume = best_ckpt_path
732
+ if accelerator.is_main_process:
733
+ logger.info(f"🔄 Auto-resuming from: {best_ckpt_path}")
734
+
735
+ if args.resume:
736
+ if os.path.exists(args.resume):
737
+ logger.info(f"🔄 Loading checkpoint from: {args.resume}")
738
+ accelerator.load_state(args.resume)
739
+
740
+ # Restore training metadata
741
+ meta_path = os.path.join(args.resume, "training_meta.pkl")
742
+ if os.path.exists(meta_path):
743
+ with open(meta_path, "rb") as f:
744
+ meta = pickle.load(f)
745
+ start_epoch = meta.get("epoch", 0)
746
+ best_val_loss = meta.get("best_val_loss", float("inf"))
747
+ patience_ctr = meta.get("patience_ctr", 0)
748
+ logger.info(
749
+ f" Restored: Epoch {start_epoch}, Best Loss: {best_val_loss:.6f}"
750
+ )
751
+ else:
752
+ logger.warning(
753
+ " ⚠️ training_meta.pkl not found, starting from epoch 0"
754
+ )
755
+
756
+ # Restore history
757
+ history_path = os.path.join(args.output_dir, "training_history.csv")
758
+ if os.path.exists(history_path):
759
+ history = pd.read_csv(history_path).to_dict("records")
760
+ logger.info(f" ✅ Loaded {len(history)} epochs from history")
761
+ else:
762
+ raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
763
+
764
+ # ==========================================================================
765
+ # 4. PHYSICAL METRIC SETUP
766
+ # ==========================================================================
767
+ # Physical MAE = normalized MAE * scaler.scale_
768
+ phys_scale = torch.tensor(
769
+ scaler.scale_, device=accelerator.device, dtype=torch.float32
770
+ )
771
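+ # Illustrative unscaling (hypothetical scaler values): if scaler.scale_ = [12.5, 0.04]
+ # and the normalized per-target MAE is [0.10, 0.20], the physical MAE reported below
+ # is [0.10 * 12.5, 0.20 * 0.04] = [1.25, 0.008] in the original target units.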
+
772
+ # ==========================================================================
773
+ # 5. TRAINING LOOP
774
+ # ==========================================================================
775
+ # Dynamic console header
776
+ if accelerator.is_main_process:
777
+ base_cols = ["Ep", "TrnLoss", "ValLoss", "R2", "PCC", "GradN", "LR", "MAE_Avg"]
778
+ param_cols = [f"MAE_P{i}" for i in range(out_dim)]
779
+ header = "{:<4} | {:<8} | {:<8} | {:<6} | {:<6} | {:<6} | {:<8} | {:<8}".format(
780
+ *base_cols
781
+ )
782
+ header += " | " + " | ".join([f"{c:<8}" for c in param_cols])
783
+ logger.info("=" * len(header))
784
+ logger.info(header)
785
+ logger.info("=" * len(header))
786
+
787
+ try:
788
+ total_training_time = 0.0
790
+
791
+ for epoch in range(start_epoch, args.epochs):
792
+ epoch_start_time = time.time()
793
+
794
+ # ==================== TRAINING PHASE ====================
795
+ model.train()
796
+ # Use GPU tensor for loss accumulation to avoid .item() sync per batch
797
+ train_loss_sum = torch.tensor(0.0, device=accelerator.device)
798
+ train_samples = 0
799
+ grad_norm_tracker = MetricTracker()
800
+
801
+ pbar = tqdm(
802
+ train_dl,
803
+ disable=not accelerator.is_main_process,
804
+ leave=False,
805
+ desc=f"Epoch {epoch + 1}",
806
+ )
807
+
808
+ for x, y in pbar:
809
+ with accelerator.accumulate(model):
810
+ pred = model(x)
811
+ loss = criterion(pred, y)
812
+
813
+ accelerator.backward(loss)
814
+
815
+ if accelerator.sync_gradients:
816
+ grad_norm = accelerator.clip_grad_norm_(
817
+ model.parameters(), args.grad_clip
818
+ )
819
+ if grad_norm is not None:
820
+ grad_norm_tracker.update(grad_norm.item())
821
+
822
+ optimizer.step()
823
+ optimizer.zero_grad(set_to_none=True)  # set_to_none=True skips the zero-fill kernel
824
+
825
+ # Per-batch LR scheduling (e.g., OneCycleLR)
826
+ if scheduler_step_per_batch:
827
+ scheduler.step()
828
+
829
+ # Accumulate as tensors to avoid .item() sync per batch
830
+ train_loss_sum += loss.detach() * x.size(0)
831
+ train_samples += x.size(0)
832
+
833
+ # Single .item() call at end of epoch (reduces GPU sync overhead)
834
+ train_loss_scalar = train_loss_sum.item()
835
+
836
+ # Synchronize training metrics across GPUs
837
+ global_loss = accelerator.reduce(
838
+ torch.tensor([train_loss_scalar], device=accelerator.device),
839
+ reduction="sum",
840
+ ).item()
841
+ global_samples = accelerator.reduce(
842
+ torch.tensor([train_samples], device=accelerator.device),
843
+ reduction="sum",
844
+ ).item()
845
+ avg_train_loss = global_loss / global_samples
846
+
847
+ # ==================== VALIDATION PHASE ====================
848
+ model.eval()
849
+ # Use GPU tensor for loss accumulation (consistent with training phase)
850
+ val_loss_sum = torch.tensor(0.0, device=accelerator.device)
851
+ val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
852
+ val_samples = 0
853
+
854
+ # Accumulate predictions locally, gather ONCE at end (reduces sync overhead)
855
+ local_preds = []
856
+ local_targets = []
857
+
858
+ with torch.inference_mode():
859
+ for x, y in val_dl:
860
+ pred = model(x)
861
+ loss = criterion(pred, y)
862
+
863
+ val_loss_sum += loss.detach() * x.size(0)
864
+ val_samples += x.size(0)
865
+
866
+ # Physical MAE
867
+ mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
868
+ val_mae_sum += mae_batch
869
+
870
+ # Store locally (no GPU sync per batch)
871
+ local_preds.append(pred)
872
+ local_targets.append(y)
873
+
874
+ # Single gather at end of validation (2 syncs instead of 2×num_batches)
875
+ all_local_preds = torch.cat(local_preds)
876
+ all_local_targets = torch.cat(local_targets)
877
+ all_preds = accelerator.gather_for_metrics(all_local_preds)
878
+ all_targets = accelerator.gather_for_metrics(all_local_targets)
879
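+ # Illustrative sync count (hypothetical loader size): with 40 validation batches per
+ # rank, gathering per batch would issue 2 * 40 = 80 collective calls per epoch;
+ # concatenating locally first keeps it at the 2 gathers above.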
+
880
+ # Synchronize validation metrics
881
+ val_loss_scalar = val_loss_sum.item()
882
+ val_metrics = torch.cat(
883
+ [
884
+ torch.tensor([val_loss_scalar], device=accelerator.device),
885
+ val_mae_sum,
886
+ ]
887
+ )
888
+ val_metrics_sync = accelerator.reduce(val_metrics, reduction="sum")
889
+
890
+ total_val_samples = accelerator.reduce(
891
+ torch.tensor([val_samples], device=accelerator.device), reduction="sum"
892
+ ).item()
893
+
894
+ avg_val_loss = val_metrics_sync[0].item() / total_val_samples
895
+ # Cast to float32 before numpy (bf16 tensors can't convert directly)
896
+ avg_mae_per_param = (
897
+ (val_metrics_sync[1:] / total_val_samples).float().cpu().numpy()
898
+ )
899
+ avg_mae = avg_mae_per_param.mean()
900
+
901
+ # ==================== LOGGING & CHECKPOINTING ====================
902
+ if accelerator.is_main_process:
903
+ # Scientific metrics - cast to float32 before numpy (bf16 can't convert)
904
+ y_pred = all_preds.float().cpu().numpy()
905
+ y_true = all_targets.float().cpu().numpy()
906
+
907
+ # Trim DDP padding
908
+ real_len = len(val_dl.dataset)
909
+ if len(y_pred) > real_len:
910
+ y_pred = y_pred[:real_len]
911
+ y_true = y_true[:real_len]
912
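+ # Illustrative padding case (hypothetical sizes): 1003 validation samples on 4 ranks
+ # may be padded to 4 x 251 = 1004 rows so every rank sees equal batches; trimming to
+ # len(val_dl.dataset) = 1003 drops the duplicated row before computing metrics.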
+
913
+ # Guard against tiny validation sets (R² undefined for <2 samples)
914
+ if len(y_true) >= 2:
915
+ r2 = r2_score(y_true, y_pred)
916
+ else:
917
+ r2 = float("nan")
918
+ pcc = calc_pearson(y_true, y_pred)
919
+ current_lr = get_lr(optimizer)
920
+
921
+ # Update history
922
+ epoch_end_time = time.time()
923
+ epoch_time = epoch_end_time - epoch_start_time
924
+ total_training_time += epoch_time
925
+
926
+ epoch_stats = {
927
+ "epoch": epoch + 1,
928
+ "train_loss": avg_train_loss,
929
+ "val_loss": avg_val_loss,
930
+ "val_r2": r2,
931
+ "val_pearson": pcc,
932
+ "val_mae_avg": avg_mae,
933
+ "grad_norm": grad_norm_tracker.avg,
934
+ "lr": current_lr,
935
+ "epoch_time": round(epoch_time, 2),
936
+ "total_time": round(total_training_time, 2),
937
+ }
938
+ for i, mae in enumerate(avg_mae_per_param):
939
+ epoch_stats[f"MAE_Phys_P{i}"] = mae
940
+
941
+ history.append(epoch_stats)
942
+
943
+ # Console display
944
+ base_str = f"{epoch + 1:<4} | {avg_train_loss:<8.4f} | {avg_val_loss:<8.4f} | {r2:<6.4f} | {pcc:<6.4f} | {grad_norm_tracker.avg:<6.4f} | {current_lr:<8.2e} | {avg_mae:<8.4f}"
945
+ param_str = " | ".join([f"{m:<8.4f}" for m in avg_mae_per_param])
946
+ logger.info(f"{base_str} | {param_str}")
947
+
948
+ # WandB logging
949
+ if args.wandb and WANDB_AVAILABLE:
950
+ log_dict = {
951
+ "main/train_loss": avg_train_loss,
952
+ "main/val_loss": avg_val_loss,
953
+ "metrics/r2_score": r2,
954
+ "metrics/pearson_corr": pcc,
955
+ "metrics/mae_avg": avg_mae,
956
+ "system/grad_norm": grad_norm_tracker.avg,
957
+ "hyper/lr": current_lr,
958
+ }
959
+ for i, mae in enumerate(avg_mae_per_param):
960
+ log_dict[f"mae_detailed/P{i}"] = mae
961
+
962
+ # Periodic scatter plots
963
+ if (epoch % 5 == 0) or (avg_val_loss < best_val_loss):
964
+ real_true = scaler.inverse_transform(y_true)
965
+ real_pred = scaler.inverse_transform(y_pred)
966
+ fig = plot_scientific_scatter(real_true, real_pred)
967
+ log_dict["plots/scatter_analysis"] = wandb.Image(fig)
968
+ plt.close(fig)
969
+
970
+ accelerator.log(log_dict)
971
+
972
+ # ==========================================================================
973
+ # DDP-SAFE CHECKPOINT LOGIC
974
+ # ==========================================================================
975
+ # Step 1: Determine if this is the best epoch (BEFORE updating best_val_loss)
976
+ is_best_epoch = False
977
+ if accelerator.is_main_process:
978
+ if avg_val_loss < best_val_loss:
979
+ is_best_epoch = True
980
+
981
+ # Step 2: Broadcast decision to all ranks (required for save_state)
982
+ is_best_epoch = broadcast_early_stop(is_best_epoch, accelerator)
983
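+ # NOTE: the best-epoch decision is made only on rank 0, while accelerator.save_state()
+ # is a collective call. Broadcasting the boolean ensures every rank either enters or
+ # skips the save together; if rank 0 saved alone, the other ranks would block at the
+ # next collective and the job would hang.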
+
984
+ # Step 3: Save checkpoint with all ranks participating
985
+ if is_best_epoch:
986
+ ckpt_dir = os.path.join(args.output_dir, "best_checkpoint")
987
+ accelerator.save_state(ckpt_dir) # All ranks must call this
988
+
989
+ # Step 4: Rank 0 handles metadata and updates tracking variables
990
+ if accelerator.is_main_process:
991
+ best_val_loss = avg_val_loss # Update AFTER checkpoint saved
992
+ patience_ctr = 0
993
+
994
+ with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
995
+ pickle.dump(
996
+ {
997
+ "epoch": epoch + 1,
998
+ "best_val_loss": best_val_loss,
999
+ "patience_ctr": patience_ctr,
1000
+ # Model info for auto-detection during inference
1001
+ "model_name": args.model,
1002
+ "in_shape": in_shape,
1003
+ "out_dim": out_dim,
1004
+ },
1005
+ f,
1006
+ )
1007
+
1008
+ unwrapped = accelerator.unwrap_model(model)
1009
+ torch.save(
1010
+ unwrapped.state_dict(),
1011
+ os.path.join(args.output_dir, "best_model_weights.pth"),
1012
+ )
1013
+
1014
+ # Copy scaler to checkpoint for portability
1015
+ scaler_src = os.path.join(args.output_dir, "scaler.pkl")
1016
+ scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
1017
+ if os.path.exists(scaler_src) and not os.path.exists(scaler_dst):
1018
+ shutil.copy2(scaler_src, scaler_dst)
1019
+
1020
+ logger.info(
1021
+ f" 💾 Best model saved (val_loss: {best_val_loss:.6f})"
1022
+ )
1023
+
1024
+ # Also save CSV on best model (ensures progress is saved)
1025
+ pd.DataFrame(history).to_csv(
1026
+ os.path.join(args.output_dir, "training_history.csv"),
1027
+ index=False,
1028
+ )
1029
+ else:
1030
+ if accelerator.is_main_process:
1031
+ patience_ctr += 1
1032
+
1033
+ # Periodic checkpoint (all ranks participate in save_state)
1034
+ periodic_checkpoint_needed = (
1035
+ args.save_every > 0 and (epoch + 1) % args.save_every == 0
1036
+ )
1037
+ if periodic_checkpoint_needed:
1038
+ ckpt_name = f"epoch_{epoch + 1}_checkpoint"
1039
+ ckpt_dir = os.path.join(args.output_dir, ckpt_name)
1040
+ accelerator.save_state(ckpt_dir) # All ranks participate
1041
+
1042
+ if accelerator.is_main_process:
1043
+ with open(os.path.join(ckpt_dir, "training_meta.pkl"), "wb") as f:
1044
+ pickle.dump(
1045
+ {
1046
+ "epoch": epoch + 1,
1047
+ "best_val_loss": best_val_loss,
1048
+ "patience_ctr": patience_ctr,
1049
+ # Model info for auto-detection during inference
1050
+ "model_name": args.model,
1051
+ "in_shape": in_shape,
1052
+ "out_dim": out_dim,
1053
+ },
1054
+ f,
1055
+ )
1056
+ logger.info(f" 📁 Periodic checkpoint: {ckpt_name}")
1057
+
1058
+ # Save CSV with each checkpoint (keeps logs in sync with model state)
1059
+ pd.DataFrame(history).to_csv(
1060
+ os.path.join(args.output_dir, "training_history.csv"),
1061
+ index=False,
1062
+ )
1063
+
1064
+ # Learning rate scheduling (epoch-based schedulers only)
1065
+ if not scheduler_step_per_batch:
1066
+ if args.scheduler == "plateau":
1067
+ scheduler.step(avg_val_loss)
1068
+ else:
1069
+ scheduler.step()
1070
+
1071
+ # DDP-safe early stopping
1072
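+ # NOTE: patience_ctr is only updated on rank 0 (see the checkpoint block above), so
+ # other ranks would never reach the patience limit on their own. Broadcasting rank 0's
+ # decision makes every rank break out of the loop in the same epoch, avoiding a hang
+ # at the next collective (reduce/gather/save_state).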
+ should_stop = (
1073
+ patience_ctr >= args.patience if accelerator.is_main_process else False
1074
+ )
1075
+ if broadcast_early_stop(should_stop, accelerator):
1076
+ if accelerator.is_main_process:
1077
+ logger.info(
1078
+ f"🛑 Early stopping at epoch {epoch + 1} (patience={args.patience})"
1079
+ )
1080
+ # Create completion flag to prevent auto-resume
1081
+ with open(
1082
+ os.path.join(args.output_dir, "training_complete.flag"), "w"
1083
+ ) as f:
1084
+ f.write(
1085
+ f"Training completed via early stopping at epoch {epoch + 1}\n"
1086
+ )
1087
+ break
1088
+
1089
+ except KeyboardInterrupt:
1090
+ logger.warning("Training interrupted. Saving emergency checkpoint...")
1091
+ accelerator.save_state(os.path.join(args.output_dir, "interrupted_checkpoint"))
1092
+
1093
+ except Exception as e:
1094
+ logger.error(f"Critical error: {e}", exc_info=True)
1095
+ raise
1096
+
1097
+ else:
1098
+ # Training loop finished without an exception (max epochs reached, or early stopping already wrote the completion flag)
1099
+ # Create completion flag to prevent auto-resume on re-run
1100
+ if accelerator.is_main_process:
1101
+ if not os.path.exists(complete_flag_path):
1102
+ with open(complete_flag_path, "w") as f:
1103
+ f.write(f"Training completed normally after {args.epochs} epochs\n")
1104
+ logger.info(f"✅ Training completed after {args.epochs} epochs")
1105
+
1106
+ finally:
1107
+ # Final CSV write to capture all epochs (including any logged after the last checkpoint save)
1108
+ if accelerator.is_main_process and len(history) > 0:
1109
+ pd.DataFrame(history).to_csv(
1110
+ os.path.join(args.output_dir, "training_history.csv"),
1111
+ index=False,
1112
+ )
1113
+
1114
+ # Generate training curves plot (PNG + SVG)
1115
+ if accelerator.is_main_process and len(history) > 0:
1116
+ try:
1117
+ fig = create_training_curves(history, show_lr=True)
1118
+ for fmt in ["png", "svg"]:
1119
+ fig.savefig(
1120
+ os.path.join(args.output_dir, f"training_curves.{fmt}"),
1121
+ dpi=FIGURE_DPI,
1122
+ bbox_inches="tight",
1123
+ )
1124
+ plt.close(fig)
1125
+ logger.info("✔ Saved: training_curves.png, training_curves.svg")
1126
+ except Exception as e:
1127
+ logger.warning(f"Could not generate training curves: {e}")
1128
+
1129
+ if args.wandb and WANDB_AVAILABLE:
1130
+ accelerator.end_training()
1131
+
1132
+ # Clean up distributed process group to prevent resource leak warning
1133
+ if torch.distributed.is_initialized():
1134
+ torch.distributed.destroy_process_group()
1135
+
1136
+ logger.info("Training completed.")
1137
+
1138
+
1139
+ if __name__ == "__main__":
1140
+ try:
1141
+ torch.multiprocessing.set_start_method("spawn")
1142
+ except RuntimeError:
1143
+ pass
1144
+ main()